Search in sources:

Example 76 with DataSet

Use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

Class TestLogisticRegression, method test1.

@Test
public void test1() {
    // Build five continuous variables named X1..X5.
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    // Random DAG over those variables; simulate 1000 rows from a SEM on it.
    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 3, 3, 3, false));
    System.out.println(dag);
    SemIm semIm = new SemIm(new SemPm(dag));
    DataSet continuousData = semIm.simulateDataRecursive(1000, false);

    // X1 is the response; X2..X5 are the regressors.
    Node x1 = continuousData.getVariable("X1");
    List<Node> regressors = new ArrayList<>();
    for (int k = 2; k <= 5; k++) {
        regressors.add(continuousData.getVariable("X" + k));
    }

    // Only X1 is given a discretization spec: two equal-count bins.
    Discretizer discretizer = new Discretizer(continuousData);
    discretizer.equalCounts(x1, 2);
    DataSet discretizedData = discretizer.discretize();

    // Logistic regression of the binarized X1 on the regressors.
    LogisticRegression regression = new LogisticRegression(discretizedData);
    DiscreteVariable target = (DiscreteVariable) discretizedData.getVariable("X1");
    regression.regress(target, regressors);
    System.out.println(regression);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Dag(edu.cmu.tetrad.graph.Dag) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) LogisticRegression(edu.cmu.tetrad.regression.LogisticRegression) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 77 with DataSet

Use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

Class BayesUpdaterClassifier, method classify.

/**
 * Classifies every case of the test data by Bayesian updating and returns,
 * for each case, the estimated (most probable) category index of the target
 * variable.
 * <p>
 * For each row of {@code testData}:
 * <ul>
 *   <li>the observed values of all non-target Bayes IM variables are entered
 *       as evidence, skipping missing values;</li>
 *   <li>the target variable is left free to range over all of its
 *       categories;</li>
 *   <li>the category with the highest updated marginal is chosen.</li>
 * </ul>
 * Ties go to the highest-indexed category, because a later category
 * replaces the current best when its marginal is greater than or equal to
 * it.
 * <p>
 * Side effects: sets {@code numCases}, {@code classifications} and
 * {@code marginals} (indexed {@code [targetCategory][case]}), and updates
 * {@code missingValueCaseFound}.
 *
 * @return one estimated target category per case, or
 *         {@link DiscreteVariable#MISSING_VALUE} for a case for which no
 *         valid marginal was obtained
 * @throws NullPointerException     if no target variable has been set
 * @throws IllegalArgumentException if a non-target Bayes IM variable cannot
 *                                  be found among the test data's variables
 */
