Search in sources:

Example 6 with GeneralizedSemPm

use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

From the class TestSimulatedFmri, method testClark.

/**
 * Repeatedly simulates two three-variable linear SEMs with non-Gaussian errors and prints the
 * graph FASK returns for each. Nothing is asserted; output is printed only.
 *
 * <p>Model 1: X --> Y, Z --> X, Z --> Y. Model 2: X --> Y, X --> Z, Y --> Z. Every error term is
 * {@code pow(Uniform(0, 1), f)}. Both models are simulated and searched 100 times.
 *
 * @throws IllegalStateException if a model expression fails to parse
 */
// @Test
public void testClark() {
    double f = .1;
    int N = 512;
    double alpha = 1.0;
    double penaltyDiscount = 1.0;
    // The same error distribution is used for every variable in both models.
    String error = "pow(Uniform(0, 1), " + f + ")";

    for (int i = 0; i < 100; i++) {
        // Model 1: X --> Y, Z --> X, Z --> Y.
        Graph out1 = simulateClarkModelAndSearch(
                new String[][]{{"X", "Y"}, {"Z", "X"}, {"Z", "Y"}},
                new String[]{"0.5 * Z + E_X", "0.5 * X + 0.5 * Z + E_Y", "E_Z"},
                error, N, alpha, penaltyDiscount);
        System.out.println(out1);

        // Model 2: X --> Y, X --> Z, Y --> Z.
        Graph out2 = simulateClarkModelAndSearch(
                new String[][]{{"X", "Y"}, {"X", "Z"}, {"Y", "Z"}},
                new String[]{"E_X", "0.4 * X + E_Y", "0.4 * X + 0.4 * Y + E_Z"},
                error, N, alpha, penaltyDiscount);
        System.out.println(out2);
    }
}

/**
 * Builds a DAG over continuous variables X, Y and Z, simulates data from a generalized SEM on it,
 * and runs FASK on the result.
 *
 * @param edges           directed edges as {from, to} name pairs, added in the given order
 * @param expressions     structural equations for X, Y and Z, in that order
 * @param error           expression assigned to every error node
 * @param sampleSize      number of rows to simulate
 * @param alpha           alpha passed to FASK
 * @param penaltyDiscount penalty discount applied to both the BIC score and FASK
 * @return the graph returned by {@code Fask.search()}
 * @throws IllegalStateException if any expression fails to parse
 */
private static Graph simulateClarkModelAndSearch(String[][] edges, String[] expressions, String error,
                                                 int sampleSize, double alpha, double penaltyDiscount) {
    String[] names = {"X", "Y", "Z"};

    Graph g = new EdgeListGraph();
    for (String name : names) {
        g.addNode(new ContinuousVariable(name));
    }
    for (String[] edge : edges) {
        g.addDirectedEdge(g.getNode(edge[0]), g.getNode(edge[1]));
    }

    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        for (int k = 0; k < names.length; k++) {
            pm.setNodeExpression(g.getNode(names[k]), expressions[k]);
        }
        for (String name : names) {
            pm.setNodeExpression(pm.getErrorNode(g.getNode(name)), error);
        }
    } catch (ParseException e) {
        // Fail fast: simulating from a partially configured model would silently test the wrong thing.
        throw new IllegalStateException("Could not parse a model expression", e);
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(sampleSize, false);

    // Fully qualified: the file also imports an unrelated SemBicScore from algcomparison.
    edu.cmu.tetrad.search.SemBicScore score =
            new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
    score.setPenaltyDiscount(penaltyDiscount);

    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(penaltyDiscount);
    fask.setAlpha(alpha);
    return fask.search();
}
Also used: DataSet (edu.cmu.tetrad.data.DataSet), Node (edu.cmu.tetrad.graph.Node), EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph), ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), Fask (edu.cmu.tetrad.search.Fask), Graph (edu.cmu.tetrad.graph.Graph), GeneralizedSemIm (edu.cmu.tetrad.sem.GeneralizedSemIm), ParseException (java.text.ParseException), CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly), GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm), SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore)

Example 7 with GeneralizedSemPm

use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

From the class GraphWrapper, method getStrongestInfluenceGraph.

/**
 * Collapses the graph of a generalized SEM into a new graph over its non-error variables, with at
 * most one edge per adjacent pair.
 *
 * <p>A pair joined by exactly one edge keeps that edge. For a pair joined by several edges, the
 * parameter {@code findParameter} locates in each endpoint's expression (keyed by the other
 * endpoint's name) is read from {@code im}. The edge is then directed into the endpoint whose
 * value is larger. No edge is added if either parameter is missing or the two values are equal.
 *
 * @param im the instantiated model supplying the graph, expressions and parameter values
 * @return a new graph over the model's non-error nodes
 */
