Search in sources :

Example 36 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

the class GesMe method search.

/**
 * Runs the GesMe search on the given data.
 *
 * <p>When the "bootstrapSampleSize" parameter is at least 1 the work is handed to a
 * {@code GeneralBootstrapTest} wrapping a fresh {@code GesMe}. Otherwise a factor analysis is run on
 * the covariance matrix of the data, the chosen loading matrix L (rotated if "useVarimax" is set,
 * unrotated otherwise) is turned into the covariance matrix L * L^T, and an {@code Fges2} search is run
 * on the score built from that matrix.
 *
 * <p>Parameters read in the non-bootstrap path: "convergenceThreshold", "numFactors", "verbose",
 * "fa_threshold", "useVarimax", "faithfulnessAssumed", "enforceMinimumLeafNodes", "maxDegree",
 * "symmetricFirstStep" and "printStream". This path also writes diagnostic output to
 * {@code System.out} regardless of "verbose".
 *
 * @param dataSet the data to search over; it is cast to {@code DataSet} in both modes.
 * @param parameters the search parameters listed above, plus "bootstrapSampleSize" and
 * "bootstrapEnsemble".
 * @return the graph produced by the underlying search.
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    // Bootstrap mode: delegate to the generic bootstrap wrapper and return early.
    if (parameters.getInt("bootstrapSampleSize") >= 1) {
        GesMe algorithm = new GesMe(compareToTrue);
        if (initialGraph != null) {
            algorithm.setInitialGraph(initialGraph);
        }
        DataSet data = (DataSet) dataSet;
        GeneralBootstrapTest bootstrapSearch = new GeneralBootstrapTest(data, algorithm, parameters.getInt("bootstrapSampleSize"));
        // Codes 0 and 2 select Preserved and Majority; every other code keeps Highest.
        BootstrapEdgeEnsemble edgeEnsemble = BootstrapEdgeEnsemble.Highest;
        int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        if (ensembleCode == 0) {
            edgeEnsemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            edgeEnsemble = BootstrapEdgeEnsemble.Majority;
        }
        bootstrapSearch.setEdgeEnsemble(edgeEnsemble);
        bootstrapSearch.setParameters(parameters);
        bootstrapSearch.setVerbose(parameters.getBoolean("verbose"));
        return bootstrapSearch.search();
    }

    // Factor analysis on the covariance matrix of the data.
    CovarianceMatrix dataCov = new CovarianceMatrix((DataSet) dataSet);
    edu.cmu.tetrad.search.FactorAnalysis factorAnalysis = new edu.cmu.tetrad.search.FactorAnalysis(dataCov);
    factorAnalysis.setThreshold(parameters.getDouble("convergenceThreshold"));
    factorAnalysis.setNumFactors(parameters.getInt("numFactors"));
    TetradMatrix unrotated = factorAnalysis.successiveResidual();
    TetradMatrix rotated = factorAnalysis.successiveFactorVarimax(unrotated);

    if (parameters.getBoolean("verbose")) {
        NumberFormat loadingFormat = NumberFormatUtil.getInstance().getNumberFormat();
        String report = "Unrotated Factor Loading Matrix:\n"
                + tableString(unrotated, loadingFormat, Double.POSITIVE_INFINITY);
        if (unrotated.columns() != 1) {
            report = report + "\n\nRotated Matrix (using sequential varimax):\n"
                    + tableString(rotated, loadingFormat, parameters.getDouble("fa_threshold"));
        }
        System.out.println(report);
        TetradLogger.getInstance().forceLogMessage(report);
    }

    // Loading matrix used for everything below.
    TetradMatrix loadings = parameters.getBoolean("useVarimax") ? rotated : unrotated;
    TetradMatrix residual = factorAnalysis.getResidual();

    // Covariance matrix L * L^T over the original variables and sample size.
    ICovarianceMatrix covFa = new CovarianceMatrix(dataCov.getVariables(), loadings.times(loadings.transpose()), dataCov.getSampleSize());
    System.out.println(covFa);

    // Order the variable indices by the diagonal of the data covariance matrix, largest first.
    final double[] vars = dataCov.getMatrix().diag().toArray();
    List<Integer> order = new ArrayList<>();
    for (int idx = 0; idx < vars.length; idx++) {
        order.add(idx);
    }
    Collections.sort(order, new Comparator<Integer>() {

        @Override
        public int compare(Integer a, Integer b) {
            // Swapped arguments give the descending order.
            return Double.compare(vars[b], vars[a]);
        }
    });

    NumberFormat nf = new DecimalFormat("0.000");
    for (int idx : order) {
        System.out.println(nf.format(vars[idx]) + " ");
    }
    System.out.println();

    // Number of highest-diagonal variables that go into tier 2; algebraically this is
    // (sqrt(8n + 1) - 1) / 2 truncated to an int.
    int n = vars.length;
    int cutoff = (int) (n * ((sqrt(8 * n + 1) - 1) / (2 * n)));

    List<Node> nodes = dataCov.getVariables();
    List<Node> leaves = new ArrayList<>();
    for (int idx : order.subList(0, cutoff)) {
        leaves.add(nodes.get(idx));
    }

    // Tier 2 holds the selected variables, tier 1 the rest; tier 2 is flagged via
    // setTierForbiddenWithin(2, true).
    IKnowledge knowledge2 = new Knowledge2();
    for (Node v : nodes) {
        knowledge2.addToTier(leaves.contains(v) ? 2 : 1, v.getName());
    }
    knowledge2.setTierForbiddenWithin(2, true);
    System.out.println("knowledge2 = " + knowledge2);

    // Configure and run Fges2 on the score built from L * L^T.
    Score faScore = this.score.getScore(covFa, parameters);
    edu.cmu.tetrad.search.Fges2 fges = new edu.cmu.tetrad.search.Fges2(faScore);
    fges.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed"));
    if (parameters.getBoolean("enforceMinimumLeafNodes")) {
        fges.setKnowledge(knowledge2);
    }
    fges.setVerbose(parameters.getBoolean("verbose"));
    fges.setMaxDegree(parameters.getInt("maxDegree"));
    fges.setSymmetricFirstStep(parameters.getBoolean("symmetricFirstStep"));
    Object printStream = parameters.get("printStream");
    if (printStream instanceof PrintStream) {
        fges.setOut((PrintStream) printStream);
    }

    if (parameters.getBoolean("verbose")) {
        // Second report: the loading matrix actually used, first unthresholded, then thresholded.
        double threshold = parameters.getDouble("fa_threshold");
        String report = "Unrotated Factor Loading Matrix:\n"
                + tableString(loadings, nf, Double.POSITIVE_INFINITY);
        if (loadings.columns() != 1) {
            report = report + "\n\nL:\n" + tableString(loadings, nf, threshold);
        }
        System.out.println(report);
        TetradLogger.getInstance().forceLogMessage(report);
    }

    System.out.println("residual = " + residual);
    return fges.search();
}
Also used : DecimalFormat(java.text.DecimalFormat) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) PrintStream(java.io.PrintStream) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) Score(edu.cmu.tetrad.search.Score) NumberFormat(java.text.NumberFormat)

Example 37 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

the class StARS2 method getD.

// static class FittingFunction implements MultivariateFunction {
// 
// private final List<DataSet> samples;
// private final int numSamples;
// private final Algorithm algorithm;
// private final double cutoff;
// private final double low;
// private final double high;
// private final String paramName;
// private final DataSet _dataSet;
// private Parameters params;
// Map<Double, Double> archive = new HashMap<>();
// 
// /**
// * Constructs a new CoefFittingFunction for the given Sem.
// */
// public FittingFunction(List<DataSet> samples, int numSamples, Parameters params, Algorithm algorithm,
// double cutoff, double low, double high, String paramName, DataSet _dataSet,
// Map<Double, Double> archive
// ) {
// this.samples = samples;
// this.numSamples = numSamples;
// this.params = params;
// this.algorithm = algorithm;
// this.cutoff = cutoff;
// this.low = low;
// this.high = high;
// this.paramName = paramName;
// this._dataSet = _dataSet;
// this.archive = archive;
// }
// 
// /**
// * Computes the maximum likelihood function value for the given
// * parameter values as given by the optimizer. These values are mapped to
// * parameter values.
// */
// 
// @Override
// public double value(double[] parameters) {
// double paramValue = parameters[0];
// paramValue = getValue(paramValue, params);
// //            paramValue = Math.round(paramValue * 10.0) / 10.0;
// if (paramValue < low) return -10000;
// if (paramValue > high) return -10000;
// if (archive.containsKey(paramValue)) {
// return archive.get(paramValue);
// }
// double D = getD(params, paramName, paramValue, samples, numSamples, algorithm, archive);
// if (D > cutoff) return -10000;
// archive.put(paramValue, D);
// return D;
// }
// }
/**
 * Computes the instability measure D for one value of the tuned parameter.
 *
 * <p>The parameter is written into {@code params}, the algorithm is run on every data set in
 * {@code boostraps}, and each result is converted with {@code GraphUtils.undirectedGraph} and mapped
 * onto the variables of the first data set. For every pair of variables, theta is the fraction of the
 * first {@code numBootstraps} graphs in which the pair is adjacent; D is the mean of
 * {@code 2 * theta * (1 - theta)} over all pairs. The result is also printed to {@code System.out}.
 *
 * @param params the parameters passed to the algorithm; {@code paramName} is overwritten in place.
 * @param paramName the name of the parameter being tuned.
 * @param paramValue the value to evaluate.
 * @param boostraps the data sets to run the algorithm on; must not be empty.
 * @param numBootstraps the number of graphs averaged per pair; must not exceed
 * {@code boostraps.size()}, otherwise indexing the graph list fails.
 * @param algorithm the algorithm to run on each data set.
 * @param archive not read or written by this method.
 * @return the mean instability D; NaN when there are fewer than two variables.
 */
