Search in sources:

Example 21 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

Class TestTransform, method testSingleTransforms.

/**
 * Checks {@code Transformation.transform} on three single-equation transforms over a small
 * continuous data set with columns x, y and z. Each equation is applied to a fresh copy of
 * the original data, so an earlier transform never feeds into a later one.
 *
 * @throws ParseException if an equation fails to parse. It is allowed to propagate so that
 *                        JUnit reports the test as an error with the parser's full stack
 *                        trace, instead of only the exception message.
 */
@Test
public void testSingleTransforms() throws ParseException {
    // Build the data set. Three rows are declared because rows 0..2 are written below.
    List<Node> list = Arrays.asList((Node) new ContinuousVariable("x"), new ContinuousVariable("y"), new ContinuousVariable("z"));
    DataSet data = new ColtDataSet(3, list);
    // Column 0 (x).
    data.setDouble(0, 0, 2);
    data.setDouble(1, 0, 3);
    data.setDouble(2, 0, 4);
    // Column 1 (y).
    data.setDouble(0, 1, 1);
    data.setDouble(1, 1, 6);
    data.setDouble(2, 1, 5);
    // Column 2 (z).
    data.setDouble(0, 2, 8);
    data.setDouble(1, 2, 8);
    data.setDouble(2, 2, 8);

    // z = x + y
    DataSet copy = new ColtDataSet((ColtDataSet) data);
    Transformation.transform(copy, "z = (x + y)");
    assertTrue(copy.getDouble(0, 2) == 3.0);
    assertTrue(copy.getDouble(1, 2) == 9.0);
    assertTrue(copy.getDouble(2, 2) == 9.0);

    // x = x + 3
    copy = new ColtDataSet((ColtDataSet) data);
    Transformation.transform(copy, "x = x + 3");
    assertTrue(copy.getDouble(0, 0) == 5.0);
    assertTrue(copy.getDouble(1, 0) == 6.0);
    assertTrue(copy.getDouble(2, 0) == 7.0);

    // x = x^2 + y + z
    copy = new ColtDataSet((ColtDataSet) data);
    Transformation.transform(copy, "x = pow(x, 2) + y + z");
    assertTrue(copy.getDouble(0, 0) == 13.0);
    assertTrue(copy.getDouble(1, 0) == 23.0);
    assertTrue(copy.getDouble(2, 0) == 29.0);
}
Also used: ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), ColtDataSet (edu.cmu.tetrad.data.ColtDataSet), DataSet (edu.cmu.tetrad.data.DataSet), Node (edu.cmu.tetrad.graph.Node), ParseException (java.text.ParseException), Test (org.junit.Test)

Example 22 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

Class TestLingamPattern, method simulateDataNonNormal.

/**
 * Simulates data from {@code semIm} using non-normal exogenous terms.
 *
 * <p>For each row, variables are visited in the graph's causal ordering. Each value is:
 * <ul>
 *   <li>a draw from that variable's distribution,</li>
 *   <li>plus the edge-coefficient-weighted values already simulated for its non-error parents,</li>
 *   <li>plus the variable's mean.</li>
 * </ul>
 * This percolation assumes the SEM is acyclic. It is fast for large simulations but hangs
 * for cyclic models.
 *
 * @param semIm         the instantiated SEM supplying the graph, edge coefficients and means.
 * @param sampleSize    number of rows to simulate; must be > 0.
 * @param distributions exogenous-term distributions, indexed like
 *                      {@code semIm.getSemPm().getVariableNodes()}.
 * @return the simulated data set, with one continuous column per variable node.
 */
private DataSet simulateDataNonNormal(SemIm semIm, int sampleSize, List<Distribution> distributions) {
    List<Node> variableNodes = semIm.getSemPm().getVariableNodes();
    List<Node> variables = new LinkedList<>();
    for (Node node : variableNodes) {
        variables.add(new ContinuousVariable(node.getName()));
    }
    DataSet dataSet = new ColtDataSet(sampleSize, variables);

    // Index arrays to speed up the simulation.
    SemGraph graph = semIm.getSemPm().getGraph();
    List<Node> tierOrdering = graph.getCausalOrdering();
    int[] tierIndices = new int[variableNodes.size()];
    for (int i = 0; i < tierIndices.length; i++) {
        tierIndices[i] = variableNodes.indexOf(tierOrdering.get(i));
    }

    // _parents[col] holds the column indices of col's non-error parents. coefs[col][j] is the
    // edge coefficient from _parents[col][j] to col. Both are loop-invariant, so they are
    // computed once here rather than queried from semIm for every row.
    int[][] _parents = new int[variables.size()][];
    double[][] coefs = new double[variables.size()][];
    for (int i = 0; i < variableNodes.size(); i++) {
        // Copy before filtering so that removing error terms never mutates a list the graph
        // itself may hold.
        List<Node> parents = new LinkedList<>(graph.getParents(variableNodes.get(i)));
        for (Iterator<Node> it = parents.iterator(); it.hasNext(); ) {
            if (it.next().getNodeType() == NodeType.ERROR) {
                it.remove();
            }
        }
        _parents[i] = new int[parents.size()];
        coefs[i] = new double[parents.size()];
        int j = 0;
        for (Node parent : parents) {
            int parentIndex = variableNodes.indexOf(parent);
            _parents[i][j] = parentIndex;
            coefs[i][j] = semIm.getEdgeCoef().get(parentIndex, i);
            j++;
        }
    }
    double[] means = semIm.getMeans();

    // Do the simulation.
    for (int row = 0; row < sampleSize; row++) {
        for (int i = 0; i < tierOrdering.size(); i++) {
            int col = tierIndices[i];
            double value = distributions.get(col).nextRandom();
            for (int j = 0; j < _parents[col].length; j++) {
                value += dataSet.getDouble(row, _parents[col][j]) * coefs[col][j];
            }
            value += means[col];
            dataSet.setDouble(row, col, value);
        }
    }
    return dataSet;
}
Also used: ColtDataSet (edu.cmu.tetrad.data.ColtDataSet), DataSet (edu.cmu.tetrad.data.DataSet), ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), Distribution (edu.cmu.tetrad.util.dist.Distribution)

