Example 36 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class AdLeafTree method getCellLeaves.

/**
 * Finds the sets of indices into the leaves of the tree for the given variables.
 * Counts are the sizes of the index sets.
 *
 * @param A A list of discrete conditioning variables.
 * @param B A discrete variable; within each cell of A, the rows are further
 * split by the categories of B.
 * @return For each cell of the variables in A (the first variable varied by the
 * second, and so on, to the last), the list of row-index sets, one per category of B.
 */
public List<List<List<Integer>>> getCellLeaves(List<DiscreteVariable> A, DiscreteVariable B) {
    // Sort the conditioning variables by their column index in the data.
    A.sort(Comparator.comparingInt(nodesHash::get));
    if (baseCase == null) {
        Vary vary = new Vary();
        this.baseCase = new ArrayList<>();
        baseCase.add(vary);
    }
    List<Vary> varies = baseCase;
    for (DiscreteVariable v : A) {
        varies = getVaries(varies, nodesHash.get(v));
    }
    List<List<List<Integer>>> rows = new ArrayList<>();
    for (Vary vary : varies) {
        for (int i = 0; i < vary.getNumCategories(); i++) {
            Vary subvary = vary.getSubvary(nodesHash.get(B), i);
            rows.add(subvary.getRows());
        }
    }
    return rows;
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)
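
A minimal consumption sketch, assuming an AdLeafTree named adTree already built over a discrete data set (x1, x2, and b are hypothetical variables from that data set): the outer list has one entry per cell of the conditioning variables, the middle list one entry per category of B, and the inner lists hold row indices.

// Hypothetical usage of getCellLeaves; adTree, x1, x2, and b are assumed
// to come from the data set the tree was built on.
List<DiscreteVariable> conditioning = new ArrayList<>();
conditioning.add(x1);
conditioning.add(x2);

List<List<List<Integer>>> cells = adTree.getCellLeaves(conditioning, b);

for (List<List<Integer>> cell : cells) {
    for (int category = 0; category < cell.size(); category++) {
        // cell.get(category) holds the row indices for this category of b.
        System.out.println("category " + category + ": " + cell.get(category).size() + " rows");
    }
}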

Example 37 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class BayesUpdaterClassifier method setTarget.

// ==========================PUBLIC METHODS========================//
public void setTarget(String target, int targetCategory) {
    // Find the target variable by name.  (Note: targetCategory is unused in
    // this excerpt.)
    DiscreteVariable targetVariable = null;
    for (int j = 0; j < getBayesImVars().size(); j++) {
        DiscreteVariable dv = (DiscreteVariable) getBayesImVars().get(j);
        if (dv.getName().equals(target)) {
            targetVariable = dv;
            break;
        }
    }
    if (targetVariable == null) {
        throw new IllegalArgumentException("Not an available target: " + target);
    }
    this.targetVariable = targetVariable;
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)
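
A usage sketch; classifier stands for an already-constructed BayesUpdaterClassifier (its construction is not shown in this excerpt), so only the setTarget behavior below is taken from the source.

// classifier is a hypothetical, already-built BayesUpdaterClassifier.
classifier.setTarget("LungCancer", 1);

try {
    classifier.setTarget("NoSuchVariable", 0);
} catch (IllegalArgumentException e) {
    // Thrown by the lookup loop above when no variable matches the name.
    System.out.println(e.getMessage());  // Not an available target: NoSuchVariable
}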

Example 38 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class IndTestMultinomialLogisticRegressionWald method expandVariable.

private List<Node> expandVariable(DataSet dataSet, Node node) {
    if (node instanceof ContinuousVariable) {
        return Collections.singletonList(node);
    }
    if (node instanceof DiscreteVariable && ((DiscreteVariable) node).getNumCategories() < 3) {
        return Collections.singletonList(node);
    }
    if (!(node instanceof DiscreteVariable)) {
        throw new IllegalArgumentException();
    }
    List<String> varCats = new ArrayList<>(((DiscreteVariable) node).getCategories());
    varCats.remove(0);
    List<Node> variables = new ArrayList<>();
    for (String cat : varCats) {
        // Create a fresh binary indicator for this category.  The original
        // do-while spun forever on a name collision (the candidate name never
        // changed between iterations); appending an index makes it terminate.
        Node newVar;
        int suffix = 0;
        do {
            String newVarName = node.getName() + "MULTINOM" + "." + cat + (suffix == 0 ? "" : "_" + suffix);
            newVar = new DiscreteVariable(newVarName, 2);
            suffix++;
        } while (dataSet.getVariable(newVar.getName()) != null);
        variables.add(newVar);
        dataSet.addVariable(newVar);
        int newVarIndex = dataSet.getColumn(newVar);
        int numCases = dataSet.getNumRows();
        for (int l = 0; l < numCases; l++) {
            Object dataCell = dataSet.getObject(l, dataSet.getColumn(node));
            int dataCellIndex = ((DiscreteVariable) node).getIndex(dataCell.toString());
            // Write 1 when this row's category equals cat, and 0 otherwise.
            dataSet.setInt(l, newVarIndex, dataCellIndex == ((DiscreteVariable) node).getIndex(cat) ? 1 : 0);
        }
    }
    return variables;
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Node(edu.cmu.tetrad.graph.Node)
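
To make the reference-cell coding concrete, a self-contained sketch with hypothetical data: a three-category variable expands into two binary indicators, and the first category becomes the all-zeros reference level, mirroring varCats.remove(0) above.

import java.util.Arrays;
import java.util.List;

public class DummyCodingDemo {
    public static void main(String[] args) {
        // Hypothetical categories of a discrete variable X.
        List<String> categories = Arrays.asList("low", "medium", "high");
        // Drop the first category as the reference level, as expandVariable does.
        List<String> kept = categories.subList(1, categories.size());
        System.out.println("X\tXMULTINOM.medium\tXMULTINOM.high");
        for (String value : categories) {
            StringBuilder row = new StringBuilder(value);
            for (String cat : kept) {
                row.append('\t').append(value.equals(cat) ? 1 : 0);
            }
            System.out.println(row);  // low 0 0 / medium 1 0 / high 0 1
        }
    }
}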

Example 39 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class MNLRLikelihood method getLik.

public double getLik(int child_index, int[] parents) {
    double lik = 0;
    Node c = variables.get(child_index);
    List<ContinuousVariable> continuous_parents = new ArrayList<>();
    List<DiscreteVariable> discrete_parents = new ArrayList<>();
    for (int p : parents) {
        Node parent = variables.get(p);
        if (parent instanceof ContinuousVariable) {
            continuous_parents.add((ContinuousVariable) parent);
        } else {
            discrete_parents.add((DiscreteVariable) parent);
        }
    }
    int p = continuous_parents.size();
    List<List<Integer>> cells = adTree.getCellLeaves(discrete_parents);
    // List<List<Integer>> cells = partition(discrete_parents);
    int[] continuousCols = new int[p];
    for (int j = 0; j < p; j++) continuousCols[j] = nodesHash.get(continuous_parents.get(j));
    for (List<Integer> cell : cells) {
        int r = cell.size();
        if (r > 1) {
            double[] mean = new double[p];
            double[] var = new double[p];
            for (int i = 0; i < p; i++) {
                for (int j = 0; j < r; j++) {
                    mean[i] += continuousData[continuousCols[i]][cell.get(j)];
                    var[i] += Math.pow(continuousData[continuousCols[i]][cell.get(j)], 2);
                }
                mean[i] /= r;
                var[i] /= r;
                var[i] -= Math.pow(mean[i], 2);
                // After the sqrt, var[i] holds the standard deviation.
                var[i] = Math.sqrt(var[i]);
                if (Double.isNaN(var[i])) {
                    // Debug trace from the original: floating-point cancellation
                    // can push the variance slightly negative, making the sqrt NaN.
                    System.out.println(var[i]);
                }
            }
            int degree = fDegree;
            if (fDegree < 1) {
                degree = (int) Math.floor(Math.log(r));
            }
            TetradMatrix subset = new TetradMatrix(r, p * degree + 1);
            for (int i = 0; i < r; i++) {
                subset.set(i, p * degree, 1);
                for (int j = 0; j < p; j++) {
                    for (int d = 0; d < degree; d++) {
                        subset.set(i, p * d + j, Math.pow((continuousData[continuousCols[j]][cell.get(i)] - mean[j]) / var[j], d + 1));
                    }
                }
            }
            if (c instanceof ContinuousVariable) {
                TetradVector target = new TetradVector(r);
                for (int i = 0; i < r; i++) {
                    target.set(i, continuousData[child_index][cell.get(i)]);
                }
                lik += multipleRegression(target, subset);
            } else {
                TetradMatrix target = new TetradMatrix(r, ((DiscreteVariable) c).getNumCategories());
                for (int i = 0; i < r; i++) {
                    for (int j = 0; j < ((DiscreteVariable) c).getNumCategories(); j++) {
                        target.set(i, j, -1);
                    }
                    target.set(i, discreteData[child_index][cell.get(i)], 1);
                }
                lik += MultinomialLogisticRegression(target, subset);
            }
        }
    }
    return lik;
}
Also used : Node(edu.cmu.tetrad.graph.Node) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) TetradVector(edu.cmu.tetrad.util.TetradVector) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable)
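
The per-cell design matrix from the loop above, isolated as a hypothetical helper (z holds one cell's continuous-parent values, already standardized by the cell mean and standard deviation): column p * d + j carries the (d + 1)-th power of standardized parent j, and the final column is the intercept.

// Mirrors the subset-matrix layout in getLik; names are illustrative.
static double[][] polynomialDesign(double[][] z, int degree) {
    int r = z.length;       // rows in the cell
    int p = z[0].length;    // continuous parents
    double[][] subset = new double[r][p * degree + 1];
    for (int i = 0; i < r; i++) {
        subset[i][p * degree] = 1.0;  // intercept column
        for (int j = 0; j < p; j++) {
            for (int d = 0; d < degree; d++) {
                subset[i][p * d + j] = Math.pow(z[i][j], d + 1);
            }
        }
    }
    return subset;
}

When fDegree < 1, getLik falls back to degree = floor(log r), so larger cells get a richer polynomial basis.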

Example 40 with DiscreteVariable

use of edu.cmu.tetrad.data.DiscreteVariable in project tetrad by cmu-phil.

the class MbClassify method classify.

// ============================PUBLIC METHODS=========================//
/**
 * Classifies the test data by Bayesian updating. The procedure is as follows. First, MBFS is run on the training
 * data to estimate an MB pattern. Bidirected edges are removed; an MB DAG G is selected from the pattern that
 * remains. Second, a Bayes model B is estimated using this G and the training data. Third, for each case in the
 * test data, the marginal for the target variable in B is calculated conditional on the values of the other
 * variables in B in the test data; these marginals are reported as classifications. B is estimated using a
 * Dirichlet estimator with a symmetric prior, with the given prior value. Updating is done using a row-summing
 * exact updater.
 * <p>
 * One consequence of using the row-summing exact updater is that classification will be fast except for cases in
 * which there are lots of missing values. The reason for this is that for such cases the number of rows that need
 * to be summed over is exponential in the number of missing values for that case; hence the parameter for the
 * maximum number of missing values. A good default is 5. Any test case with more than that many missing values
 * is skipped.
 *
 * @return The classifications.
 */
public int[] classify() {
    IndependenceTest indTest = new IndTestChiSquare(train, alpha);
    Mbfs search = new Mbfs(indTest, depth);
    search.setDepth(depth);
    // Hiton search = new Hiton(indTest, depth);
    // Mmmb search = new Mmmb(indTest, depth);
    List<Node> mbPlusTarget = search.findMb(target);
    mbPlusTarget.add(train.getVariable(target));
    DataSet subset = train.subsetColumns(mbPlusTarget);
    System.out.println("subset vars = " + subset.getVariables());
    Pc patternSearch = new Pc(new IndTestChiSquare(subset, 0.05));
    // patternSearch.setMaxIndegree(depth);
    Graph mbPattern = patternSearch.search();
    // MbFanSearch search = new MbFanSearch(indTest, depth);
    // Graph mbPattern = search.search(target);
    TetradLogger.getInstance().log("details", "Pattern = " + mbPattern);
    MbUtils.trimToMbNodes(mbPattern, train.getVariable(target), true);
    TetradLogger.getInstance().log("details", "Trimmed pattern = " + mbPattern);
    // Removing bidirected edges from the pattern before selecting a DAG.
    for (Edge edge : mbPattern.getEdges()) {
        if (Edges.isBidirectedEdge(edge)) {
            mbPattern.removeEdge(edge);
        }
    }
    Graph selectedDag = MbUtils.getOneMbDag(mbPattern);
    TetradLogger.getInstance().log("details", "Selected DAG = " + selectedDag);
    TetradLogger.getInstance().log("details", "Vars = " + selectedDag.getNodes());
    TetradLogger.getInstance().log("details", "\nClassification using selected MB DAG:");
    NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
    List<Node> mbNodes = selectedDag.getNodes();
    // The Markov blanket nodes will correspond to a subset of the variables
    // in the training dataset.  Find the subset dataset.
    DataSet trainDataSubset = train.subsetColumns(mbNodes);
    // To create a Bayes net for the Markov blanket we need the DAG.
    BayesPm bayesPm = new BayesPm(selectedDag);
    // To parameterize the Bayes net we need the number of values
    // of each variable.
    List<Node> varsTrain = trainDataSubset.getVariables();
    for (int i1 = 0; i1 < varsTrain.size(); i1++) {
        DiscreteVariable trainingVar = (DiscreteVariable) varsTrain.get(i1);
        bayesPm.setCategories(mbNodes.get(i1), trainingVar.getCategories());
    }
    // Create an updater for the instantiated Bayes net.
    TetradLogger.getInstance().log("info", "Estimating Bayes net; please wait...");
    DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(bayesPm, this.prior);
    BayesIm bayesIm = DirichletEstimator.estimate(prior, trainDataSubset);
    RowSummingExactUpdater updater = new RowSummingExactUpdater(bayesIm);
    // The subset dataset of the dataset to be classified containing
    // the variables in the Markov blanket.
    DataSet testSubset = this.test.subsetColumns(mbNodes);
    // Get the raw data from the dataset to be classified, the number
    // of variables, and the number of cases.
    int numCases = testSubset.getNumRows();
    int[] estimatedCategories = new int[numCases];
    Arrays.fill(estimatedCategories, -1);
    // The variables in the dataset.
    List<Node> varsClassify = testSubset.getVariables();
    // For each test case, compute the estimated value of the target variable;
    // these feed the crosstabulation array below.
    for (int k = 0; k < numCases; k++) {
        // Create an Evidence instance for the instantiated Bayes net
        // which will allow that updating.
        Proposition proposition = Proposition.tautology(bayesIm);
        // Restrict all other variables to their observed values in
        // this case.
        int numMissing = 0;
        for (int testIndex = 0; testIndex < varsClassify.size(); testIndex++) {
            DiscreteVariable var = (DiscreteVariable) varsClassify.get(testIndex);
            // If it's the target, ignore it.
            if (var.equals(targetVariable)) {
                continue;
            }
            int trainIndex = proposition.getNodeIndex(var.getName());
            // If it's not in the train subset, ignore it.
            if (trainIndex == -99) {
                continue;
            }
            int testValue = testSubset.getInt(k, testIndex);
            if (testValue == -99) {
                numMissing++;
            } else {
                proposition.setCategory(trainIndex, testValue);
            }
        }
        if (numMissing > this.maxMissing) {
            TetradLogger.getInstance().log("details", "classification(" + k + ") = " + "not done since number of missing values too high " + "(" + numMissing + ").");
            continue;
        }
        Evidence evidence = Evidence.tautology(bayesIm);
        evidence.getProposition().restrictToProposition(proposition);
        updater.setEvidence(evidence);
        // for each possible value of target compute its probability in
        // the updated Bayes net.  Select the value with the highest
        // probability as the estimated getValue.
        int targetIndex = proposition.getNodeIndex(targetVariable.getName());
        // Straw man values--to be replaced.
        double highestProb = -0.1;
        int _category = -1;
        for (int category = 0; category < targetVariable.getNumCategories(); category++) {
            double marginal = updater.getMarginal(targetIndex, category);
            if (marginal > highestProb) {
                highestProb = marginal;
                _category = category;
            }
        }
        // The marginal may be undefined if this combination of values never
        // occurred in the training dataset.  If that happens skip the case.
        if (_category < 0) {
            System.out.println("classification(" + k + ") is undefined " + "(undefined marginals).");
            continue;
        }
        String estimatedCategory = targetVariable.getCategories().get(_category);
        TetradLogger.getInstance().log("details", "classification(" + k + ") = " + estimatedCategory);
        estimatedCategories[k] = _category;
    }
    // Create a crosstabulation table to store the counts of observed
    // versus estimated occurrences of each value of the target variable.
    int targetIndex = varsClassify.indexOf(targetVariable);
    int numCategories = targetVariable.getNumCategories();
    int[][] crossTabs = new int[numCategories][numCategories];
    // Will count the number of cases where the target variable
    // is correctly classified.
    int numberCorrect = 0;
    int numberCounted = 0;
    for (int k = 0; k < numCases; k++) {
        int estimatedCategory = estimatedCategories[k];
        int observedValue = testSubset.getInt(k, targetIndex);
        if (estimatedCategory < 0) {
            continue;
        }
        crossTabs[observedValue][estimatedCategory]++;
        numberCounted++;
        if (observedValue == estimatedCategory) {
            numberCorrect++;
        }
    }
    double percentCorrect1 = 100.0 * ((double) numberCorrect) / ((double) numberCounted);
    // Print the cross classification.
    TetradLogger.getInstance().log("details", "");
    TetradLogger.getInstance().log("details", "\t\t\tEstimated\t");
    TetradLogger.getInstance().log("details", "Observed\t");
    StringBuilder buf0 = new StringBuilder();
    buf0.append("\t");
    for (int m = 0; m < numCategories; m++) {
        buf0.append(targetVariable.getCategory(m)).append("\t");
    }
    TetradLogger.getInstance().log("details", buf0.toString());
    for (int k = 0; k < numCategories; k++) {
        StringBuilder buf = new StringBuilder();
        buf.append(targetVariable.getCategory(k)).append("\t");
        for (int m = 0; m < numCategories; m++) buf.append(crossTabs[k][m]).append("\t");
        TetradLogger.getInstance().log("details", buf.toString());
    }
    TetradLogger.getInstance().log("details", "");
    TetradLogger.getInstance().log("details", "Number correct = " + numberCorrect);
    TetradLogger.getInstance().log("details", "Number counted = " + numberCounted);
    TetradLogger.getInstance().log("details", "Percent correct = " + nf.format(percentCorrect1) + "%");
    crossTabulation = crossTabs;
    percentCorrect = percentCorrect1;
    return estimatedCategories;
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) List(java.util.List) NumberFormat(java.text.NumberFormat)
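
A driver sketch under stated assumptions: the MbClassify constructor arguments below are guessed from the fields classify() reads (train, test, target, alpha, depth, prior, maxMissing) and are not a confirmed signature.

// Hypothetical wiring; the real MbClassify constructor may differ.
static void runClassification(DataSet train, DataSet test) {
    MbClassify classifier = new MbClassify(train, test, "Disease",
            /* alpha */ 0.05, /* depth */ -1, /* prior */ 1.0, /* maxMissing */ 5);
    int[] estimated = classifier.classify();  // -1 marks skipped or undefined cases
    int classified = 0;
    for (int category : estimated) {
        if (category >= 0) {
            classified++;
        }
    }
    System.out.println("classified " + classified + " of " + estimated.length + " cases");
}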

Aggregations

DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 56
Node (edu.cmu.tetrad.graph.Node): 37
DataSet (edu.cmu.tetrad.data.DataSet): 18
ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 16
ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 11
LinkedList (java.util.LinkedList): 9
Test (org.junit.Test): 5
ArrayList (java.util.ArrayList): 4
Dag (edu.cmu.tetrad.graph.Dag): 3
NumberFormat (java.text.NumberFormat): 3
Element (nu.xom.Element): 3
EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 2
Graph (edu.cmu.tetrad.graph.Graph): 2
LogisticRegression (edu.cmu.tetrad.regression.LogisticRegression): 2
List (java.util.List): 2
Elements (nu.xom.Elements): 2
DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D): 1
TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph): 1
StoredCellProbs (edu.cmu.tetrad.bayes.StoredCellProbs): 1
BoxDataSet (edu.cmu.tetrad.data.BoxDataSet): 1