Search in sources :

Example 16 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class BayesUpdaterClassifier method classify.

/**
 * Classifies every case (row) of the test data by Bayesian updating and
 * returns the estimated category of the target variable for each case.
 *
 * <p>For each case the observed values of the non-target variables are
 * entered as evidence, the marginal of every target category is computed,
 * and the category with the highest marginal becomes the estimate (on ties
 * the highest-indexed category wins). A variable whose value is missing in
 * a case is left unconstrained for that case.
 *
 * <p>Side effects: sets {@code numCases}, {@code classifications} (the
 * returned array) and {@code marginals} (indexed
 * {@code [targetCategory][case]}). {@code missingValueCaseFound} is reset
 * at the start of every case, so after the call it only describes the last
 * case processed.
 *
 * @return the estimated target category index for each case; a case for
 * which no category could be selected holds
 * {@code DiscreteVariable.MISSING_VALUE}.
 * @throws NullPointerException     if the target variable has not been set.
 * @throws IllegalArgumentException if a non-target variable of the Bayes IM
 *                                  cannot be found in the test data.
 */
public int[] classify() {
    if (targetVariable == null) {
        throw new NullPointerException("Target not set.");
    }
    // Create an updater for the instantiated Bayes net.
    BayesUpdater bayesUpdater = new RowSummingExactUpdater(getBayesIm());
    // Get the number of variables and the number of cases of the dataset
    // to be classified.
    int nvars = getBayesImVars().size();
    int ncases = testData.getNumRows();
    // Map each Bayes IM variable to its column in the test data. The slot
    // of the target keeps its default of 0, so the subset below repeats
    // data column 0 at that position; it is never entered as evidence.
    int[] varIndices = new int[nvars];
    List<Node> dataVars = testData.getVariables();
    for (int i = 0; i < nvars; i++) {
        DiscreteVariable variable = (DiscreteVariable) getBayesImVars().get(i);
        if (variable == targetVariable) {
            continue;
        }
        varIndices[i] = dataVars.indexOf(variable);
        if (varIndices[i] == -1) {
            throw new IllegalArgumentException("Can't find the (non-target) variable " + variable + " in the data. Either it's not there, or else its " + "categories are in a different order.");
        }
    }
    DataSet selectedData = testData.subsetColumns(varIndices);
    this.numCases = ncases;
    int[] estimatedValues = new int[ncases];
    int numTargetCategories = targetVariable.getNumCategories();
    double[][] probOfClassifiedValues = new double[numTargetCategories][ncases];
    Arrays.fill(estimatedValues, -1);
    // Position of the target among the Bayes IM variables. It does not
    // depend on the case or the variable, so look it up once instead of
    // once per variable per case.
    int targetVarIndex = getBayesImVars().indexOf(targetVariable);
    // Classify each case in turn by Bayesian updating.
    for (int i = 0; i < ncases; i++) {
        // Create an Evidence instance for the instantiated Bayes net
        // which will allow that updating.
        Evidence evidence = Evidence.tautology(getBayesIm());
        // Let the target variable range over all its values.
        int itarget = evidence.getNodeIndex(targetVariable.getName());
        evidence.getProposition().setVariable(itarget, true);
        this.missingValueCaseFound = false;
        // Restrict every other variable to the value observed in this
        // case; variables with a missing value stay unrestricted.
        for (int j = 0; j < nvars; j++) {
            if (j == targetVarIndex) {
                continue;
            }
            int observedValue = selectedData.getInt(i, j);
            if (observedValue == DiscreteVariable.MISSING_VALUE) {
                this.missingValueCaseFound = true;
                continue;
            }
            String jName = getBayesImVars().get(j).getName();
            int jIndex = evidence.getNodeIndex(jName);
            evidence.getProposition().setCategory(jIndex, observedValue);
        }
        // Update using those values.
        bayesUpdater.setEvidence(evidence);
        // For each possible value of the target compute its probability in
        // the updated Bayes net. Select the value with the highest
        // probability as the estimated value.
        Node targetNode = getBayesIm().getNode(targetVariable.getName());
        int indexTargetBN = getBayesIm().getNodeIndex(targetNode);
        // Straw man values--to be replaced. Stays -1 when no marginal
        // compares >= the starting probability (e.g. all are NaN).
        int estimatedValue = -1;
        double highestProb = -0.1;
        for (int j = 0; j < numTargetCategories; j++) {
            double marginal = bayesUpdater.getMarginal(indexTargetBN, j);
            probOfClassifiedValues[j][i] = marginal;
            if (marginal >= highestProb) {
                highestProb = marginal;
                estimatedValue = j;
            }
        }
        // No valid marginal was returned for this case: log the case and
        // record it as missing rather than as a classification.
        if (estimatedValue < 0) {
            TetradLogger.getInstance().log("details", "Case " + i + " does not return valid marginal.");
            for (int m = 0; m < nvars; m++) {
                TetradLogger.getInstance().log("details", "  " + selectedData.getDouble(i, m));
            }
            estimatedValues[i] = DiscreteVariable.MISSING_VALUE;
            continue;
        }
        estimatedValues[i] = estimatedValue;
    }
    this.classifications = estimatedValues;
    this.marginals = probOfClassifiedValues;
    return estimatedValues;
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)

