Search in sources:

Example 26 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

The class BayesXmlParser, method makeBayesPm.

/**
 * Builds a {@link BayesPm} from a {@code <parents>} element.
 *
 * <p>Every entry of {@code variables} becomes a node of a new {@link Dag}. Each
 * {@code <parentsFor name="...">} child adds one directed edge per nested
 * {@code <parent name="...">} element, from the named parent to the named child.
 * Finally the categories of every variable are copied into the resulting BayesPm.
 *
 * <p>NOTE(review): variable names are resolved through the {@code namesToVars}
 * field, which {@code getVariables} fills; this method presumably must run after
 * it — confirm call order.
 *
 * @param variables the discrete variables of the network; each is cast to
 *                  {@link DiscreteVariable}.
 * @param element1  the {@code <parents>} element.
 * @return a BayesPm over the described DAG with the variables' categories.
 * @throws IllegalArgumentException if an element has an unexpected name or refers
 *                                  to an unknown variable.
 */
private BayesPm makeBayesPm(List<Node> variables, Element element1) {
    if (!"parents".equals(element1.getQualifiedName())) {
        throw new IllegalArgumentException("Expecting 'parents' element.");
    }

    Dag graph = new Dag();
    for (Node variable : variables) {
        graph.addNode(variable);
    }

    Elements elements = element1.getChildElements();
    for (int i = 0; i < elements.size(); i++) {
        Element e1 = elements.get(i);
        if (!"parentsFor".equals(e1.getQualifiedName())) {
            throw new IllegalArgumentException("Expecting 'parentsFor' element.");
        }
        String varName = e1.getAttributeValue("name");
        Node var = namesToVars.get(varName);
        // Fail with a clear message instead of handing a null node to the graph.
        if (var == null) {
            throw new IllegalArgumentException("Unknown variable in 'parentsFor' element: " + varName);
        }

        Elements elements1 = e1.getChildElements();
        for (int j = 0; j < elements1.size(); j++) {
            Element e2 = elements1.get(j);
            if (!"parent".equals(e2.getQualifiedName())) {
                throw new IllegalArgumentException("Expecting 'parent' element.");
            }
            String parentName = e2.getAttributeValue("name");
            Node parent = namesToVars.get(parentName);
            if (parent == null) {
                throw new IllegalArgumentException("Unknown parent variable '" + parentName + "' for " + varName);
            }
            graph.addDirectedEdge(parent, var);
        }
    }

    BayesPm bayesPm = new BayesPm(graph);
    for (Node variable1 : variables) {
        DiscreteVariable graphVariable = (DiscreteVariable) variable1;
        bayesPm.setCategories(graphVariable, graphVariable.getCategories());
    }
    return bayesPm;
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Node(edu.cmu.tetrad.graph.Node) Element(nu.xom.Element) Dag(edu.cmu.tetrad.graph.Dag) Elements(nu.xom.Elements)

Example 27 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

The class BayesXmlParser, method getVariables.

/**
 * Parses a {@code <bnVariables>} element into a list of discrete variables.
 *
 * <p>Each child must be a {@code <discreteVariable>} element carrying
 * {@code name}, {@code x} and {@code y} attributes, plus an optional
 * {@code latent} attribute (the value {@code "yes"} marks the variable latent).
 * Its {@code <category name="...">} children give the categories, in order.
 * On success the {@code namesToVars} field is replaced by a name-to-variable map
 * of the parsed variables.
 *
 * @param element0 the {@code <bnVariables>} element.
 * @return the parsed variables, in document order.
 * @throws IllegalArgumentException if an element has an unexpected name.
 * @throws NumberFormatException    if {@code x} or {@code y} is missing or not an
 *                                  integer.
 */
private List<Node> getVariables(Element element0) {
    if (!"bnVariables".equals(element0.getQualifiedName())) {
        throw new IllegalArgumentException("Expecting 'bnVariables' element.");
    }

    List<Node> variables = new LinkedList<>();
    Elements elements = element0.getChildElements();
    for (int i = 0; i < elements.size(); i++) {
        Element e1 = elements.get(i);
        if (!"discreteVariable".equals(e1.getQualifiedName())) {
            throw new IllegalArgumentException("Expecting 'discreteVariable' element.");
        }

        String name = e1.getAttributeValue("name");
        // Null-safe: a missing 'latent' attribute means "not latent".
        boolean isLatent = "yes".equals(e1.getAttributeValue("latent"));
        // Integer.parseInt replaces the deprecated Integer(String) constructor and
        // throws the same NumberFormatException for a missing or malformed value.
        int x = Integer.parseInt(e1.getAttributeValue("x"));
        int y = Integer.parseInt(e1.getAttributeValue("y"));

        Elements categoryElements = e1.getChildElements();
        List<String> categories = new LinkedList<>();
        for (int j = 0; j < categoryElements.size(); j++) {
            Element e2 = categoryElements.get(j);
            if (!"category".equals(e2.getQualifiedName())) {
                throw new IllegalArgumentException("Expecting 'category' element.");
            }
            categories.add(e2.getAttributeValue("name"));
        }

        DiscreteVariable var = new DiscreteVariable(name, categories);
        if (isLatent) {
            var.setNodeType(NodeType.LATENT);
        }
        var.setCenterX(x);
        var.setCenterY(y);
        variables.add(var);
    }

    // Only replace the lookup table once every variable parsed successfully.
    namesToVars = new HashMap<>();
    for (Node v : variables) {
        namesToVars.put(v.getName(), v);
    }
    return variables;
}
Also used : Node(edu.cmu.tetrad.graph.Node) Element(nu.xom.Element) Elements(nu.xom.Elements) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)