private static double getD(Parameters params, String paramName, double paramValue, List<DataSet> boostraps, int numBootstraps, Algorithm algorithm, Map<Double, Double> archive) {
    params.set(paramName, paramValue);
    List<Graph> graphs = new ArrayList<>();
    for (DataSet d : boostraps) {
        Graph e = GraphUtils.undirectedGraph(algorithm.search(d, params));
        e = GraphUtils.replaceNodes(e, boostraps.get(0).getVariables());
        graphs.add(e);
    }
    int p = boostraps.get(0).getNumColumns();
    List<Node> nodes = graphs.get(0).getNodes();
    double D = 0.0;
    for (int i = 0; i < p; i++) {
        for (int j = i + 1; j < p; j++) {
            Node x = nodes.get(i);
            Node y = nodes.get(j);
            double theta = 0.0;
            for (int k = 0; k < numBootstraps; k++) {
                if (graphs.get(k).isAdjacentTo(x, y)) {
                    theta += 1.0;
                }
            }
            theta /= numBootstraps;
            double xsi = 2 * theta * (1.0 - theta);
            D += xsi;
        }
    }
    // Number of unordered pairs, computed in double: the int product p * (p - 1) silently wraps
    // around once p exceeds about 46341.
    double numPairs = p * (p - 1.0) / 2.0;
    D /= numPairs;
    System.out.println(paramName + " = " + paramValue + " D = " + D);
    return D;
}
Also used : TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList)

