Search in sources:

Example 1 with Mimbuild2

Use of edu.cmu.tetrad.search.Mimbuild2 in project tetrad by cmu-phil.

Class TestPurify, method test2.

/**
 * Purifies the true measurement clusters of a simulated single-factor model with
 * {@code PurifyTetradBased2}, then checks that {@code Mimbuild2}, run on the purified
 * clusters, finds exactly two edges among the latent variables (for the fixed seed used here).
 */
@Test
public void test2() {
    RandomUtil.getInstance().setSeed(48290483L);
    Graph graph = new EdgeListGraph(DataGraphUtils.randomSingleFactorModel(3, 3, 5, 0, 0, 0));
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);

    // True clustering: for each latent, its non-latent neighbours in the generating graph.
    List<Node> latents = collectLatentNodes(graph);
    List<List<Node>> clustering = new ArrayList<>();
    for (Node latent : latents) {
        // Copy first so removeAll never mutates the list returned by getAdjacentNodes,
        // in case that list is backed by the graph itself.
        List<Node> indicators = new ArrayList<>(graph.getAdjacentNodes(latent));
        indicators.removeAll(latents);
        clustering.add(indicators);
    }

    ContinuousTetradTest test = new ContinuousTetradTest(data, TestType.TETRAD_WISHART, 0.001);
    IPurify purify = new PurifyTetradBased2(test);
    List<List<Node>> purifiedClustering = purify.purify(clustering);

    List<String> latentNames = new ArrayList<>();
    for (Node latent : latents) {
        latentNames.add(latent.getName());
    }

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(0.0001);
    Graph estimated = mimbuild.search(purifiedClustering, latentNames, new CovarianceMatrix(data));

    // Only the structural part (edges among latents) is checked.
    Graph estimatedStructure = estimated.subgraph(collectLatentNodes(estimated));
    assertEquals(2, estimatedStructure.getNumEdges());
}

/**
 * Returns the nodes of {@code graph} whose type is {@code NodeType.LATENT}, in the order
 * produced by {@code graph.getNodes()}.
 */
private static List<Node> collectLatentNodes(Graph graph) {
    List<Node> latents = new ArrayList<>();
    for (Node node : graph.getNodes()) {
        if (node.getNodeType() == NodeType.LATENT) {
            latents.add(node);
        }
    }
    return latents;
}
Also used: Mimbuild2 (edu.cmu.tetrad.search.Mimbuild2), ArrayList (java.util.ArrayList), SemPm (edu.cmu.tetrad.sem.SemPm), List (java.util.List), SemIm (edu.cmu.tetrad.sem.SemIm), Test (org.junit.Test)

Example 2 with Mimbuild2

Use of edu.cmu.tetrad.search.Mimbuild2 in project tetrad by cmu-phil.

Class TestMimbuild2, method test1.

/**
 * Simulates data from a random single-factor model, recovers measurement clusters with FOFC
 * (GAP variant, Wishart tetrad test), reidentifies the latent variables against the true model,
 * and checks the structural Hamming distance between the true latent structure and the one
 * estimated by {@code Mimbuild2} (for the fixed seed used here).
 */
@Test
public void test1() {
    RandomUtil.getInstance().setSeed(49283494L);

    Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 6, 0, 0, 0);
    Graph mimStructure = structure(mim);

    Parameters params = new Parameters();
    params.set("coefLow", .5);
    params.set("coefHigh", 1.5);
    SemPm pm = new SemPm(mim);
    SemIm im = new SemIm(pm, params);
    DataSet data = im.simulateData(300, false);

    FindOneFactorClusters fofc = new FindOneFactorClusters(data, TestType.TETRAD_WISHART,
            FindOneFactorClusters.Algorithm.GAP, 0.001);
    // The search result itself is not needed; the clusters are read from fofc afterwards.
    fofc.search();
    List<List<Node>> partition = fofc.getClusters();

    // Name the found clusters after the true latents they correspond to.
    List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(0.001);
    mimbuild.setMinClusterSize(3);
    Graph mimbuildStructure = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data));

    int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure);
    assertEquals(7, shd);
}
Also used: Mimbuild2 (edu.cmu.tetrad.search.Mimbuild2), Parameters (edu.cmu.tetrad.util.Parameters), SemPm (edu.cmu.tetrad.sem.SemPm), SemIm (edu.cmu.tetrad.sem.SemIm), Test (org.junit.Test)

Example 3 with Mimbuild2

Use of edu.cmu.tetrad.search.Mimbuild2 in project tetrad by cmu-phil.

Class MimBuildRunner, method execute.

// ===================PUBLIC METHODS OVERRIDING ABSTRACT================//

/**
 * Runs Mimbuild2 over the clusters stored in the "clusters" parameter and publishes the
 * resulting structure graph, full graph, clusters and latent covariance matrix.
 *
 * <p>When "showMaxP" is set, the model with the highest p-value seen so far is also tracked in
 * the parameters ("maxP", "maxStructureGraph", "maxFullGraph", "maxClusters", "maxAlpha"). That
 * best-so-far model is displayed in place of the current one.
 *
 * @throws IllegalStateException if no clusters have been specified in the parameters
 * @throws Exception as declared by the overridden method
 */
