Search in sources:

Example 1 with GeneralizedSemPm

Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

Class GeneralSemSimulationSpecial1, method simulate.

/**
 * Draws one data set from a generalized SEM whose structure follows {@code graph}.
 *
 * @param graph      the graph the simulated model is built over
 * @param parameters source of the integer "sampleSize" setting passed to the simulator
 * @return the simulated data set
 */
private DataSet simulate(Graph graph, Parameters parameters) {
    GeneralizedSemIm instantiated = new GeneralizedSemIm(getPm(graph));
    int requestedSamples = parameters.getInt("sampleSize");
    return instantiated.simulateData(requestedSamples, false);
}
Also used: GeneralizedSemIm (edu.cmu.tetrad.sem.GeneralizedSemIm), GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm)

Example 2 with GeneralizedSemPm

Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

Class GeneralSemSimulationSpecial1, method getPm.

/**
 * Builds the generalized SEM parametric model for {@code graph}.
 *
 * <p>Every variable node, latent or measured, gets the expression {@code TSUM(NEW(B)*$)}.
 * Every error node gets {@code U(-.5,.5)}. Each model parameter whose name starts with one of
 * the prefixes in {@code paramMap} is then assigned that prefix's distribution expression.
 *
 * @param graph the graph whose structure the model follows
 * @return the configured parametric model
 * @throws IllegalStateException if any template fails to parse; the {@link ParseException} is
 *     kept as the cause (previously it was printed and a partially configured model returned)
 */
private GeneralizedSemPm getPm(Graph graph) {
    GeneralizedSemPm pm = new GeneralizedSemPm(graph);
    List<Node> variablesNodes = pm.getVariableNodes();
    List<Node> errorNodes = pm.getErrorNodes();

    // Parameter-name prefix -> distribution expression for the parameters created by NEW(...).
    Map<String, String> paramMap = new HashMap<>();
    paramMap.put("s", "U(1,3)");
    paramMap.put("B", "Split(-1.5,-.5,.5,1.5)");
    paramMap.put("C", "Split(-1.5,-.5,.5,1.5)");
    paramMap.put("T", "U(.5,1.5)");
    paramMap.put("A", "U(0,.25)");
    paramMap.put("exoErrorType", "U(-.5,.5)");
    paramMap.put("funcType", "U(1,5)");

    // Candidate edge functions; only funcs[0] (the linear sum) is currently selected below.
    // NOTE(review): funcs[2] uses NEW(b) (lowercase), which has no entry in paramMap.
    String[] funcs = { "TSUM(NEW(B)*$)", "TSUM(NEW(B)*$+NEW(C)*sin(NEW(T)*$+NEW(A)))", "TSUM(NEW(B)*(.5*$ + .5*(sqrt(abs(NEW(b)*$+NEW(exoErrorType))) ) ) )" };
    String nonlinearStructuralEdgesFunction = funcs[0];
    String nonlinearFactorMeasureEdgesFunction = funcs[0];

    try {
        TemplateExpander expander = TemplateExpander.getInstance();

        for (Node node : variablesNodes) {
            // Latent nodes take the structural function; all others take the measurement one.
            String function = node.getNodeType() == NodeType.LATENT
                    ? nonlinearStructuralEdgesFunction
                    : nonlinearFactorMeasureEdgesFunction;
            pm.setNodeExpression(node, expander.expandTemplate(function, pm, node));
        }

        for (Node node : errorNodes) {
            pm.setNodeExpression(node, expander.expandTemplate("U(-.5,.5)", pm, node));
        }

        for (String parameter : pm.getParameters()) {
            for (Map.Entry<String, String> entry : paramMap.entrySet()) {
                if (parameter.startsWith(entry.getKey())) {
                    pm.setParameterExpression(parameter, entry.getValue());
                }
            }
        }
    } catch (ParseException e) {
        // Fail loudly: returning a half-configured model would silently simulate the wrong SEM.
        throw new IllegalStateException("Parse error in fixing parameters.", e);
    }
    return pm;
}
Also used: Node (edu.cmu.tetrad.graph.Node), ParseException (java.text.ParseException), GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm)

Example 3 with GeneralizedSemPm

Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

Class MixedUtils, method GaussianCategoricalPm.

