Search in sources :

Example 21 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class DataConvertUtils method toMixedDataBox.

/**
 * Wraps a mixed (continuous + discrete) tabular dataset in a {@link BoxDataSet}
 * backed by a {@link MixedDataBox}.
 *
 * <p>One variable is created per {@link MixedVarInfo}, in the order returned by
 * {@code getMixedVarInfos()}: a {@link ContinuousVariable} when the info reports
 * itself as continuous, otherwise a {@link DiscreteVariable} carrying the info's
 * categories.
 *
 * @param mixedTabularDataset the dataset to convert
 * @return a data model over the dataset's continuous and discrete data arrays
 */
public static DataModel toMixedDataBox(MixedTabularDataset mixedTabularDataset) {
    int rowCount = mixedTabularDataset.getNumOfRows();
    MixedVarInfo[] varInfos = mixedTabularDataset.getMixedVarInfos();
    double[][] continuousColumns = mixedTabularDataset.getContinuousData();
    int[][] discreteColumns = mixedTabularDataset.getDiscreteData();

    List<Node> variables = new LinkedList<>();
    for (int i = 0; i < varInfos.length; i++) {
        MixedVarInfo info = varInfos[i];
        Node variable;
        if (!info.isContinuous()) {
            variable = new DiscreteVariable(info.getName(), info.getCategories());
        } else {
            variable = new ContinuousVariable(info.getName());
        }
        variables.add(variable);
    }

    MixedDataBox dataBox = new MixedDataBox(variables, rowCount, continuousColumns, discreteColumns);
    return new BoxDataSet(dataBox, variables);
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Node(edu.cmu.tetrad.graph.Node) BoxDataSet(edu.cmu.tetrad.data.BoxDataSet) MixedDataBox(edu.cmu.tetrad.data.MixedDataBox) LinkedList(java.util.LinkedList) MixedVarInfo(edu.pitt.dbmi.data.reader.tabular.MixedVarInfo)

Example 22 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class GdistanceTest method main.

/**
 * Demo driver: builds two random graphs over the same 16 continuous variables,
 * reads a location map from the relative path {@code locationMap.txt}
 * (comma-delimited), computes the {@link Gdistance} distances between the two
 * graphs, prints them, and writes them to {@code Gdistances.txt}.
 *
 * <p>Any exception raised while reading the map, computing the distances, or
 * writing the output is reported via its stack trace; it is not rethrown.
 *
 * @param args ignored
 */
public static void main(String... args) {
    // first generate a couple random graphs
    int numVars = 16;
    int numEdges = 16;
    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }
    Graph testdag1 = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
    Graph testdag2 = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);

    // load the location map; the working directory is printed to help diagnose
    // a missing locationMap.txt, since the path below is relative
    String workingDirectory = System.getProperty("user.dir");
    System.out.println(workingDirectory);
    Path mapPath = Paths.get("locationMap.txt");
    System.out.println(mapPath);
    TabularDataReader dataReaderMap = new ContinuousTabularDataFileReader(mapPath.toFile(), Delimiter.COMMA);
    try {
        DataSet locationMap = (DataSet) DataConvertUtils.toDataModel(dataReaderMap.readInData());

        // then compare their distance
        double xdist = 2.4;
        double ydist = 2.4;
        double zdist = 2;
        Gdistance gdist = new Gdistance(locationMap, xdist, ydist, zdist);
        List<Double> output = gdist.distances(testdag1, testdag2);
        System.out.println(output);

        // try-with-resources guarantees the file handle is released even if
        // printing fails part-way through
        try (PrintWriter writer = new PrintWriter("Gdistances.txt", "UTF-8")) {
            writer.println(output);
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}
Also used : Path(java.nio.file.Path) TabularDataReader(edu.pitt.dbmi.data.reader.tabular.TabularDataReader) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) ContinuousTabularDataFileReader(edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Graph(edu.cmu.tetrad.graph.Graph) PrintWriter(java.io.PrintWriter)

Example 23 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class HsimRobustCompare method run.

// *************Public Methods*****************//
/**
 * Compares structure-learning errors on an original simulated data set against
 * errors on a fully resimulated data set and on a schedule of hybrid
 * resimulations.
 *
 * <p>Steps, in order: seed the shared random generator with a fixed value;
 * build a random DAG over {@code numVars} variables; simulate discrete data
 * from a random Bayes IM on that DAG; run FGES (BDeu score) on the data and
 * evaluate it against the true DAG; pick a DAG from the FGES output, fit a
 * Dirichlet Bayes IM to the original data and simulate a new data set from it;
 * run {@link HsimRepeatAutoRun} on the original data; finally run FGES on the
 * fully simulated data and evaluate it against the picked DAG.
 *
 * @param numVars         number of variables in the generated graph
 * @param edgesPerNode    edge density; the edge count is {@code (int) (numVars * edgesPerNode)}
 * @param numCases        number of rows to simulate for each data set
 * @param penaltyDiscount penalty discount passed to both FGES runs
 * @param resimSize       passed through to {@code HsimRepeatAutoRun.run}
 * @param repeat          passed through to {@code HsimRepeatAutoRun.run}
 * @param verbose         when true, prints the FGES output graph and the error arrays
 * @return a list of three arrays, in this order: errors for the original data,
 *         errors for the full resimulation, errors for the hybrid resimulation.
 *         NOTE(review): the verbose output indexes positions 0-4 of each array,
 *         so each is assumed to hold at least five values — confirm against
 *         {@code HsimUtils.errorEval} and {@code HsimRepeatAutoRun.run}.
 */
public static List<double[]> run(int numVars, double edgesPerNode, int numCases, double penaltyDiscount, int resimSize, int repeat, boolean verbose) {
    // first generate the data; the seed is a fixed constant (presumably so that
    // repeated runs draw the same random graph and data)
    RandomUtil.getInstance().setSeed(1450184147770L);
    final int numEdges = (int) (numVars * edgesPerNode);
    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }
    Graph odag = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
    BayesPm bayesPm = new BayesPm(odag, 2, 2);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    // oData is the original data set, and odag is the original dag.
    DataSet oData = bayesIm.simulateData(numCases, false);

    // then run FGES
    Graph oGraphOut = fgesSearch(oData, penaltyDiscount);
    if (verbose) {
        System.out.println(oGraphOut);
    }

    // calculate FGES errors
    double[] oErrors = HsimUtils.errorEval(oGraphOut, odag);
    if (verbose) {
        System.out.println(oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
    }

    // create various simulated data sets
    // let's do the full simulated data set first: a dag in the FGES pattern fit to the data set.
    PatternToDag pickdag = new PatternToDag(oGraphOut);
    Graph fgesDag = pickdag.patternToDagMeek();
    Dag fgesdag2 = new Dag(fgesDag);
    BayesPm simBayesPm = new BayesPm(fgesdag2, bayesPm);
    DirichletBayesIm simIM = DirichletBayesIm.symmetricDirichletIm(simBayesPm, 1.0);
    DirichletEstimator simEstimator = new DirichletEstimator();
    DirichletBayesIm fittedIM = simEstimator.estimate(simIM, oData);
    DataSet simData = fittedIM.simulateData(numCases, false);

    // next let's do a schedule of small hsims
    HsimRepeatAutoRun study = new HsimRepeatAutoRun(oData);
    double[] hsimErrors = study.run(resimSize, repeat);

    // calculate errors for all simulated output graphs
    // full simulation errors first
    Graph simGraphOut = fgesSearch(simData, penaltyDiscount);
    double[] simErrors = HsimUtils.errorEval(simGraphOut, fgesdag2);

    // first, let's just see what the errors are.
    if (verbose) {
        System.out.println("Original erors are: " + oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
        System.out.println("Full resim errors are: " + simErrors[0] + " " + simErrors[1] + " " + simErrors[2] + " " + simErrors[3] + " " + simErrors[4]);
        System.out.println("HSim errors are: " + hsimErrors[0] + " " + hsimErrors[1] + " " + hsimErrors[2] + " " + hsimErrors[3] + " " + hsimErrors[4]);
    }

    List<double[]> output = new ArrayList<>();
    output.add(oErrors);
    output.add(simErrors);
    output.add(hsimErrors);
    return output;
}

/** Runs a non-verbose FGES search with a BDeu score over {@code data} and returns the resulting graph. */
private static Graph fgesSearch(DataSet data, double penaltyDiscount) {
    Fges fges = new Fges(new BDeuScore(data));
    fges.setVerbose(false);
    fges.setNumPatternsToStore(0);
    fges.setPenaltyDiscount(penaltyDiscount);
    return fges.search();
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ArrayList(java.util.ArrayList) PatternToDag(edu.cmu.tetrad.search.PatternToDag) Fges(edu.cmu.tetrad.search.Fges) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) BDeuScore(edu.cmu.tetrad.search.BDeuScore)

Example 24 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class TestLogisticRegression method test1.

/**
 * Smoke test for {@link LogisticRegression}: simulates 1000 rows from a SEM over
 * a random DAG on X1..X5, discretizes X1 into two equal-count bins, regresses the
 * discretized X1 on X2..X5, and prints the result. Makes no assertions.
 */
@Test
public void test1() {
    List<Node> variables = new ArrayList<>();
    for (int index = 1; index <= 5; index++) {
        variables.add(new ContinuousVariable("X" + index));
    }

    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 3, 3, 3, false));
    System.out.println(dag);

    SemIm semIm = new SemIm(new SemPm(dag));
    DataSet continuousData = semIm.simulateDataRecursive(1000, false);

    // Look up the target and the predictors on the continuous data set, before discretizing.
    Node target = continuousData.getVariable("X1");
    List<Node> regressors = new ArrayList<>();
    for (int index = 2; index <= 5; index++) {
        regressors.add(continuousData.getVariable("X" + index));
    }

    Discretizer discretizer = new Discretizer(continuousData);
    discretizer.equalCounts(target, 2);
    DataSet discretizedData = discretizer.discretize();

    LogisticRegression regression = new LogisticRegression(discretizedData);
    DiscreteVariable binaryTarget = (DiscreteVariable) discretizedData.getVariable("X1");
    regression.regress(binaryTarget, regressors);
    System.out.println(regression);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Dag(edu.cmu.tetrad.graph.Dag) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) LogisticRegression(edu.cmu.tetrad.regression.LogisticRegression) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 25 with ContinuousVariable

