Use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.
The class IonSearchEditor, method addSpecialMenus.
/**
 * Adds this editor's menus to {@code menuBar}.
 *
 * <p>Adds, in order:
 * <ul>
 *   <li>an "Independence" menu of test choices, unless the algorithm runner is an
 *       {@code IGesRunner};</li>
 *   <li>a "Graph" menu with inspection, DAG-selection, reorientation and history actions;</li>
 *   <li>a layout menu.</li>
 * </ul>
 *
 * @param menuBar the menu bar to populate.
 */
protected void addSpecialMenus(JMenuBar menuBar) {
    if (!(getAlgorithmRunner() instanceof IGesRunner)) {
        JMenu test = new JMenu("Independence");
        menuBar.add(test);
        IndTestMenuItems.addIndependenceTestChoices(test, this);
    }

    JMenu graph = new JMenu("Graph");
    // NOTE(review): "forbid_latent_common_causes" in these labels looks like a search-and-replace
    // artifact (both actions operate on a pattern); confirm the intended menu text.
    JMenuItem showDags = new JMenuItem("Show DAGs in forbid_latent_common_causes");
    JMenuItem dagInPattern = new JMenuItem("Choose DAG in forbid_latent_common_causes");
    JMenuItem gesOrient = new JMenuItem("Global Score-based Reorientation");
    JMenuItem nextGraph = new JMenuItem("Next Graph");
    JMenuItem previousGraph = new JMenuItem("Previous Graph");

    graph.add(new GraphPropertiesAction(getWorkbench()));
    graph.add(new PathsAction(getWorkbench()));
    graph.add(new TriplesAction(getWorkbench().getGraph(), getAlgorithmRunner()));
    graph.addSeparator();
    graph.add(dagInPattern);
    graph.add(gesOrient);
    graph.addSeparator();
    graph.add(previousGraph);
    graph.add(nextGraph);
    graph.addSeparator();
    graph.add(showDags);
    graph.addSeparator();
    graph.add(new JMenuItem(new SelectBidirectedAction(getWorkbench())));
    graph.add(new JMenuItem(new SelectUndirectedAction(getWorkbench())));
    menuBar.add(graph);

    // Opens a PatternDisplay of the algorithm runner's result graph in a palette-layer window.
    showDags.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            Window owner = (Window) getTopLevelAncestor();
            new WatchedProcess(owner) {
                public void watch() {
                    // Needs to be a pattern search; this isn't checked before running the
                    // algorithm because of allowable "slop"--e.g. bidirected edges.
                    AlgorithmRunner runner = getAlgorithmRunner();
                    Graph graph = runner.getGraph();
                    if (graph == null) {
                        JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), "No result graph.");
                        return;
                    }
                    PatternDisplay display = new PatternDisplay(graph);
                    GraphWorkbench workbench = getWorkbench();
                    // NOTE(review): the title reads "Independence Facts" although the window shows
                    // the DAGs in the pattern; looks copy-pasted — confirm.
                    EditorWindow editorWindow = new EditorWindow(display, "Independence Facts", "Close", false, workbench);
                    DesktopController.getInstance().addEditorWindow(editorWindow, JLayeredPane.PALETTE_LAYER);
                    editorWindow.setVisible(true);
                }
            };
        }
    });

    // Replaces the current graph with one DAG chosen from the pattern (bidirected edges dropped).
    dagInPattern.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            Graph graph = new EdgeListGraph(getGraph());
            // Remove bidirected edges before selecting a DAG. Iterate over a snapshot so removing
            // edges cannot disturb the iteration, whether or not getEdges() returns a copy.
            for (Edge edge : new java.util.ArrayList<Edge>(graph.getEdges())) {
                if (Edges.isBidirectedEdge(edge)) {
                    graph.removeEdge(edge);
                }
            }
            PatternToDag search = new PatternToDag(new EdgeListGraphSingleConnections(graph));
            Graph dag = search.patternToDagMeek();
            getGraphHistory().add(dag);
            getWorkbench().setGraph(dag);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(dag);
            firePropertyChange("modelChanged", null, null);
        }
    });

    // Reorients the current graph via SearchGraphUtils.reorient using the runner's data model and
    // this editor's knowledge. Unlike the other actions it does not update the runner's result graph.
    gesOrient.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            DataModel dataModel = getAlgorithmRunner().getDataModel();
            final Graph graph = SearchGraphUtils.reorient(getGraph(), dataModel, getKnowledge());
            getGraphHistory().add(graph);
            getWorkbench().setGraph(graph);
            firePropertyChange("modelChanged", null, null);
        }
    });

    // Step forward through the graph history.
    nextGraph.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            Graph next = getGraphHistory().next();
            getWorkbench().setGraph(next);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(next);
            firePropertyChange("modelChanged", null, null);
        }
    });

    // Step backward through the graph history.
    previousGraph.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            Graph previous = getGraphHistory().previous();
            getWorkbench().setGraph(previous);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(previous);
            firePropertyChange("modelChanged", null, null);
        }
    });

    menuBar.add(new LayoutMenu(this));
}
Use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.
The class FciCcdSearchEditor, method addSpecialMenus.
/**
 * Adds this editor's menus to {@code menuBar}.
 *
 * <p>Adds, in order:
 * <ul>
 *   <li>an "Independence" menu of test choices, unless the algorithm runner is an
 *       {@code IGesRunner};</li>
 *   <li>a "Graph" menu with inspection, DAG-selection, reorientation and history actions;</li>
 *   <li>a layout menu.</li>
 * </ul>
 *
 * @param menuBar the menu bar to populate.
 */
