Search in sources:

Example 11 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in the project tetrad by cmu-phil.

From class AbstractMBSearchRunner, method setSearchResults.

/**
 * Sets the results of the search.
 */
/**
 * Records the nodes selected by the search and rebuilds the data model from them.
 *
 * <p>A non-empty selection is turned into a data model by taking those columns from the
 * source data set. An empty selection instead gets a fresh {@code ColtDataSet} built from
 * the (empty) node list, with the same row count as the source.
 *
 * @param nodes the selected nodes; must not be null.
 * @throws NullPointerException if {@code nodes} is null.
 */
void setSearchResults(List<Node> nodes) {
    if (nodes == null) {
        throw new NullPointerException("nodes were null.");
    }

    // Keep a private copy so later changes to the caller's list do not leak in.
    this.variables = new ArrayList<>(nodes);

    if (!nodes.isEmpty()) {
        this.dataModel = this.source.subsetColumns(nodes);
    } else {
        this.dataModel = new ColtDataSet(this.source.getNumRows(), nodes);
    }

    this.setDataModel(this.dataModel);
}
Also used : ColtDataSet(edu.cmu.tetrad.data.ColtDataSet)

Example 12 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in the project tetrad by cmu-phil.

From class TabularComparison, method newExecution.

/**
 * Starts a fresh comparison run.
 *
 * <p>Replaces {@code statistics} with a new list of the statistics to report. It then
 * replaces {@code dataSet} with a new {@code ColtDataSet}, created with 0 rows and one
 * continuous column per statistic named after its abbreviation. Values are formatted with
 * two decimal places.
 */
private void newExecution() {
    statistics = new ArrayList<>();

    // Reported statistics; their order here is the column order of the table.
    statistics.add(new AdjacencyPrecision());
    statistics.add(new AdjacencyRecall());
    statistics.add(new ArrowheadPrecision());
    statistics.add(new ArrowheadRecall());
    statistics.add(new TwoCyclePrecision());
    statistics.add(new TwoCycleRecall());
    statistics.add(new TwoCycleFalsePositive());
    // Currently disabled: ElapsedTime, F1Adj, F1Arrow, MathewsCorrAdj, MathewsCorrArrow, SHD.

    List<Node> columns = new ArrayList<>(statistics.size());
    for (Statistic stat : statistics) {
        String abbreviation = stat.getAbbreviation();
        columns.add(new ContinuousVariable(abbreviation));
    }

    DecimalFormat twoDecimals = new DecimalFormat("0.00");
    dataSet = new ColtDataSet(0, columns);
    dataSet.setNumberFormat(twoDecimals);
}
Also used : Node(edu.cmu.tetrad.graph.Node) DecimalFormat(java.text.DecimalFormat) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet)

Example 13 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in the project tetrad by cmu-phil.

From class RemoveMissingCasesDataFilter, method filter.

/**
 * Returns a new data set containing only the rows of {@code data} that have no missing
 * values.
 *
 * <p>A row is kept when, in every column, the column's variable (cast to {@code Variable})
 * reports the cell's value as not missing via {@code isMissingValue}. Kept rows keep their
 * original relative order. The result is built over the same variable list that
 * {@code data.getVariables()} returns.
 *
 * @param data the data set to filter.
 * @return a new data set holding only the complete rows of {@code data}.
 */
public DataSet filter(DataSet data) {
    List<Node> variables = data.getVariables();
    int numRows = data.getNumRows();

    // Scan each row for missing values exactly once, remembering the indices of the complete
    // rows, so they need not be re-checked while copying.
    int[] completeRows = new int[numRows];
    int numComplete = 0;
    for (int row = 0; row < numRows; row++) {
        if (isCompleteRow(data, row)) {
            completeRows[numComplete++] = row;
        }
    }

    DataSet newDataSet = new ColtDataSet(numComplete, variables);
    int numColumns = data.getNumColumns();
    for (int newRow = 0; newRow < numComplete; newRow++) {
        int row = completeRows[newRow];
        for (int col = 0; col < numColumns; col++) {
            newDataSet.setObject(newRow, col, data.getObject(row, col));
        }
    }
    return newDataSet;
}