Example 23 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

Class TestKernelGaussian, method testMedianBandwidth.

/**
 * Verifies that a Gaussian kernel built over the one-column sample 1, 2, 3, 4, 5 sets its
 * bandwidth to the median distance between points, which for this sample is 2.
 */
@Test
public void testMedianBandwidth() {
    Node x = new ContinuousVariable("X");
    DataSet dataset = new ColtDataSet(5, Arrays.asList(x));
    // Row r of column 0 holds the value r + 1, i.e. 1..5.
    for (int r = 0; r < 5; r++) {
        dataset.setDouble(r, 0, r + 1);
    }
    assertTrue(new KernelGaussian(dataset, x).getBandwidth() == 2);
}
Also used: ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), ColtDataSet (edu.cmu.tetrad.data.ColtDataSet), DataSet (edu.cmu.tetrad.data.DataSet), Node (edu.cmu.tetrad.graph.Node), KernelGaussian (edu.cmu.tetrad.search.kernel.KernelGaussian), Test (org.junit.Test)

Example 24 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

Class BayesUpdaterClassifierEditor, method showClassification.

/**
 * Rebuilds the "Classification" tab from the classifier's current results.
 *
 * <p>The tab shows one row per classified case:
 * <ul>
 *   <li>column 0 ("Result") holds the predicted category;</li>
 *   <li>each following column holds the marginal probability of one target category.</li>
 * </ul>
 * Any existing "Classification" tab is removed first. The new tab is inserted where an old
 * one was, or appended if there was none.
 */
private void showClassification() {
    // Remove every existing "Classification" tab. Walk backwards so that removing a tab does
    // not shift the next candidate into the slot just checked; a forward loop would skip it.
    // tabIndex ends at the lowest removed position, which is still a valid insertion index
    // after all removals.
    int tabIndex = -1;
    for (int i = getTabbedPane().getTabCount() - 1; i >= 0; i--) {
        if ("Classification".equals(getTabbedPane().getTitleAt(i))) {
            getTabbedPane().remove(i);
            tabIndex = i;
        }
    }

    // Put the class information into a DataSet.
    int[] classifications = getClassifier().getClassifications();
    double[][] marginals = getClassifier().getMarginals();
    int maxCategory = 0;
    for (int classification : classifications) {
        maxCategory = Math.max(maxCategory, classification);
    }

    List<Node> variables = new LinkedList<>();
    // Read the target from the same classifier that produced the results above.
    DiscreteVariable targetVariable = getClassifier().getTargetVariable();
    DiscreteVariable classVar = new DiscreteVariable(targetVariable.getName(), maxCategory + 1);
    variables.add(classVar);
    // marginals is indexed [category][case]: one probability column per target category.
    for (int i = 0; i < marginals.length; i++) {
        variables.add(new ContinuousVariable("P(" + targetVariable + "=" + i + ")"));
    }
    classVar.setName("Result");

    DataSet dataSet = new ColtDataSet(classifications.length, variables);
    for (int i = 0; i < classifications.length; i++) {
        dataSet.setInt(i, 0, classifications[i]);
        for (int j = 0; j < marginals.length; j++) {
            dataSet.setDouble(i, j + 1, marginals[j][i]);
        }
    }

    JScrollPane scroll = new JScrollPane(new DataDisplay(dataSet));
    if (tabIndex == -1) {
        getTabbedPane().add("Classification", scroll);
    } else {
        getTabbedPane().add(scroll, tabIndex);
        getTabbedPane().setTitleAt(tabIndex, "Classification");
    }
}
Also used: ColtDataSet (edu.cmu.tetrad.data.ColtDataSet), DataSet (edu.cmu.tetrad.data.DataSet), Node (edu.cmu.tetrad.graph.Node), LinkedList (java.util.LinkedList), ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)