protected void addSpecialMenus(JMenuBar menuBar) {
    if (!(getAlgorithmRunner() instanceof IGesRunner)) {
        JMenu test = new JMenu("Independence");
        menuBar.add(test);
        IndTestMenuItems.addIndependenceTestChoices(test, this);
    }

    JMenu graph = new JMenu("Graph");
    // NOTE(review): "forbid_latent_common_causes" in these labels looks like a search-and-replace
    // artifact (both actions operate on a pattern); confirm the intended menu text.
    JMenuItem showDags = new JMenuItem("Show DAGs in forbid_latent_common_causes");
    JMenuItem dagInPattern = new JMenuItem("Choose DAG in forbid_latent_common_causes");
    JMenuItem gesOrient = new JMenuItem("Global Score-based Reorientation");
    JMenuItem nextGraph = new JMenuItem("Next Graph");
    JMenuItem previousGraph = new JMenuItem("Previous Graph");

    graph.add(new GraphPropertiesAction(getWorkbench()));
    graph.add(new PathsAction(getWorkbench()));
    graph.add(new TriplesAction(getWorkbench().getGraph(), getAlgorithmRunner()));
    graph.addSeparator();
    graph.add(dagInPattern);
    graph.add(gesOrient);
    graph.addSeparator();
    graph.add(previousGraph);
    graph.add(nextGraph);
    graph.addSeparator();
    graph.add(showDags);
    graph.addSeparator();
    graph.add(new JMenuItem(new SelectBidirectedAction(getWorkbench())));
    graph.add(new JMenuItem(new SelectUndirectedAction(getWorkbench())));
    menuBar.add(graph);

    // Opens a PatternDisplay of the algorithm runner's result graph in a palette-layer window.
    showDags.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            Window owner = (Window) getTopLevelAncestor();
            new WatchedProcess(owner) {
                public void watch() {
                    // Needs to be a pattern search; this isn't checked before running the
                    // algorithm because of allowable "slop"--e.g. bidirected edges.
                    AlgorithmRunner runner = getAlgorithmRunner();
                    Graph graph = runner.getGraph();
                    if (graph == null) {
                        JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), "No result graph.");
                        return;
                    }
                    PatternDisplay display = new PatternDisplay(graph);
                    GraphWorkbench workbench = getWorkbench();
                    // NOTE(review): the title reads "Independence Facts" although the window shows
                    // the DAGs in the pattern; looks copy-pasted — confirm.
                    EditorWindow editorWindow = new EditorWindow(display, "Independence Facts", "Close", false, workbench);
                    DesktopController.getInstance().addEditorWindow(editorWindow, JLayeredPane.PALETTE_LAYER);
                    editorWindow.setVisible(true);
                }
            };
        }
    });

    // Replaces the current graph with one DAG chosen from the pattern (bidirected edges dropped).
    dagInPattern.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            Graph graph = new EdgeListGraph(getGraph());
            // Remove bidirected edges before selecting a DAG. Iterate over a snapshot so removing
            // edges cannot disturb the iteration, whether or not getEdges() returns a copy.
            for (Edge edge : new java.util.ArrayList<Edge>(graph.getEdges())) {
                if (Edges.isBidirectedEdge(edge)) {
                    graph.removeEdge(edge);
                }
            }
            PatternToDag search = new PatternToDag(new EdgeListGraphSingleConnections(graph));
            Graph dag = search.patternToDagMeek();
            getGraphHistory().add(dag);
            getWorkbench().setGraph(dag);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(dag);
            firePropertyChange("modelChanged", null, null);
        }
    });

    // Reorients the current graph via SearchGraphUtils.reorient using the runner's data model and
    // this editor's knowledge. Unlike the other actions it does not update the runner's result graph.
    gesOrient.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            DataModel dataModel = getAlgorithmRunner().getDataModel();
            final Graph graph = SearchGraphUtils.reorient(getGraph(), dataModel, getKnowledge());
            getGraphHistory().add(graph);
            getWorkbench().setGraph(graph);
            firePropertyChange("modelChanged", null, null);
        }
    });

    // Step forward through the graph history.
    nextGraph.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            Graph next = getGraphHistory().next();
            getWorkbench().setGraph(next);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(next);
            firePropertyChange("modelChanged", null, null);
        }
    });

    // Step backward through the graph history.
    previousGraph.addActionListener(new ActionListener() {
        public void actionPerformed(ActionEvent e) {
            Graph previous = getGraphHistory().previous();
            getWorkbench().setGraph(previous);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(previous);
            firePropertyChange("modelChanged", null, null);
        }
    });

    menuBar.add(new LayoutMenu(this));
}
Use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.
The class HsimAutoRun, method run.
// ***********Public methods*************//