Example 28 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

The class BdeMetric, method computeObservedCounts.

// public double scoreLnGam() {
// 
// double[][][] priorProbs;
// double[][] priorProbsRowSum;
// 
// Graph graph = bayesPm.getDag();
// 
// int n = graph.getNumNodes();
// 
// observedCounts = new int[n][][];
// priorProbs = new double[n][][];
// 
// int[][] observedCountsRowSum = new int[n][];
// priorProbsRowSum = new double[n][];
// 
// bayesIm = new MlBayesIm(bayesPm);
// 
// for (int i = 0; i < n; i++) {
// //int numRows = bayesImMixed.getNumRows(i);
// int numRows = bayesIm.getNumRows(i);
// observedCounts[i] = new int[numRows][];
// priorProbs[i] = new double[numRows][];
// 
// observedCountsRowSum[i] = new int[numRows];
// priorProbsRowSum[i] = new double[numRows];
// 
// //for(int j = 0; j < bayesImMixed.getNumRows(i); j++) {
// for (int j = 0; j < numRows; j++) {
// 
// observedCountsRowSum[i][j] = 0;
// priorProbsRowSum[i][j] = 0;
// 
// //int numCols = bayesImMixed.getNumColumns(i);
// int numCols = bayesIm.getNumColumns(i);
// observedCounts[i][j] = new int[numCols];
// priorProbs[i][j] = new double[numCols];
// }
// }
// 
// //At this point set values in both observedCounts and priorProbs
// computeObservedCounts();
// //Set all priorProbs (i.e. estimated counts) to 1.0.  Eventually they may be
// //supplied as a parameter of the constructor of this class.
// for (int i = 0; i < n; i++) {
// for (int j = 0; j < bayesIm.getNumRows(i); j++) {
// for (int k = 0; k < bayesIm.getNumColumns(i); k++) {
// priorProbs[i][j][k] = 1.0;
// }
// }
// }
// 
// 
// for (int i = 0; i < n; i++) {
// for (int j = 0; j < bayesIm.getNumRows(i); j++) {
// for (int k = 0; k < bayesIm.getNumColumns(i); k++) {
// observedCountsRowSum[i][j] += observedCounts[i][j][k];
// priorProbsRowSum[i][j] += priorProbs[i][j][k];
// }
// }
// }
// 
// //double outerProduct = 1.0;
// double sum = 0.0;
// 
// //Debug print
// //System.out.println("counts and priors");
// //for(int i = 0; i < n; i++)
// //    for(int j = 0; j < bayesIm.getNumRows(i); j++) {
// //        System.out.println(observedCountsRowSum[i][j] + " " + priorProbsRowSum[i][j]);
// //    }
// 
// for (int i = 0; i < n; i++) {
// 
// int qi = bayesIm.getNumRows(i);
// //double prodj = 1.0;
// double sumj = 0.0;
// for (int j = 0; j < qi; j++) {
// 
// try {
// double numerator =
// ProbUtils.lngamma(priorProbsRowSum[i][j]);
// double denom = ProbUtils.lngamma(priorProbsRowSum[i][j] +
// observedCountsRowSum[i][j]);
// //System.out.println("num = " + numerator + " denom = " + denom);
// sumj += (numerator - denom);
// } catch (Exception e) {
// e.printStackTrace();
// }
// 
// int ri = bayesIm.getNumColumns(i);
// 
// //double prodk = 1.0;
// double sumk = 0.0;
// for (int k = 0; k < ri; k++) {
// try {
// sumk += ProbUtils.lngamma(
// priorProbs[i][j][k] + observedCounts[i][j][k]) -
// ProbUtils.lngamma(priorProbs[i][j][k]);
// } catch (Exception e) {
// e.printStackTrace();
// }
// }
// 
// sumj += sumk;
// }
// sum += sumj;
// }
// 
// return sum;
// }
/**
 * Fills {@code observedCounts} with the number of data-set cases falling into
 * each cell of every variable's conditional probability table.
 *
 * <p>For the variable in data column {@code j}, counts are written to
 * {@code observedCounts[varIndex][row][category]}:
 * <ul>
 *   <li>{@code varIndex} is the variable's node index in {@code bayesIm};</li>
 *   <li>{@code row} is a parent-value combination as enumerated by
 *       {@code bayesIm.getParentValues}; a variable without parents uses row 0
 *       only;</li>
 *   <li>{@code category} is the value stored in the data set.</li>
 * </ul>
 *
 * <p>The node order of {@code bayesIm} and the column order of {@code dataSet}
 * are not assumed to coincide. {@code observedCounts} is always addressed by node
 * index, and parent values are read from the data column whose variable has the
 * parent's name.
 *
 * <p>NOTE(review): assumes {@code observedCounts} was already allocated per node
 * with {@code bayesIm}'s row/column dimensions — confirm at the caller.
 *
 * @throws IllegalStateException if a parent of a data-set variable has no
 *                               column in the data set.
 */