private static Graph getStrongestInfluenceGraph(GeneralizedSemIm im) {
    GeneralizedSemPm pm = im.getGeneralizedSemPm();
    Graph modelGraph = im.getGeneralizedSemPm().getGraph();

    // Measured variables only: error nodes are dropped and the original order is preserved.
    List<Node> measured = new ArrayList<>(modelGraph.getNodes());
    measured.removeIf(candidate -> candidate.getNodeType() == NodeType.ERROR);

    Graph collapsed = new EdgeListGraph(measured);

    for (Edge current : modelGraph.getEdges()) {
        Node a = current.getNode1();
        Node b = current.getNode2();

        // Ignore edges touching an error node, and pairs that have already been resolved.
        boolean irrelevant = !collapsed.containsNode(a)
                || !collapsed.containsNode(b)
                || collapsed.isAdjacentTo(a, b);
        if (irrelevant) {
            continue;
        }

        List<Edge> between = modelGraph.getEdges(a, b);
        if (between.size() == 1) {
            // A lone connection is copied over unchanged.
            collapsed.addEdge(between.get(0));
            continue;
        }

        // Several edges join a and b. In each endpoint's expression, look up the parameter tied to
        // the other endpoint. Presumably this is the coefficient on it; verify against findParameter.
        Expression exprA = pm.getNodeExpression(a);
        Expression exprB = pm.getNodeExpression(b);
        String paramInA = findParameter(exprA, b.getName());
        String paramInB = findParameter(exprB, a.getName());
        if (paramInA == null || paramInB == null) {
            continue;
        }

        double valueInA = im.getParameterValue(paramInA);
        double valueInB = im.getParameterValue(paramInB);

        // NOTE(review): signed values are compared, not magnitudes; a tie leaves the pair unconnected.
        if (valueInB > valueInA) {
            collapsed.addDirectedEdge(a, b);
        } else if (valueInA > valueInB) {
            collapsed.addDirectedEdge(b, a);
        }
    }
    return collapsed;
}
Also used : VariableExpression(edu.cmu.tetrad.calculator.expression.VariableExpression) Expression(edu.cmu.tetrad.calculator.expression.Expression) ArrayList(java.util.ArrayList) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm)

Example 8 with GeneralizedSemPm

use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

From the class LeeHastieSimulation, method simulate.

/**
 * Simulates a mixed continuous/discrete data set from {@code dag}.
 *
 * <p>On first use, and whenever the cached order no longer matches {@code dag}, a random ordering
 * of the graph's nodes is stored in {@code shuffledOrder}. The first
 * {@code percentDiscrete}% of that ordering become discrete variables and the rest stay
 * continuous (recorded as 0 categories). Reusing the cached order keeps that assignment
 * consistent across calls on graphs with the same variables.
 *
 * @param dag        the graph to simulate from; it is not modified
 * @param parameters supplies "percentDiscrete", "numCategories", "saveLatentVars" and "sampleSize"
 * @return the simulated data, converted by {@code MixedUtils.makeMixedData}
 */
private DataSet simulate(Graph dag, Parameters parameters) {
    HashMap<String, Integer> nd = new HashMap<>();

    // Copy so the caller's graph is never reordered; only the contents and size are read here.
    List<Node> nodes = new ArrayList<>(dag.getNodes());

    // The cached order can only be reused if it covers exactly this graph's variables. Otherwise
    // indexing it below would overflow or assign types to names the graph does not contain.
    boolean cacheUsable = this.shuffledOrder != null && this.shuffledOrder.size() == nodes.size();
    if (cacheUsable) {
        for (Node cached : this.shuffledOrder) {
            if (dag.getNode(cached.getName()) == null) {
                cacheUsable = false;
                break;
            }
        }
    }
    if (!cacheUsable) {
        List<Node> shuffledNodes = new ArrayList<>(nodes);
        Collections.shuffle(shuffledNodes);
        this.shuffledOrder = shuffledNodes;
    }

    final double discreteCutoff = nodes.size() * parameters.getDouble("percentDiscrete") * 0.01;

    for (int i = 0; i < nodes.size(); i++) {
        String name = this.shuffledOrder.get(i).getName();
        if (i < discreteCutoff) {
            // NOTE(review): min and max both read "numCategories", so the range passed to
            // pickNumCategories has min == max. Confirm whether separate parameters were intended.
            final int minNumCategories = parameters.getInt("numCategories");
            final int maxNumCategories = parameters.getInt("numCategories");
            nd.put(name, pickNumCategories(minNumCategories, maxNumCategories));
        } else {
            nd.put(name, 0);
        }
    }

    Graph graph = MixedUtils.makeMixedGraph(dag, nd);
    GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(graph, "Split(-1.5,-.5,.5,1.5)");
    GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);
    boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
    DataSet ds = im.simulateDataAvoidInfinity(parameters.getInt("sampleSize"), saveLatentVars);
    return MixedUtils.makeMixedData(ds, nd);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) Graph(edu.cmu.tetrad.graph.Graph) RandomGraph(edu.cmu.tetrad.algcomparison.graph.RandomGraph) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm)

Example 9 with GeneralizedSemPm

use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

From the class MixedUtils, method main.

