Search in sources:

Example 16 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

Class SemEstimatorWrapper, method serializableInstance.

// public SemEstimatorWrapper(DataWrapper dataWrapper,
// SemPmWrapper semPmWrapper,
// SemImWrapper semImWrapper,
// Parameters params) {
// if (dataWrapper == null) {
// throw new NullPointerException();
// }
// 
// if (semPmWrapper == null) {
// throw new NullPointerException();
// }
// 
// if (semImWrapper == null) {
// throw new NullPointerException();
// }
// 
// DataSet dataSet =
// (DataSet) dataWrapper.getSelectedDataModel();
// SemPm semPm = semPmWrapper.getSemPm();
// SemIm semIm = semImWrapper.getSemIm();
// 
// this.semEstimator = new SemEstimator(dataSet, semPm, getOptimizer());
// if (!degreesOfFreedomCheck(semPm)) return;
// this.semEstimator.setTrueSemIm(semIm);
// this.semEstimator.setNumRestarts(getParams().getInt("numRestarts", 1));
// this.semEstimator.estimate();
// 
// this.params = params;
// 
// log();
// }
/**
 * Builds a small, fully populated exemplar of this class for the serialization tests.
 *
 * <p>The exemplar has the following parts:
 * <ul>
 *   <li>a data set with one continuous variable "X" and 10 rows, each filled with a
 *       random double;</li>
 *   <li>a one-node DAG over that variable;</li>
 *   <li>a {@code SemPm} built from that DAG;</li>
 *   <li>a fresh, default {@code Parameters} object.</li>
 * </ul>
 *
 * @return a new wrapper built from the exemplar data set and model
 * @see TetradSerializableUtils
 */
public static SemEstimatorWrapper serializableInstance() {
    ContinuousVariable variableX = new ContinuousVariable("X");
    List<Node> nodes = new LinkedList<>();
    nodes.add(variableX);

    DataSet data = new ColtDataSet(10, nodes);
    int numRows = data.getNumRows();
    int numCols = data.getNumColumns();
    for (int row = 0; row < numRows; row++) {
        for (int col = 0; col < numCols; col++) {
            data.setDouble(row, col, RandomUtil.getInstance().nextDouble());
        }
    }

    Dag graph = new Dag();
    graph.addNode(variableX);
    SemPm semPm = new SemPm(graph);
    return new SemEstimatorWrapper(data, semPm, new Parameters());
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) Parameters(edu.cmu.tetrad.util.Parameters) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Dag(edu.cmu.tetrad.graph.Dag) LinkedList(java.util.LinkedList)

Example 17 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

Class CovMatrixTable, method setVariableName.

/**
 * Renames the covariance-matrix variable at {@code index} to {@code name}.
 *
 * <p>If any variable already carries {@code name}, nothing is changed. This prevents
 * duplicate variable names.
 *
 * @param index position of the variable to rename in the covariance matrix's variable list
 * @param name  the requested new name; compared with {@code equals}, so it must not be null
 */
private void setVariableName(int index, String name) {
    List<?> vars = getCovMatrix().getVariables();

    // Refuse duplicates: if the requested name is already taken, leave all names untouched.
    for (Object candidate : vars) {
        if (name.equals(((ContinuousVariable) candidate).getName())) {
            return;
        }
    }

    ((ContinuousVariable) vars.get(index)).setName(name);
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) List(java.util.List)

Example 18 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

Class IndTestMultinomialLogisticRegressionWald, method isIndependentRegression.

/**
 * Tests whether x is independent of y given z by regressing x on y's column(s) plus the
 * columns of every node in z.
 *
 * <p>Which columns represent y depends on its type:
 * <ul>
 *   <li>A {@link ContinuousVariable} y contributes the single column of the same name in
 *       {@code internalData}. Its p-value is read from index 1 of the regression result.</li>
 *   <li>Any other y contributes every column in {@code variablesPerNode.get(y)}. The smallest
 *       of those columns' p-values (indices 1..k) is used.</li>
 * </ul>
 * Index 0 is presumably the intercept; confirm against {@code RegressionResult}.
 *
 * <p>Only rows returned by {@code getNonMissingRows(x, y, z)} are used. The chosen p-value is
 * stored in {@code lastP}, and the fact is logged to "independencies" or "dependencies".
 * If the regression throws, the test reports dependence (false) and {@code lastP} is set to
 * {@link Double#NaN} so no earlier p-value is left behind.
 *
 * @param x the response node
 * @param y the node whose association with x is being tested
 * @param z the conditioning set
 * @return true if the p-value is strictly greater than {@code alpha}; false otherwise,
 *     including when the regression fails
 * @throws IllegalArgumentException if x, y, or any member of z is not a key of
 *     {@code variablesPerNode}
 */