/**
 * Runs one hybrid-resimulation (Hsim) trial on {@code data} and scores FGES on the result.
 *
 * <p>The trial proceeds in four steps:
 * <ol>
 *   <li>Run FGES (BDeu score, penalty discount 2.0) on {@code data` and choose a DAG from the
 *       resulting pattern with {@code PatternToDag.patternToDagMeek()}.</li>
 *   <li>Select up to {@code resimSize} nodes, grown breadth-first from a random centroid. Jump
 *       to a random unselected node only when the connected region is exhausted.</li>
 *   <li>Resimulate those nodes with {@code Hsim}. If {@code write} is set, write the new data
 *       set to {@code filenameOut}.</li>
 *   <li>Run FGES on the resimulated data and compare it with the chosen DAG, restricted to the
 *       resimulated nodes and their parents.</li>
 * </ol>
 *
 * @param resimSize number of nodes to resimulate.
 * @return the error values from {@code HsimUtils.errorEval}. Returns an all-zero array of
 *         length 5 if any exception is raised; the stack trace is printed in that case.
 */
public double[] run(int resimSize) {
    double[] output = new double[5];
    try {
        // ========first make the Dag for Hsim==========
        BDeuScore score = new BDeuScore(data);
        double penaltyDiscount = 2.0;
        Fges fges = new Fges(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setPenaltyDiscount(penaltyDiscount);
        Graph estGraph = fges.search();
        Graph estPattern = new EdgeListGraphSingleConnections(estGraph);
        PatternToDag patternToDag = new PatternToDag(estPattern);
        Graph estGraphDAG = patternToDag.patternToDagMeek();
        Dag estDAG = new Dag(estGraphDAG);

        // ===========Identify the nodes to be resimulated===========
        Random random = new Random();
        List<Node> allNodes = estGraph.getNodes();
        // Throws (and is caught below) if the graph has no nodes.
        Node centroid = allNodes.get(random.nextInt(allNodes.size()));
        if (verbose) {
            System.out.println("the centroid is " + centroid);
        }
        List<Node> queue = new ArrayList<>();
        queue.add(centroid);
        // Breadth-first growth: 'next' is the first queued node whose neighbors have not been
        // expanded yet, so every node is expanded exactly once and no node is queued twice.
        int next = 0;
        while (queue.size() < resimSize) {
            if (next < queue.size()) {
                for (Node neighbor : estGraph.getAdjacentNodes(queue.get(next++))) {
                    if (!queue.contains(neighbor)) {
                        queue.add(neighbor);
                    }
                }
            } else {
                // Connected region exhausted: jump to a random node that is not yet selected.
                List<Node> unselected = new ArrayList<>(allNodes);
                unselected.removeAll(queue);
                if (unselected.isEmpty()) {
                    break; // every node is already selected
                }
                queue.add(unselected.get(random.nextInt(unselected.size())));
            }
        }
        // Expanding a node may overshoot resimSize; drop the most recently added nodes.
        while (queue.size() > resimSize) {
            queue.remove(queue.size() - 1);
        }
        Set<Node> simnodes = new HashSet<Node>(queue);
        if (verbose) {
            System.out.println("the resimmed nodes are " + simnodes);
        }

        // ===========Apply the hybrid resimulation===============
        Hsim hsim = new Hsim(estDAG, simnodes, data);
        DataSet newDataSet = hsim.hybridsimulate();
        if (write) {
            try (FileWriter fileWriter = new FileWriter(filenameOut)) {
                DataWriter.writeRectangularData(newDataSet, fileWriter, delimiter);
            }
        }

        // =======Run FGES on the output data, and compare it to the original learned graph
        BDeuScore newscore = new BDeuScore(newDataSet);
        Fges fgesOut = new Fges(newscore);
        fgesOut.setVerbose(false);
        fgesOut.setNumPatternsToStore(0);
        fgesOut.setPenaltyDiscount(2.0);
        Graph estGraphOut = fgesOut.search();
        // Map the output graph onto the DAG's node objects so the two graphs can be compared.
        estGraphOut = GraphUtils.replaceNodes(estGraphOut, estDAG.getNodes());

        // Restrict the comparison to the simnodes and the edges to their parents.
        Set<Node> allParents = HsimUtils.getAllParents(estGraphOut, simnodes);
        Set<Node> addParents = HsimUtils.getAllParents(estDAG, simnodes);
        allParents.addAll(addParents);
        Graph estEvalGraphOut = HsimUtils.evalEdges(estGraphOut, simnodes, allParents);
        Graph estEvalGraph = HsimUtils.evalEdges(estDAG, simnodes, allParents);
        estEvalGraphOut = GraphUtils.replaceNodes(estEvalGraphOut, estEvalGraph.getNodes());
        output = HsimUtils.errorEval(estEvalGraphOut, estEvalGraph);
        if (verbose) {
            System.out.println(output[0] + " " + output[1] + " " + output[2] + " " + output[3] + " " + output[4]);
        }
    } catch (Exception e) {
        // Best effort: report the failure and return whatever 'output' holds (zeros unless
        // the failure happened after errorEval).
        e.printStackTrace();
    }
    return output;
}
Use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.
The class HsimEvalFromData, method main.
/**
 * Compares full resimulation (Fsim) and hybrid resimulation (Hsim) as predictors of FGES
 * accuracy on a fixed data set.
 *
 * <p>Each of {@code iterations} iterations does the following:
 * <ol>
 *   <li>Load {@code graph/graph.1.txt} and {@code data/data.1.txt}.</li>
 *   <li>Measure FGES (SEM BIC score) errors against the loaded DAG; these are the "target
 *       errors".</li>
 *   <li>Record the squared difference between the target errors and the errors predicted by
 *       Fsim (averaged over the repeat count) and by {@code HsimRepeatAC}.</li>
 * </ol>
 *
 * <p>The per-setting mean squared errors are written to {@code latexTable.txt} (as a LaTeX
 * table) and {@code HvsF-SimulationEvaluation.txt}. Any exception is caught and its stack
 * trace printed.
 *
 * @param args ignored.
 */
public static void main(String[] args) {
    long timestart = System.nanoTime();
    System.out.println("Beginning Evaluation");
    String nl = System.lineSeparator();
    StringBuilder output = new StringBuilder(
            "Simulation edu.cmu.tetrad.study output comparing Fsim and Hsim on predicting graph discovery accuracy" + nl);
    int iterations = 100;
    // vars/cases/edgeratio only appear in report labels; vars and cases are overwritten from the data.
    int vars = 20;
    int cases = 500;
    int edgeratio = 3;
    List<Integer> hsimRepeat = Arrays.asList(40);
    List<Integer> fsimRepeat = Arrays.asList(40);

    // One list of squared errors per Fsim repeat setting, accumulated across iterations.
    @SuppressWarnings("unchecked")
    List<PRAOerrors>[] fsimErrsByPars = new ArrayList[fsimRepeat.size()];
    for (int i = 0; i < fsimErrsByPars.length; i++) {
        fsimErrsByPars[i] = new ArrayList<PRAOerrors>();
    }
    // One list per (resim size, Hsim repeat) setting; only resim size 1 is used here.
    @SuppressWarnings("unchecked")
    List<PRAOerrors>[][] hsimErrsByPars = new ArrayList[1][hsimRepeat.size()];
    for (int i = 0; i < hsimRepeat.size(); i++) {
        hsimErrsByPars[0][i] = new ArrayList<PRAOerrors>();
    }

    try {
        for (int iterate = 0; iterate < iterations; iterate++) {
            System.out.println("iteration " + iterate);

            // Load the true graph and the data.
            Graph graph1 = GraphUtils.loadGraphTxt(new File("graph/graph.1.txt"));
            Dag odag = new Dag(graph1);
            Set<String> eVars = new HashSet<String>();
            eVars.add("MULT");
            Path dataFile = Paths.get("data/data.1.txt");
            TabularDataReader dataReader = new ContinuousTabularDataFileReader(dataFile.toFile(), Delimiter.TAB);
            DataSet data1 = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData(eVars));
            vars = data1.getNumColumns();
            cases = data1.getNumRows();
            edgeratio = 3;

            // Target errors: the original FGES output on the data, scored against the true DAG.
            ICovarianceMatrix newcov = new CovarianceMatrixOnTheFly(data1);
            SemBicScore oscore = new SemBicScore(newcov);
            Fges ofgs = new Fges(oscore);
            ofgs.setVerbose(false);
            ofgs.setNumPatternsToStore(0);
            Graph oFGSGraph = ofgs.search();
            PRAOerrors oErrors = new PRAOerrors(HsimUtils.errorEval(oFGSGraph, odag), "target errors");

            // Step 1: full resimulation, for each repeat setting.
            for (int whichFrepeat = 0; whichFrepeat < fsimRepeat.size(); whichFrepeat++) {
                int repeat = fsimRepeat.get(whichFrepeat);
                ArrayList<PRAOerrors> errorsList = new ArrayList<PRAOerrors>();
                for (int r = 0; r < repeat; r++) {
                    // Pick a DAG in the FGES pattern and fit a SEM IM to it and the data.
                    // (GeneralizedSemEstimator was reported to fail here, so a linear SEM is used.)
                    PatternToDag pickdag = new PatternToDag(oFGSGraph);
                    Graph fgsDag = pickdag.patternToDagMeek();
                    Dag fgsdag2 = new Dag(fgsDag);
                    SemPm simSemPm = new SemPm(fgsdag2);
                    SemEstimator simSemEstimator = new SemEstimator(data1, simSemPm);
                    SemIm fittedIM = simSemEstimator.estimate();
                    DataSet simData = fittedIM.simulateData(data1.getNumRows(), false);

                    // Run FGES on the fully resimulated data and score it against the chosen DAG.
                    ICovarianceMatrix simcov = new CovarianceMatrixOnTheFly(simData);
                    SemBicScore simscore = new SemBicScore(simcov);
                    Fges simfgs = new Fges(simscore);
                    simfgs.setVerbose(false);
                    simfgs.setNumPatternsToStore(0);
                    Graph simGraphOut = simfgs.search();
                    errorsList.add(new PRAOerrors(HsimUtils.errorEval(simGraphOut, fgsdag2), "Fsim errors " + r));
                }
                PRAOerrors avErrors = new PRAOerrors(errorsList, "Average errors for Fsim at repeat=" + repeat);

                // Squared errors of prediction relative to the target errors.
                double dAR = avErrors.getAdjRecall() - oErrors.getAdjRecall();
                double dAP = avErrors.getAdjPrecision() - oErrors.getAdjPrecision();
                double dOR = avErrors.getOrientRecall() - oErrors.getOrientRecall();
                double dOP = avErrors.getOrientPrecision() - oErrors.getOrientPrecision();
                PRAOerrors fsim2 = new PRAOerrors(new double[] { dAR * dAR, dAP * dAP, dOR * dOR, dOP * dOP },
                        "squared errors for Fsim at repeat=" + repeat);
                fsimErrsByPars[whichFrepeat].add(fsim2);
            }

            // Step 2: hybrid resimulation (resim size 1), for each repeat setting.
            for (int whichHrepeat = 0; whichHrepeat < hsimRepeat.size(); whichHrepeat++) {
                int repeat = hsimRepeat.get(whichHrepeat);
                HsimRepeatAC study = new HsimRepeatAC(data1);
                PRAOerrors hsimErrors = new PRAOerrors(study.run(1, repeat),
                        "Hsim errors at rsize=" + 1 + " repeat=" + repeat);

                double dAR = hsimErrors.getAdjRecall() - oErrors.getAdjRecall();
                double dAP = hsimErrors.getAdjPrecision() - oErrors.getAdjPrecision();
                double dOR = hsimErrors.getOrientRecall() - oErrors.getOrientRecall();
                double dOP = hsimErrors.getOrientPrecision() - oErrors.getOrientPrecision();
                PRAOerrors hsim2 = new PRAOerrors(new double[] { dAR * dAR, dAP * dAP, dOR * dOR, dOP * dOP },
                        "squared errors for Hsim, rsize=" + 1 + " repeat=" + repeat);
                hsimErrsByPars[0][whichHrepeat].add(hsim2);
            }
        }

        // Average the squared errors for each Fsim/Hsim setting across all iterations.
        PRAOerrors[] fMSE = new PRAOerrors[fsimRepeat.size()];
        PRAOerrors[][] hMSE = new PRAOerrors[1][hsimRepeat.size()];
        String[][] latexTableArray = new String[1 * hsimRepeat.size() + fsimRepeat.size()][5];
        for (int j = 0; j < fMSE.length; j++) {
            fMSE[j] = new PRAOerrors(fsimErrsByPars[j], "MSE for Fsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " frepeat=" + fsimRepeat.get(j) + " iterations=" + iterations);
            output.append(fMSE[j].allToString()).append(nl);
            latexTableArray[j] = prelimToPRAOtable(fMSE[j]);
        }
        for (int j = 0; j < hMSE.length; j++) {
            for (int k = 0; k < hMSE[j].length; k++) {
                hMSE[j][k] = new PRAOerrors(hsimErrsByPars[j][k], "MSE for Hsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " rsize=" + 1 + " repeat=" + hsimRepeat.get(k) + " iterations=" + iterations);
                output.append(hMSE[j][k].allToString()).append(nl);
                latexTableArray[fsimRepeat.size() + j * hMSE[j].length + k] = prelimToPRAOtable(hMSE[j][k]);
            }
        }

        // Record the Fsim/Hsim mean squared errors.
        String latexTable = HsimUtils.makeLatexTable(latexTableArray);
        try (PrintWriter writer = new PrintWriter("latexTable.txt", "UTF-8")) {
            writer.println(latexTable);
        }
        try (PrintWriter writer2 = new PrintWriter("HvsF-SimulationEvaluation.txt", "UTF-8")) {
            writer2.println(output.toString());
        }
        long timestop = System.nanoTime();
        System.out.println("Evaluation Concluded. Duration: " + (timestop - timestart) / 1000000000 + "s");
    } catch (Exception e) {
        e.printStackTrace();
    }
}
Use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.
The class HsimRobustCompare, method run.
// *************Public Methods*****************//

/**
 * Compares FGES errors on simulated data with the errors predicted by a full resimulation
 * (Fsim) and by hybrid resimulation (Hsim).
 *
 * <p>First resets the global {@code RandomUtil} seed to a fixed value. Then:
 * <ol>
 *   <li>Generate a random DAG over {@code numVars} variables with
 *       {@code (int) (numVars * edgesPerNode)} edges, and simulate {@code numCases} cases from
 *       a random {@code MlBayesIm} on it.</li>
 *   <li>Run FGES (BDeu score) on those data.</li>
 *   <li>Fit a symmetric Dirichlet Bayes IM to a DAG chosen from the FGES pattern and simulate
 *       from it (Fsim).</li>
 *   <li>Run {@code HsimRepeatAutoRun} (Hsim).</li>
 * </ol>
 *
 * @param numVars         number of variables in the random DAG.
 * @param edgesPerNode    edges per variable; the edge count is truncated to an int.
 * @param numCases        number of cases to simulate for both the original and Fsim data.
 * @param penaltyDiscount FGES penalty discount, used for both FGES runs.
 * @param resimSize       passed to {@code HsimRepeatAutoRun.run}.
 * @param repeat          passed to {@code HsimRepeatAutoRun.run}.
 * @param verbose         whether to print intermediate graphs and errors.
 * @return three error arrays, in order:
 *         <ol>
 *           <li>FGES on the original data vs the true DAG;</li>
 *           <li>FGES on the Fsim data vs the DAG it was simulated from;</li>
 *           <li>the Hsim errors.</li>
 *         </ol>
 */
public static List<double[]> run(int numVars, double edgesPerNode, int numCases, double penaltyDiscount, int resimSize, int repeat, boolean verbose) {
    RandomUtil.getInstance().setSeed(1450184147770L);

    // Generate the original DAG (odag) and the original data set (oData).
    final int numEdges = (int) (numVars * edgesPerNode);
    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }
    Graph odag = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
    BayesPm bayesPm = new BayesPm(odag, 2, 2);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    DataSet oData = bayesIm.simulateData(numCases, false);

    // Run FGES on the original data and score it against the true DAG.
    BDeuScore oscore = new BDeuScore(oData);
    Fges fges = new Fges(oscore);
    fges.setVerbose(false);
    fges.setNumPatternsToStore(0);
    fges.setPenaltyDiscount(penaltyDiscount);
    Graph oGraphOut = fges.search();
    if (verbose) {
        System.out.println(oGraphOut);
    }
    double[] oErrors = HsimUtils.errorEval(oGraphOut, odag);
    if (verbose) {
        System.out.println(oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
    }

    // Full resimulation: a DAG in the FGES pattern, with a Dirichlet IM fit to the original data.
    // The order of these calls is kept as-is because they may consume the seeded RNG.
    PatternToDag pickdag = new PatternToDag(oGraphOut);
    Dag fgesdag2 = new Dag(pickdag.patternToDagMeek());
    BayesPm simBayesPm = new BayesPm(fgesdag2, bayesPm);
    DirichletBayesIm simIM = DirichletBayesIm.symmetricDirichletIm(simBayesPm, 1.0);
    DirichletEstimator simEstimator = new DirichletEstimator();
    DirichletBayesIm fittedIM = simEstimator.estimate(simIM, oData);
    DataSet simData = fittedIM.simulateData(numCases, false);

    // Hybrid resimulation.
    HsimRepeatAutoRun study = new HsimRepeatAutoRun(oData);
    double[] hsimErrors = study.run(resimSize, repeat);

    // FGES on the fully resimulated data, scored against the DAG it was simulated from.
    BDeuScore simscore = new BDeuScore(simData);
    Fges simfges = new Fges(simscore);
    simfges.setVerbose(false);
    simfges.setNumPatternsToStore(0);
    simfges.setPenaltyDiscount(penaltyDiscount);
    Graph simGraphOut = simfges.search();
    double[] simErrors = HsimUtils.errorEval(simGraphOut, fgesdag2);

    if (verbose) {
        System.out.println("Original errors are: " + oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
        System.out.println("Full resim errors are: " + simErrors[0] + " " + simErrors[1] + " " + simErrors[2] + " " + simErrors[3] + " " + simErrors[4]);
        System.out.println("HSim errors are: " + hsimErrors[0] + " " + hsimErrors[1] + " " + hsimErrors[2] + " " + hsimErrors[3] + " " + hsimErrors[4]);
    }

    List<double[]> output = new ArrayList<>();
    output.add(oErrors);
    output.add(simErrors);
    output.add(hsimErrors);
    return output;
}
Aggregations