// main for testing
/**
 * Manual smoke test for the mixed Gaussian/categorical model builders.
 *
 * <p>Builds a five-variable mixed graph and prints its parametric model, including the "C"
 * starts-with template before and after overriding it. It then instantiates the model, prints
 * 15 simulated rows, and prints the category count of X4.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    // X3 and X4 get 4 categories; 0 marks X1, X2 and X5 as continuous.
    String[] variableNames = {"X1", "X2", "X3", "X4", "X5"};
    int[] categoryCounts = {0, 0, 4, 4, 0};
    HashMap<String, Integer> nd = new HashMap<>();
    for (int k = 0; k < variableNames.length; k++) {
        nd.put(variableNames[k], categoryCounts[k]);
    }

    Graph mixed = makeMixedGraph(GraphConverter.convert("X1-->X2,X2-->X3,X3-->X4, X5-->X4"), nd);

    GeneralizedSemPm semPm = GaussianCategoricalPm(mixed, "Split(-1.5,-1,1,1.5)");
    System.out.println(semPm);
    System.out.println("STARTS WITH");
    System.out.println(semPm.getStartsWithParameterTemplate("C"));

    // Override the template for parameters starting with "C"; any failure is only reported.
    try {
        MixedUtils.setStartsWith("C", "Split(-.9,-.5,.5,.9)", semPm);
    } catch (Throwable problem) {
        problem.printStackTrace();
    }

    System.out.println("STARTS WITH");
    System.out.println(semPm.getStartsWithParameterTemplate("C"));
    System.out.println(semPm);

    GeneralizedSemIm semIm = GaussianCategoricalIm(semPm);
    System.out.println(semIm);

    DataSet sample = semIm.simulateDataFisher(15);
    System.out.println(sample);

    DiscreteVariable x4 = (DiscreteVariable) mixed.getNode("X4");
    System.out.println("num cats " + x4.getNumCategories());
}
Also used : GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm)

Example 10 with GeneralizedSemPm

use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

From the class TestSimulatedFmri, method testClark2.

/**
 * Simulates the DAG X --> Y, X --> Z, Y --> Z with non-Gaussian errors
 * {@code pow(Uniform(0, 1), 1.5)} and 1000 samples, runs FASK, and prints the resulting graph.
 * Nothing is asserted; output is printed only.
 *
 * @throws IllegalStateException if a model expression fails to parse
 */
// @Test
public void testClark2() {
    Node x = new ContinuousVariable("X");
    Node y = new ContinuousVariable("Y");
    Node z = new ContinuousVariable("Z");
    Graph g = new EdgeListGraph();
    g.addNode(x);
    g.addNode(y);
    g.addNode(z);
    g.addDirectedEdge(x, y);
    g.addDirectedEdge(x, z);
    g.addDirectedEdge(y, z);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        pm.setNodeExpression(g.getNode("X"), "E_X");
        pm.setNodeExpression(g.getNode("Y"), "0.4 * X + E_Y");
        pm.setNodeExpression(g.getNode("Z"), "0.4 * X + 0.4 * Y + E_Z");
        String error = "pow(Uniform(0, 1), 1.5)";
        pm.setNodeExpression(pm.getErrorNode(g.getNode("X")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Y")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Z")), error);
    } catch (ParseException e) {
        // Fail fast: continuing would simulate from a partially configured model.
        throw new IllegalStateException("Could not parse a model expression", e);
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(1000, false);

    // Fully qualified: the file also imports an unrelated SemBicScore from algcomparison.
    edu.cmu.tetrad.search.SemBicScore score =
            new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(1);
    fask.setAlpha(0.5);
    Graph out = fask.search();
    System.out.println(out);
}
Also used: DataSet (edu.cmu.tetrad.data.DataSet), Node (edu.cmu.tetrad.graph.Node), EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph), ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), Fask (edu.cmu.tetrad.search.Fask), Graph (edu.cmu.tetrad.graph.Graph), GeneralizedSemIm (edu.cmu.tetrad.sem.GeneralizedSemIm), ParseException (java.text.ParseException), CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly), GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm), SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore)

Aggregations

GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm): 11, GeneralizedSemIm (edu.cmu.tetrad.sem.GeneralizedSemIm): 7, Node (edu.cmu.tetrad.graph.Node): 5, ParseException (java.text.ParseException): 5, DataSet (edu.cmu.tetrad.data.DataSet): 4, Graph (edu.cmu.tetrad.graph.Graph): 4, EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 3, RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 2, SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore): 2, ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 2, CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly): 2, Fask (edu.cmu.tetrad.search.Fask): 2, SingleGraph (edu.cmu.tetrad.algcomparison.graph.SingleGraph): 1, Expression (edu.cmu.tetrad.calculator.expression.Expression): 1, VariableExpression (edu.cmu.tetrad.calculator.expression.VariableExpression): 1, ArrayList (java.util.ArrayList): 1, HashMap (java.util.HashMap): 1