Search in sources :

Example 11 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

In class TestProposition, method sampleBayesIm2:

private BayesIm sampleBayesIm2() {
    // Three-node DAG: a -> b, a -> c, b -> c. Only b is given an explicit category
    // count (3); a and c keep whatever default BayesPm assigns.
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");

    Dag dag = new Dag();
    for (Node v : new Node[]{a, b, c}) {
        dag.addNode(v);
    }
    dag.addDirectedEdge(a, b);
    dag.addDirectedEdge(a, c);
    dag.addDirectedEdge(b, c);

    BayesPm pm = new BayesPm(dag);
    pm.setNumCategories(b, 3);
    BayesIm im = new MlBayesIm(pm);

    // Conditional probability tables indexed [nodeIndex][rowIndex][category].
    // Node index 1 is the one with three categories (presumably b); indices are
    // assumed to follow the order in which nodes were added to the DAG.
    double[][][] cpts = {
            {{.3, .7}},
            {{.3, .4, .3}, {.6, .1, .3}},
            {{.9, .1}, {.1, .9}, {.5, .5}, {.2, .8}, {.6, .4}, {.7, .3}}
    };
    for (int node = 0; node < cpts.length; node++) {
        for (int row = 0; row < cpts[node].length; row++) {
            for (int cat = 0; cat < cpts[node][row].length; cat++) {
                im.setProbability(node, row, cat, cpts[node][row][cat]);
            }
        }
    }
    return im;
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) Dag(edu.cmu.tetrad.graph.Dag) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 12 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

In class ConditionalGaussianSimulation, method simulate:

/**
 * Simulates one mixed (discrete + continuous) data set from {@code G}.
 * <p>
 * The first {@code percentDiscrete}% of the nodes in the cached {@code shuffledOrder} are made
 * discrete, each with a category count from {@code pickNumCategories(minCategories,
 * maxCategories)}. The remaining nodes are given 0 categories, which {@code makeMixedGraph}
 * presumably treats as continuous. Columns are then filled node by node in the order given by
 * {@code G.getCausalOrdering()}:
 * <ul>
 *   <li>A discrete node is drawn from a randomly parameterized {@code MlBayesIm}. Each
 *       continuous parent is represented there by an "Ersatz_" discrete node, whose value is
 *       obtained by binning the parent's simulated column with {@code getBreakpoints}.</li>
 *   <li>A continuous node is the sum of a normal draw, its continuous parents times their
 *       coefficients, and a mean. Each of these parameters is looked up per configuration of
 *       the discrete parents via {@code getParamValue}.</li>
 * </ul>
 *
 * @param G          the graph to simulate from (reassigned locally to its mixed version)
 * @param parameters source of sampleSize, percentDiscrete, minCategories, maxCategories and
 *                   saveLatentVars
 * @return the simulated data, restricted to measured variables unless saveLatentVars is true
 */