private boolean isIndependentRegression(Node x, Node y, List<Node> z) {
    requireKnownNode(x);
    requireKnownNode(y);
    for (Node node : z) {
        requireKnownNode(node);
    }

    boolean yIsContinuous = y instanceof ContinuousVariable;

    List<Node> regressors = new ArrayList<>();
    if (yIsContinuous) {
        regressors.add(internalData.getVariable(y.getName()));
    } else {
        regressors.addAll(variablesPerNode.get(y));
    }
    for (Node _z : z) {
        regressors.addAll(variablesPerNode.get(_z));
    }

    int[] _rows = getNonMissingRows(x, y, z);
    regression.setRows(_rows);

    RegressionResult result;
    try {
        result = regression.regress(x, regressors);
    } catch (Exception ignored) {
        // The regression could not be fit; report dependence as before. Without this reset,
        // lastP would still hold the p-value of the previous test.
        this.lastP = Double.NaN;
        return false;
    }

    double[] pValues = result.getP();
    double p = 1;
    if (yIsContinuous) {
        p = pValues[1];
    } else {
        int numYColumns = variablesPerNode.get(y).size();
        for (int i = 0; i < numYColumns; i++) {
            double val = pValues[1 + i];
            // Plain '<' (not Math.min) so a NaN p-value is skipped rather than propagated.
            if (val < p) {
                p = val;
            }
        }
    }

    this.lastP = p;
    boolean indep = p > alpha;
    if (indep) {
        TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFactMsg(x, y, z, p));
    } else {
        TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, p));
    }
    return indep;
}

/** Throws IllegalArgumentException if {@code node} is not a key of {@code variablesPerNode}. */
private void requireKnownNode(Node node) {
    if (!variablesPerNode.containsKey(node)) {
        throw new IllegalArgumentException("Unrecognized node: " + node);
    }
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Node(edu.cmu.tetrad.graph.Node) RegressionResult(edu.cmu.tetrad.regression.RegressionResult)

Example 19 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

Class MGM, method runTests1.

/**
 * Runs an ad-hoc developer harness for {@code MGM} on the mediation test data.
 *
 * <p>The harness does the following:
 * <ul>
 *   <li>Loads {@code med_test_C.txt} into {@code xIn} and {@code med_test_D.txt} into
 *       {@code yIn}. These are presumably the continuous and discrete columns; confirm.</li>
 *   <li>Builds 24 {@code ContinuousVariable}s X0..X23 and 24 {@code DiscreteVariable}s
 *       Y0..Y23.</li>
 *   <li>Constructs an MGM with all three lambda values set to 0.2.</li>
 * </ul>
 *
 * <p>It then prints:
 * <ul>
 *   <li>the initial weights;</li>
 *   <li>the time for 50,000 rounds of "copy then {@code assign}" and of "copy then rebind";</li>
 *   <li>the initial smooth value (printed as "nll") and non-smooth value (printed as "reg
 *       term");</li>
 *   <li>the time of {@code learnEdges(700)}, the final objective values, the parameters, and
 *       the adjacency matrix.</li>
 * </ul>
 *
 * <p>The data directory defaults to the original author's checkout. Set the system property
 * {@code mgm.testDataPath} to use a different directory. An {@link IOException} while loading
 * is printed and the harness returns.
 */
private static void runTests1() {
    final int numPerType = 24;
    final int benchmarkReps = 50000;
    try {
        String path = System.getProperty("mgm.testDataPath",
                "/Users/ajsedgewick/tetrad_master/tetrad/tetrad-lib/src/main/java/edu/pitt/csb/mgm/test_data");
        System.out.println(path);
        DoubleMatrix2D xIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_C.txt").getDoubleData().toArray());
        DoubleMatrix2D yIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_D.txt").getDoubleData().toArray());

        // 2 per discrete variable; presumably the number of categories of each Y (confirm with MGM).
        int[] levels = new int[numPerType];
        Arrays.fill(levels, 2);
        Node[] vars = new Node[2 * numPerType];
        for (int i = 0; i < numPerType; i++) {
            vars[i] = new ContinuousVariable("X" + i);
            vars[i + numPerType] = new DiscreteVariable("Y" + i);
        }

        double lam = .2;
        MGM model = new MGM(xIn, yIn, new ArrayList<>(Arrays.asList(vars)), levels, new double[] { lam, lam, lam });
        System.out.println("Weights: " + Arrays.toString(model.weights.toArray()));

        // Micro-benchmark: refresh a matrix in place (assign) vs. rebinding the reference.
        DoubleMatrix2D test = xIn.copy();
        DoubleMatrix2D test2 = xIn.copy();
        long t = System.currentTimeMillis();
        for (int i = 0; i < benchmarkReps; i++) {
            test2 = xIn.copy();
            test.assign(test2);
        }
        System.out.println("assign Time: " + (System.currentTimeMillis() - t));

        t = System.currentTimeMillis();
        for (int i = 0; i < benchmarkReps; i++) {
            if (Thread.currentThread().isInterrupted()) {
                break;
            }
            test2 = xIn.copy();
            test = test2;
        }
        System.out.println("equals Time: " + (System.currentTimeMillis() - t));

        System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));

        t = System.currentTimeMillis();
        model.learnEdges(700);
        System.out.println("Orig Time: " + (System.currentTimeMillis() - t));
        System.out.println("nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
        System.out.println("params:\n" + model.params);
        System.out.println("adjMat:\n" + model.adjMatFromMGM());
    } catch (IOException ex) {
        ex.printStackTrace();
    }
}
Also used : IOException(java.io.IOException) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) DoubleMatrix2D(cern.colt.matrix.DoubleMatrix2D)