/**
 * Generates a generalized SEM PM from {@code trueGraph} for mixed Gaussian and categorical
 * variables.
 *
 * <p>The distribution code of each node comes from {@code getNodeDists(trueGraph)}:
 * <ul>
 *   <li>0 means Gaussian.</li>
 *   <li>A positive value is the number of categories.</li>
 *   <li>1 (a constant) is rejected.</li>
 * </ul>
 *
 * <p>A categorical node is expressed as {@code DiscError(<error>, ...)} with one argument per
 * category, and its error term is set to {@code U(0,1)}. Every coefficient*parent term for a
 * categorical parent is replaced by {@code Switch(<parent>, ...)} over fresh parameters, one per
 * parent category.
 *
 * @param trueGraph     the graph whose structure the model follows
 * @param paramTemplate distribution expression used for parameters starting with B, C and D
 * @return the configured parametric model
 * @throws IllegalArgumentException if a node's distribution is 1 or is missing from the map
 * @throws IllegalStateException    if any generated template fails to parse
 */
public static GeneralizedSemPm GaussianCategoricalPm(Graph trueGraph, String paramTemplate) throws IllegalStateException {
    Map<String, Integer> nodeDists = getNodeDists(trueGraph);
    GeneralizedSemPm semPm = new GeneralizedSemPm(trueGraph);
    try {
        List<Node> variableNodes = semPm.getVariableNodes();
        TemplateExpander expander = TemplateExpander.getInstance();
        semPm.setStartsWithParametersTemplate("B", paramTemplate);
        semPm.setStartsWithParametersTemplate("C", paramTemplate);
        semPm.setStartsWithParametersTemplate("D", paramTemplate);
        // empirically should give us a stddev of 1 - 2
        semPm.setStartsWithParametersTemplate("s", "U(1,2)");

        for (Node node : variableNodes) {
            List<Node> parents = trueGraph.getParents(node);
            Node eNode = semPm.getErrorNode(node);
            // Start from the PM's current expressions; Gaussian nodes keep these apart from the
            // Switch substitution below.
            String curEx = semPm.getNodeExpressionString(node);
            String errEx = semPm.getNodeExpressionString(eNode);

            // 0 means Gaussian; a positive value is the number of categories.
            int curDist = distForNode(nodeDists, node);
            if (curDist == 1) {
                throw new IllegalArgumentException("Dist for node " + node.getName() + " is set to one (i.e. constant) which is not supported.");
            }

            if (curDist > 0) {
                // One DiscError argument per category: constant weights for a root node, a fresh
                // linear combination of the parents otherwise. The error name is inserted
                // literally, so special characters in it cannot corrupt the template.
                String perCategory = parents.isEmpty() ? ",1" : ", TSUM(NEW(C)*$)";
                StringBuilder discTemplate = new StringBuilder("DiscError(").append(eNode.getName());
                for (int l = 0; l < curDist; l++) {
                    discTemplate.append(perCategory);
                }
                discTemplate.append(')');
                curEx = expander.expandTemplate(discTemplate.toString(), semPm, node);
                errEx = expander.expandTemplate("U(0,1)", semPm, eNode);
            }

            // For every categorical parent, swap its coefficient*parent term for a Switch over
            // per-category parameters: C terms -> NEW(D) for a categorical child, B terms ->
            // NEW(C) otherwise.
            String coefPrefix = curDist > 0 ? "C" : "B";
            String switchArg = curDist > 0 ? ",NEW(D)" : ",NEW(C)";
            String newTemp = curEx;
            for (Node parNode : parents) {
                int parDist = distForNode(nodeDists, parNode);
                if (parDist <= 0) {
                    continue;
                }
                String parName = parNode.getName();
                StringBuilder disRep = new StringBuilder("Switch(").append(parName);
                for (int l = 0; l < parDist; l++) {
                    disRep.append(switchArg);
                }
                disRep.append(')');
                // Quote the parent name so regex metacharacters match literally; the lookahead
                // keeps e.g. X1 from matching the start of X10.
                String regex = "(" + coefPrefix + "[0-9]*\\*" + java.util.regex.Pattern.quote(parName) + ")(?![0-9])";
                newTemp = newTemp.replaceAll(regex, java.util.regex.Matcher.quoteReplacement(disRep.toString()));
            }
            if (!newTemp.isEmpty()) {
                curEx = expander.expandTemplate(newTemp, semPm, node);
            }
            semPm.setNodeExpression(node, curEx);
            semPm.setNodeExpression(eNode, errEx);
        }
    } catch (ParseException e) {
        throw new IllegalStateException("Parse error in fixing parameters.", e);
    }
    return semPm;
}