/** Returns true if no cell in the given row of {@code data} holds a missing value. */
private static boolean isCompleteRow(DataSet data, int row) {
    for (int col = 0; col < data.getNumColumns(); col++) {
        Variable variable = (Variable) data.getVariable(col);
        if (variable.isMissingValue(data.getObject(row, col))) {
            return false;
        }
    }
    return true;
}
Also used : ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node)

Example 14 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in the project tetrad by cmu-phil.

From class EmBayesEstimator, method initialize.

/**
 * Builds the working state used by the EM estimator.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Estimates {@code observedIm} from {@code dataSet} with {@code DirichletEstimator},
 *       starting from {@code DirichletBayesIm.symmetricDirichletIm(bayesPmObs, 0.5)}.</li>
 *   <li>Builds {@code mixedData}, with one column per entry of {@code nodes} in the same
 *       order. Latent nodes get a new {@code DiscreteVariable} marked LATENT, and every cell
 *       of that column is set to -99. Other nodes reuse the {@code dataSet} variable with
 *       the same name, and their integer values are copied row by row.</li>
 *   <li>Sets {@code allVariables} from {@code mixedData} and calls
 *       {@code estimateIM(bayesPm, mixedData)}.</li>
 *   <li>Allocates {@code estimatedCounts}, {@code estimatedCountsDenom} and
 *       {@code condProbs}, sized per node from {@code estimatedIm.getNumRows(i)} and
 *       {@code estimatedIm.getNumColumns(i)}.</li>
 * </ol>
 *
 * <p>NOTE(review): step 4 reads {@code estimatedIm}, which this method never assigns
 * directly. Presumably {@code estimateIM} sets it, so that call must stay before the
 * allocation — confirm.
 */
