Search in sources :

Example 1 with GeneralizedSemIm

use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.

From class GeneralSemSimulationSpecial1, method simulate.

/**
 * Builds a generalized SEM instantiation over {@code graph} (using the PM from {@code getPm})
 * and draws a data set whose row count is the "sampleSize" parameter.
 */
private DataSet simulate(Graph graph, Parameters parameters) {
    final GeneralizedSemIm instantiated = new GeneralizedSemIm(getPm(graph));
    final int rows = parameters.getInt("sampleSize");
    // Second argument is passed as false, exactly as before (presumably "don't keep latent columns" — confirm).
    return instantiated.simulateData(rows, false);
}
Also used : GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm)

Example 2 with GeneralizedSemIm

use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.

From class MixedUtils, method GaussianCategoricalIm.

/**
 * Builds an instantiated mixed (continuous/discrete) generalized SEM from {@code pm} and
 * normalizes the parameters of every edge that touches a discrete variable so that the whole
 * edge is driven by a single weight.
 *
 * <p>For each non-error parent of each variable node, the value the {@link GeneralizedSemIm}
 * constructor drew for the edge's first parameter is taken as the edge weight {@code w}. The
 * edge's parameters are then rewritten:
 * <ul>
 *   <li>continuous-continuous edges are left untouched;</li>
 *   <li>discrete-discrete edges (child with {@code cL} categories, parent with {@code pL}): each
 *       child category {@code i} gets {@code |w|} at parent category {@code weightInds[i]} and
 *       {@code -|w|/(pL-1)} at every other parent category. Deterministically
 *       {@code weightInds[i] = i % pL} (the diagonal, wrapping when {@code cL > pL}); when
 *       {@code discParamRand} is true those indices are shuffled with {@code arrayPermute};</li>
 *   <li>continuous-discrete edges (either direction): the parameters are replaced with
 *       {@code evenSplitVector(w, L)} (deterministic) or {@code generateMixedEdgeParams(w, L)}
 *       (random), where {@code L} is the category count of the discrete endpoint. Note that
 *       the signed {@code w}, not {@code |w|}, is used here. Per the original description these
 *       are expected to be evenly spaced between -w and w, respectively uniform draws rescaled to
 *       have w as maximum and sum 0 — see those helpers to confirm.</li>
 * </ul>
 *
 * @param pm            the parameterized model; its graph must have a distribution entry (0 for
 *                      continuous, otherwise the number of categories) for every variable, as read
 *                      by {@code getNodeDists}
 * @param discParamRand true for random edge generation behavior, false for deterministic
 * @return a new {@link GeneralizedSemIm} for {@code pm} with the mixed-edge parameters rewritten
 */
public static GeneralizedSemIm GaussianCategoricalIm(GeneralizedSemPm pm, boolean discParamRand) {
    Map<String, Integer> nodeDists = getNodeDists(pm.getGraph());
    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    List<Node> nodes = pm.getVariableNodes();
    // NOTE(original author): this needs to be changed for cyclic graphs.
    for (Node n : nodes) {
        Set<Node> parNodes = pm.getReferencedNodes(n);
        for (Node par : parNodes) {
            if (par.getNodeType() == NodeType.ERROR) {
                continue;
            }
            int cL = nodeDists.get(n.getName());
            int pL = nodeDists.get(par.getName());
            // c-c edges keep the values drawn by the IM constructor.
            if (cL == 0 && pL == 0) {
                continue;
            }
            List<String> params = getEdgeParams(n, par, pm);
            // The first parameter's drawn value serves as the weight for the whole edge.
            double w = im.getParameterValue(params.get(0));
            if (cL > 0 && pL > 0) {
                // d-d edge: one "hot" parent category per child category, the rest share a negative background.
                double hot = Math.abs(w);
                // When pL == 1 this is infinite/NaN, but then every weightInds[i] is 0 == j, so it is never used.
                double background = hot / ((double) pL - 1.0);
                int[] weightInds = new int[cL];
                for (int i = 0; i < cL; i++) {
                    weightInds[i] = i % pL;
                }
                if (discParamRand) {
                    weightInds = arrayPermute(weightInds);
                }
                // Assumes params are laid out row-major by child category — TODO confirm against getEdgeParams.
                for (int i = 0; i < cL; i++) {
                    for (int j = 0; j < pL; j++) {
                        im.setParameterValue(params.get(i * pL + j), weightInds[i] == j ? hot : -background);
                    }
                }
            } else {
                // c-d edge: spread the signed weight over the discrete endpoint's categories.
                int curL = (pL > 0 ? pL : cL);
                double[] newWeights = discParamRand
                        ? generateMixedEdgeParams(w, curL)
                        : evenSplitVector(w, curL);
                int count = 0;
                for (String p : params) {
                    im.setParameterValue(p, newWeights[count]);
                    count++;
                }
            }
        }
    }
    return im;
}
Also used : GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm)

