Search in sources :

Example 6 with TimeLagGraph

Use of edu.cmu.tetrad.graph.TimeLagGraph in the project tetrad by cmu-phil.

In the class MlBayesIm, method simulateTimeSeries:

/**
 * Simulates a discrete data set from this Bayes IM, treating its DAG as a time-lag graph.
 *
 * <p>One discrete column is created per lag-0 node of the time-lag graph, named after that
 * node's id. Each of the {@code sampleSize} rows is drawn by forward sampling. The lag-0 nodes
 * are visited in the causal order of their contemporaneous subgraph, and each node's category
 * is chosen from this IM's conditional probability table given its parents' sampled values.
 *
 * <p>NOTE(review): each row starts from an all-zero value vector. Any parent that is not a
 * lag-0 node is therefore read as category 0 rather than taken from a previously simulated
 * row. Confirm whether lagged parents should be conditioned on earlier rows instead.
 *
 * @param sampleSize the number of rows to simulate.
 * @return a data set with {@code sampleSize} rows and one discrete column per lag-0 node,
 *     with columns in the order returned by {@code getLag0Nodes()}.
 * @throws IllegalStateException if a probability needed during sampling is NaN (not filled in).
 */
private DataSet simulateTimeSeries(int sampleSize) {
    TimeLagGraph timeSeriesGraph = getBayesPm().getDag().getTimeLagGraph();

    // Fetch the lag-0 nodes once, so column j of the data set is guaranteed to correspond to
    // lag0Nodes.get(j) when sampled values are written back below.
    List<Node> lag0Nodes = timeSeriesGraph.getLag0Nodes();

    List<Node> variables = new ArrayList<>();
    for (Node node : lag0Nodes) {
        final DiscreteVariable e = new DiscreteVariable(timeSeriesGraph.getNodeId(node).getName());
        e.setNodeType(node.getNodeType());
        variables.add(e);
    }

    DataSet fullData = new BoxDataSet(new VerticalIntDataBox(sampleSize, variables.size()), variables);

    // IM node index backing each data-set column.
    int[] columnNodeIndex = new int[lag0Nodes.size()];
    for (int col = 0; col < lag0Nodes.size(); col++) {
        columnNodeIndex[col] = getNodeIndex(lag0Nodes.get(col));
    }

    // Visit the lag-0 nodes in the causal order of their contemporaneous subgraph, so that
    // contemporaneous parents are sampled before their children.
    Graph contemporaneousDag = timeSeriesGraph.subgraph(lag0Nodes);
    List<Node> tierOrdering = contemporaneousDag.getCausalOrdering();
    int[] tiers = new int[tierOrdering.size()];
    for (int i = 0; i < tierOrdering.size(); i++) {
        tiers[i] = getNodeIndex(tierOrdering.get(i));
    }

    // Parent indices address `point` (nodes.length entries), so nodes.length bounds the number
    // of parents of any node. Sizing this by the number of lag-0 nodes could overflow when a
    // node also has parents outside the lag-0 set.
    int[] combination = new int[nodes.length];

    // Construct the sample.
    for (int i = 0; i < sampleSize; i++) {
        int[] point = new int[nodes.length];

        for (int nodeIndex : tiers) {
            double cutoff = RandomUtil.getInstance().nextDouble();

            for (int k = 0; k < getNumParents(nodeIndex); k++) {
                combination[k] = point[getParent(nodeIndex, k)];
            }

            int rowIndex = getRowIndex(nodeIndex, combination);
            double sum = 0.0;

            // Pick the first category whose cumulative probability reaches the uniform cutoff.
            for (int k = 0; k < getNumColumns(nodeIndex); k++) {
                double probability = getProbability(nodeIndex, rowIndex, k);

                if (Double.isNaN(probability)) {
                    throw new IllegalStateException("Some probability " + "values in the BayesIm are not filled in; " + "cannot simulate data.");
                }

                sum += probability;

                if (sum >= cutoff) {
                    point[nodeIndex] = k;
                    break;
                }
            }
        }

        // Store the sampled row. Previously these values were discarded, so the returned
        // data set was never filled in.
        for (int col = 0; col < columnNodeIndex.length; col++) {
            fullData.setInt(i, col, point[columnNodeIndex[col]]);
        }
    }

    return fullData;
}
Also used: Node (edu.cmu.tetrad.graph.Node), TimeLagGraph (edu.cmu.tetrad.graph.TimeLagGraph), Graph (edu.cmu.tetrad.graph.Graph)

Aggregations

TimeLagGraph (edu.cmu.tetrad.graph.TimeLagGraph): 6; Node (edu.cmu.tetrad.graph.Node): 5; Graph (edu.cmu.tetrad.graph.Graph): 2; RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 1; SingleGraph (edu.cmu.tetrad.algcomparison.graph.SingleGraph): 1; LagGraphParams (edu.cmu.tetrad.gene.tetrad.gene.graph.LagGraphParams): 1; RandomActiveLagGraph (edu.cmu.tetrad.gene.tetrad.gene.graph.RandomActiveLagGraph): 1; LaggedFactor (edu.cmu.tetrad.gene.tetrad.gene.history.LaggedFactor): 1; BooleanGlassGeneIm (edu.cmu.tetrad.gene.tetradapp.model.BooleanGlassGeneIm): 1; BooleanGlassGenePm (edu.cmu.tetrad.gene.tetradapp.model.BooleanGlassGenePm): 1; GraphNode (edu.cmu.tetrad.graph.GraphNode): 1; SemIm (edu.cmu.tetrad.sem.SemIm): 1; SemPm (edu.cmu.tetrad.sem.SemPm): 1; Test (org.junit.Test): 1