Usage of edu.cmu.tetrad.search.Mimbuild2 in the tetrad project by cmu-phil.
From class TestPurify, method test2.
/**
 * Simulates data from a random single-factor model, builds the true clustering
 * (each latent's non-latent neighbors), purifies it with
 * {@link PurifyTetradBased2}, and checks that {@link Mimbuild2} recovers exactly
 * two edges among the latent variables.
 */
@Test
public void test2() {
    RandomUtil.getInstance().setSeed(48290483L);
    Graph graph = new EdgeListGraph(DataGraphUtils.randomSingleFactorModel(3, 3, 5, 0, 0, 0));
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);

    // Latent nodes of the generating model.
    List<Node> latents = new ArrayList<>();
    for (Node node : graph.getNodes()) {
        if (node.getNodeType() == NodeType.LATENT) {
            latents.add(node);
        }
    }

    // True clustering: for each latent, its adjacent nodes with the other latents removed.
    // The adjacency list is copied so removeAll cannot touch whatever list the graph handed back.
    List<List<Node>> clustering = new ArrayList<>();
    for (Node latent : latents) {
        List<Node> indicators = new ArrayList<>(graph.getAdjacentNodes(latent));
        indicators.removeAll(latents);
        clustering.add(indicators);
    }

    ContinuousTetradTest test = new ContinuousTetradTest(data, TestType.TETRAD_WISHART, 0.001);
    IPurify purify = new PurifyTetradBased2(test);
    List<List<Node>> purifiedClustering = purify.purify(clustering);

    List<String> latentsNames = new ArrayList<>();
    for (Node latent : latents) {
        latentsNames.add(latent.getName());
    }

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(0.0001);
    Graph estimatedGraph = mimbuild.search(purifiedClustering, latentsNames, new CovarianceMatrix(data));

    // Restrict the estimate to its latent nodes and count the structural edges.
    List<Node> estimatedLatents = new ArrayList<>();
    for (Node node : estimatedGraph.getNodes()) {
        if (node.getNodeType() == NodeType.LATENT) {
            estimatedLatents.add(node);
        }
    }
    Graph estimatedStructure = estimatedGraph.subgraph(estimatedLatents);
    assertEquals(2, estimatedStructure.getNumEdges());
}
Usage of edu.cmu.tetrad.search.Mimbuild2 in the tetrad project by cmu-phil.
From class TestMimbuild2, method test1.
/**
 * Simulates data from a random single-factor model, recovers the measurement
 * clusters with FOFC (GAP algorithm, Wishart tetrad test), and checks the
 * structural Hamming distance between the true latent structure and the one
 * estimated by {@link Mimbuild2}.
 */
@Test
public void test1() {
    RandomUtil.getInstance().setSeed(49283494L);

    Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 6, 0, 0, 0);
    Graph mimStructure = structure(mim);

    Parameters params = new Parameters();
    params.set("coefLow", 0.5);
    params.set("coefHigh", 1.5);
    SemPm pm = new SemPm(mim);
    SemIm im = new SemIm(pm, params);
    DataSet data = im.simulateData(300, false);

    // search() is run before getClusters(), as before; presumably the clusters depend on it.
    FindOneFactorClusters fofc = new FindOneFactorClusters(data, TestType.TETRAD_WISHART,
            FindOneFactorClusters.Algorithm.GAP, 0.001);
    fofc.search();
    List<List<Node>> partition = fofc.getClusters();

    List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(0.001);
    mimbuild.setMinClusterSize(3);
    Graph mimbuildStructure = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data));

    int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure);
    assertEquals(7, shd);
}
Usage of edu.cmu.tetrad.search.Mimbuild2 in the tetrad project by cmu-phil.
From class MimBuildRunner, method execute.
// ===================PUBLIC METHODS OVERRIDING ABSTRACT================//
/**
 * Executes the algorithm, producing (at least) a result workbench. Must be
 * implemented in the extending class.
 * <p>
 * Runs {@link Mimbuild2} over the clusters stored in the {@code clusters}
 * parameter, then records the structure graph, the full graph, the latent
 * covariance matrix and the clustering returned by Mimbuild2, and logs the
 * resulting p-value. When {@code showMaxP} is set, the model with the largest
 * p-value recorded in the parameters ({@code maxP}, {@code maxStructureGraph},
 * {@code maxClusters}, {@code maxFullGraph}, {@code maxAlpha}) replaces this
 * run's results.
 *
 * @throws IllegalStateException if the {@code clusters} parameter has not been set.
 * @throws Exception             if the underlying search fails.
 */