public void execute() throws Exception {
    DataSet data = this.dataSet;
    double alpha = getParams().getDouble("alpha", 0.001);

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(alpha);
    mimbuild.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2()));
    mimbuild.setMinClusterSize(getParams().getBoolean("includeThreeClusters", true) ? 3 : 4);

    Clusters clusters = (Clusters) getParams().get("clusters", null);
    if (clusters == null) {
        throw new IllegalStateException(
                "No clusters have been specified; Mimbuild requires a measurement clustering.");
    }

    List<List<Node>> partition = ClusterUtils.clustersToPartition(clusters, data.getVariables());
    List<String> latentNames = new ArrayList<>();
    for (int i = 0; i < clusters.getNumClusters(); i++) {
        latentNames.add(clusters.getClusterName(i));
    }

    CovarianceMatrix cov = new CovarianceMatrix(data);
    Graph structureGraph = mimbuild.search(partition, latentNames, cov);
    GraphUtils.circleLayout(structureGraph, 200, 200, 150);
    GraphUtils.fruchtermanReingoldLayout(structureGraph);

    ICovarianceMatrix latentsCov = mimbuild.getLatentsCov();
    TetradLogger.getInstance().log("details", "Latent covs = \n" + latentsCov);

    Graph fullGraph = mimbuild.getFullGraph();
    GraphUtils.circleLayout(fullGraph, 200, 200, 150);
    GraphUtils.fruchtermanReingoldLayout(fullGraph);

    setResultGraph(fullGraph);
    setFullGraph(fullGraph);
    // Publish Mimbuild's own clustering rather than the input partition.
    // NOTE(review): presumably it can differ, e.g. if clusters below the minimum size were dropped.
    setClusters(ClusterUtils.partitionToClusters(mimbuild.getClustering()));
    setStructureGraph(structureGraph);
    getParams().set("latentVariableNames", new ArrayList<>(latentNames));
    this.covMatrix = latentsCov;

    double p = mimbuild.getpValue();
    TetradLogger.getInstance().log("details", "\nStructure graph = " + structureGraph);
    TetradLogger.getInstance().log("details", getLatentClustersString(fullGraph).toString());
    TetradLogger.getInstance().log("details", "P = " + p);

    if (getParams().getBoolean("showMaxP", false)) {
        showMaxPModel(p, alpha, structureGraph, fullGraph);
    }
}

/**
 * Records the current model as the best-so-far when its p-value exceeds the stored "maxP".
 * Then displays the recorded best-so-far model and logs it under "maxmodel".
 */
private void showMaxPModel(double p, double alpha, Graph structureGraph, Graph fullGraph) {
    // Seed with -infinity: a p-value never exceeds 1.0, so a default of 1.0 would prevent the
    // first run from ever being recorded.
    double maxP = getParams().getDouble("maxP", Double.NEGATIVE_INFINITY);
    if (p > maxP) {
        maxP = p;
        getParams().set("maxP", p);
        getParams().set("maxStructureGraph", structureGraph);
        getParams().set("maxClusters", getClusters());
        getParams().set("maxFullGraph", fullGraph);
        getParams().set("maxAlpha", alpha);
    }

    Graph maxStructureGraph = (Graph) getParams().get("maxStructureGraph", null);
    Graph maxFullGraph = (Graph) getParams().get("maxFullGraph", null);
    Clusters maxClusters = (Clusters) getParams().get("maxClusters", null);

    // Only swap in recorded values; never replace the current results with null.
    if (maxStructureGraph != null) {
        setStructureGraph(maxStructureGraph);
    }
    if (maxFullGraph != null) {
        setFullGraph(maxFullGraph);
    }
    if (maxClusters != null) {
        setClusters(maxClusters);
    }
    if (maxFullGraph != null) {
        setResultGraph(maxFullGraph);
    }

    TetradLogger.getInstance().log("maxmodel", "\nMAX Graph = " + maxStructureGraph);
    if (maxFullGraph != null) {
        TetradLogger.getInstance().log("maxmodel", getLatentClustersString(maxFullGraph).toString());
    }
    TetradLogger.getInstance().log("maxmodel", "MAX P = " + maxP);
}
Also used: Mimbuild2 (edu.cmu.tetrad.search.Mimbuild2), ArrayList (java.util.ArrayList), Graph (edu.cmu.tetrad.graph.Graph), List (java.util.List)

Aggregations

Mimbuild2 (edu.cmu.tetrad.search.Mimbuild2): 3, SemIm (edu.cmu.tetrad.sem.SemIm): 2, SemPm (edu.cmu.tetrad.sem.SemPm): 2, ArrayList (java.util.ArrayList): 2, List (java.util.List): 2, Test (org.junit.Test): 2, Graph (edu.cmu.tetrad.graph.Graph): 1, Parameters (edu.cmu.tetrad.util.Parameters): 1