Example 25 with ColtDataSet

Use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

Class MlBayesImObs, method simulateTimeSeries.

/**
 * Simulates {@code sampleSize} rows of discrete data over the lag-0 variables of the
 * time-lag graph.
 *
 * <p>Within each row, lag-0 nodes are sampled in the causal order of the contemporaneous
 * (lag-0) subgraph. Each node is drawn from its conditional probability table, given the
 * values already drawn for its parents.
 *
 * @param sampleSize number of rows to simulate.
 * @return a data set with one discrete column per lag-0 node, in {@code getLag0Nodes()} order.
 * @throws IllegalStateException if a probability needed for sampling is NaN (not filled in).
 */
private DataSet simulateTimeSeries(int sampleSize) {
    TimeLagGraph timeSeriesGraph = getBayesPm().getDag().getTimeLagGraph();
    List<Node> lag0Nodes = timeSeriesGraph.getLag0Nodes();

    // One column per lag-0 node. colToNodeIndex[c] is the BayesIm index of column c's node;
    // it is used to copy each sampled row out of `point` into the data set.
    List<Node> variables = new ArrayList<>();
    int[] colToNodeIndex = new int[lag0Nodes.size()];
    for (int c = 0; c < lag0Nodes.size(); c++) {
        Node node = lag0Nodes.get(c);
        colToNodeIndex[c] = getNodeIndex(node);
        // Sampled values range over [0, getNumColumns(nodeIndex)), so declare that many
        // categories to keep every written value in range.
        variables.add(new DiscreteVariable(timeSeriesGraph.getNodeId(node).getName(), getNumColumns(colToNodeIndex[c])));
    }
    DataSet fullData = new ColtDataSet(sampleSize, variables);

    Graph contemporaneousDag = timeSeriesGraph.subgraph(lag0Nodes);
    List<Node> tierOrdering = contemporaneousDag.getCausalOrdering();
    int[] tiers = new int[tierOrdering.size()];
    for (int i = 0; i < tierOrdering.size(); i++) {
        tiers[i] = getNodeIndex(tierOrdering.get(i));
    }

    // Parent values of the node currently being sampled. A node's parents may include lagged
    // nodes, so their count can exceed tierOrdering.size(). nodes.length bounds the number of
    // distinct parents, because parent indices address `point`, which has nodes.length slots.
    int[] combination = new int[nodes.length];

    // Construct the sample.
    for (int i = 0; i < sampleSize; i++) {
        // NOTE(review): point is freshly zeroed for every row, so any parent that is a lagged
        // node is read as category 0; values from earlier rows are not carried forward.
        int[] point = new int[nodes.length];
        for (int nodeIndex : tiers) {
            double cutoff = RandomUtil.getInstance().nextDouble();
            for (int k = 0; k < getNumParents(nodeIndex); k++) {
                combination[k] = point[getParent(nodeIndex, k)];
            }
            int rowIndex = getRowIndex(nodeIndex, combination);
            double sum = 0.0;
            for (int k = 0; k < getNumColumns(nodeIndex); k++) {
                double probability = getProbability(nodeIndex, rowIndex, k);
                if (Double.isNaN(probability)) {
                    throw new IllegalStateException("Some probability " + "values in the BayesIm are not filled in; " + "cannot simulate data.");
                }
                sum += probability;
                if (sum >= cutoff) {
                    point[nodeIndex] = k;
                    break;
                }
            }
        }
        // Record this row's lag-0 values in the returned data set.
        for (int c = 0; c < colToNodeIndex.length; c++) {
            fullData.setInt(i, c, point[colToNodeIndex[c]]);
        }
    }
    return fullData;
}
Also used: DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable), ColtDataSet (edu.cmu.tetrad.data.ColtDataSet), DataSet (edu.cmu.tetrad.data.DataSet)

Aggregations

ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 28; DataSet (edu.cmu.tetrad.data.DataSet): 24; Node (edu.cmu.tetrad.graph.Node): 21; ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 17; LinkedList (java.util.LinkedList): 13; Test (org.junit.Test): 12; DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 9; RandomUtil (edu.cmu.tetrad.util.RandomUtil): 6; Parameters (edu.cmu.tetrad.util.Parameters): 3; ArrayList (java.util.ArrayList): 3; KernelGaussian (edu.cmu.tetrad.search.kernel.KernelGaussian): 2; DecimalFormat (java.text.DecimalFormat): 2; ParseException (java.text.ParseException): 2; KMeans (edu.cmu.tetrad.cluster.KMeans): 1; CellTable (edu.cmu.tetrad.data.CellTable): 1; DataModelList (edu.cmu.tetrad.data.DataModelList): 1; TimeSeriesData (edu.cmu.tetrad.data.TimeSeriesData): 1; Dag (edu.cmu.tetrad.graph.Dag): 1; RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset): 1; RegressionResult (edu.cmu.tetrad.regression.RegressionResult): 1