Search in sources:

Example 11 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

From class TestSimulatedFmri, method testClark.

/**
 * Exploratory check of FASK on two three-variable linear SEMs with skewed errors.
 *
 * <p>Each of 100 iterations simulates {@code N} samples from each model, runs FASK
 * with the given alpha and penalty discount, and prints the recovered graph.
 * Nothing is asserted; the printed output is meant to be inspected by hand.
 */
// @Test
public void testClark() {
    double f = .1;
    int N = 512;
    double alpha = 1.0;
    double penaltyDiscount = 1.0;
    // Every error term is Uniform(0, 1) raised to the power f (loop-invariant).
    String error = "pow(Uniform(0, 1), " + f + ")";
    for (int i = 0; i < 100; i++) {
        // Model 1: X -> Y, Z -> X, Z -> Y.
        runClarkModel(new String[][]{{"X", "Y"}, {"Z", "X"}, {"Z", "Y"}},
                "0.5 * Z + E_X", "0.5 * X + 0.5 * Z + E_Y", "E_Z",
                error, N, alpha, penaltyDiscount);
        // Model 2: X -> Y, X -> Z, Y -> Z.
        runClarkModel(new String[][]{{"X", "Y"}, {"X", "Z"}, {"Y", "Z"}},
                "E_X", "0.4 * X + E_Y", "0.4 * X + 0.4 * Y + E_Z",
                error, N, alpha, penaltyDiscount);
    }
}

/**
 * Simulates data from one generalized SEM over X, Y, Z and prints the graph FASK finds.
 *
 * @param edges           directed edges as {from, to} variable-name pairs, added in order
 * @param xExpression     structural expression for X
 * @param yExpression     structural expression for Y
 * @param zExpression     structural expression for Z
 * @param error           expression used for every error node
 * @param sampleSize      number of rows to simulate
 * @param alpha           alpha passed to FASK
 * @param penaltyDiscount penalty discount passed to both the BIC score and FASK
 * @throws IllegalStateException if any expression fails to parse
 */
private void runClarkModel(String[][] edges, String xExpression, String yExpression,
                           String zExpression, String error, int sampleSize,
                           double alpha, double penaltyDiscount) {
    Graph g = new EdgeListGraph();
    g.addNode(new ContinuousVariable("X"));
    g.addNode(new ContinuousVariable("Y"));
    g.addNode(new ContinuousVariable("Z"));
    for (String[] edge : edges) {
        g.addDirectedEdge(g.getNode(edge[0]), g.getNode(edge[1]));
    }
    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        pm.setNodeExpression(g.getNode("X"), xExpression);
        pm.setNodeExpression(g.getNode("Y"), yExpression);
        pm.setNodeExpression(g.getNode("Z"), zExpression);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("X")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Y")), error);
        pm.setNodeExpression(pm.getErrorNode(g.getNode("Z")), error);
    } catch (ParseException e) {
        // The expressions are fixed strings, so a parse failure is a bug in the test.
        // Swallowing it would simulate from a model other than the intended one.
        throw new IllegalStateException("Could not parse SEM expression", e);
    }
    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(sampleSize, false);
    edu.cmu.tetrad.search.SemBicScore score = new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
    score.setPenaltyDiscount(penaltyDiscount);
    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(penaltyDiscount);
    fask.setAlpha(alpha);
    Graph out = fask.search();
    System.out.println(out);
}
Also used: DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Fask(edu.cmu.tetrad.search.Fask) Graph(edu.cmu.tetrad.graph.Graph) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) ParseException(java.text.ParseException) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore)

Example 12 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

From class TestSemEvidence, method constructGraph1.

/**
 * Builds the five-node test DAG X1 -> X2 -> X3 -> X4 -> X5 plus the extra edge X1 -> X4.
 * Nodes are added in order X1..X5; edges are added in the order listed in {@code arcs}.
 */