Example 3 with GeneralizedSemIm

use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.

From class MGM, method runTests2.

/**
 * Exercises MGM with no sparsity penalty. Builds a five-variable mixed model (X1 and X2 mapped
 * to 0, X3..X5 mapped to 4 — presumably 0 = continuous, otherwise the category count, as used by
 * MixedUtils). It then simulates 1000 rows and fits MGM with every lambda set to zero. It prints
 * the PM, the IM, the objective terms before and after learning, the learned parameters and the
 * adjacency matrix.
 */
private static void runTests2() {
    Graph skeleton = GraphConverter.convert("X1-->X2,X3-->X2,X4-->X5");

    String[] names = { "X1", "X2", "X3", "X4", "X5" };
    int[] levels = { 0, 0, 4, 4, 4 };
    HashMap<String, Integer> categories = new HashMap<>();
    for (int k = 0; k < names.length; k++) {
        categories.put(names[k], levels[k]);
    }

    Graph mixedGraph = MixedUtils.makeMixedGraph(skeleton, categories);
    GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(mixedGraph, "Split(-1.5,-.5,.5,1.5)");
    System.out.println(pm);
    GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);
    System.out.println(im);

    final int sampleCount = 1000;
    DataSet data = MixedUtils.makeMixedData(im.simulateDataFisher(sampleCount), categories);

    // All three penalty weights are zero, i.e. an unregularized fit.
    final double lambda = 0;
    MGM model = new MGM(data, new double[] { lambda, lambda, lambda });

    System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
    System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
    model.learn(1e-8, 1000);
    System.out.println("Learned nll: " + model.smoothValue(model.params.toMatrix1D()));
    System.out.println("Learned reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
    System.out.println("params:\n" + model.params);
    System.out.println("adjMat:\n" + model.adjMatFromMGM());
}
Also used : HashMap(java.util.HashMap) DataSet(edu.cmu.tetrad.data.DataSet) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm)

Example 4 with GeneralizedSemIm

use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.

From class TestSimulatedFmri, method testClark.

// @Test
/**
 * Repeatedly simulates two three-variable linear models with non-Gaussian errors
 * {@code pow(Uniform(0, 1), f)} and prints the graph FASK recovers from each.
 *
 * <p>Model 1: Z -> X, Z -> Y, X -> Y. Model 2: X -> Y, X -> Z, Y -> Z. Each model is
 * simulated 100 times with {@code N} rows.
 */
public void testClark() {
    double f = .1;
    int N = 512;
    double alpha = 1.0;
    double penaltyDiscount = 1.0;
    String error = "pow(Uniform(0, 1), " + f + ")";
    for (int i = 0; i < 100; i++) {
        runClarkModel(new String[][] { { "X", "Y" }, { "Z", "X" }, { "Z", "Y" } },
                "0.5 * Z + E_X", "0.5 * X + 0.5 * Z + E_Y", "E_Z",
                error, N, alpha, penaltyDiscount);
        runClarkModel(new String[][] { { "X", "Y" }, { "X", "Z" }, { "Y", "Z" } },
                "E_X", "0.4 * X + E_Y", "0.4 * X + 0.4 * Y + E_Z",
                error, N, alpha, penaltyDiscount);
    }
}

/**
 * Builds a graph over continuous variables X, Y, Z with the given directed edges. It assigns the
 * given node expressions and a shared error expression, then simulates {@code sampleSize} rows.
 * Finally it runs FASK with a SEM BIC score and prints the resulting graph.
 *
 * @param edges           directed edges as {from, to} variable-name pairs
 * @param exprX           expression for X
 * @param exprY           expression for Y
 * @param exprZ           expression for Z
 * @param error           expression assigned to the error node of each of X, Y, Z
 * @param sampleSize      number of rows to simulate
 * @param alpha           FASK alpha
 * @param penaltyDiscount penalty discount for both the score and FASK
 * @throws IllegalStateException if any expression fails to parse
 */