private DataSet simulate(Graph G, Parameters parameters) {
    HashMap<String, Integer> nd = new HashMap<>();
    List<Node> nodes = G.getNodes();

    // The node order that decides which variables become discrete is drawn once and cached
    // on the instance, so repeated calls discretize the same variables.
    if (this.shuffledOrder == null) {
        List<Node> shuffledNodes = new ArrayList<>(nodes);
        Collections.shuffle(shuffledNodes);
        this.shuffledOrder = shuffledNodes;
    }

    for (int i = 0; i < nodes.size(); i++) {
        if (i < nodes.size() * parameters.getDouble("percentDiscrete") * 0.01) {
            final int minNumCategories = parameters.getInt("minCategories");
            final int maxNumCategories = parameters.getInt("maxCategories");
            final int value = pickNumCategories(minNumCategories, maxNumCategories);
            nd.put(shuffledOrder.get(i).getName(), value);
        } else {
            // 0 categories: presumably interpreted by makeMixedGraph as "continuous".
            nd.put(shuffledOrder.get(i).getName(), 0);
        }
    }

    G = makeMixedGraph(G, nd);
    nodes = G.getNodes();
    final int sampleSize = parameters.getInt("sampleSize");
    DataSet mixedData = new BoxDataSet(new MixedDataBox(nodes, sampleSize), nodes);

    List<Node> X = new ArrayList<>();
    List<Node> A = new ArrayList<>();
    for (Node node : G.getNodes()) {
        if (node instanceof ContinuousVariable) {
            X.add(node);
        } else {
            A.add(node);
        }
    }
    Graph AG = G.subgraph(A);
    Graph XG = G.subgraph(X);

    // In the discrete (Bayes) graph, each continuous parent of a discrete node is replaced by
    // a single "Ersatz_" discrete node with a random number of categories.
    Map<ContinuousVariable, DiscreteVariable> erstatzNodes = new HashMap<>();
    Map<String, ContinuousVariable> erstatzNodesReverse = new HashMap<>();
    for (Node y : A) {
        for (Node x : G.getParents(y)) {
            if (x instanceof ContinuousVariable) {
                DiscreteVariable ersatz = erstatzNodes.get(x);
                if (ersatz == null) {
                    ersatz = new DiscreteVariable("Ersatz_" + x.getName(), RandomUtil.getInstance().nextInt(3) + 2);
                    erstatzNodes.put((ContinuousVariable) x, ersatz);
                    erstatzNodesReverse.put(ersatz.getName(), (ContinuousVariable) x);
                    AG.addNode(ersatz);
                }
                AG.addDirectedEdge(ersatz, y);
            }
        }
    }

    BayesPm bayesPm = new BayesPm(AG);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    SemPm semPm = new SemPm(XG);
    Map<Combination, Double> paramValues = new HashMap<>();

    List<Node> tierOrdering = G.getCausalOrdering();
    int[] tiers = new int[tierOrdering.size()];
    for (int t = 0; t < tierOrdering.size(); t++) {
        tiers[t] = nodes.indexOf(tierOrdering.get(t));
    }

    // Breakpoints used to bin a continuous column into its ersatz categories, cached by
    // column index so every child of the same continuous parent bins it identically.
    Map<Integer, double[]> breakpointsMap = new HashMap<>();

    // Fill one whole column at a time, following the causal ordering (presumably parents
    // before children, so parent columns are complete before they are read).
    for (int mixedIndex : tiers) {
        Node y = nodes.get(mixedIndex);

        if (y instanceof DiscreteVariable) {
            // Row-invariant Bayes IM lookups, done once per column.
            int bayesIndex = bayesIm.getNodeIndex(y);
            int[] bayesParents = bayesIm.getParents(bayesIndex);
            int numCategories = bayesIm.getNumColumns(bayesIndex);

            for (int i = 0; i < sampleSize; i++) {
                int[] parentValues = new int[bayesParents.length];
                for (int k = 0; k < parentValues.length; k++) {
                    Node bayesParent = bayesIm.getVariables().get(bayesParents[k]);
                    DiscreteVariable _parent = (DiscreteVariable) bayesParent;
                    ContinuousVariable orig = erstatzNodesReverse.get(_parent.getName());
                    int value;
                    if (orig != null) {
                        // Ersatz parent: bin the continuous original's simulated value.
                        int mixedParentColumn = mixedData.getColumn(orig);
                        double d = mixedData.getDouble(i, mixedParentColumn);
                        double[] breakpoints = breakpointsMap.get(mixedParentColumn);
                        if (breakpoints == null) {
                            breakpoints = getBreakpoints(mixedData, _parent, mixedParentColumn);
                            breakpointsMap.put(mixedParentColumn, breakpoints);
                        }
                        // Category = index of the first breakpoint above d; values at or
                        // above every breakpoint go to the last bin.
                        value = breakpoints.length;
                        for (int j = 0; j < breakpoints.length; j++) {
                            if (d < breakpoints[j]) {
                                value = j;
                                break;
                            }
                        }
                    } else {
                        value = mixedData.getInt(i, mixedData.getColumn(bayesParent));
                    }
                    parentValues[k] = value;
                }

                int rowIndex = bayesIm.getRowIndex(bayesIndex, parentValues);

                // Inverse-CDF draw from the conditional distribution. Default to the last
                // category so that rounding error (row summing to slightly less than 1)
                // cannot push a draw into category 0.
                double r = RandomUtil.getInstance().nextDouble();
                int sampled = Math.max(numCategories - 1, 0);
                double cumulative = 0.0;
                for (int k = 0; k < numCategories; k++) {
                    cumulative += bayesIm.getProbability(bayesIndex, rowIndex, k);
                    if (cumulative >= r) {
                        sampled = k;
                        break;
                    }
                }
                mixedData.setInt(i, mixedIndex, sampled);
            }
        } else {
            Set<DiscreteVariable> discreteParents = new HashSet<>();
            Set<ContinuousVariable> continuousParents = new HashSet<>();
            for (Node node : G.getParents(y)) {
                if (node instanceof DiscreteVariable) {
                    discreteParents.add((DiscreteVariable) node);
                } else {
                    continuousParents.add((ContinuousVariable) node);
                }
            }

            // Row-invariant lookups, done once per column. Copying each set into a list
            // keeps its iteration order, so parameters are requested in the same order
            // on every row.
            List<DiscreteVariable> dParents = new ArrayList<>(discreteParents);
            int[] dColumns = new int[dParents.size()];
            for (int k = 0; k < dColumns.length; k++) {
                dColumns[k] = mixedData.getColumn(dParents.get(k));
            }
            List<ContinuousVariable> cParents = new ArrayList<>(continuousParents);
            int[] cColumns = new int[cParents.size()];
            Parameter[] coefParams = new Parameter[cParents.size()];
            for (int k = 0; k < cColumns.length; k++) {
                cColumns[k] = nodes.indexOf(cParents.get(k));
                coefParams[k] = semPm.getParameter(cParents.get(k), y);
            }
            Parameter varParam = semPm.getParameter(y, y);
            Parameter muParam = semPm.getMeanParameter(y);

            for (int i = 0; i < sampleSize; i++) {
                // This row's discrete-parent configuration selects which parameter
                // values apply.
                int[] dValues = new int[dParents.size()];
                Combination varComb = new Combination(varParam);
                Combination muComb = new Combination(muParam);
                for (int k = 0; k < dValues.length; k++) {
                    dValues[k] = mixedData.getInt(i, dColumns[k]);
                    varComb.addParamValue(dParents.get(k), dValues[k]);
                    muComb.addParamValue(dParents.get(k), dValues[k]);
                }

                // NOTE(review): the value of the (y, y) parameter is passed as nextNormal's
                // second argument; confirm whether that argument is an SD or a variance.
                double value = RandomUtil.getInstance().nextNormal(0, getParamValue(varComb, paramValues));
                for (int k = 0; k < cColumns.length; k++) {
                    Combination coefComb = new Combination(coefParams[k]);
                    for (int d = 0; d < dValues.length; d++) {
                        coefComb.addParamValue(dParents.get(d), dValues[d]);
                    }
                    value += mixedData.getDouble(i, cColumns[k]) * getParamValue(coefComb, paramValues);
                }
                value += getParamValue(muComb, paramValues);
                mixedData.setDouble(i, mixedIndex, value);
            }
        }
    }

    boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
    return saveLatentVars ? mixedData : DataUtils.restrictToMeasured(mixedData);
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) RandomGraph(edu.cmu.tetrad.algcomparison.graph.RandomGraph) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 13 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