private void initialize() {
    // Estimate the IM over the observed variables from dataSet, starting from a symmetric
    // Dirichlet prior over bayesPmObs (0.5 is passed as the symmetric parameter).
    DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(bayesPmObs, 0.5);
    observedIm = DirichletEstimator.estimate(prior, dataSet);
    // MLBayesEstimator dirichEst = new MLBayesEstimator();
    // observedIm = dirichEst.estimate(bayesPmObs, dataSet);
    // System.out.println("Estimated Bayes IM for Measured Variables:  ");
    // System.out.println(observedIm);
    // mixedData gets one column per entry of nodes, in that order. Observed nodes reuse the
    // matching dataSet column; each latent node gets a new column holding missing data
    // (written as -99) for every case.
    int numFullCases = dataSet.getNumRows();
    // Column variables for mixedData, appended in the order of nodes.
    List<Node> variables = new LinkedList<>();
    for (Node node : nodes) {
        if (node.getNodeType() == NodeType.LATENT) {
            // Latent nodes are not looked up in dataSet; a new discrete variable is created
            // with the category count bayesPm gives for the node.
            int numCategories = bayesPm.getNumCategories(node);
            DiscreteVariable latentVar = new DiscreteVariable(node.getName(), numCategories);
            latentVar.setNodeType(NodeType.LATENT);
            variables.add(latentVar);
        } else {
            // Non-latent nodes reuse dataSet's variable with the same name as bayesPm's.
            String name = bayesPm.getVariable(node).getName();
            Node variable = dataSet.getVariable(name);
            variables.add(variable);
        }
    }
    DataSet dsMixed = new ColtDataSet(numFullCases, variables);
    // Column j of dsMixed corresponds to nodes[j], because variables was filled in the
    // order of nodes above.
    for (int j = 0; j < nodes.length; j++) {
        if (nodes[j].getNodeType() == NodeType.LATENT) {
            // NOTE(review): -99 is presumably the discrete missing-value code — confirm.
            for (int i = 0; i < numFullCases; i++) {
                dsMixed.setInt(i, j, -99);
            }
        } else {
            // Copy the observed column from dataSet, located by variable name.
            String name = bayesPm.getVariable(nodes[j]).getName();
            Node variable = dataSet.getVariable(name);
            int index = dataSet.getColumn(variable);
            for (int i = 0; i < numFullCases; i++) {
                dsMixed.setInt(i, j, dataSet.getInt(i, index));
            }
        }
    }
    // System.out.println(dsMixed);
    mixedData = dsMixed;
    allVariables = mixedData.getVariables();
    // Find the bayes net which is parameterized using mixedData or set randomly when that's
    // not possible.
    estimateIM(bayesPm, mixedData);
    // NOTE(review): estimatedIm is read below but not assigned in this method; presumably
    // estimateIM sets it — confirm.
    // The following DEBUG section tests a case specified by P. Spirtes
    // DEBUG TAIL:   For use with embayes_l1x1x2x3V3.dat
    /*
        Node l1Node = graph.getNode("L1");
        //int l1Index = bayesImMixed.getNodeIndex(l1Node);
        int l1index = estimatedIm.getNodeIndex(l1Node);
        Node x1Node = graph.getNode("X1");
        //int x1Index = bayesImMixed.getNodeIndex(x1Node);
        int x1Index = estimatedIm.getNodeIndex(x1Node);
        Node x2Node = graph.getNode("X2");
        //int x2Index = bayesImMixed.getNodeIndex(x2Node);
        int x2Index = estimatedIm.getNodeIndex(x2Node);
        Node x3Node = graph.getNode("X3");
        //int x3Index = bayesImMixed.getNodeIndex(x3Node);
        int x3Index = estimatedIm.getNodeIndex(x3Node);

        estimatedIm.setProbability(l1index, 0, 0, 0.5);
        estimatedIm.setProbability(l1index, 0, 1, 0.5);

        //bayesImMixed.setProbability(x1Index, 0, 0, 0.33333);
        //bayesImMixed.setProbability(x1Index, 0, 1, 0.66667);
        estimatedIm.setProbability(x1Index, 0, 0, 0.6);      //p(x1 = 0 | l1 = 0)
        estimatedIm.setProbability(x1Index, 0, 1, 0.4);      //p(x1 = 1 | l1 = 0)
        estimatedIm.setProbability(x1Index, 1, 0, 0.4);      //p(x1 = 0 | l1 = 1)
        estimatedIm.setProbability(x1Index, 1, 1, 0.6);      //p(x1 = 1 | l1 = 1)

        //bayesImMixed.setProbability(x2Index, 1, 0, 0.66667);
        //bayesImMixed.setProbability(x2Index, 1, 1, 0.33333);
        estimatedIm.setProbability(x2Index, 1, 0, 0.4);      //p(x2 = 0 | l1 = 1)
        estimatedIm.setProbability(x2Index, 1, 1, 0.6);      //p(x2 = 1 | l1 = 1)
        estimatedIm.setProbability(x2Index, 0, 0, 0.6);      //p(x2 = 0 | l1 = 0)
        estimatedIm.setProbability(x2Index, 0, 1, 0.4);      //p(x2 = 1 | l1 = 0)

        //bayesImMixed.setProbability(x3Index, 1, 0, 0.66667);
        //bayesImMixed.setProbability(x3Index, 1, 1, 0.33333);
        estimatedIm.setProbability(x3Index, 1, 0, 0.4);      //p(x3 = 0 | l1 = 1)
        estimatedIm.setProbability(x3Index, 1, 1, 0.6);      //p(x3 = 1 | l1 = 1)
        estimatedIm.setProbability(x3Index, 0, 0, 0.6);      //p(x3 = 0 | l1 = 0)
        estimatedIm.setProbability(x3Index, 0, 1, 0.4);      //p(x3 = 1 | l1 = 0)
        */
    // END of TAIL
    // System.out.println("bayes IM estimated by estimateIM");
    // System.out.println(bayesImMixed);
    // System.out.println(estimatedIm);
    // Allocate per-node work arrays shaped like estimatedIm. For node i:
    //   estimatedCounts[i] and condProbs[i] are [numRows][numCols],
    //   estimatedCountsDenom[i] is [numRows].
    estimatedCounts = new double[nodes.length][][];
    estimatedCountsDenom = new double[nodes.length][];
    condProbs = new double[nodes.length][][];
    for (int i = 0; i < nodes.length; i++) {
        // int numRows = bayesImMixed.getNumRows(i);
        int numRows = estimatedIm.getNumRows(i);
        estimatedCounts[i] = new double[numRows][];
        estimatedCountsDenom[i] = new double[numRows];
        condProbs[i] = new double[numRows][];
        // for(int j = 0; j < bayesImMixed.getNumRows(i); j++) {
        for (int j = 0; j < estimatedIm.getNumRows(i); j++) {
            // int numCols = bayesImMixed.getNumColumns(i);
            int numCols = estimatedIm.getNumColumns(i);
            estimatedCounts[i][j] = new double[numCols];
            condProbs[i][j] = new double[numCols];
        }
    }
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) LinkedList(java.util.LinkedList)