Example 17 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class LogisticRegressionRunner method execute.

// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Executes the algorithm: runs a logistic regression of the response
 * (target) variable on the selected predictor variables, using the data set
 * at the current model index, and stores the outcome in {@code result}.
 *
 * <p>{@code outGraph} is always reset to an empty graph. When the inputs
 * are unusable the method records an explanation in {@code report} and
 * returns without running the regression. This happens when no response or
 * predictors are set, when the response is also listed as a predictor, or
 * when the response is not a discrete variable of the data set.
 */
public void execute() {
    outGraph = new EdgeListGraph();
    if (regressorNames == null || regressorNames.isEmpty() || targetName == null) {
        report = "Response and predictor variables not set.";
        return;
    }
    if (regressorNames.contains(targetName)) {
        report = "Response must not be a predictor.";
        return;
    }
    // The regression runs on the selected data set itself, so the response
    // and the predictors are all resolved against that one data set.
    DataSet modelData = dataSets.get(getModelIndex());
    Node target = modelData.getVariable(targetName);
    // The regression needs a DiscreteVariable response; report the problem
    // instead of failing on the cast below.
    if (!(target instanceof DiscreteVariable)) {
        report = "Response variable must be a discrete variable in the data set.";
        return;
    }
    // Get the list of regressors selected by the user.
    List<Node> regressorNodes = new ArrayList<>();
    for (String s : regressorNames) {
        regressorNodes.add(modelData.getVariable(s));
    }
    LogisticRegression logRegression = new LogisticRegression(modelData);
    logRegression.setAlpha(alpha);
    this.result = logRegression.regress((DiscreteVariable) target, regressorNodes);
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ArrayList(java.util.ArrayList) LogisticRegression(edu.cmu.tetrad.regression.LogisticRegression)

Example 18 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class EmBayesEstimator method expectation.

/**
 * Expectation step. Takes an instantiated Bayes net (BayesIm) whose graph
 * includes all the variables (observed and latent) and computes estimated
 * counts using the data in the DataSet mixedData.
 *
 * <p>The counts that are estimated correspond to cells in the conditional
 * probability tables of the Bayes net. The outermost loop (indexed by j) is
 * over the set of variables. If the variable has no parents, each case in
 * the dataset is examined and the count for the observed value of the
 * variable is increased by 1.0; if the value of the variable is missing,
 * the marginal probabilities of its values given the values of the
 * variables that are available for that case are used to increment the
 * corresponding estimated counts.
 *
 * <p>If a variable has parents then there is a loop which steps through all
 * possible sets of values of its parents. This loop is indexed by the
 * variable "row". Each case in the dataset is examined. If the variable and
 * all its parents have values in the case, the corresponding estimated
 * counts are incremented by 1.0. If the variable or any of its parents have
 * missing values, the joint marginal is computed for the variable and the
 * set of values of its parents corresponding to "row", given the values
 * that are observed in the case, and the corresponding estimated counts are
 * incremented by that probability.
 *
 * <p>The estimated counts are stored in the double[][][] array
 * estimatedCounts. The count (possibly fractional) of the number of times
 * each combination of parent values occurs is stored in the double[][]
 * array estimatedCountsDenom. These two arrays are then used to fill the
 * condProbs array with the estimated conditional probabilities.
 *
 * <p>NOTE(review): the arrays are indexed both by j (position in
 * allVariables) and by varIndex (node index in the Bayes IM), and Bayes IM
 * parent indices are used as data columns. This appears to assume both
 * orderings coincide -- confirm.
 *
 * @param inputBayesIm the instantiated Bayes net used to compute the
 *                     marginals for cases with missing values.
 */
private void expectation(BayesIm inputBayesIm) {
    int numCases = mixedData.getNumRows();
    int numVariables = allVariables.size();
    RowSummingExactUpdater rseu = new RowSummingExactUpdater(inputBayesIm);
    for (int j = 0; j < numVariables; j++) {
        DiscreteVariable var = (DiscreteVariable) allVariables.get(j);
        String varName = var.getName();
        Node varNode = graph.getNode(varName);
        int varIndex = inputBayesIm.getNodeIndex(varNode);
        int[] parentVarIndices = inputBayesIm.getParents(varIndex);
        if (parentVarIndices.length == 0) {
            // Variables with no parents: a single row of counts.
            for (int col = 0; col < var.getNumCategories(); col++) {
                estimatedCounts[j][0][col] = 0.0;
            }
            for (int i = 0; i < numCases; i++) {
                // -99 marks a missing value in mixedData.
                int observedValue = mixedData.getInt(i, j);
                if (observedValue != -99) {
                    // This case has a value for the variable: count it.
                    estimatedCounts[j][0][observedValue] += 1.0;
                    continue;
                }
                // Value missing: spread one count over the categories
                // according to their marginal probabilities given the
                // values that are observed in this case.
                Evidence evidenceThisCase = observedEvidenceForCase(inputBayesIm, i);
                if (evidenceThisCase == null) {
                    // No other variable contained useful data.
                    continue;
                }
                rseu.setEvidence(evidenceThisCase);
                for (int m = 0; m < var.getNumCategories(); m++) {
                    estimatedCounts[j][0][m] += rseu.getMarginal(varIndex, m);
                }
            }
        } else {
            // Variables with parents: one row per combination of parent
            // values.
            int numRows = inputBayesIm.getNumRows(varIndex);
            for (int row = 0; row < numRows; row++) {
                int[] parValues = inputBayesIm.getParentValues(varIndex, row);
                estimatedCountsDenom[varIndex][row] = 0.0;
                for (int col = 0; col < var.getNumCategories(); col++) {
                    estimatedCounts[varIndex][row][col] = 0.0;
                }
                for (int i = 0; i < numCases; i++) {
                    // A case belongs to this row unless one of its observed
                    // parent values differs from the row's value; a missing
                    // parent value is compatible with every row.
                    boolean parentMatch = true;
                    boolean parentMissing = false;
                    for (int p = 0; p < parentVarIndices.length; p++) {
                        int parentValue = mixedData.getInt(i, parentVarIndices[p]);
                        if (parentValue == -99) {
                            parentMissing = true;
                        } else if (parentValue != parValues[p]) {
                            parentMatch = false;
                            break;
                        }
                    }
                    if (!parentMatch) {
                        // Not a matching case; go to next.
                        continue;
                    }
                    int childValue = mixedData.getInt(i, j);
                    if (childValue != -99 && !parentMissing) {
                        // Fully observed: a whole count for this cell.
                        estimatedCounts[j][row][childValue] += 1.0;
                        estimatedCountsDenom[j][row] += 1.0;
                        continue;
                    }
                    // Missing data (the variable or one of its parents):
                    // add the joint marginal of the variable and this
                    // combination of parent values, conditional on the
                    // values that are observed in this case.
                    Evidence evidenceThisCase = observedEvidenceForCase(inputBayesIm, i);
                    if (evidenceThisCase == null) {
                        // Nothing observed in this case; it carries no
                        // information.
                        continue;
                    }
                    rseu.setEvidence(evidenceThisCase);
                    estimatedCountsDenom[j][row] += rseu.getJointMarginal(parentVarIndices, parValues);
                    // Slot 0 holds the child, slots 1.. hold the parents.
                    int[] parPlusChildIndices = new int[parentVarIndices.length + 1];
                    int[] parPlusChildValues = new int[parentVarIndices.length + 1];
                    parPlusChildIndices[0] = varIndex;
                    for (int pc = 1; pc < parPlusChildIndices.length; pc++) {
                        parPlusChildIndices[pc] = parentVarIndices[pc - 1];
                        parPlusChildValues[pc] = parValues[pc - 1];
                    }
                    for (int m = 0; m < var.getNumCategories(); m++) {
                        parPlusChildValues[0] = m;
                        estimatedCounts[j][row][m] += rseu.getJointMarginal(parPlusChildIndices, parPlusChildValues);
                    }
                }
            }
        }
    }
    // Turn the estimated counts into conditional probabilities. The values
    // are stored in condProbs and also written into a locally created
    // MlBayesIm, which is not returned or stored by this method.
    BayesIm outputBayesIm = new MlBayesIm(bayesPm);
    for (int j = 0; j < nodes.length; j++) {
        DiscreteVariable var = (DiscreteVariable) allVariables.get(j);
        String varName = var.getName();
        Node varNode = graph.getNode(varName);
        int varIndex = inputBayesIm.getNodeIndex(varNode);
        int numRows = inputBayesIm.getNumRows(j);
        int numCols = inputBayesIm.getNumColumns(j);
        if (numRows == 1) {
            // Single row: normalise by the total of the row's counts.
            double sum = 0.0;
            for (int m = 0; m < numCols; m++) {
                sum += estimatedCounts[j][0][m];
            }
            for (int m = 0; m < numCols; m++) {
                condProbs[j][0][m] = estimatedCounts[j][0][m] / sum;
                outputBayesIm.setProbability(varIndex, 0, m, condProbs[j][0][m]);
            }
        } else {
            for (int row = 0; row < numRows; row++) {
                for (int m = 0; m < numCols; m++) {
                    if (estimatedCountsDenom[j][row] != 0.0) {
                        condProbs[j][row][m] = estimatedCounts[j][row][m] / estimatedCountsDenom[j][row];
                    } else {
                        // No (fractional) case matched this parent
                        // combination: the probability is undefined.
                        condProbs[j][row][m] = Double.NaN;
                    }
                    outputBayesIm.setProbability(varIndex, row, m, condProbs[j][row][m]);
                }
            }
        }
    }
}

/**
 * Builds the evidence for one case of mixedData: every variable that has a
 * value in the case (anything other than the missing marker -99) is fixed
 * to that value; variables with a missing value stay unconstrained.
 *
 * @param inputBayesIm the Bayes IM the evidence is defined over.
 * @param caseIndex    the row of mixedData to read.
 * @return the evidence, or {@code null} if the case has no observed value
 * at all.
 */
private Evidence observedEvidenceForCase(BayesIm inputBayesIm, int caseIndex) {
    Evidence evidence = Evidence.tautology(inputBayesIm);
    boolean existsEvidence = false;
    for (int k = 0; k < allVariables.size(); k++) {
        int value = mixedData.getInt(caseIndex, k);
        if (value == -99) {
            continue;
        }
        existsEvidence = true;
        String otherVarName = allVariables.get(k).getName();
        Node otherNode = graph.getNode(otherVarName);
        int otherIndex = inputBayesIm.getNodeIndex(otherNode);
        evidence.getProposition().setCategory(otherIndex, value);
    }
    return existsEvidence ? evidence : null;
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Node(edu.cmu.tetrad.graph.Node)

Example 19 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class EmBayesEstimator method initialize.

/**
 * Sets up the working state of the estimator.
 *
 * <ol>
 * <li>Estimates {@code observedIm} from {@code dataSet} using a symmetric
 * Dirichlet prior (parameter 0.5) over {@code bayesPmObs}.</li>
 * <li>Builds {@code mixedData}: one column per entry of {@code nodes}.
 * Columns of latent nodes are filled with the missing marker -99; the other
 * columns are copied from the matching variable of {@code dataSet}.</li>
 * <li>Calls {@code estimateIM(bayesPm, mixedData)}.</li>
 * <li>Allocates {@code estimatedCounts}, {@code estimatedCountsDenom} and
 * {@code condProbs} with the row/column dimensions reported by
 * {@code estimatedIm}.</li>
 * </ol>
 */
private void initialize() {
    DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(bayesPmObs, 0.5);
    observedIm = DirichletEstimator.estimate(prior, dataSet);

    // mixedData has the measured columns of the data plus one column per
    // latent variable; a latent column is missing in every case.
    int numFullCases = dataSet.getNumRows();
    List<Node> mixedVariables = new LinkedList<>();
    for (int n = 0; n < nodes.length; n++) {
        Node node = nodes[n];
        if (node.getNodeType() != NodeType.LATENT) {
            // Measured: reuse the variable of the data set with that name.
            String measuredName = bayesPm.getVariable(node).getName();
            mixedVariables.add(dataSet.getVariable(measuredName));
            continue;
        }
        // Latent: create a fresh discrete variable with the categories
        // count taken from the parametric model.
        int numCategories = bayesPm.getNumCategories(node);
        DiscreteVariable latentVar = new DiscreteVariable(node.getName(), numCategories);
        latentVar.setNodeType(NodeType.LATENT);
        mixedVariables.add(latentVar);
    }

    DataSet dsMixed = new ColtDataSet(numFullCases, mixedVariables);
    for (int col = 0; col < nodes.length; col++) {
        boolean latent = nodes[col].getNodeType() == NodeType.LATENT;
        int sourceColumn = -1;
        if (!latent) {
            String measuredName = bayesPm.getVariable(nodes[col]).getName();
            Node measuredVariable = dataSet.getVariable(measuredName);
            sourceColumn = dataSet.getColumn(measuredVariable);
        }
        for (int row = 0; row < numFullCases; row++) {
            if (latent) {
                dsMixed.setInt(row, col, -99);
            } else {
                dsMixed.setInt(row, col, dataSet.getInt(row, sourceColumn));
            }
        }
    }
    mixedData = dsMixed;
    allVariables = mixedData.getVariables();

    // Find the bayes net which is parameterized using mixedData or set
    // randomly when that's not possible.
    estimateIM(bayesPm, mixedData);

    // Allocate the count and probability tables, one per node, shaped like
    // the conditional probability tables of estimatedIm.
    estimatedCounts = new double[nodes.length][][];
    estimatedCountsDenom = new double[nodes.length][];
    condProbs = new double[nodes.length][][];
    for (int node = 0; node < nodes.length; node++) {
        int rows = estimatedIm.getNumRows(node);
        estimatedCounts[node] = new double[rows][];
        estimatedCountsDenom[node] = new double[rows];
        condProbs[node] = new double[rows][];
        for (int row = 0; row < estimatedIm.getNumRows(node); row++) {
            int cols = estimatedIm.getNumColumns(node);
            estimatedCounts[node][row] = new double[cols];
            condProbs[node][row] = new double[cols];
        }
    }
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) Node(edu.cmu.tetrad.graph.Node) LinkedList(java.util.LinkedList)

Example 20 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class MlBayesImObs method simulateDataHelper.

/**
 * Simulates a sample with the given sample size.
 *
 * <p>A discrete variable, carrying the category names from the parametric
 * model, is created for every node that is kept. The position of each kept
 * node within {@code nodes} is recorded and handed to
 * {@code constructSample} together with the new data set.
 *
 * @param sampleSize      the sample size.
 * @param latentDataSaved if true every node gets a column; if false only
 *                        nodes of type MEASURED do.
 * @return the simulated sample as a DataSet.
 */
private DataSet simulateDataHelper(int sampleSize, boolean latentDataSaved) {
    // keptNodeIndex[c] is the index into nodes of the c-th kept column.
    int[] keptNodeIndex = new int[nodes.length];
    int keptCount = 0;
    List<Node> columns = new LinkedList<>();
    for (int n = 0; n < nodes.length; n++) {
        boolean keep = latentDataSaved || nodes[n].getNodeType() == NodeType.MEASURED;
        if (keep) {
            int categoryCount = bayesPm.getNumCategories(nodes[n]);
            List<String> categoryNames = new LinkedList<>();
            int c = 0;
            while (c < categoryCount) {
                categoryNames.add(bayesPm.getCategory(nodes[n], c));
                c++;
            }
            columns.add(new DiscreteVariable(nodes[n].getName(), categoryNames));
            keptNodeIndex[keptCount] = n;
            keptCount++;
        }
    }
    DataSet sample = new ColtDataSet(sampleSize, columns);
    constructSample(sampleSize, keptCount, sample, keptNodeIndex);
    return sample;
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet)

Aggregations

DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)56 Node (edu.cmu.tetrad.graph.Node)37 DataSet (edu.cmu.tetrad.data.DataSet)18 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)16 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)11 LinkedList (java.util.LinkedList)9 Test (org.junit.Test)5 ArrayList (java.util.ArrayList)4 Dag (edu.cmu.tetrad.graph.Dag)3 NumberFormat (java.text.NumberFormat)3 Element (nu.xom.Element)3 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)2 Graph (edu.cmu.tetrad.graph.Graph)2 LogisticRegression (edu.cmu.tetrad.regression.LogisticRegression)2 List (java.util.List)2 Elements (nu.xom.Elements)2 DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D)1 TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph)1 StoredCellProbs (edu.cmu.tetrad.bayes.StoredCellProbs)1 BoxDataSet (edu.cmu.tetrad.data.BoxDataSet)1