use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.

the class TestDeltaTetradTest method makePm.

/**
 * Builds a SEM parametric model over continuous variables X1..X5 in which X5
 * has a directed edge to each of X1, X2, X3 and X4, and there are no other edges.
 *
 * @return the {@link SemPm} for that graph
 */
private SemPm makePm() {
    final int variableCount = 5;
    List<Node> variableNodes = new ArrayList<>();
    for (int index = 1; index <= variableCount; index++) {
        variableNodes.add(new ContinuousVariable("X" + index));
    }

    SemGraph graph = new SemGraph(new EdgeListGraph(variableNodes));

    // X5 (the last node) is the common parent of every other variable.
    Node parent = variableNodes.get(variableCount - 1);
    for (int index = 0; index < variableCount - 1; index++) {
        graph.addDirectedEdge(parent, variableNodes.get(index));
    }
    return new SemPm(graph);
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm)

Aggregations

ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)91 DataSet (edu.cmu.tetrad.data.DataSet)48 Node (edu.cmu.tetrad.graph.Node)46 Test (org.junit.Test)42 ArrayList (java.util.ArrayList)28 Graph (edu.cmu.tetrad.graph.Graph)22 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)19 SemPm (edu.cmu.tetrad.sem.SemPm)18 SemIm (edu.cmu.tetrad.sem.SemIm)16 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)15 LinkedList (java.util.LinkedList)13 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)12 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)8 DMSearch (edu.cmu.tetrad.search.DMSearch)7 Dag (edu.cmu.tetrad.graph.Dag)6 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)5 RandomUtil (edu.cmu.tetrad.util.RandomUtil)5 ParseException (java.text.ParseException)4 CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly)3 Knowledge2 (edu.cmu.tetrad.data.Knowledge2)3