public int[] classify() {
    if (targetVariable == null) {
        throw new NullPointerException("Target not set.");
    }

    // Create an updater for the instantiated Bayes net.
    BayesUpdater bayesUpdater = new RowSummingExactUpdater(getBayesIm());

    // Sizes and target lookups that do not change from case to case.
    int nvars = getBayesImVars().size();
    int ncases = testData.getNumRows();
    int targetIndexInIm = getBayesImVars().indexOf(targetVariable);
    Node targetNode = getBayesIm().getNode(targetVariable.getName());
    int indexTargetBN = getBayesIm().getNodeIndex(targetNode);

    // Map each non-target Bayes IM variable to its column in the test data.
    // The target's slot is left at 0, so that column of selectedData is a
    // placeholder (data column 0) and is never read as evidence.
    int[] varIndices = new int[nvars];
    List<Node> dataVars = testData.getVariables();
    for (int i = 0; i < nvars; i++) {
        DiscreteVariable variable = (DiscreteVariable) getBayesImVars().get(i);
        if (variable == targetVariable) {
            continue;
        }
        varIndices[i] = dataVars.indexOf(variable);
        if (varIndices[i] == -1) {
            throw new IllegalArgumentException("Can't find the (non-target) variable " + variable + " in the data. Either it's not there, or else its " + "categories are in a different order.");
        }
    }

    // Column j of selectedData now lines up with Bayes IM variable j.
    DataSet selectedData = testData.subsetColumns(varIndices);
    this.numCases = ncases;
    int[] estimatedValues = new int[ncases];
    int numTargetCategories = targetVariable.getNumCategories();
    double[][] probOfClassifiedValues = new double[numTargetCategories][ncases];
    Arrays.fill(estimatedValues, -1);

    // Classify each case by entering its observed values as evidence and
    // performing Bayesian updating.
    for (int i = 0; i < ncases; i++) {
        // Fresh evidence for this case; the target ranges over all its values.
        Evidence evidence = Evidence.tautology(getBayesIm());
        int itarget = evidence.getNodeIndex(targetVariable.getName());
        evidence.getProposition().setVariable(itarget, true);

        // NOTE(review): this flag is reset for every case, so after the loop
        // it reflects only the last case -- confirm that is intended.
        this.missingValueCaseFound = false;

        // Fix every non-target variable to its observed value in this case.
        for (int j = 0; j < nvars; j++) {
            if (j == targetIndexInIm) {
                continue;
            }
            int observedValue = selectedData.getInt(i, j);
            if (observedValue == DiscreteVariable.MISSING_VALUE) {
                this.missingValueCaseFound = true;
                continue;
            }
            String jName = getBayesImVars().get(j).getName();
            int jIndex = evidence.getNodeIndex(jName);
            evidence.getProposition().setCategory(jIndex, observedValue);
        }

        // Update using those values.
        bayesUpdater.setEvidence(evidence);

        // Choose the target category with the highest updated marginal. The
        // starting bound of -0.1 lets any marginal >= -0.1 win, so
        // estimatedValue stays -1 only when no marginal compares >= -0.1
        // (e.g. all NaN).
        int estimatedValue = -1;
        double highestProb = -0.1;
        for (int j = 0; j < numTargetCategories; j++) {
            double marginal = bayesUpdater.getMarginal(indexTargetBN, j);
            probOfClassifiedValues[j][i] = marginal;
            if (marginal >= highestProb) {
                highestProb = marginal;
                estimatedValue = j;
            }
        }

        // No valid marginal was obtained for this case -- presumably because
        // its evidence combination did not occur in the training data (TODO
        // confirm). Log the evidence values and mark the case as missing.
        if (estimatedValue < 0) {
            TetradLogger.getInstance().log("details", "Case " + i + " does not return valid marginal.");
            for (int m = 0; m < nvars; m++) {
                // The target's slot holds data column 0, not this case's
                // target value, so it is not logged.
                if (m == targetIndexInIm) {
                    continue;
                }
                TetradLogger.getInstance().log("details", "  " + selectedData.getDouble(i, m));
            }
            estimatedValues[i] = DiscreteVariable.MISSING_VALUE;
            continue;
        }
        estimatedValues[i] = estimatedValue;
    }

    this.classifications = estimatedValues;
    this.marginals = probOfClassifiedValues;
    return estimatedValues;
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)

Example 78 with DataSet

Use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

Class TestSemVarMeans, method testMeansCholesky.