In class Comparison2, method compare:

/**
 * Builds a true graph and, unless running on oracles, a data set, then runs the configured
 * search algorithm and records its output against the true graph.
 * <p>
 * The data source is chosen as follows:
 * <ul>
 *   <li>If {@code params.isNoData()}, a random DAG is generated. The search then uses a
 *       d-separation test ({@code IndTestDSep}) and a {@code GraphScore} built from that
 *       DAG, with no data at all.</li>
 *   <li>If a data file is set (e.g. loaded from the fixed directory when
 *       {@code params.isDataFromFile()}), the true graph and data come from disk.</li>
 *   <li>Otherwise, a random DAG over {@code getNumVars()} variables is generated and
 *       {@code getSampleSize()} rows are simulated from it. Continuous data uses
 *       {@code LargeScaleSimulation}; discrete data uses a random 3-category
 *       {@code MlBayesIm}.</li>
 * </ul>
 * The independence test and/or score are then built from the data. The search's output, the
 * reference graph (pattern or PAG of the true DAG) and the elapsed time are stored in the
 * result.
 *
 * @param params comparison settings; its data type, graph file and data file may be updated
 * @return the populated comparison result
 * @throws IllegalArgumentException if a required setting is missing or unrecognized, or if no
 *                                  data set or true graph is available
 * @throws IllegalStateException    if a data file cannot be read
 * @throws NullPointerException     if the data directory cannot be listed
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet = null;
    Graph trueDag = null;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);

    if (params.isDataFromFile()) {
        // Path to the data directory.
        // NOTE(review): machine-specific absolute path; consider making it configurable.
        String path = "/Users/dmalinsky/Documents/research/data/danexamples";
        File dir = new File(path);
        File[] files = dir.listFiles();
        if (files == null) {
            throw new NullPointerException("No files in " + path);
        }

        // NOTE(review): contains()/endsWith() matching on the graph number is loose (e.g.
        // graph 1 may also match files for graph 11); confirm the naming scheme before
        // tightening it.
        String graphNum = String.valueOf(params.getGraphNum());
        for (File file : files) {
            String name = file.getName();
            if (name.startsWith("graph") && name.contains(graphNum) && name.endsWith(".g.txt")) {
                params.setGraphFile(name);
                trueDag = GraphUtils.loadGraphTxt(file);
                break;
            }
        }

        String trialGraph = graphNum.concat("-").concat(String.valueOf(params.getTrial())).concat(".dat.txt");
        for (File file : files) {
            if (file.getName().startsWith("graph") && file.getName().endsWith(trialGraph)) {
                Path dataFile = Paths.get(path.concat("/").concat(file.getName()));
                Delimiter delimiter = Delimiter.TAB;
                try {
                    TabularDataReader dataReader;
                    if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
                        dataReader = new ContinuousTabularDataFileReader(dataFile.toFile(), delimiter);
                    } else {
                        dataReader = new VerticalDiscreteTabularDataReader(dataFile.toFile(), delimiter);
                    }
                    dataSet = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData());
                } catch (IOException e) {
                    // Without the data the comparison cannot run; fail here with the cause
                    // rather than later with an unexplained null data set.
                    throw new IllegalStateException("Could not read data file " + dataFile, e);
                }
                params.setDataFile(file.getName());
                break;
            }
        }
        System.out.println("current graph file = " + params.getGraphFile());
        System.out.println("current data set file = " + params.getDataFile());
    }

    if (params.isNoData()) {
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < params.getNumVars(); i++) {
            nodes.add(new ContinuousVariable("X" + (i + 1)));
        }
        trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);

        // Added 5.25.16 for tsFCI: TsFCI is evaluated on a time-lag graph.
        // NOTE(review): this discards the graph generated just above and draws a new one.
        if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            trueDag = TimeSeriesUtils.graphToLagGraph(trueDag, 2);
            System.out.println("Creating Time Lag Graph : " + trueDag);
        }

        // Test and score are derived from the true graph itself; no data is involved.
        test = new IndTestDSep(trueDag);
        score = new GraphScore(trueDag);
        runSearchAndRecord(params, test, score, trueDag, result, true);
        return result;
    } else if (params.getDataFile() != null) {
        System.out.println("Using data from file... ");
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        } else {
            System.out.println("Using graph from file... ");
        }
        if (trueDag == null) {
            throw new IllegalArgumentException("True graph not loaded.");
        }
        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }

        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new ContinuousVariable("X" + (i + 1)));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);

            // Added 6.08.16 for tsFCI: TsFCI is evaluated on a time-lag graph.
            if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
                trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
                trueDag = TimeSeriesUtils.graphToLagGraph(trueDag, 2);
                System.out.println("Creating Time Lag Graph : " + trueDag);
            }
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }

            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
                // Added 6.08.16 for tsFCI: re-simulate (up to 5 attempts) until the lag-1
                // VAR matrix A1 = (I - B)^-1 * Gamma1 has all eigenvalues inside the unit
                // circle; otherwise shrink the coefficient range and simulate once more.
                sim.setCoefRange(0.20, 0.50);
                boolean isStableTetradMatrix;
                int attempt = 1;
                int tierSize = params.getNumVars();
                int[] sub = new int[tierSize];
                int[] sub2 = new int[tierSize];
                for (int i = 0; i < tierSize; i++) {
                    sub[i] = i;
                    sub2[i] = tierSize + i;
                }
                do {
                    dataSet = sim.simulateDataFisher(params.getSampleSize());
                    TetradMatrix coefMat = new TetradMatrix(sim.getCoefficientMatrix());
                    TetradMatrix B = coefMat.getSelection(sub, sub);
                    TetradMatrix Gamma1 = coefMat.getSelection(sub2, sub);
                    TetradMatrix Gamma0 = TetradMatrix.identity(tierSize).minus(B);
                    TetradMatrix A1 = Gamma0.inverse().times(Gamma1);
                    isStableTetradMatrix = TimeSeriesUtils.allEigenvaluesAreSmallerThanOneInModulus(A1);
                    System.out.println("isStableTetradMatrix? : " + isStableTetradMatrix);
                    attempt++;
                } while ((!isStableTetradMatrix) && attempt <= 5);
                if (!isStableTetradMatrix) {
                    System.out.println("%%%%%%%%%% WARNING %%%%%%%% not a stable coefficient matrix, forcing coefs to [0.15,0.3]");
                    System.out.println("Made " + (attempt - 1) + " attempts to get stable matrix.");
                    sim.setCoefRange(0.15, 0.3);
                    dataSet = sim.simulateDataFisher(params.getSampleSize());
                } else {
                    System.out.println("Coefficient matrix is stable.");
                }
            } else {
                // Every other algorithm needs plain simulated continuous data. Without this
                // branch, dataSet stays null and the method always fails with "No data set."
                dataSet = sim.simulateDataFisher(params.getSampleSize());
            }
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }

        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    }

    // Build the independence test from the data; this also pins the data type.
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    // Build the score from the data; this also pins the data type.
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        BDeuScore bDeuScore = new BDeuScore(dataSet);
        bDeuScore.setSamplePrior(params.getSamplePrior());
        bDeuScore.setStructurePrior(params.getStructurePrior());
        score = bDeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    runSearchAndRecord(params, test, score, trueDag, result, false);
    return result;
}

/**
 * Runs the search selected by {@code params.getAlgorithm()} and stores in {@code result}:
 * the estimated graph, the reference graph derived from {@code trueDag} (pattern for
 * PC-family/FGES, PAG for FCI-family), the elapsed wall-clock time and the true DAG.
 *
 * @param params          supplies the algorithm choice
 * @param test            independence test for test-based algorithms (may be null otherwise)
 * @param score           score for FGES/GFCI (may be null otherwise)
 * @param trueDag         the generating graph
 * @param result          receives the outputs
 * @param logTsFciResults whether to print the correct and estimated graphs after a TsFCI run
 * @throws IllegalArgumentException if the algorithm is unset or unrecognized, or the test or
 *                                  score it needs is missing
 */