Example 38 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

the class StARS method getD.

/**
 * Computes the instability measure D for one value of the tuned parameter.
 *
 * <p>The parameter is written into {@code params}, then the algorithm is run on every sample using
 * fork/join tasks on the shared pool, and each resulting graph is mapped onto the variables of the first
 * sample. For every pair of variables, theta is the fraction of graphs in which the pair is adjacent;
 * D is the mean of {@code 2 * theta * (1 - theta)} over all pairs.
 *
 * <p>NOTE(review): {@code params} and {@code algorithm} are shared by all worker tasks; this method does
 * not protect them, so they presumably need to tolerate concurrent use - confirm.
 *
 * @param params the parameters passed to the algorithm; {@code paramName} is overwritten in place.
 * @param paramName the name of the parameter being tuned.
 * @param paramValue the value to evaluate.
 * @param samples the data sets to run the algorithm on; must not be empty.
 * @param algorithm the algorithm to run on each sample.
 * @return the mean instability D; NaN when there are fewer than two variables.
 */
private static double getD(Parameters params, String paramName, double paramValue, final List<DataSet> samples, Algorithm algorithm) {
    params.set(paramName, paramValue);
    // Filled by the worker tasks below; every add is guarded because ArrayList is not thread-safe.
    final List<Graph> graphs = new ArrayList<>();
    final ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();

    // Splits the sample range [from, to) in half until a piece is at most `chunk` long, then runs
    // the algorithm on each sample of that piece.
    class StabilityAction extends RecursiveAction {

        private final int chunk;

        private final int from;

        private final int to;

        private StabilityAction(int chunk, int from, int to) {
            this.chunk = chunk;
            this.from = from;
            this.to = to;
        }

        @Override
        protected void compute() {
            if (to - from <= chunk) {
                for (int s = from; s < to; s++) {
                    Graph e = algorithm.search(samples.get(s), params);
                    e = GraphUtils.replaceNodes(e, samples.get(0).getVariables());
                    // Several tasks reach this point at the same time; an unguarded add on the
                    // shared ArrayList can lose elements, leave null slots or throw.
                    synchronized (graphs) {
                        graphs.add(e);
                    }
                }
            } else {
                // Overflow-safe midpoint.
                final int mid = from + (to - from) / 2;
                StabilityAction left = new StabilityAction(chunk, from, mid);
                StabilityAction right = new StabilityAction(chunk, mid, to);
                left.fork();
                right.compute();
                left.join();
            }
        }
    }

    final int chunk = 1;
    pool.invoke(new StabilityAction(chunk, 0, samples.size()));

    int p = samples.get(0).getNumColumns();
    List<Node> nodes = graphs.get(0).getNodes();
    double D = 0.0;
    int count = 0;
    for (int i = 0; i < p; i++) {
        for (int j = i + 1; j < p; j++) {
            double theta = 0.0;
            Node x = nodes.get(i);
            Node y = nodes.get(j);
            for (int k = 0; k < graphs.size(); k++) {
                if (graphs.get(k).isAdjacentTo(x, y)) {
                    theta += 1.0;
                }
            }
            theta /= graphs.size();
            double xsi = 2 * theta * (1.0 - theta);
            D += xsi;
            count++;
        }
    }
    D /= (double) count;
    return D;
}
Also used : TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) RecursiveAction(java.util.concurrent.RecursiveAction) ForkJoinPool(java.util.concurrent.ForkJoinPool)