private void computeObservedCounts() {
    int numNodes = bayesPm.getDag().getNumNodes();

    // Map each BayesIm node index to the data-set column holding the variable of
    // the same name (-1 when no column does).
    int[] columnOfNode = new int[numNodes];
    for (int n = 0; n < numNodes; n++) {
        columnOfNode[n] = -1;
    }
    for (int c = 0; c < dataSet.getNumColumns(); c++) {
        Node node = bayesPm.getDag().getNode(dataSet.getVariables().get(c).getName());
        int nodeIndex = node == null ? -1 : bayesIm.getNodeIndex(node);
        if (nodeIndex >= 0 && nodeIndex < numNodes) {
            columnOfNode[nodeIndex] = c;
        }
    }

    for (int j = 0; j < dataSet.getNumColumns(); j++) {
        DiscreteVariable var = (DiscreteVariable) dataSet.getVariables().get(j);
        Node varNode = bayesPm.getDag().getNode(var.getName());
        int varIndex = bayesIm.getNodeIndex(varNode);
        int[] parentVarIndices = bayesIm.getParents(varIndex);
        int numCategories = var.getNumCategories();

        if (parentVarIndices.length == 0) {
            // No parents: a single table row collects every case.
            for (int col = 0; col < numCategories; col++) {
                observedCounts[varIndex][0][col] = 0;
            }
            for (int i = 0; i < dataSet.getNumRows(); i++) {
                observedCounts[varIndex][0][dataSet.getInt(i, j)]++;
            }
            continue;
        }

        // Parent indices address bayesIm nodes; translate them to data columns once.
        int[] parentColumns = new int[parentVarIndices.length];
        for (int p = 0; p < parentVarIndices.length; p++) {
            parentColumns[p] = columnOfNode[parentVarIndices[p]];
            if (parentColumns[p] < 0) {
                throw new IllegalStateException(
                        "A parent of " + var.getName() + " has no column in the data set.");
            }
        }

        int numRows = bayesIm.getNumRows(varIndex);
        for (int row = 0; row < numRows; row++) {
            int[] parValues = bayesIm.getParentValues(varIndex, row);
            for (int col = 0; col < numCategories; col++) {
                observedCounts[varIndex][row][col] = 0;
            }
            for (int i = 0; i < dataSet.getNumRows(); i++) {
                // Count this case only if every parent takes the value this row represents.
                boolean parentMatch = true;
                for (int p = 0; p < parentColumns.length; p++) {
                    if (parValues[p] != dataSet.getInt(i, parentColumns[p])) {
                        parentMatch = false;
                        break;
                    }
                }
                if (parentMatch) {
                    observedCounts[varIndex][row][dataSet.getInt(i, j)]++;
                }
            }
        }
    }
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Node(edu.cmu.tetrad.graph.Node)

Example 29 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

The class EmBayesProperties, method setGraph.

/**
 * Sets the graph and rebuilds the parametric and blank instantiated models from it.
 *
 * <p>A {@link BayesPm} is built over a {@link Dag} copy of {@code graph}. For every
 * DAG node whose name matches a data-set variable, that variable's categories
 * are copied into the BayesPm. A blank {@link MlBayesIm} is then created from the
 * BayesPm. The given {@code graph} itself is stored as-is.
 *
 * @param graph the graph to use; must not be null.
 * @throws NullPointerException if {@code graph} is null.
 * @throws ClassCastException   if a data-set variable is not a
 *                              {@link DiscreteVariable}.
 */
public final void setGraph(Graph graph) {
    if (graph == null) {
        throw new NullPointerException();
    }

    // Index the data set's discrete variables by name; the variable's own name is
    // the key, so no intermediate graph node is needed.
    List<Node> vars = dataSet.getVariables();
    Map<String, DiscreteVariable> nodesToVars = new HashMap<>();
    for (int i = 0; i < dataSet.getNumColumns(); i++) {
        DiscreteVariable var = (DiscreteVariable) vars.get(i);
        nodesToVars.put(var.getName(), var);
    }

    BayesPm bayesPm = new BayesPm(new Dag(graph));
    for (Node node1 : bayesPm.getDag().getNodes()) {
        DiscreteVariable var = nodesToVars.get(node1.getName());
        if (var != null) {
            bayesPm.setCategories(node1, var.getCategories());
        }
    }

    this.graph = graph;
    this.bayesPm = bayesPm;
    this.blankBayesIm = new MlBayesIm(bayesPm);
}
Also used : HashMap(java.util.HashMap) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) Dag(edu.cmu.tetrad.graph.Dag) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)