/**
 * Returns the distribution code recorded for {@code node} in {@code nodeDists}.
 *
 * @throws IllegalArgumentException if {@code nodeDists} has no entry for the node's name
 */
private static int distForNode(Map<String, Integer> nodeDists, Node node) {
    Integer dist = nodeDists.get(node.getName());
    if (dist == null) {
        throw new IllegalArgumentException("No distribution given for node " + node.getName() + ".");
    }
    return dist;
}
Also used: ParseException (java.text.ParseException), GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm)

Example 4 with GeneralizedSemPm

Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

Class MixedUtils, method GaussianTrinaryPm.

/**
 * Generates a generalized SEM PM from {@code trueGraph} for mixed Gaussian and trinary
 * variables.
 *
 * <p>Legacy: the original author marked this method "Don't use, buggy". One concrete defect is
 * fixed here: with several discrete parents, only the last parent's substitution used to
 * survive. Other problems the author had in mind may remain, so prefer
 * {@code GaussianCategoricalPm}.
 *
 * <p>A node mapped to {@code "Disc"} in {@code nodeDists} is expressed with a three-argument
 * {@code DiscError} template, and its error term is set to {@code U(0,1)}. For every discrete
 * parent, its {@code B*parent} term is replaced by a nested {@code IF} over three fresh D
 * parameters.
 *
 * @param trueGraph     the graph whose structure the model follows
 * @param nodeDists     node name to distribution; {@code "Disc"} marks a discrete node, any other
 *                      value is treated as continuous. Every node needs an entry.
 * @param maxSample     unused
 * @param paramTemplate distribution expression used for parameters starting with B and D
 * @return the configured parametric model
 * @throws IllegalStateException if any generated template fails to parse
 */
public static GeneralizedSemPm GaussianTrinaryPm(Graph trueGraph, HashMap<String, String> nodeDists, int maxSample, String paramTemplate) throws IllegalStateException {
    GeneralizedSemPm semPm = new GeneralizedSemPm(trueGraph);
    try {
        List<Node> variableNodes = semPm.getVariableNodes();
        TemplateExpander expander = TemplateExpander.getInstance();
        semPm.setStartsWithParametersTemplate("B", paramTemplate);
        semPm.setStartsWithParametersTemplate("D", paramTemplate);
        // empirically should give us a stddev of 1 - 2
        semPm.setStartsWithParametersTemplate("al", "U(.3,1.3)");
        semPm.setStartsWithParametersTemplate("s", "U(1,2)");
        // "err" is a placeholder for the node's error-term name.
        String templateDisc = "DiscError(err, (TSUM(NEW(B)*$)), (TSUM(NEW(B)*$)), (TSUM(NEW(B)*$)))";
        String templateDisc0 = "DiscError(err, .001,.001,.001)";

        for (Node node : variableNodes) {
            List<Node> parents = trueGraph.getParents(node);
            Node eNode = semPm.getErrorNode(node);
            String curEx = semPm.getNodeExpressionString(node);
            String errEx = semPm.getNodeExpressionString(eNode);

            if (nodeDists.get(node.getName()).equals("Disc")) {
                String template = parents.isEmpty() ? templateDisc0 : templateDisc;
                // Literal replacement: special characters in the error name are not interpreted.
                template = template.replace("err", eNode.getName());
                curEx = expander.expandTemplate(template, semPm, node);
                errEx = expander.expandTemplate("U(0,1)", semPm, eNode);
            }

            // Rewrite every discrete parent's B*parent term. Substitutions accumulate so that
            // all discrete parents are rewritten, not only the last one.
            String replaced = curEx;
            boolean anyDiscreteParent = false;
            for (Node parNode : parents) {
                if (nodeDists.get(parNode.getName()).equals("Disc")) {
                    String curName = parNode.getName();
                    String disRep = "IF(" + curName + "=0,NEW(D),IF(" + curName + "=1,NEW(D),NEW(D)))";
                    // Quote the name so regex metacharacters match literally; the lookahead keeps
                    // e.g. X1 from matching the start of X10.
                    String regex = "(B[0-9]*\\*" + java.util.regex.Pattern.quote(curName) + ")(?![0-9])";
                    replaced = replaced.replaceAll(regex, java.util.regex.Matcher.quoteReplacement(disRep));
                    anyDiscreteParent = true;
                }
            }
            if (anyDiscreteParent && !replaced.isEmpty()) {
                curEx = expander.expandTemplate(replaced, semPm, node);
            }
            semPm.setNodeExpression(node, curEx);
            semPm.setNodeExpression(eNode, errEx);
        }
    } catch (ParseException e) {
        throw new IllegalStateException("Parse error in fixing parameters.", e);
    }
    return semPm;
}
Also used: ParseException (java.text.ParseException), GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm)