@Test
public void testMeansCholesky() {
    SemPm pm = new SemPm(constructGraph1());

    // Start from fixed (non-random) initial parameter values.
    for (Parameter parameter : pm.getParameters()) {
        parameter.setInitializedRandomly(false);
    }

    SemIm trueIm = new SemIm(pm);
    double[] targetMeans = {5.0, 4.0, 3.0, 2.0, 1.0};
    RandomUtil.getInstance().setSeed(-379467L);

    // Assign the target means to the variable nodes, in node order.
    int idx = 0;
    for (Node node : trueIm.getVariableNodes()) {
        trueIm.setMean(node, targetMeans[idx++]);
    }

    // Simulate via the Cholesky method, then re-estimate the SEM.
    DataSet sample = trueIm.simulateDataCholesky(1000, false);
    SemEstimator estimator = new SemEstimator(sample, pm);
    estimator.estimate();
    SemIm fittedIm = estimator.getEstimatedSem();

    // Each recovered mean must lie within 0.6 of the simulating mean.
    for (Node node : pm.getVariableNodes()) {
        assertEquals(trueIm.getMean(node), fittedIm.getMean(node), 0.6);
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Parameter(edu.cmu.tetrad.sem.Parameter) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 79 with DataSet

Use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

Class TestSemVarMeans, method testMeansReducedForm.

@Test
public void testMeansReducedForm() {
    SemPm pm = new SemPm(constructGraph1());

    // Start from fixed (non-random) initial parameter values.
    for (Parameter parameter : pm.getParameters()) {
        parameter.setInitializedRandomly(false);
    }

    SemIm trueIm = new SemIm(pm);
    double[] targetMeans = {5.0, 4.0, 3.0, 2.0, 1.0};
    RandomUtil.getInstance().setSeed(-379467L);

    // Assign the target means to the variable nodes, in node order.
    int idx = 0;
    for (Node node : trueIm.getVariableNodes()) {
        trueIm.setMean(node, targetMeans[idx++]);
    }

    // Simulate via the reduced-form method, then re-estimate the SEM.
    DataSet sample = trueIm.simulateDataReducedForm(1000, false);
    SemEstimator estimator = new SemEstimator(sample, pm);
    estimator.estimate();
    SemIm fittedIm = estimator.getEstimatedSem();

    // Each recovered mean must lie within 0.5 of the simulating mean.
    for (Node node : pm.getVariableNodes()) {
        assertEquals(trueIm.getMean(node), fittedIm.getMean(node), 0.5);
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Parameter(edu.cmu.tetrad.sem.Parameter) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 80 with DataSet

Use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

Class TestSemVarMeans, method testMeansRecursive.

@Test
public void testMeansRecursive() {
    SemPm pm = new SemPm(constructGraph1());

    // Start from fixed (non-random) initial parameter values.
    for (Parameter parameter : pm.getParameters()) {
        parameter.setInitializedRandomly(false);
    }

    SemIm trueIm = new SemIm(pm);
    double[] targetMeans = {5.0, 4.0, 3.0, 2.0, 1.0};
    RandomUtil.getInstance().setSeed(-379467L);

    // Assign the target means to the variable nodes, in node order.
    int idx = 0;
    for (Node node : trueIm.getVariableNodes()) {
        trueIm.setMean(node, targetMeans[idx++]);
    }

    // Simulate recursively, then re-estimate the SEM.
    DataSet sample = trueIm.simulateDataRecursive(1000, false);
    SemEstimator estimator = new SemEstimator(sample, pm);
    estimator.estimate();
    SemIm fittedIm = estimator.getEstimatedSem();

    // Each recovered mean must lie within 0.5 of the simulating mean.
    for (Node node : pm.getVariableNodes()) {
        assertEquals(trueIm.getMean(node), fittedIm.getMean(node), 0.5);
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Parameter(edu.cmu.tetrad.sem.Parameter) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Aggregations

DataSet (edu.cmu.tetrad.data.DataSet): 216, Test (org.junit.Test): 65, Graph (edu.cmu.tetrad.graph.Graph): 64, Node (edu.cmu.tetrad.graph.Node): 60, ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 48, ArrayList (java.util.ArrayList): 45, ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 36, GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest): 32, EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 29, SemIm (edu.cmu.tetrad.sem.SemIm): 28, SemPm (edu.cmu.tetrad.sem.SemPm): 28, BootstrapEdgeEnsemble (edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble): 26, DataModel (edu.cmu.tetrad.data.DataModel): 22, Parameters (edu.cmu.tetrad.util.Parameters): 22, DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 20, File (java.io.File): 16, ParseException (java.text.ParseException): 16, LinkedList (java.util.LinkedList): 14, ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix): 13, DMSearch (edu.cmu.tetrad.search.DMSearch): 10