Example 20 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

Class ADTreeTest, method main.

/**
 * Builds an ADTree over simulated data and times two count queries.
 *
 * <p>The data is simulated as follows:
 * <ul>
 *   <li>a random forward-edge graph is built over 40 nodes X1..X40 with 40 edges;</li>
 *   <li>it is wrapped in a {@code BayesPm} and an {@code MlBayesIm} with
 *       {@code MlBayesIm.RANDOM};</li>
 *   <li>500 rows are simulated from that model.</li>
 * </ul>
 *
 * <p>Each row is copied into a {@code DataTableImpl<Node, Short>}. Column j is read as
 * {@code data.getInt(i, j)} and keyed by {@code variables.get(j)}. This assumes the simulated
 * data set's column order matches {@code variables}; TODO confirm. The time to build the tree
 * is printed, and then the count and time for a 2-variable and a 4-variable query.
 *
 * @param args ignored
 * @throws Exception propagated from model construction, simulation or tree building
 */
public static void main(String[] args) throws Exception {
    int columns = 40;
    int numEdges = 40;
    int rows = 500;

    List<Node> variables = new ArrayList<>();
    for (int i = 0; i < columns; i++) {
        variables.add(new ContinuousVariable("X" + (i + 1)));
    }

    Graph graph = GraphUtils.randomGraphRandomForwardEdges(variables, 0, numEdges, 30, 15, 15, false, true);
    BayesPm pm = new BayesPm(graph);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(rows, false);

    // This implementation uses a DataTable to represent the data.
    // The first type parameter is the type for the variables;
    // the second is the type for the values of the variables.
    DataTableImpl<Node, Short> dataTable = new DataTableImpl<>(variables);
    for (int i = 0; i < rows; i++) {
        ArrayList<Short> intArray = new ArrayList<>();
        for (int j = 0; j < columns; j++) {
            intArray.add((short) data.getInt(i, j));
        }
        dataTable.addRow(intArray);
    }

    // Create the tree.
    long start = System.currentTimeMillis();
    ADTree<Node, Short> adTree = new ADTree<>(dataTable);
    System.out.println(String.format("Generated tree in %s millis", System.currentTimeMillis() - start));

    // A query is an arbitrary map from variables to their values.
    TreeMap<Node, Short> query = new TreeMap<>();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    printTimedCount(adTree, query);

    query.clear();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X2"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    query.put(node(pm, "X10"), (short) 1);
    printTimedCount(adTree, query);
}

/** Runs one count query against {@code adTree} and prints the count and elapsed wall-clock ms. */
private static void printTimedCount(ADTree<Node, Short> adTree, TreeMap<Node, Short> query) {
    long start = System.currentTimeMillis();
    System.out.println(String.format("Count is %d", adTree.count(query)));
    System.out.println(String.format("Query in %s ms", System.currentTimeMillis() - start));
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) TreeMap(java.util.TreeMap) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Graph(edu.cmu.tetrad.graph.Graph) BayesIm(edu.cmu.tetrad.bayes.BayesIm) MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Aggregations

ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 91; DataSet (edu.cmu.tetrad.data.DataSet): 48; Node (edu.cmu.tetrad.graph.Node): 46; Test (org.junit.Test): 42; ArrayList (java.util.ArrayList): 28; Graph (edu.cmu.tetrad.graph.Graph): 22; ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 19; SemPm (edu.cmu.tetrad.sem.SemPm): 18; SemIm (edu.cmu.tetrad.sem.SemIm): 16; DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable): 15; LinkedList (java.util.LinkedList): 13; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 12; TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 8; DMSearch (edu.cmu.tetrad.search.DMSearch): 7; Dag (edu.cmu.tetrad.graph.Dag): 6; CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 5; RandomUtil (edu.cmu.tetrad.util.RandomUtil): 5; ParseException (java.text.ParseException): 4; CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly): 3; Knowledge2 (edu.cmu.tetrad.data.Knowledge2): 3