Example 5 with GeneralizedSemPm

Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.

Class MGM, method runTests2.

/**
 * Exercises MGM with a zero penalty (lambda = 0) on a small simulated mixed data set.
 *
 * <p>Builds a five-node graph in which X1 and X2 are continuous and X3, X4 and X5 have 4
 * categories. It simulates 1000 samples from the corresponding generalized SEM, fits MGM, and
 * prints the smooth and non-smooth objective values before and after learning, followed by the
 * learned parameters and adjacency matrix.
 */
private static void runTests2() {
    Graph skeleton = GraphConverter.convert("X1-->X2,X3-->X2,X4-->X5");

    // Node name -> distribution code (0 = continuous, 4 = four categories).
    HashMap<String, Integer> distributions = new HashMap<>();
    for (String continuousName : new String[] { "X1", "X2" }) {
        distributions.put(continuousName, 0);
    }
    for (String discreteName : new String[] { "X3", "X4", "X5" }) {
        distributions.put(discreteName, 4);
    }

    Graph mixedGraph = MixedUtils.makeMixedGraph(skeleton, distributions);
    GeneralizedSemPm semPm = MixedUtils.GaussianCategoricalPm(mixedGraph, "Split(-1.5,-.5,.5,1.5)");
    System.out.println(semPm);
    GeneralizedSemIm semIm = MixedUtils.GaussianCategoricalIm(semPm);
    System.out.println(semIm);

    DataSet simulated = semIm.simulateDataFisher(1000);
    DataSet mixedData = MixedUtils.makeMixedData(simulated, distributions);

    double noPenalty = 0;
    double[] penalties = { noPenalty, noPenalty, noPenalty };
    MGM fitted = new MGM(mixedData, penalties);

    System.out.println("Init nll: " + fitted.smoothValue(fitted.params.toMatrix1D()));
    System.out.println("Init reg term: " + fitted.nonSmoothValue(fitted.params.toMatrix1D()));
    fitted.learn(1e-8, 1000);
    System.out.println("Learned nll: " + fitted.smoothValue(fitted.params.toMatrix1D()));
    System.out.println("Learned reg term: " + fitted.nonSmoothValue(fitted.params.toMatrix1D()));
    System.out.println("params:\n" + fitted.params);
    System.out.println("adjMat:\n" + fitted.adjMatFromMGM());
}
Also used: HashMap (java.util.HashMap), DataSet (edu.cmu.tetrad.data.DataSet), GeneralizedSemIm (edu.cmu.tetrad.sem.GeneralizedSemIm), GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm)

Aggregations

GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm): 11; GeneralizedSemIm (edu.cmu.tetrad.sem.GeneralizedSemIm): 7; Node (edu.cmu.tetrad.graph.Node): 5; ParseException (java.text.ParseException): 5; DataSet (edu.cmu.tetrad.data.DataSet): 4; Graph (edu.cmu.tetrad.graph.Graph): 4; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 3; RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 2; SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore): 2; ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 2; CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly): 2; Fask (edu.cmu.tetrad.search.Fask): 2; SingleGraph (edu.cmu.tetrad.algcomparison.graph.SingleGraph): 1; Expression (edu.cmu.tetrad.calculator.expression.Expression): 1; VariableExpression (edu.cmu.tetrad.calculator.expression.VariableExpression): 1; ArrayList (java.util.ArrayList): 1; HashMap (java.util.HashMap): 1