Example 30 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

The class DirichletBayesIm, method simulateDataHelper.

/**
 * Creates a discrete data set with one column per simulated node and fills it
 * with {@code sampleSize} simulated cases.
 *
 * <p>When {@code latentDataSaved} is false, only nodes of type
 * {@link NodeType#MEASURED} get a column. Otherwise every node does. Each
 * column is a {@link DiscreteVariable} carrying the node's name and its
 * categories from {@code bayesPm}. The rows are produced by {@code constructSample}.
 * That call receives an array whose entry {@code k} is the index in {@code nodes}
 * of the node stored in column {@code k}.
 *
 * @param sampleSize      number of cases to simulate.
 * @param randomUtil      optional random number generator used while creating
 *                        the data.
 * @param latentDataSaved whether columns are also created for non-measured nodes.
 * @return the simulated sample.
 */
private DataSet simulateDataHelper(int sampleSize, RandomUtil randomUtil, boolean latentDataSaved) {
    int[] columnToNode = new int[nodes.length];
    List<Node> columns = new LinkedList<>();
    int numMeasured = 0;

    for (int nodeIdx = 0; nodeIdx < nodes.length; nodeIdx++) {
        boolean keepColumn = latentDataSaved || nodes[nodeIdx].getNodeType() == NodeType.MEASURED;
        if (keepColumn) {
            List<String> categoryNames = new LinkedList<>();
            int count = bayesPm.getNumCategories(nodes[nodeIdx]);
            for (int c = 0; c < count; c++) {
                categoryNames.add(bayesPm.getCategory(nodes[nodeIdx], c));
            }
            columns.add(new DiscreteVariable(nodes[nodeIdx].getName(), categoryNames));
            columnToNode[numMeasured++] = nodeIdx;
        }
    }

    DataSet sample = new ColtDataSet(sampleSize, columns);
    constructSample(sampleSize, randomUtil, numMeasured, sample, columnToNode);
    return sample;
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node)

Aggregations

DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)56 Node (edu.cmu.tetrad.graph.Node)37 DataSet (edu.cmu.tetrad.data.DataSet)18 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)16 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)11 LinkedList (java.util.LinkedList)9 Test (org.junit.Test)5 ArrayList (java.util.ArrayList)4 Dag (edu.cmu.tetrad.graph.Dag)3 NumberFormat (java.text.NumberFormat)3 Element (nu.xom.Element)3 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)2 Graph (edu.cmu.tetrad.graph.Graph)2 LogisticRegression (edu.cmu.tetrad.regression.LogisticRegression)2 List (java.util.List)2 Elements (nu.xom.Elements)2 DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D)1 TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph)1 StoredCellProbs (edu.cmu.tetrad.bayes.StoredCellProbs)1 BoxDataSet (edu.cmu.tetrad.data.BoxDataSet)1