private static void runClarkModel(String[][] edges, String exprX, String exprY, String exprZ,
        String error, int sampleSize, double alpha, double penaltyDiscount) {
    Graph g = new EdgeListGraph();
    for (String name : new String[] { "X", "Y", "Z" }) {
        g.addNode(new ContinuousVariable(name));
    }
    for (String[] edge : edges) {
        g.addDirectedEdge(g.getNode(edge[0]), g.getNode(edge[1]));
    }
    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        pm.setNodeExpression(g.getNode("X"), exprX);
        pm.setNodeExpression(g.getNode("Y"), exprY);
        pm.setNodeExpression(g.getNode("Z"), exprZ);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("X")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Y")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Z")), error);
    } catch (ParseException e) {
        // Fail loudly: continuing would simulate from a partially configured PM, not the intended model.
        throw new IllegalStateException("Could not parse a node expression for the Clark model", e);
    }
    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(sampleSize, false);
    edu.cmu.tetrad.search.SemBicScore score =
            new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
    score.setPenaltyDiscount(penaltyDiscount);
    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(penaltyDiscount);
    fask.setAlpha(alpha);
    Graph out = fask.search();
    System.out.println(out);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Fask(edu.cmu.tetrad.search.Fask) Graph(edu.cmu.tetrad.graph.Graph) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) ParseException(java.text.ParseException) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore)

Example 5 with GeneralizedSemIm

use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.

From class LeeHastieSimulation, method simulate.

/**
 * Simulates a mixed continuous/discrete data set over {@code dag} (Lee-Hastie style).
 *
 * <p>A random node order is fixed on the first call and cached in {@code shuffledOrder}. On
 * every call, the first {@code percentDiscrete}% of positions in that order are made discrete.
 * Their category count comes from {@code pickNumCategories} with both bounds read from
 * "numCategories". All other variables are continuous (0). The mixed graph is then parameterized
 * with a Gaussian/categorical PM, simulated with "sampleSize" rows, and converted to mixed data.
 */
private DataSet simulate(Graph dag, Parameters parameters) {
    HashMap<String, Integer> categoriesByName = new HashMap<>();
    List<Node> nodes = dag.getNodes();
    Collections.shuffle(nodes);

    // Fix the variable order once so that repeated calls make the same variables discrete.
    if (this.shuffledOrder == null) {
        List<Node> order = new ArrayList<>(nodes);
        Collections.shuffle(order);
        this.shuffledOrder = order;
    }

    final int nodeCount = nodes.size();
    for (int idx = 0; idx < nodeCount; idx++) {
        int categories = 0;
        if (idx < nodes.size() * parameters.getDouble("percentDiscrete") * 0.01) {
            final int lowest = parameters.getInt("numCategories");
            final int highest = parameters.getInt("numCategories");
            categories = pickNumCategories(lowest, highest);
        }
        categoriesByName.put(shuffledOrder.get(idx).getName(), categories);
    }

    Graph mixedGraph = MixedUtils.makeMixedGraph(dag, categoriesByName);
    GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(mixedGraph, "Split(-1.5,-.5,.5,1.5)");
    GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);
    boolean keepLatents = parameters.getBoolean("saveLatentVars");
    DataSet simulated = im.simulateDataAvoidInfinity(parameters.getInt("sampleSize"), keepLatents);
    return MixedUtils.makeMixedData(simulated, categoriesByName);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) Graph(edu.cmu.tetrad.graph.Graph) RandomGraph(edu.cmu.tetrad.algcomparison.graph.RandomGraph) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm)

Aggregations

GeneralizedSemIm (edu.cmu.tetrad.sem.GeneralizedSemIm): 8
GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm): 7
DataSet (edu.cmu.tetrad.data.DataSet): 4
Graph (edu.cmu.tetrad.graph.Graph): 4
Node (edu.cmu.tetrad.graph.Node): 4
EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 3
RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 2
SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore): 2
ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 2
CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly): 2
Fask (edu.cmu.tetrad.search.Fask): 2
ParseException (java.text.ParseException): 2
SingleGraph (edu.cmu.tetrad.algcomparison.graph.SingleGraph): 1
HashMap (java.util.HashMap): 1