use of edu.cmu.tetrad.search.BDeuScore in project tetrad by cmu-phil.
the class HsimAutoRun method run.
// ***********Public methods*************//
/**
 * Runs one hybrid-resimulation experiment on this instance's {@code data}.
 *
 * <p>The method
 * <ol>
 *   <li>searches {@code data} with FGES (BDeu score) and turns the result into a DAG,</li>
 *   <li>picks a random centroid node and grows a set of nodes around it by repeatedly
 *       adding graph neighbours (falling back to a random not-yet-selected node when a
 *       node has no new neighbours),</li>
 *   <li>resimulates the selected nodes with {@code Hsim},</li>
 *   <li>optionally writes the resimulated data to {@code filenameOut} (when {@code write} is set),</li>
 *   <li>runs FGES again on the resimulated data and compares the two graphs, restricted to
 *       the resimulated nodes and their parents, via {@code HsimUtils.errorEval}.</li>
 * </ol>
 *
 * @param resimSize number of nodes to resimulate; values larger than the number of nodes
 *                  in the learned graph are capped at the graph size
 * @return the array returned by {@code HsimUtils.errorEval}; if any step throws, the stack
 *         trace is printed and an all-zero array of length 5 is returned instead
 */
public double[] run(int resimSize) {
    // Stays all-zero if anything in the try block throws.
    double[] output = new double[5];
    // Same FGES penalty discount for the original and the resimulated search.
    final double penaltyDiscount = 2.0;
    try {
        // ========first make the Dag for Hsim==========
        BDeuScore score = new BDeuScore(data);
        Fges fges = new Fges(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setPenaltyDiscount(penaltyDiscount);
        Graph estGraph = fges.search();
        Graph estPattern = new EdgeListGraphSingleConnections(estGraph);
        PatternToDag patternToDag = new PatternToDag(estPattern);
        Graph estGraphDAG = patternToDag.patternToDagMeek();
        Dag estDAG = new Dag(estGraphDAG);

        // ===========Identify the nodes to be resimulated===========
        // select a random node as the centroid
        List<Node> allNodes = estGraph.getNodes();
        int size = allNodes.size();
        Random random = new Random();
        Node centroid = allNodes.get(random.nextInt(size));
        if (verbose) {
            System.out.println("the centroid is " + centroid);
        }
        List<Node> queue = new ArrayList<>();
        queue.add(centroid);
        // The queue holds distinct nodes only, so it can never grow past the graph size.
        int target = Math.min(resimSize, size);
        while (queue.size() < target) {
            // Only expand the nodes that were in the queue when this pass started.
            int qsize = queue.size();
            for (int i = 0; i < qsize; i++) {
                // Copy the adjacency list so a list owned by the graph is never mutated.
                List<Node> queueAdd = new ArrayList<>(estGraph.getAdjacentNodes(queue.get(i)));
                // remove nodes that are already in queue
                queueAdd.removeAll(queue);
                if (queueAdd.isEmpty()) {
                    // No new neighbours: jump to a random node that is not selected yet.
                    // Non-empty because queue.size() < target <= size at this point.
                    List<Node> unvisited = new ArrayList<>(allNodes);
                    unvisited.removeAll(queue);
                    queueAdd.add(unvisited.get(random.nextInt(unvisited.size())));
                }
                queue.addAll(queueAdd);
                // break early when queue reaches the requested size
                if (queue.size() >= target) {
                    break;
                }
            }
        }
        // if queue is too big, remove nodes from the end until it is small enough.
        while (queue.size() > resimSize) {
            queue.remove(queue.size() - 1);
        }
        Set<Node> simnodes = new HashSet<>(queue);
        if (verbose) {
            System.out.println("the resimmed nodes are " + simnodes);
        }

        // ===========Apply the hybrid resimulation===============
        Hsim hsim = new Hsim(estDAG, simnodes, data);
        DataSet newDataSet = hsim.hybridsimulate();
        // write output to a new file
        if (write) {
            // try-with-resources closes the writer even when writing fails.
            try (FileWriter fileWriter = new FileWriter(filenameOut)) {
                DataWriter.writeRectangularData(newDataSet, fileWriter, delimiter);
            }
        }

        // =======Run FGES on the output data, and compare it to the original learned graph
        BDeuScore newscore = new BDeuScore(newDataSet);
        Fges fgesOut = new Fges(newscore);
        fgesOut.setVerbose(false);
        fgesOut.setNumPatternsToStore(0);
        fgesOut.setPenaltyDiscount(penaltyDiscount);
        Graph estGraphOut = fgesOut.search();
        // Make both graphs share the same Node objects before comparing them.
        estGraphOut = GraphUtils.replaceNodes(estGraphOut, estDAG.getNodes());
        // restrict the comparison to the simnodes and edges to their parents
        Set<Node> allParents = HsimUtils.getAllParents(estGraphOut, simnodes);
        Set<Node> addParents = HsimUtils.getAllParents(estDAG, simnodes);
        allParents.addAll(addParents);
        Graph estEvalGraphOut = HsimUtils.evalEdges(estGraphOut, simnodes, allParents);
        Graph estEvalGraph = HsimUtils.evalEdges(estDAG, simnodes, allParents);
        estEvalGraphOut = GraphUtils.replaceNodes(estEvalGraphOut, estEvalGraph.getNodes());
        output = HsimUtils.errorEval(estEvalGraphOut, estEvalGraph);
        if (verbose) {
            System.out.println(output[0] + " " + output[1] + " " + output[2] + " " + output[3] + " " + output[4]);
        }
    } catch (Exception e) {
        // Best effort: report the failure and fall through to return whatever output holds.
        e.printStackTrace();
    }
    return output;
}
use of edu.cmu.tetrad.search.BDeuScore in project tetrad by cmu-phil.
the class HsimRobustCompare method run.
// *************Public Methods*****************//
/**
 * Compares FGES error statistics on an original data set against two kinds of resimulated
 * data: a full resimulation from a model fitted to the FGES output, and repeated hybrid
 * resimulations ({@code HsimRepeatAutoRun}).
 *
 * <p>The global {@code RandomUtil} seed is reset to a fixed value at the start of every call.
 *
 * @param numVars         number of variables in the randomly generated graph
 * @param edgesPerNode    multiplied by {@code numVars} (and truncated) to get the edge count
 * @param numCases        number of cases to simulate for the original and the fully
 *                        resimulated data sets
 * @param penaltyDiscount penalty discount passed to both FGES searches in this method
 * @param resimSize       passed through to {@code HsimRepeatAutoRun.run}
 * @param repeat          passed through to {@code HsimRepeatAutoRun.run}
 * @param verbose         when true, prints the learned graph and the error arrays to stdout
 * @return a list of three arrays, in this order: errors of FGES on the original data versus
 *         the generating graph, errors of FGES on the fully resimulated data versus the DAG
 *         it was simulated from, and the errors reported by the hybrid resimulation study
 */
public static List<double[]> run(int numVars, double edgesPerNode, int numCases, double penaltyDiscount, int resimSize, int repeat, boolean verbose) {
    // first generate the data
    RandomUtil.getInstance().setSeed(1450184147770L);
    final int numEdges = (int) (numVars * edgesPerNode);
    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }
    Graph odag = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
    BayesPm bayesPm = new BayesPm(odag, 2, 2);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    // oData is the original data set, and odag is the original dag.
    DataSet oData = bayesIm.simulateData(numCases, false);

    // then run FGES
    BDeuScore oscore = new BDeuScore(oData);
    Fges fges = new Fges(oscore);
    fges.setVerbose(false);
    fges.setNumPatternsToStore(0);
    fges.setPenaltyDiscount(penaltyDiscount);
    Graph oGraphOut = fges.search();
    if (verbose) {
        System.out.println(oGraphOut);
    }
    // calculate FGES errors
    double[] oErrors = HsimUtils.errorEval(oGraphOut, odag);
    if (verbose) {
        System.out.println(oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
    }

    // create various simulated data sets
    // full simulated data set first: a dag in the FGES pattern, fit to the data set.
    PatternToDag pickdag = new PatternToDag(oGraphOut);
    Graph fgesDag = pickdag.patternToDagMeek();
    Dag fgesdag2 = new Dag(fgesDag);
    BayesPm simBayesPm = new BayesPm(fgesdag2, bayesPm);
    DirichletBayesIm simIM = DirichletBayesIm.symmetricDirichletIm(simBayesPm, 1.0);
    DirichletEstimator simEstimator = new DirichletEstimator();
    DirichletBayesIm fittedIM = simEstimator.estimate(simIM, oData);
    DataSet simData = fittedIM.simulateData(numCases, false);

    // next, a schedule of small hsims
    HsimRepeatAutoRun study = new HsimRepeatAutoRun(oData);
    double[] hsimErrors = study.run(resimSize, repeat);

    // calculate errors for the fully simulated data: FGES output versus the DAG it came from
    BDeuScore simscore = new BDeuScore(simData);
    Fges simfges = new Fges(simscore);
    simfges.setVerbose(false);
    simfges.setNumPatternsToStore(0);
    simfges.setPenaltyDiscount(penaltyDiscount);
    Graph simGraphOut = simfges.search();
    double[] simErrors = HsimUtils.errorEval(simGraphOut, fgesdag2);

    // first, let's just see what the errors are.
    if (verbose) {
        System.out.println("Original erors are: " + oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
        System.out.println("Full resim errors are: " + simErrors[0] + " " + simErrors[1] + " " + simErrors[2] + " " + simErrors[3] + " " + simErrors[4]);
        System.out.println("HSim errors are: " + hsimErrors[0] + " " + hsimErrors[1] + " " + hsimErrors[2] + " " + hsimErrors[3] + " " + hsimErrors[4]);
    }

    // Callers rely on this order: original, full resimulation, hybrid resimulation.
    List<double[]> output = new ArrayList<>();
    output.add(oErrors);
    output.add(simErrors);
    output.add(hsimErrors);
    return output;
}
Aggregations