private Graph constructGraph1() {
    Graph dag = new EdgeListGraph();
    Node[] vars = new Node[5];
    for (int k = 0; k < vars.length; k++) {
        vars[k] = new GraphNode("X" + (k + 1));
    }
    for (Node v : vars) {
        dag.addNode(v);
    }
    // {from, to} indices into vars: X1->X2, X2->X3, X3->X4, X1->X4, X4->X5.
    int[][] arcs = {{0, 1}, {1, 2}, {2, 3}, {0, 3}, {3, 4}};
    for (int[] arc : arcs) {
        dag.addDirectedEdge(vars[arc[0]], vars[arc[1]]);
    }
    return dag;
}
Also used: EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node)

Example 13 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

From class LogisticRegressionRunner, method execute.

// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Runs a logistic regression of the selected response variable on the selected
 * predictor variables of the current model's data set and stores the outcome in
 * {@code result}.
 *
 * <p>{@code outGraph} is reset to an empty graph on every call. If the selection is
 * invalid, this method sets {@code report` to a message and returns without running
 * the regression. Invalid selections are:
 * <ul>
 *   <li>no response or no predictors chosen;</li>
 *   <li>the response is also a predictor;</li>
 *   <li>the response is missing or is not discrete;</li>
 *   <li>a predictor is missing from the data set.</li>
 * </ul>
 */
public void execute() {
    outGraph = new EdgeListGraph();
    if (regressorNames == null || regressorNames.isEmpty() || targetName == null) {
        report = "Response and predictor variables not set.";
        return;
    }
    if (regressorNames.contains(targetName)) {
        report = "Response must not be a predictor.";
        return;
    }
    DataSet dataSet = dataSets.get(getModelIndex());
    Node target = dataSet.getVariable(targetName);
    if (target == null) {
        report = "Response variable not found in data: " + targetName;
        return;
    }
    // Logistic regression needs a discrete response; checking here replaces an
    // unhandled ClassCastException from the cast below.
    if (!(target instanceof DiscreteVariable)) {
        report = "Response variable must be discrete: " + targetName;
        return;
    }
    // Predictors are exactly the variables the user selected, looked up in the
    // same data set that is handed to the regression.
    List<Node> regressorNodes = new ArrayList<>();
    for (String name : regressorNames) {
        Node regressor = dataSet.getVariable(name);
        if (regressor == null) {
            report = "Predictor variable not found in data: " + name;
            return;
        }
        regressorNodes.add(regressor);
    }
    LogisticRegression logRegression = new LogisticRegression(dataSet);
    logRegression.setAlpha(alpha);
    this.result = logRegression.regress((DiscreteVariable) target, regressorNodes);
}
Also used: DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ArrayList(java.util.ArrayList) LogisticRegression(edu.cmu.tetrad.regression.LogisticRegression)

Example 14 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

From class BayesProperties, method getLikelihoodRatioP.

/**
 * Calculates the p-value of the graph with respect to the given data, using a
 * likelihood-ratio (chi-square) test of {@code graph} against the complete graph over
 * the same nodes.
 *
 * <p>Side effects: stores {@code likelihood}, {@code chisq}, {@code dof} and {@code bic}
 * in fields and prints intermediate quantities to standard output.
 *
 * @param graph the graph being tested (the null model)
 * @return the p-value; 1.0 when the test has no degrees of freedom (for example when
 *         {@code graph} has as many parameters as the complete graph)
 */
public final double getLikelihoodRatioP(Graph graph) {
    // Alternative (saturated) model: the complete graph, with edges oriented along the
    // node-list order so it is acyclic. The input graph is the null hypothesis.
    List<Node> nodes = graph.getNodes();
    Graph graph0 = new EdgeListGraph(nodes);
    for (int i = 0; i < nodes.size(); i++) {
        for (int j = i + 1; j < nodes.size(); j++) {
            graph0.addDirectedEdge(nodes.get(i), nodes.get(j));
        }
    }
    Ret r0 = getLikelihood2(graph0);
    Ret r1 = getLikelihood2(graph);
    this.likelihood = r1.getLik();
    double lDiff = r0.getLik() - r1.getLik();
    System.out.println("lDiff = " + lDiff);
    int nDiff = r0.getDof() - r1.getDof();
    System.out.println("nDiff = " + nDiff);
    double chisq = 2.0 * lDiff;
    double dof = nDiff;
    this.chisq = chisq;
    this.dof = dof;
    int N = dataSet.getNumRows();
    this.bic = 2 * r1.getLik() - r1.getDof() * Math.log(N);
    System.out.println("bic = " + bic);
    System.out.println("chisq = " + chisq);
    System.out.println("dof = " + dof);
    double p;
    if (dof <= 0) {
        // ChiSquaredDistribution rejects non-positive degrees of freedom; with no extra
        // parameters in the alternative there is nothing to reject.
        p = 1.0;
    } else {
        p = 1.0 - new ChiSquaredDistribution(dof).cumulativeProbability(chisq);
    }
    System.out.println("p = " + p);
    return p;
}
Also used: ChiSquaredDistribution(org.apache.commons.math3.distribution.ChiSquaredDistribution) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) Node(edu.cmu.tetrad.graph.Node)

Example 15 with EdgeListGraph

use of edu.cmu.tetrad.graph.EdgeListGraph in project tetrad by cmu-phil.

From class SimulationEditor, method resetPanel.

/**
 * Rebuilds the simulation from the current dropdown selections and refreshes the
 * parameter tab (index 0 of {@code tabbedPane}).
 *
 * <p>The graph source defaults to the simulation's own graph, or to an empty graph when
 * there is none. If the graph is not fixed, it is replaced by the generator chosen in the
 * graphs dropdown. If the simulation is not fixed, a new simulation of the type chosen in
 * the simulations dropdown is installed. Both selections are saved into the simulation's
 * parameters. Unrecognized selections raise {@link IllegalArgumentException}.
 */
private void resetPanel(Simulation simulation, String[] graphItems, String[] simulationItems, JTabbedPane tabbedPane) {
    RandomGraph randomGraph;
    if (simulation.getSourceGraph() != null) {
        randomGraph = new SingleGraph(simulation.getSourceGraph());
    } else {
        randomGraph = new SingleGraph(new EdgeListGraph());
    }

    if (!simulation.isFixedGraph()) {
        String chosenGraph = (String) graphsDropdown.getSelectedItem();
        simulation.getParams().set("graphsDropdownPreference", chosenGraph);
        if (chosenGraph.equals(graphItems[0])) {
            randomGraph = new RandomForward();
        } else if (chosenGraph.equals(graphItems[1])) {
            randomGraph = new ScaleFree();
        } else if (chosenGraph.equals(graphItems[2])) {
            randomGraph = new Cyclic();
        } else if (chosenGraph.equals(graphItems[3])) {
            randomGraph = new RandomSingleFactorMim();
        } else if (chosenGraph.equals(graphItems[4])) {
            randomGraph = new RandomTwoFactorMim();
        } else {
            throw new IllegalArgumentException("Unrecognized simulation type: " + chosenGraph);
        }
    }

    if (!simulation.isFixedSimulation()) {
        // The two original branches differed only in how LinearFisherModel is built,
        // so the source-graph check is captured once here.
        boolean noSourceGraph = simulation.getSourceGraph() == null;
        String chosenSimulation = (String) simulationsDropdown.getSelectedItem();
        simulation.getParams().set("simulationsDropdownPreference", chosenSimulation);
        simulation.setFixedGraph(false);
        if (randomGraph instanceof SingleGraph) {
            simulation.setFixedGraph(true);
        }
        if (chosenSimulation.equals(simulationItems[0])) {
            simulation.setSimulation(new BayesNetSimulation(randomGraph), simulation.getParams());
        } else if (chosenSimulation.equals(simulationItems[1])) {
            simulation.setSimulation(new SemSimulation(randomGraph), simulation.getParams());
        } else if (chosenSimulation.equals(simulationItems[2])) {
            LinearFisherModel fisherModel = noSourceGraph
                    ? new LinearFisherModel(randomGraph, simulation.getInputDataModelList())
                    : new LinearFisherModel(randomGraph);
            simulation.setSimulation(fisherModel, simulation.getParams());
        } else if (chosenSimulation.equals(simulationItems[3])) {
            simulation.setSimulation(new LeeHastieSimulation(randomGraph), simulation.getParams());
        } else if (chosenSimulation.equals(simulationItems[4])) {
            simulation.setSimulation(new ConditionalGaussianSimulation(randomGraph), simulation.getParams());
        } else if (chosenSimulation.equals(simulationItems[5])) {
            simulation.setSimulation(new TimeSeriesSemSimulation(randomGraph), simulation.getParams());
        } else {
            throw new IllegalArgumentException("Unrecognized simulation type: " + chosenSimulation);
        }
    }

    tabbedPane.setComponentAt(0, new PaddingPanel(getParameterPanel(simulation, simulation.getSimulation(), simulation.getParams())));
}
Also used: LinearFisherModel(edu.cmu.tetrad.algcomparison.simulation.LinearFisherModel) RandomSingleFactorMim(edu.cmu.tetrad.algcomparison.graph.RandomSingleFactorMim) TimeSeriesSemSimulation(edu.cmu.tetrad.algcomparison.simulation.TimeSeriesSemSimulation) GeneralSemSimulation(edu.cmu.tetrad.algcomparison.simulation.GeneralSemSimulation) SemSimulation(edu.cmu.tetrad.algcomparison.simulation.SemSimulation) StandardizedSemSimulation(edu.cmu.tetrad.algcomparison.simulation.StandardizedSemSimulation) RandomForward(edu.cmu.tetrad.algcomparison.graph.RandomForward) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) BayesNetSimulation(edu.cmu.tetrad.algcomparison.simulation.BayesNetSimulation) ScaleFree(edu.cmu.tetrad.algcomparison.graph.ScaleFree) PaddingPanel(edu.cmu.tetradapp.ui.PaddingPanel) RandomTwoFactorMim(edu.cmu.tetrad.algcomparison.graph.RandomTwoFactorMim) RandomGraph(edu.cmu.tetrad.algcomparison.graph.RandomGraph) Cyclic(edu.cmu.tetrad.algcomparison.graph.Cyclic) LeeHastieSimulation(edu.cmu.tetrad.algcomparison.simulation.LeeHastieSimulation) ConditionalGaussianSimulation(edu.cmu.tetrad.algcomparison.simulation.ConditionalGaussianSimulation) SingleGraph(edu.cmu.tetrad.algcomparison.graph.SingleGraph)

Aggregations

EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph) 46, Graph (edu.cmu.tetrad.graph.Graph) 36, Node (edu.cmu.tetrad.graph.Node) 31, ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable) 11, DataSet (edu.cmu.tetrad.data.DataSet) 10, Edge (edu.cmu.tetrad.graph.Edge) 8, GraphNode (edu.cmu.tetrad.graph.GraphNode) 8, ArrayList (java.util.ArrayList) 8, DMSearch (edu.cmu.tetrad.search.DMSearch) 6, Test (org.junit.Test) 6, SemIm (edu.cmu.tetrad.sem.SemIm) 5, SemPm (edu.cmu.tetrad.sem.SemPm) 5, List (java.util.List) 5, IOException (java.io.IOException) 4, RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph) 3, TetradMatrix (edu.cmu.tetrad.util.TetradMatrix) 3, DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D) 2, DenseDoubleMatrix2D (cern.colt.matrix.impl.DenseDoubleMatrix2D) 2, SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore) 2, TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) 2