Example 39 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

the class Mgm method search.

/**
 * Runs MGM on the given data.
 *
 * <p>The data must contain at least one {@code ContinuousVariable} and at least one
 * {@code DiscreteVariable}. When the "bootstrapSampleSize" parameter is at least 1 the search is handed
 * to a {@code GeneralBootstrapTest} wrapping a fresh {@code Mgm}; otherwise {@code MGM} is run directly
 * with the lambda values "mgmParam1", "mgmParam2" and "mgmParam3".
 *
 * @param ds the data to search over; cast to {@code DataSet} in bootstrap mode.
 * @param parameters the search parameters.
 * @return the graph produced by the underlying search.
 * @throws IllegalArgumentException if the data lacks a continuous or a discrete variable.
 */
@Override
public Graph search(DataModel ds, Parameters parameters) {
    // MGM needs mixed data: look for at least one variable of each kind.
    boolean sawContinuous = false;
    boolean sawDiscrete = false;
    for (Node variable : ds.getVariables()) {
        sawContinuous = sawContinuous || variable instanceof ContinuousVariable;
        sawDiscrete = sawDiscrete || variable instanceof DiscreteVariable;
    }
    if (!(sawContinuous && sawDiscrete)) {
        throw new IllegalArgumentException("You need at least one continuous and one discrete variable to run MGM.");
    }

    // Bootstrap mode: delegate to the generic bootstrap wrapper and return early.
    if (parameters.getInt("bootstrapSampleSize") >= 1) {
        Mgm algorithm = new Mgm();
        DataSet data = (DataSet) ds;
        GeneralBootstrapTest bootstrapSearch = new GeneralBootstrapTest(data, algorithm, parameters.getInt("bootstrapSampleSize"));
        // Codes 0 and 2 select Preserved and Majority; every other code keeps Highest.
        BootstrapEdgeEnsemble edgeEnsemble = BootstrapEdgeEnsemble.Highest;
        int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        if (ensembleCode == 0) {
            edgeEnsemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            edgeEnsemble = BootstrapEdgeEnsemble.Majority;
        }
        bootstrapSearch.setEdgeEnsemble(edgeEnsemble);
        bootstrapSearch.setParameters(parameters);
        bootstrapSearch.setVerbose(parameters.getBoolean("verbose"));
        return bootstrapSearch.search();
    }

    // Direct mode: run MGM with the three lambda values.
    DataSet mixed = DataUtils.getMixedDataSet(ds);
    double[] lambda = {
        parameters.getDouble("mgmParam1"),
        parameters.getDouble("mgmParam2"),
        parameters.getDouble("mgmParam3")
    };
    return new MGM(mixed, lambda).search();
}
Also used : GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) MGM(edu.pitt.csb.mgm.MGM)

