
Example 11 with DMSearch

Use of edu.cmu.tetrad.search.DMSearch in project tetrad by cmu-phil.

From the class TestDM, method test3.

@Test
public void test3() {
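    // Fix the random seed so the simulated data, and hence the test outcome, are reproducible.
    RandomUtil.getInstance().setSeed(29483818483L);
    // emptyGraph is a TestDM helper not shown here; presumably it creates the 12 nodes X0 to X11.
    Graph graph = emptyGraph(12);
    // Generating graph: inputs {X0, X1} feed X2, X3, X6, X7, X10, X11;
    // {X4, X5} feed X6, X7, X10, X11; {X8, X9} feed X10, X11.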
    graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X2"));
    graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X3"));
    graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X2"));
    graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X3"));
    graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X6"));
    graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X7"));
    graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X6"));
    graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X7"));
    graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X10"));
    graph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("X11"));
    graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X10"));
    graph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("X11"));
    graph.addDirectedEdge(new ContinuousVariable("X4"), new ContinuousVariable("X6"));
    graph.addDirectedEdge(new ContinuousVariable("X4"), new ContinuousVariable("X7"));
    graph.addDirectedEdge(new ContinuousVariable("X5"), new ContinuousVariable("X6"));
    graph.addDirectedEdge(new ContinuousVariable("X5"), new ContinuousVariable("X7"));
    graph.addDirectedEdge(new ContinuousVariable("X4"), new ContinuousVariable("X10"));
    graph.addDirectedEdge(new ContinuousVariable("X4"), new ContinuousVariable("X11"));
    graph.addDirectedEdge(new ContinuousVariable("X5"), new ContinuousVariable("X10"));
    graph.addDirectedEdge(new ContinuousVariable("X5"), new ContinuousVariable("X11"));
    graph.addDirectedEdge(new ContinuousVariable("X8"), new ContinuousVariable("X10"));
    graph.addDirectedEdge(new ContinuousVariable("X8"), new ContinuousVariable("X11"));
    graph.addDirectedEdge(new ContinuousVariable("X9"), new ContinuousVariable("X10"));
    graph.addDirectedEdge(new ContinuousVariable("X9"), new ContinuousVariable("X11"));
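    // Parameterize a SEM on the generating graph and simulate 100,000 samples from it.
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(100000, false);
    // Run DM search, telling it which column indices are inputs and which are outputs.
    DMSearch search = new DMSearch();
    search.setInputs(new int[] { 0, 1, 4, 5, 8, 9 });
    search.setOutputs(new int[] { 2, 3, 6, 7, 10, 11 });
    search.setData(data);
    search.setTrueInputs(search.getInputs());
    Graph foundGraph = search.search();
    print("Test Case 3");
    // Expected result: one latent per group of outputs that share the same input parents.
    // L1: {X0, X1} -> X2, X3; L2: {X0, X1, X4, X5} -> X6, X7; L0: all six inputs -> X10, X11.
    Graph trueGraph = new EdgeListGraph();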
    trueGraph.addNode(new ContinuousVariable("X0"));
    trueGraph.addNode(new ContinuousVariable("X1"));
    trueGraph.addNode(new ContinuousVariable("X2"));
    trueGraph.addNode(new ContinuousVariable("X3"));
    trueGraph.addNode(new ContinuousVariable("X4"));
    trueGraph.addNode(new ContinuousVariable("X5"));
    trueGraph.addNode(new ContinuousVariable("X6"));
    trueGraph.addNode(new ContinuousVariable("X7"));
    trueGraph.addNode(new ContinuousVariable("X8"));
    trueGraph.addNode(new ContinuousVariable("X9"));
    trueGraph.addNode(new ContinuousVariable("X10"));
    trueGraph.addNode(new ContinuousVariable("X11"));
    trueGraph.addNode(new ContinuousVariable("L0"));
    trueGraph.addNode(new ContinuousVariable("L1"));
    trueGraph.addNode(new ContinuousVariable("L2"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L1"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L1"));
    trueGraph.addDirectedEdge(new ContinuousVariable("L1"), new ContinuousVariable("X2"));
    trueGraph.addDirectedEdge(new ContinuousVariable("L1"), new ContinuousVariable("X3"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L2"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L2"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X4"), new ContinuousVariable("L2"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X5"), new ContinuousVariable("L2"));
    trueGraph.addDirectedEdge(new ContinuousVariable("L2"), new ContinuousVariable("X6"));
    trueGraph.addDirectedEdge(new ContinuousVariable("L2"), new ContinuousVariable("X7"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X0"), new ContinuousVariable("L0"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X1"), new ContinuousVariable("L0"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X4"), new ContinuousVariable("L0"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X5"), new ContinuousVariable("L0"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X8"), new ContinuousVariable("L0"));
    trueGraph.addDirectedEdge(new ContinuousVariable("X9"), new ContinuousVariable("L0"));
    trueGraph.addDirectedEdge(new ContinuousVariable("L0"), new ContinuousVariable("X10"));
    trueGraph.addDirectedEdge(new ContinuousVariable("L0"), new ContinuousVariable("X11"));
    assertTrue(foundGraph.equals(trueGraph));
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) DMSearch(edu.cmu.tetrad.search.DMSearch) Test(org.junit.Test)
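The expected graph in test3 follows a simple pattern: outputs that have exactly the same set of input parents in the generating graph are explained by a single latent. The sketch below is a plain-Java illustration of that pattern, not Tetrad's DMSearch (which recovers the structure from statistical tests on the simulated data, not from the known graph). The class `ParentSetGrouping` and its methods are hypothetical names.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

// Hypothetical illustration: groups outputs by the exact set of inputs pointing into them.
// It reads the known generating graph; it is not DMSearch, which works from the data.
class ParentSetGrouping {

    // Each edge is {from, to}. Returns: parent set -> outputs that share it.
    static Map<Set<String>, List<String>> groupOutputsByParents(String[][] edges) {
        // Collect the parents of each child, sorted so the grouping is deterministic.
        Map<String, Set<String>> parents = new TreeMap<>();
        for (String[] edge : edges) {
            parents.computeIfAbsent(edge[1], k -> new TreeSet<>()).add(edge[0]);
        }
        // Children with identical parent sets fall into one group (one latent per group).
        Map<Set<String>, List<String>> groups = new LinkedHashMap<>();
        for (Map.Entry<String, Set<String>> entry : parents.entrySet()) {
            groups.computeIfAbsent(entry.getValue(), k -> new ArrayList<>()).add(entry.getKey());
        }
        return groups;
    }

    // The 24 generating edges of test3, written as input blocks -> the outputs they feed.
    static String[][] test3Edges() {
        String[][][] blocks = {
            {{"X0", "X1"}, {"X2", "X3", "X6", "X7", "X10", "X11"}},
            {{"X4", "X5"}, {"X6", "X7", "X10", "X11"}},
            {{"X8", "X9"}, {"X10", "X11"}},
        };
        List<String[]> edges = new ArrayList<>();
        for (String[][] block : blocks) {
            for (String from : block[0]) {
                for (String to : block[1]) {
                    edges.add(new String[] {from, to});
                }
            }
        }
        return edges.toArray(new String[0][]);
    }

    public static void main(String[] args) {
        groupOutputsByParents(test3Edges())
            .forEach((parentSet, outs) -> System.out.println(parentSet + " -> " + outs));
    }
}
```

Running `main` prints three groups, `[X0, X1, X4, X5, X8, X9] -> [X10, X11]`, `[X0, X1] -> [X2, X3]` and `[X0, X1, X4, X5] -> [X6, X7]`, which line up with L0, L1 and L2 in `trueGraph`.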

Example 12 with DMSearch

Use of edu.cmu.tetrad.search.DMSearch in project tetrad by cmu-phil.

From the class TestDM, method internaltest9.

@Ignore
public int internaltest9(double initialDiscount) {
    RandomUtil.getInstance().setSeed(29483818483L);
    int nInputs = 17610;
    int nOutputs = 12042;
    int[] inputs = new int[nInputs];
    int[] outputs = new int[nOutputs];
    // Inputs occupy columns 0 .. nInputs - 1.
    for (int i = 0; i < nInputs; i++) {
        inputs[i] = i;
    }
    // Outputs occupy the next nOutputs columns, so no column is both an input and an output.
    for (int i = 0; i < nOutputs; i++) {
        outputs[i] = nInputs + i;
    }
    print("test 9");
    // The discount is passed in by the caller, which runs this method once per value;
    // looping over discounts inside the method seemed slower than the non-loop version.
    DMSearch result = new DMSearch();
    result.setAlphaPC(.000001);
    result.setAlphaSober(.000001);
    result.setDiscount(initialDiscount);
    // Caution: readAndSearchData creates and runs its own DMSearch, so the alpha and
    // discount settings above are discarded when result is reassigned here.
    result = readAndSearchData("src/edu/cmu/tetradproj/amurrayw/final_joined_data_no_par.txt", inputs, outputs, true, inputs);
    print("Finished search, now writing output to file.");
    // Write the DM structure converted to an edge-list graph; the stream is closed automatically.
    File file = new File("src/edu/cmu/tetradproj/amurrayw/final_output_" + initialDiscount + "_.txt");
    try (PrintStream outStream = new PrintStream(new FileOutputStream(file))) {
        outStream.println(result.getDmStructure().latentStructToEdgeListGraph(result.getDmStructure()));
    } catch (java.io.FileNotFoundException e) {
        print("Can't write to file.");
    }
    // Write the unconverted DM structure as well.
    File file2 = new File("src/edu/cmu/tetradproj/amurrayw/unconverted_output" + initialDiscount + "_.txt");
    try (PrintStream outStream = new PrintStream(new FileOutputStream(file2))) {
        outStream.println(result.getDmStructure());
    } catch (java.io.FileNotFoundException e) {
        print("Can't write to file.");
    }
    print("DONE");
    return 1;
}
Also used : PrintStream(java.io.PrintStream) FileOutputStream(java.io.FileOutputStream) File(java.io.File) DMSearch(edu.cmu.tetrad.search.DMSearch) Ignore(org.junit.Ignore)
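The two index loops are meant to split the data columns into disjoint blocks: all inputs first, then all outputs. Starting the outputs at `nInputs - 1` instead of `nInputs` would make one column both an input and an output. A small check for this, assuming the data file stores every input column before every output column, might look like the following; `IndexBlocks`, `block` and `disjoint` are hypothetical names, not part of Tetrad.

```java
import java.util.BitSet;

// Hypothetical helper, not part of Tetrad: builds contiguous column-index blocks
// and checks that the input and output blocks do not share any column.
class IndexBlocks {

    // Returns {start, start + 1, ..., start + count - 1}.
    static int[] block(int start, int count) {
        int[] idx = new int[count];
        for (int i = 0; i < count; i++) {
            idx[i] = start + i;
        }
        return idx;
    }

    // True if no index appears in both arrays (indices must be non-negative).
    static boolean disjoint(int[] a, int[] b) {
        BitSet seen = new BitSet();
        for (int i : a) {
            seen.set(i);
        }
        for (int i : b) {
            if (seen.get(i)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int nInputs = 17610;
        int nOutputs = 12042;
        int[] inputs = block(0, nInputs);          // columns 0 .. 17609
        int[] outputs = block(nInputs, nOutputs);  // columns 17610 .. 29651
        System.out.println(disjoint(inputs, outputs));
        // Starting one column early shares column 17609 with the inputs.
        System.out.println(disjoint(inputs, block(nInputs - 1, nOutputs)));
    }
}
```

A `BitSet` keeps the overlap check linear in the number of indices, which matters at this scale (about 30,000 columns).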

Example 13 with DMSearch

Use of edu.cmu.tetrad.search.DMSearch in project tetrad by cmu-phil.

From the class TestDM, method readAndSearchData.

// Reads in data and runs the search. Note: assumes variable names are of the form X0, X1, etc.
// The inputs and outputs arrays hold the column indices of the respective variables.
public DMSearch readAndSearchData(String fileLocation, int[] inputs, int[] outputs, boolean useGES, int[] trueInputs) {
    File file = new File(fileLocation);
    DataSet data;
    try {
        TabularDataReader dataReader = new ContinuousTabularDataFileReader(file, Delimiter.SPACE);
        data = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData());
    } catch (IOException e) {
        // Without data there is nothing to search, so fail fast instead of passing null on.
        throw new IllegalStateException("Failed to read in data from " + fileLocation, e);
    }
    print("Read Data");
    DMSearch search = new DMSearch();
    search.setInputs(inputs);
    search.setOutputs(outputs);
    if (!useGES) {
        // Turn off FGES and set the PC significance level to .05.
        search.setAlphaPC(.05);
        search.setUseFges(false);
    }
    search.setData(data);
    search.setTrueInputs(trueInputs);
    search.search();
    return search;
}
Also used : TabularDataReader(edu.pitt.dbmi.data.reader.tabular.TabularDataReader) DataSet(edu.cmu.tetrad.data.DataSet) ContinuousTabularDataFileReader(edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader) IOException(java.io.IOException) File(java.io.File) DMSearch(edu.cmu.tetrad.search.DMSearch)
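`readAndSearchData` relies on the data file having a header of names X0, X1, ... so that the integer indices in `inputs` and `outputs` correspond to variable names. The JDK-only sketch below shows that layout and checks it; it is a hypothetical stand-in, not Tetrad's `ContinuousTabularDataFileReader`, and it assumes (more strictly than the original comment says) that column j is named exactly `Xj`.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

// Hypothetical JDK-only reader for a space-delimited file with a header row.
// It illustrates the expected file layout; it is not Tetrad's reader.
class SpaceDelimitedData {
    final String[] names;
    final double[][] rows;

    SpaceDelimitedData(String[] names, double[][] rows) {
        this.names = names;
        this.rows = rows;
    }

    static SpaceDelimitedData read(Reader source) throws IOException {
        try (BufferedReader in = new BufferedReader(source)) {
            String header = in.readLine();
            if (header == null) {
                throw new IOException("empty file");
            }
            String[] names = header.trim().split("\\s+");
            // Assumed convention: column j is named Xj, so indices and names agree.
            for (int j = 0; j < names.length; j++) {
                if (!names[j].equals("X" + j)) {
                    throw new IOException("column " + j + " is named " + names[j]);
                }
            }
            List<double[]> rows = new ArrayList<>();
            String line;
            while ((line = in.readLine()) != null) {
                if (line.isBlank()) {
                    continue; // skip blank lines
                }
                String[] tokens = line.trim().split("\\s+");
                if (tokens.length != names.length) {
                    throw new IOException("data row " + (rows.size() + 1) + " has " + tokens.length + " values");
                }
                double[] row = new double[tokens.length];
                for (int j = 0; j < tokens.length; j++) {
                    row[j] = Double.parseDouble(tokens[j]);
                }
                rows.add(row);
            }
            return new SpaceDelimitedData(names, rows.toArray(new double[0][]));
        }
    }

    public static void main(String[] args) throws IOException {
        SpaceDelimitedData d = read(new StringReader("X0 X1 X2\n0.5 1.0 -2\n3 4 5\n"));
        System.out.println(d.names.length + " variables, " + d.rows.length + " rows");
    }
}
```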

Example 14 with DMSearch

Use of edu.cmu.tetrad.search.DMSearch in project tetrad by cmu-phil.

From the class TestDM, method internaltest17.

@Ignore
public int internaltest17(double initialDiscount) {
    RandomUtil.getInstance().setSeed(29483818483L);
    int nInputs = 17610;
    int nOutputs = 12042;
    int[] inputs = new int[nInputs];
    int[] outputs = new int[nOutputs];
    int[] trueInputs = new int[] { 2761, 2762, 4450, 2247, 16137, 13108, 12530, 231, 1223, 1379, 5379, 12745, 14913, 16066, 16197, 16199, 17353, 17392, 4397, 3009, 3143, 5478, 5479, 5480, 5481, 5482, 7474, 12884, 12885, 12489, 9112, 1943, 9114, 1950, 9644, 9645, 9647 };
    // Inputs occupy columns 0 .. nInputs - 1.
    for (int i = 0; i < nInputs; i++) {
        inputs[i] = i;
    }
    // Outputs occupy the next nOutputs columns, so no column is both an input and an output.
    for (int i = 0; i < nOutputs; i++) {
        outputs[i] = nInputs + i;
    }
    print("test 17");
    // The discount is passed in by the caller, which runs this method once per value;
    // looping over discounts inside the method seemed slower than the non-loop version.
    DMSearch result = new DMSearch();
    // result.setAlphaPC(1e-30);
    // result.setAlphaSober(1e-30);
    result.setAlphaPC(1e-6);
    result.setAlphaSober(1e-6);
    result.setDiscount(initialDiscount);
    // Caution: readAndSearchData creates and runs its own DMSearch (with alphaPC = .05 when
    // useGES is false), so the settings above are discarded when result is reassigned here.
    result = readAndSearchData("src/edu/cmu/tetradproj/amurrayw/final_joined_data_no_par_fixed.txt", inputs, outputs, false, trueInputs);
    print("Finished search, now writing output to file.");
    // Write the DM structure converted to an edge-list graph; the stream is closed automatically.
    File file = new File("src/edu/cmu/tetradproj/amurrayw/final_output_" + initialDiscount + "_.txt");
    try (PrintStream outStream = new PrintStream(new FileOutputStream(file))) {
        outStream.println(result.getDmStructure().latentStructToEdgeListGraph(result.getDmStructure()));
    } catch (java.io.FileNotFoundException e) {
        print("Can't write to file.");
    }
    // Write the unconverted DM structure as well.
    File file2 = new File("src/edu/cmu/tetradproj/amurrayw/unconverted_output" + initialDiscount + "_.txt");
    try (PrintStream outStream = new PrintStream(new FileOutputStream(file2))) {
        outStream.println(result.getDmStructure());
    } catch (java.io.FileNotFoundException e) {
        print("Can't write to file.");
    }
    print("DONE");
    return 1;
}
Also used : PrintStream(java.io.PrintStream) FileOutputStream(java.io.FileOutputStream) File(java.io.File) DMSearch(edu.cmu.tetrad.search.DMSearch) Ignore(org.junit.Ignore)
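The hard-coded `trueInputs` in internaltest17 should all be valid input columns, each listed once. A quick check of that might look like this; `TrueInputCheck` and `validate` are hypothetical names, and the check makes no assumption about how DMSearch uses the true inputs internally.

```java
import java.util.BitSet;
import java.util.Optional;

// Hypothetical sanity check, not part of Tetrad or TestDM.
class TrueInputCheck {

    // Empty if every index is a valid input column listed once; otherwise the first problem found.
    static Optional<String> validate(int[] trueInputs, int nInputs) {
        BitSet seen = new BitSet(nInputs);
        for (int idx : trueInputs) {
            if (idx < 0 || idx >= nInputs) {
                return Optional.of("index " + idx + " is outside [0, " + nInputs + ")");
            }
            if (seen.get(idx)) {
                return Optional.of("index " + idx + " is listed more than once");
            }
            seen.set(idx);
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        // The trueInputs array from internaltest17.
        int[] trueInputs = { 2761, 2762, 4450, 2247, 16137, 13108, 12530, 231, 1223, 1379, 5379,
            12745, 14913, 16066, 16197, 16199, 17353, 17392, 4397, 3009, 3143, 5478, 5479, 5480,
            5481, 5482, 7474, 12884, 12885, 12489, 9112, 1943, 9114, 1950, 9644, 9645, 9647 };
        System.out.println(validate(trueInputs, 17610).orElse("all true inputs are valid"));
    }
}
```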

Aggregations (usage counts across the examples):

DMSearch (edu.cmu.tetrad.search.DMSearch): 14
DataSet (edu.cmu.tetrad.data.DataSet): 10
EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 9
Graph (edu.cmu.tetrad.graph.Graph): 9
SemIm (edu.cmu.tetrad.sem.SemIm): 9
SemPm (edu.cmu.tetrad.sem.SemPm): 9
ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 7
Test (org.junit.Test): 7
Ignore (org.junit.Ignore): 6
File (java.io.File): 5
FileOutputStream (java.io.FileOutputStream): 4
PrintStream (java.io.PrintStream): 4
Node (edu.cmu.tetrad.graph.Node): 2
ContinuousTabularDataFileReader (edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader): 1
TabularDataReader (edu.pitt.dbmi.data.reader.tabular.TabularDataReader): 1
IOException (java.io.IOException): 1