Example 15 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in the project tetrad by cmu-phil.

From class MlBayesImObs, method simulateDataHelper.

/**
 * Simulates a sample with the given sample size.
 *
 * <p>One discrete column is created for each included node, using the node's name and the
 * categories {@code bayesPm} lists for it. The new data set is then passed to
 * {@code constructSample} together with the column-to-node index map.
 *
 * @param sampleSize      the sample size.
 * @param latentDataSaved if true, every node gets a column; if false, only nodes whose type
 *                        is {@code NodeType.MEASURED} get a column.
 * @return the simulated sample as a DataSet.
 */
private DataSet simulateDataHelper(int sampleSize, boolean latentDataSaved) {
    // Number of columns created so far. This includes non-measured nodes when
    // latentDataSaved is true.
    int numColumns = 0;
    // map[c] is the index into nodes of the node stored in column c. Entries at or beyond
    // numColumns are unused.
    int[] map = new int[nodes.length];
    List<Node> variables = new LinkedList<>();

    for (int j = 0; j < nodes.length; j++) {
        // Unless latent data is requested, skip every node that is not measured.
        if (!latentDataSaved && nodes[j].getNodeType() != NodeType.MEASURED) {
            continue;
        }

        int numCategories = bayesPm.getNumCategories(nodes[j]);
        List<String> categories = new LinkedList<>();
        for (int k = 0; k < numCategories; k++) {
            categories.add(bayesPm.getCategory(nodes[j], k));
        }

        variables.add(new DiscreteVariable(nodes[j].getName(), categories));
        map[numColumns++] = j;
    }

    DataSet dataSet = new ColtDataSet(sampleSize, variables);
    // constructSample is defined elsewhere; presumably it fills dataSet in place with the
    // sampled values — confirm.
    constructSample(sampleSize, numColumns, dataSet, map);
    return dataSet;
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet)

Aggregations

ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)28 DataSet (edu.cmu.tetrad.data.DataSet)24 Node (edu.cmu.tetrad.graph.Node)21 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)17 LinkedList (java.util.LinkedList)13 Test (org.junit.Test)12 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)9 RandomUtil (edu.cmu.tetrad.util.RandomUtil)6 Parameters (edu.cmu.tetrad.util.Parameters)3 ArrayList (java.util.ArrayList)3 KernelGaussian (edu.cmu.tetrad.search.kernel.KernelGaussian)2 DecimalFormat (java.text.DecimalFormat)2 ParseException (java.text.ParseException)2 KMeans (edu.cmu.tetrad.cluster.KMeans)1 CellTable (edu.cmu.tetrad.data.CellTable)1 DataModelList (edu.cmu.tetrad.data.DataModelList)1 TimeSeriesData (edu.cmu.tetrad.data.TimeSeriesData)1 Dag (edu.cmu.tetrad.graph.Dag)1 RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset)1 RegressionResult (edu.cmu.tetrad.regression.RegressionResult)1