public void execute() throws Exception {
    DataSet data = this.dataSet;
    double alpha = getParams().getDouble("alpha", 0.001);

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(alpha);
    mimbuild.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2()));
    // includeThreeClusters (default true) lowers the minimum cluster size from 4 to 3.
    mimbuild.setMinClusterSize(getParams().getBoolean("includeThreeClusters", true) ? 3 : 4);

    Clusters clusters = (Clusters) getParams().get("clusters", null);
    if (clusters == null) {
        // Fail with a descriptive message instead of a bare NullPointerException further down.
        throw new IllegalStateException("Mimbuild requires clusters, but the \"clusters\" parameter is not set.");
    }
    List<List<Node>> partition = ClusterUtils.clustersToPartition(clusters, data.getVariables());
    List<String> latentNames = new ArrayList<>();
    for (int i = 0; i < clusters.getNumClusters(); i++) {
        latentNames.add(clusters.getClusterName(i));
    }

    CovarianceMatrix cov = new CovarianceMatrix(data);
    Graph structureGraph = mimbuild.search(partition, latentNames, cov);
    GraphUtils.circleLayout(structureGraph, 200, 200, 150);
    GraphUtils.fruchtermanReingoldLayout(structureGraph);

    ICovarianceMatrix latentsCov = mimbuild.getLatentsCov();
    TetradLogger.getInstance().log("details", "Latent covs = \n" + latentsCov);

    Graph fullGraph = mimbuild.getFullGraph();
    GraphUtils.circleLayout(fullGraph, 200, 200, 150);
    GraphUtils.fruchtermanReingoldLayout(fullGraph);

    setResultGraph(fullGraph);
    setFullGraph(fullGraph);
    // Report the clustering exactly as Mimbuild2 returns it.
    setClusters(ClusterUtils.partitionToClusters(mimbuild.getClustering()));
    setStructureGraph(structureGraph);
    getParams().set("latentVariableNames", new ArrayList<>(latentNames));
    this.covMatrix = latentsCov;

    double p = mimbuild.getpValue();
    TetradLogger.getInstance().log("details", "\nStructure graph = " + structureGraph);
    TetradLogger.getInstance().log("details", getLatentClustersString(fullGraph).toString());
    TetradLogger.getInstance().log("details", "P = " + p);

    if (getParams().getBoolean("showMaxP", false)) {
        // NOTE(review): maxP defaults to 1.0; if p is a p-value (<= 1.0) this update can only
        // fire once maxP has been set lower elsewhere — confirm that is the intent.
        if (p > getParams().getDouble("maxP", 1.0)) {
            getParams().set("maxP", p);
            getParams().set("maxStructureGraph", structureGraph);
            getParams().set("maxClusters", getClusters());
            getParams().set("maxFullGraph", fullGraph);
            getParams().set("maxAlpha", alpha);
        }

        // Replace this run's results with the best recorded model.
        // NOTE(review): if no max model has been stored yet, the graphs below are set to null.
        setStructureGraph((Graph) getParams().get("maxStructureGraph", null));
        setFullGraph((Graph) getParams().get("maxFullGraph", null));
        if (getParams().get("maxClusters", null) != null) {
            setClusters((Clusters) getParams().get("maxClusters", null));
        }
        setResultGraph((Graph) getParams().get("maxFullGraph", null));
        TetradLogger.getInstance().log("maxmodel", "\nMAX Graph = " + getParams().get("maxStructureGraph", null));
        TetradLogger.getInstance().log("maxmodel", getLatentClustersString((Graph) getParams().get("maxFullGraph", null)).toString());
        TetradLogger.getInstance().log("maxmodel", "MAX P = " + getParams().getDouble("maxP", 1.0));
    }
}