private static void runSearchAndRecord(ComparisonParameters params, IndependenceTest test, Score score,
                                       Graph trueDag, ComparisonResult result, boolean logTsFciResults) {
    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }
    long time1 = System.currentTimeMillis();
    if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) {
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        Fges search = new Fges(score);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
        // GFCI uses both a test and a score.
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        GFci search = new GFci(test, score);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        TsFci search = new TsFci(test);
        IKnowledge knowledge = getKnowledge(trueDag);
        search.setKnowledge(knowledge);
        result.setResultGraph(search.search());
        result.setCorrectResult(new TsDagToPag(trueDag).convert());
        if (logTsFciResults) {
            System.out.println("Correct result for trial = " + result.getCorrectResult());
            System.out.println("Search result for trial = " + result.getResultGraph());
        }
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }
    long time2 = System.currentTimeMillis();
    long elapsed = time2 - time1;
    result.setElapsed(elapsed);
    result.setTrueDag(trueDag);
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) List(java.util.List) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) TabularDataReader(edu.pitt.dbmi.data.reader.tabular.TabularDataReader) VerticalDiscreteTabularDataReader(edu.pitt.dbmi.data.reader.tabular.VerticalDiscreteTabularDataReader) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) Path(java.nio.file.Path) Delimiter(edu.pitt.dbmi.data.Delimiter) ContinuousTabularDataFileReader(edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 14 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

In class PerformanceTests, method testFgesMb:

/**
 * Benchmarks FGES-MB against a single simulated data set.
 *
 * <p>A random DAG is built once with {@code makeDag}, and its pattern is computed with
 * {@code SearchGraphUtils.patternForDag}. Data are simulated from the DAG, either via
 * {@link LargeScaleSimulation#simulateDataFisher} (continuous) or from a randomly parameterized
 * {@link MlBayesIm} (discrete). An {@link FgesMb} search is then configured with a SemBicScore or
 * BDeuScore respectively.
 *
 * <p>Each run picks a random target variable and searches for its Markov-blanket graph. The
 * result is compared against the subgraph of the true pattern induced by the target, its
 * adjacents, and the parents of its children in that pattern. Per-run and averaged confusion
 * counts, precision/recall figures and timings are written to {@code out}, which is closed
 * before returning.
 *
 * @param numVars    number of variables in the generated DAG
 * @param edgeFactor edge density passed to {@code makeDag}; the log reports
 *                   {@code (int) (numVars * edgeFactor)} edges
 * @param numCases   number of simulated samples
 * @param numRuns    number of randomly chosen targets to search for
 * @param continuous true to simulate continuous data scored with SemBicScore, false for
 *                   discrete data scored with BDeuScore
 */
private void testFgesMb(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) {
    double penaltyDiscount = 4.0;
    int structurePrior = 10;
    int samplePrior = 10;
    int maxIndegree = -1;
    // boolean faithfulness = false;
    List<int[][]> allCounts = new ArrayList<>();
    List<double[]> comparisons = new ArrayList<>();
    // Never filled in this method; handed to printAverageStatistics unchanged.
    List<Double> degrees = new ArrayList<>();
    List<Long> elapsedTimes = new ArrayList<>();
    System.out.println("Making dag");
    Graph dag = makeDag(numVars, edgeFactor);
    System.out.println(new Date());
    System.out.println("Calculating pattern for DAG");
    Graph pattern = SearchGraphUtils.patternForDag(dag);
    // Node i (in dag.getNodes() order) is assigned tier i for the simulators.
    int[] tiers = new int[dag.getNumNodes()];
    for (int i = 0; i < dag.getNumNodes(); i++) {
        tiers[i] = i;
    }
    System.out.println("Graph done");
    long time1 = System.currentTimeMillis();
    out.println("Graph done");
    System.out.println(new Date());
    System.out.println("Starting simulation");
    FgesMb fges;
    List<Node> vars;
    if (continuous) {
        init(new File("FgesMb.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Penalty discount = " + penaltyDiscount);
        out.println("Depth = " + maxIndegree);
        out.println();
        out.println(new Date());
        vars = dag.getNodes();
        LargeScaleSimulation simulator = new LargeScaleSimulation(dag, vars, tiers);
        simulator.setVerbose(false);
        simulator.setOut(out);
        DataSet data = simulator.simulateDataFisher(numCases);
        System.out.println("Finishing simulation");
        System.out.println(new Date());
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
        System.out.println(new Date());
        System.out.println("Making covariance matrix");
        ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data, true);
        // Stamp AFTER construction so the logged interval actually covers building the matrix.
        long time3 = System.currentTimeMillis();
        System.out.println("Covariance matrix done");
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms\n");
        SemBicScore score = new SemBicScore(cov);
        score.setPenaltyDiscount(penaltyDiscount);
        System.out.println(new Date());
        System.out.println("\nStarting FGES-MB");
        fges = new FgesMb(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setOut(System.out);
        // fges.setHeuristicSpeedup(faithfulness);
        fges.setMaxIndegree(maxIndegree);
        fges.setCycleBound(-1);
    } else {
        init(new File("FgesMb.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Sample prior = " + samplePrior);
        out.println("Structure prior = " + structurePrior);
        out.println("Depth = " + maxIndegree);
        out.println();
        out.println(new Date());
        BayesPm pm = new BayesPm(dag, 3, 3);
        MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
        DataSet data = im.simulateData(numCases, false, tiers);
        vars = data.getVariables();
        // Rebind the reference pattern onto the data set's variables so targets drawn from
        // `vars` can be looked up in it.
        pattern = GraphUtils.replaceNodes(pattern, vars);
        System.out.println("Finishing simulation");
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
        long time3 = System.currentTimeMillis();
        BDeuScore score = new BDeuScore(data);
        score.setStructurePrior(structurePrior);
        score.setSamplePrior(samplePrior);
        System.out.println(new Date());
        System.out.println("\nStarting FGES");
        long time4 = System.currentTimeMillis();
        fges = new FgesMb(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setOut(System.out);
        // fges.setHeuristicSpeedup(faithfulness);
        fges.setMaxIndegree(maxIndegree);
        fges.setCycleBound(-1);
        long timeb = System.currentTimeMillis();
        out.println("Time consructing BDeu score " + (time4 - time3) + " ms");
        out.println("Time for FGES-MB constructor " + (timeb - time4) + " ms");
        out.println();
    }
    // Stays 0 while the NaN-skip logic inside the loop remains commented out.
    int numSkipped = 0;
    for (int run = 0; run < numRuns; run++) {
        out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");
        Node target = vars.get(RandomUtil.getInstance().nextInt(vars.size()));
        System.out.println("Target = " + target);
        long timea = System.currentTimeMillis();
        Graph estPattern = fges.search(target);
        long elapsed = System.currentTimeMillis() - timea;
        // Reference graph: the target, its adjacents in the true pattern, and the parents of
        // its children in that pattern.
        Set<Node> mb = new HashSet<>();
        mb.add(target);
        mb.addAll(pattern.getAdjacentNodes(target));
        for (Node child : pattern.getChildren(target)) {
            mb.addAll(pattern.getParents(child));
        }
        Graph trueMbGraph = pattern.subgraph(new ArrayList<>(mb));
        // Report the search alone; building trueMbGraph above is evaluation bookkeeping.
        out.println("Time for FGES-MB search " + elapsed + " ms");
        out.println();
        System.out.println("Done with FGES");
        System.out.println(new Date());
        double[] comparison = new double[4];
        System.out.println("Counting misclassifications.");
        // NOTE(review): the row/column meanings of `counts` are defined by
        // GraphUtils.edgeMisclassificationCounts; the index choices below are kept as-is.
        int[][] counts = GraphUtils.edgeMisclassificationCounts(trueMbGraph, estPattern, false);
        allCounts.add(counts);
        System.out.println(new Date());
        int sumRow = counts[4][0] + counts[4][3] + counts[4][5];
        int sumCol = counts[0][3] + counts[4][3] + counts[5][3] + counts[7][3];
        int trueArrow = counts[4][3];
        int sumTrueAdjacencies = 0;
        for (int i = 0; i < 7; i++) {
            for (int j = 0; j < 5; j++) {
                sumTrueAdjacencies += counts[i][j];
            }
        }
        int falsePositiveAdjacencies = 0;
        for (int j = 0; j < 5; j++) {
            falsePositiveAdjacencies += counts[7][j];
        }
        int falseNegativeAdjacencies = 0;
        for (int i = 0; i < 5; i++) {
            falseNegativeAdjacencies += counts[i][5];
        }
        // [0] adjacency precision, [1] adjacency recall, [2]/[3] arrow ratios (column/row based).
        comparison[0] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falsePositiveAdjacencies);
        comparison[1] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falseNegativeAdjacencies);
        comparison[2] = trueArrow / (double) sumCol;
        comparison[3] = trueArrow / (double) sumRow;
        // if (Double.isNaN(comparison[0]) || Double.isNaN(comparison[1]) || Double.isNaN(comparison[2]) ||
        // Double.isNaN(comparison[3])) {
        // run--;
        // numSkipped++;
        // continue;
        // }
        comparisons.add(comparison);
        out.println(GraphUtils.edgeMisclassifications(counts));
        out.println(precisionRecall(comparison));
        // printAverageConfusion("Average", allCounts);
        elapsedTimes.add(elapsed);
        out.println("\nElapsed: " + elapsed + " ms");
    }
    printAverageConfusion("Average", allCounts, new DecimalFormat("0.0"));
    printAveragePrecisionRecall(comparisons);
    out.println("Number of runs skipped because of undefined accuracies: " + numSkipped);
    printAverageStatistics(elapsedTimes, degrees);
    out.close();
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) DecimalFormat(java.text.DecimalFormat) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 15 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

the class PerformanceTests method testFges.

/**
 * Benchmarks FGES over several independently simulated data sets.
 *
 * <p>Each run does the following:
 * <ul>
 *   <li>builds a fresh random DAG with {@code makeDag} and computes its pattern with
 *       {@code SearchGraphUtils.patternForDag};</li>
 *   <li>simulates {@code numCases} samples, via {@link LargeScaleSimulation#simulateDataFisher}
 *       when {@code continuous} is true, otherwise from a randomly parameterized
 *       {@link MlBayesIm};</li>
 *   <li>runs {@link Fges} with a SemBicScore or BDeuScore respectively;</li>
 *   <li>compares the estimated pattern with the true pattern.</li>
 * </ul>
 *
 * <p>Per-run and averaged confusion counts, precision/recall figures and timings are written
 * to {@code out}, which is closed before returning.
 *
 * @param numVars    number of variables in each generated DAG
 * @param edgeFactor edge density passed to {@code makeDag}; the log reports
 *                   {@code (int) (numVars * edgeFactor)} edges
 * @param numCases   number of simulated samples per run
 * @param numRuns    number of independent DAG/data/search runs
 * @param continuous true for continuous data scored with SemBicScore, false for discrete
 *                   data scored with BDeuScore
 */
private void testFges(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) {
    out.println(new Date());
    // RandomUtil.getInstance().setSeed(4828384343999L);
    double penaltyDiscount = 4.0;
    // NOTE(review): maxIndegree is only written to the continuous header ("Depth = ..."); it is
    // never applied to the Fges search, and discrete runs log "Depth = 1". Confirm intent.
    int maxIndegree = 5;
    boolean faithfulness = true;
    // BDeu priors for discrete runs: logged in the header and applied to the score.
    int structurePrior = 1;
    int samplePrior = 1;
    // RandomUtil.getInstance().setSeed(50304050454L);
    List<int[][]> allCounts = new ArrayList<>();
    List<double[]> comparisons = new ArrayList<>();
    // Never filled (the degree computation below is commented out); passed through as-is.
    List<Double> degrees = new ArrayList<>();
    List<Long> elapsedTimes = new ArrayList<>();
    if (continuous) {
        init(new File("fges.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Penalty discount = " + penaltyDiscount);
        out.println("Depth = " + maxIndegree);
        out.println();
    } else {
        init(new File("fges.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Sample prior = " + samplePrior);
        out.println("Structure prior = " + structurePrior);
        out.println("Depth = " + 1);
        out.println();
    }
    for (int run = 0; run < numRuns; run++) {
        out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");
        System.out.println("Making dag");
        out.println(new Date());
        Graph dag = makeDag(numVars, edgeFactor);
        System.out.println(new Date());
        System.out.println("Calculating pattern for DAG");
        Graph pattern = SearchGraphUtils.patternForDag(dag);
        List<Node> vars = dag.getNodes();
        // Node i (in dag.getNodes() order) is assigned tier i for the simulators.
        int[] tiers = new int[vars.size()];
        for (int i = 0; i < vars.size(); i++) {
            tiers[i] = i;
        }
        System.out.println("Graph done");
        long time1 = System.currentTimeMillis();
        out.println("Graph done");
        System.out.println(new Date());
        System.out.println("Starting simulation");
        Graph estPattern;
        long elapsed;
        if (continuous) {
            LargeScaleSimulation simulator = new LargeScaleSimulation(dag, vars, tiers);
            simulator.setVerbose(false);
            simulator.setOut(out);
            DataSet data = simulator.simulateDataFisher(numCases);
            System.out.println("Finishing simulation");
            System.out.println(new Date());
            long time2 = System.currentTimeMillis();
            out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
            System.out.println(new Date());
            System.out.println("Making covariance matrix");
            ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data, true);
            // Stamp AFTER construction so the logged interval actually covers building the matrix.
            long time3 = System.currentTimeMillis();
            System.out.println("Covariance matrix done");
            out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms\n");
            SemBicScore score = new SemBicScore(cov);
            score.setPenaltyDiscount(penaltyDiscount);
            System.out.println(new Date());
            System.out.println("\nStarting FGES");
            long timea = System.currentTimeMillis();
            Fges fges = new Fges(score);
            // fges.setVerbose(false);
            fges.setNumPatternsToStore(0);
            fges.setOut(System.out);
            fges.setFaithfulnessAssumed(faithfulness);
            fges.setCycleBound(-1);
            long timeb = System.currentTimeMillis();
            estPattern = fges.search();
            long timec = System.currentTimeMillis();
            out.println("Time for FGES constructor " + (timeb - timea) + " ms");
            // Construction/configuration is reported on the line above; search time excludes it.
            out.println("Time for FGES search " + (timec - timeb) + " ms");
            out.println();
            out.flush();
            // The averaged statistics use the total (construction + search).
            elapsed = timec - timea;
        } else {
            BayesPm pm = new BayesPm(dag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            DataSet data = im.simulateData(numCases, false, tiers);
            System.out.println("Finishing simulation");
            long time2 = System.currentTimeMillis();
            out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
            long time3 = System.currentTimeMillis();
            BDeuScore score = new BDeuScore(data);
            score.setStructurePrior(structurePrior);
            score.setSamplePrior(samplePrior);
            System.out.println(new Date());
            System.out.println("\nStarting FGES");
            long timea = System.currentTimeMillis();
            Fges fges = new Fges(score);
            // fges.setVerbose(false);
            fges.setNumPatternsToStore(0);
            fges.setOut(System.out);
            fges.setFaithfulnessAssumed(faithfulness);
            fges.setCycleBound(-1);
            long timeb = System.currentTimeMillis();
            estPattern = fges.search();
            long timec = System.currentTimeMillis();
            out.println("Time consructing BDeu score " + (timea - time3) + " ms");
            out.println("Time for FGES constructor " + (timeb - timea) + " ms");
            // Construction/configuration is reported on the line above; search time excludes it.
            out.println("Time for FGES search " + (timec - timeb) + " ms");
            out.println();
            // The averaged statistics use the total (construction + search).
            elapsed = timec - timea;
        }
        System.out.println("Done with FGES");
        System.out.println(new Date());
        // System.out.println("Replacing nodes");
        // 
        // estPattern = GraphUtils.replaceNodes(estPattern, dag.getNodes());
        // System.out.println("Calculating degree");
        // 
        // double degree = GraphUtils.degree(estPattern);
        // degrees.add(degree);
        // 
        // out.println("Degree out output graph = " + degree);
        double[] comparison = new double[4];
        // int adjFn = GraphUtils.countAdjErrors(pattern, estPattern);
        // int adjFp = GraphUtils.countAdjErrors(estPattern, pattern);
        // int trueAdj = pattern.getNumEdges();
        // 
        // comparison[0] = trueAdj / (double) (trueAdj + adjFp);
        // comparison[1] = trueAdj / (double) (trueAdj + adjFn);
        System.out.println("Counting misclassifications.");
        // Rebind the estimate onto the true pattern's node objects before comparing.
        estPattern = GraphUtils.replaceNodes(estPattern, pattern.getNodes());
        // NOTE(review): the row/column meanings of `counts` are defined by
        // GraphUtils.edgeMisclassificationCounts; the index choices below are kept as-is.
        int[][] counts = GraphUtils.edgeMisclassificationCounts(pattern, estPattern, false);
        allCounts.add(counts);
        System.out.println(new Date());
        int sumRow = counts[4][0] + counts[4][3] + counts[4][5];
        int sumCol = counts[0][3] + counts[4][3] + counts[5][3] + counts[7][3];
        int trueArrow = counts[4][3];
        int sumTrueAdjacencies = 0;
        for (int i = 0; i < 7; i++) {
            for (int j = 0; j < 5; j++) {
                sumTrueAdjacencies += counts[i][j];
            }
        }
        int falsePositiveAdjacencies = 0;
        for (int j = 0; j < 5; j++) {
            falsePositiveAdjacencies += counts[7][j];
        }
        int falseNegativeAdjacencies = 0;
        for (int i = 0; i < 5; i++) {
            falseNegativeAdjacencies += counts[i][5];
        }
        // [0] adjacency precision, [1] adjacency recall, [2]/[3] arrow ratios (column/row based).
        comparison[0] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falsePositiveAdjacencies);
        comparison[1] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falseNegativeAdjacencies);
        comparison[2] = trueArrow / (double) sumCol;
        comparison[3] = trueArrow / (double) sumRow;
        comparisons.add(comparison);
        out.println(GraphUtils.edgeMisclassifications(counts));
        out.println(precisionRecall(comparison));
        elapsedTimes.add(elapsed);
        out.println("\nElapsed: " + elapsed + " ms");
    }
    printAverageConfusion("Average", allCounts);
    printAveragePrecisionRecall(comparisons);
    printAverageStatistics(elapsedTimes, degrees);
    out.close();
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Aggregations

MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm)26 BayesPm (edu.cmu.tetrad.bayes.BayesPm)23 BayesIm (edu.cmu.tetrad.bayes.BayesIm)20 Test (org.junit.Test)15 Graph (edu.cmu.tetrad.graph.Graph)9 Node (edu.cmu.tetrad.graph.Node)8 DataSet (edu.cmu.tetrad.data.DataSet)6 Dag (edu.cmu.tetrad.graph.Dag)6 ArrayList (java.util.ArrayList)6 LargeScaleSimulation (edu.cmu.tetrad.sem.LargeScaleSimulation)4 Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm)3 GraphNode (edu.cmu.tetrad.graph.GraphNode)3 Parameters (edu.cmu.tetrad.util.Parameters)3 GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest)3 RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph)2 ChiSquare (edu.cmu.tetrad.algcomparison.independence.ChiSquare)2 IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper)2 BdeuScore (edu.cmu.tetrad.algcomparison.score.BdeuScore)2 ScoreWrapper (edu.cmu.tetrad.algcomparison.score.ScoreWrapper)2 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)2