Example 40 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

the class MixedFgesDiscretingContinuousVariables method search.

/**
 * Runs FGES on a discretized version of the given data.
 *
 * <p>When the "bootstrapSampleSize" parameter is at least 1 the search is handed to a
 * {@code GeneralBootstrapTest} wrapping a fresh instance of this algorithm. Otherwise a
 * {@code Discretizer} is built over the continuous data set, {@code equalIntervals} is requested with
 * "numCategories" for every {@code ContinuousVariable}, FGES is run on the discretized data, and the
 * result is passed through {@code convertBack}.
 *
 * @param dataSet the data to search over; cast to {@code DataSet} in bootstrap mode.
 * @param parameters the search parameters.
 * @return the graph produced by the underlying search.
 */
public Graph search(DataModel dataSet, Parameters parameters) {
    // Bootstrap mode: delegate to the generic bootstrap wrapper and return early.
    if (parameters.getInt("bootstrapSampleSize") >= 1) {
        MixedFgesDiscretingContinuousVariables algorithm = new MixedFgesDiscretingContinuousVariables(score);
        DataSet data = (DataSet) dataSet;
        GeneralBootstrapTest bootstrapSearch = new GeneralBootstrapTest(data, algorithm, parameters.getInt("bootstrapSampleSize"));
        // Codes 0 and 2 select Preserved and Majority; every other code keeps Highest.
        BootstrapEdgeEnsemble edgeEnsemble = BootstrapEdgeEnsemble.Highest;
        int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        if (ensembleCode == 0) {
            edgeEnsemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            edgeEnsemble = BootstrapEdgeEnsemble.Majority;
        }
        bootstrapSearch.setEdgeEnsemble(edgeEnsemble);
        bootstrapSearch.setParameters(parameters);
        bootstrapSearch.setVerbose(parameters.getBoolean("verbose"));
        return bootstrapSearch.search();
    }

    // Direct mode: request a discretization for every continuous variable.
    Discretizer discretizer = new Discretizer(DataUtils.getContinuousDataSet(dataSet));
    for (Node variable : dataSet.getVariables()) {
        if (variable instanceof ContinuousVariable) {
            discretizer.equalIntervals(variable, parameters.getInt("numCategories"));
        }
    }

    // Search over the discretized data, then convert the result back.
    DataModel discretized = discretizer.discretize();
    DataSet discreteData = DataUtils.getDiscreteDataSet(discretized);
    Fges fges = new Fges(score.getScore(discreteData, parameters));
    Graph found = fges.search();
    return convertBack(discreteData, found);
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) Node(edu.cmu.tetrad.graph.Node) Fges(edu.cmu.tetrad.search.Fges)

Aggregations

Node (edu.cmu.tetrad.graph.Node)674 ArrayList (java.util.ArrayList)129 Graph (edu.cmu.tetrad.graph.Graph)106 GraphNode (edu.cmu.tetrad.graph.GraphNode)64 DataSet (edu.cmu.tetrad.data.DataSet)59 LinkedList (java.util.LinkedList)55 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)48 Test (org.junit.Test)48 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)46 List (java.util.List)45 Dag (edu.cmu.tetrad.graph.Dag)41 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)41 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)40 ChoiceGenerator (edu.cmu.tetrad.util.ChoiceGenerator)37 Endpoint (edu.cmu.tetrad.graph.Endpoint)29 DisplayNode (edu.cmu.tetradapp.workbench.DisplayNode)26 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)25 Edge (edu.cmu.tetrad.graph.Edge)23 SemIm (edu.cmu.tetrad.sem.SemIm)19 DepthChoiceGenerator (edu.cmu.tetrad.util.DepthChoiceGenerator)19