Search in sources:

Example 81 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

Class IndTestRegression, method getVariableNames.

/**
 * Collects the name of every variable returned by {@code getVariables()},
 * preserving that list's order.
 *
 * @return a newly allocated, mutable list of variable names.
 */
public List<String> getVariableNames() {
    List<String> names = new ArrayList<>();
    for (Node node : getVariables()) {
        String name = node.getName();
        names.add(name);
    }
    return names;
}
Also used : Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList)

Example 82 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

Class FgesMbRunner, method execute.

// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Executes the FGES-MB search for the Markov blanket of the variable named by
 * the {@code targetName} parameter, then stores the resulting graph and the
 * top graphs recorded by the search on this runner.
 *
 * <p>The score is chosen from the input model:
 * <ul>
 *   <li>a {@link Graph} is scored with {@code GraphScore};</li>
 *   <li>a continuous {@link DataSet} or an {@link ICovarianceMatrix} is scored
 *       with {@code SemBicScore} (parameter {@code penaltyDiscount});</li>
 *   <li>a discrete {@link DataSet} is scored with {@code BDeuScore};</li>
 *   <li>a {@link DataModelList} is scored with {@code SemBicScoreImages}
 *       (all continuous) or {@code BdeuScoreImages} (all discrete).</li>
 * </ul>
 *
 * @throws RuntimeException if neither a data model nor a source graph is available.
 * @throws IllegalStateException if a data set is neither continuous nor discrete.
 * @throws IllegalArgumentException if a data model list is not uniformly
 *     continuous or discrete, or if the model type is not supported.
 */
public void execute() {
    Parameters params = getParams();
    IKnowledge knowledge = (IKnowledge) params.get("knowledge", new Knowledge2());
    String targetName = params.getString("targetName", null);

    Object model = getDataModel();
    if (model == null && getSourceGraph() != null) {
        model = getSourceGraph();
    }
    if (model == null) {
        throw new RuntimeException("Data source is unspecified. You may need to double click all your data boxes, \n" + "then click Save, and then right click on them and select Propagate Downstream. \n" + "The issue is that we use a seed to simulate from IM's, so your data is not saved to \n" + "file when you save the session. It can, however, be recreated from the saved seed.");
    }

    Node target;
    if (model instanceof Graph) {
        GraphScore score = new GraphScore((Graph) model);
        target = score.getVariable(targetName);
        fges = new FgesMb(score);
    } else if (model instanceof DataSet) {
        DataSet dataSet = (DataSet) model;
        if (dataSet.isContinuous()) {
            SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
            target = score.getVariable(targetName);
            score.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4));
            fges = new FgesMb(score);
        } else if (dataSet.isDiscrete()) {
            // NOTE(review): the priors are hard-coded to 1 here, while the
            // DataModelList branch reads "samplePrior"/"structurePrior" from
            // params. Confirm whether single data sets should honor them too.
            BDeuScore score = new BDeuScore(dataSet);
            score.setSamplePrior(1);
            score.setStructurePrior(1);
            target = score.getVariable(targetName);
            fges = new FgesMb(score);
        } else {
            throw new IllegalStateException("Data set must either be continuous or discrete.");
        }
    } else if (model instanceof ICovarianceMatrix) {
        SemBicScore score = new SemBicScore((ICovarianceMatrix) model);
        score.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4));
        target = score.getVariable(targetName);
        fges = new FgesMb(score);
    } else if (model instanceof DataModelList) {
        DataModelList list = (DataModelList) model;
        for (DataModel dataModel : list) {
            if (!(dataModel instanceof DataSet || dataModel instanceof ICovarianceMatrix)) {
                throw new IllegalArgumentException("Need a combination of all continuous data sets or " + "covariance matrices, or else all discrete data sets, or else a single initialGraph.");
            }
        }
        // NOTE(review): the "firstNontriangular" parameter previously selected
        // between two identical branches, so it has no effect here.
        if (allContinuous(list)) {
            SemBicScoreImages score = new SemBicScoreImages(list);
            score.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4));
            target = score.getVariable(targetName);
            fges = new FgesMb(score);
        } else if (allDiscrete(list)) {
            BdeuScoreImages score = new BdeuScoreImages(list);
            score.setSamplePrior(params.getDouble("samplePrior", 1));
            score.setStructurePrior(params.getDouble("structurePrior", 1));
            target = score.getVariable(targetName);
            fges = new FgesMb(score);
        } else {
            throw new IllegalArgumentException("Data must be either all discrete or all continuous.");
        }
    } else {
        // Previously this only printed a message and fell through. That either
        // threw an NPE on a null `fges` or re-ran a stale search object from an
        // earlier execution with a null target.
        throw new IllegalArgumentException("No viable input.");
    }

    fges.setKnowledge(knowledge);
    fges.setNumPatternsToStore(params.getInt("numPatternsToSave", 1));
    fges.setVerbose(true);
    fges.setMaxIndegree(params.getInt("depth", -1));
    Graph graph = fges.search(target);

    if (getSourceGraph() != null) {
        GraphUtils.arrangeBySourceGraph(graph, getSourceGraph());
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
        SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge);
    } else {
        GraphUtils.circleLayout(graph, 200, 200, 150);
    }

    this.topGraphs = new ArrayList<>(fges.getTopGraphs());
    if (topGraphs.isEmpty()) {
        // Fall back to the graph just found, not the previous run's result graph.
        topGraphs.add(new ScoredGraph(graph, Double.NaN));
    }
    setIndex(topGraphs.size() - 1);
    setResultGraph(graph);
}
Also used : Parameters(edu.cmu.tetrad.util.Parameters) Node(edu.cmu.tetrad.graph.Node) Graph(edu.cmu.tetrad.graph.Graph)

Example 83 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

Class FofcRunner, method execute.

// ===================PUBLIC METHODS OVERRIDING ABSTRACT================//
/**
 * Runs FindOneFactorClusters (FOFC) on the current data set or covariance
 * matrix. Latent names are optionally re-identified against a known true
 * graph, and a laid-out copy of the result is stored together with its
 * clusters.
 *
 * <p>The {@code tetradTestType} parameter is normalized: any value other than
 * {@code TETRAD_DELTA} or {@code TETRAD_WISHART} is replaced by
 * {@code TETRAD_DELTA} and written back to the parameters.
 *
 * @throws IllegalArgumentException if the data is neither a {@code DataSet}
 *     nor a {@code CovarianceMatrix}.
 * @throws RuntimeException if the search graph cannot be copied by serialization.
 */
public void execute() {
    Object source = getData();
    TestType tetradTestType = (TestType) getParams().get("tetradTestType", TestType.TETRAD_WISHART);
    if (tetradTestType == null || (!(tetradTestType == TestType.TETRAD_DELTA || tetradTestType == TestType.TETRAD_WISHART))) {
        tetradTestType = TestType.TETRAD_DELTA;
        getParams().set("tetradTestType", tetradTestType);
    }
    FindOneFactorClusters.Algorithm algorithm = (FindOneFactorClusters.Algorithm) getParams().get("fofcAlgorithm", FindOneFactorClusters.Algorithm.GAP);

    Graph searchGraph;
    if (source instanceof DataSet) {
        FindOneFactorClusters fofc = new FindOneFactorClusters((DataSet) source, tetradTestType, algorithm, getParams().getDouble("alpha", 0.001));
        searchGraph = fofc.search();
    } else if (source instanceof CovarianceMatrix) {
        FindOneFactorClusters fofc = new FindOneFactorClusters((CovarianceMatrix) source, tetradTestType, algorithm, getParams().getDouble("alpha", 0.001));
        searchGraph = fofc.search();
    } else {
        throw new IllegalArgumentException("Unrecognized data type.");
    }

    if (semIm != null) {
        // NOTE(review): this cast throws ClassCastException when the source is a
        // CovarianceMatrix. Confirm that semIm is only set for DataSet input.
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        List<String> variableNames = ReidentifyVariables.reidentifyVariables2(partition, trueGraph, (DataSet) getData());
        rename(searchGraph, partition, variableNames);
    } else if (trueGraph != null) {
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        List<String> variableNames = ReidentifyVariables.reidentifyVariables1(partition, trueGraph);
        rename(searchGraph, partition, variableNames);
    }

    System.out.println("Search Graph " + searchGraph);

    Graph graph;
    try {
        // Serialize/deserialize round trip yields an independent copy of the
        // search graph (presumably so the layout below does not touch it).
        graph = new MarshalledObject<>(searchGraph).get();
    } catch (java.io.IOException | ClassNotFoundException e) {
        // Rethrow with the cause preserved; no separate printStackTrace, which
        // reported the same failure twice.
        throw new RuntimeException(e);
    }
    GraphUtils.circleLayout(graph, 200, 200, 150);
    GraphUtils.fruchtermanReingoldLayout(graph);
    setResultGraph(graph);
    setClusters(MimUtils.convertToClusters(graph, getData().getVariables()));
}
Also used : Node(edu.cmu.tetrad.graph.Node) TestType(edu.cmu.tetrad.search.TestType) Graph(edu.cmu.tetrad.graph.Graph) FindOneFactorClusters(edu.cmu.tetrad.search.FindOneFactorClusters) MarshalledObject(java.rmi.MarshalledObject) ArrayList(java.util.ArrayList) List(java.util.List)

Example 84 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

Class FofcRunner, method getVariables.

/**
 * Creates one latent continuous variable for each name returned by
 * {@code getVariableNames()}, in the same order.
 *
 * @return a newly allocated list of latent nodes.
 */
public List<Node> getVariables() {
    List<Node> result = new ArrayList<>();
    for (String variableName : getVariableNames()) {
        Node latent = new ContinuousVariable(variableName);
        latent.setNodeType(NodeType.LATENT);
        result.add(latent);
    }
    return result;
}
Also used : Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList)

Example 85 with Node

use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.

Class FtfcRunner, method execute.

// ===================PUBLIC METHODS OVERRIDING ABSTRACT================//
/**
 * Runs FindTwoFactorClusters (FTFC) on the current data set or covariance
 * matrix. Latent names are optionally re-identified against a known true
 * graph, and a laid-out copy of the result is stored together with its
 * clusters.
 *
 * <p>The {@code tetradTestType} parameter is normalized: any value other than
 * {@code TETRAD_DELTA} or {@code TETRAD_WISHART} is replaced by
 * {@code TETRAD_DELTA} and written back to the parameters. Note that the
 * normalized value is not passed to {@code FindTwoFactorClusters} here.
 *
 * @throws IllegalArgumentException if the data is neither a {@code DataSet}
 *     nor a {@code CovarianceMatrix}.
 * @throws RuntimeException if the search graph cannot be copied by serialization.
 */
public void execute() {
    Object source = getData();
    TestType tetradTestType = (TestType) getParams().get("tetradTestType", TestType.TETRAD_WISHART);
    if (tetradTestType == null || (!(tetradTestType == TestType.TETRAD_DELTA || tetradTestType == TestType.TETRAD_WISHART))) {
        tetradTestType = TestType.TETRAD_DELTA;
        getParams().set("tetradTestType", tetradTestType);
    }
    FindTwoFactorClusters.Algorithm algorithm = (FindTwoFactorClusters.Algorithm) getParams().get("ftfcAlgorithm", FindTwoFactorClusters.Algorithm.GAP);

    Graph searchGraph;
    if (source instanceof DataSet) {
        FindTwoFactorClusters ftfc = new FindTwoFactorClusters((DataSet) source, algorithm, getParams().getDouble("alpha", 0.001));
        ftfc.setVerbose(true);
        searchGraph = ftfc.search();
    } else if (source instanceof CovarianceMatrix) {
        FindTwoFactorClusters ftfc = new FindTwoFactorClusters((CovarianceMatrix) source, algorithm, getParams().getDouble("alpha", 0.001));
        ftfc.setVerbose(true);
        searchGraph = ftfc.search();
    } else {
        throw new IllegalArgumentException("Unrecognized data type.");
    }

    if (semIm != null) {
        // NOTE(review): this cast throws ClassCastException when the source is a
        // CovarianceMatrix. Confirm that semIm is only set for DataSet input.
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        List<String> variableNames = ReidentifyVariables.reidentifyVariables2(partition, trueGraph, (DataSet) getData());
        rename(searchGraph, partition, variableNames);
    } else if (trueGraph != null) {
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        List<String> variableNames = ReidentifyVariables.reidentifyVariables1(partition, trueGraph);
        rename(searchGraph, partition, variableNames);
    }

    System.out.println("Search Graph " + searchGraph);

    Graph graph;
    try {
        // Serialize/deserialize round trip yields an independent copy of the
        // search graph (presumably so the layout below does not touch it).
        graph = new MarshalledObject<>(searchGraph).get();
    } catch (java.io.IOException | ClassNotFoundException e) {
        // Rethrow with the cause preserved; no separate printStackTrace, which
        // reported the same failure twice.
        throw new RuntimeException(e);
    }
    GraphUtils.circleLayout(graph, 200, 200, 150);
    GraphUtils.fruchtermanReingoldLayout(graph);
    setResultGraph(graph);
    setClusters(MimUtils.convertToClusters(graph, getData().getVariables()));
}
Also used : Node(edu.cmu.tetrad.graph.Node) TestType(edu.cmu.tetrad.search.TestType) Graph(edu.cmu.tetrad.graph.Graph) MarshalledObject(java.rmi.MarshalledObject) ArrayList(java.util.ArrayList) List(java.util.List) FindTwoFactorClusters(edu.cmu.tetrad.search.FindTwoFactorClusters)

Aggregations

Node (edu.cmu.tetrad.graph.Node)674 ArrayList (java.util.ArrayList)129 Graph (edu.cmu.tetrad.graph.Graph)106 GraphNode (edu.cmu.tetrad.graph.GraphNode)64 DataSet (edu.cmu.tetrad.data.DataSet)59 LinkedList (java.util.LinkedList)55 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)48 Test (org.junit.Test)48 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)46 List (java.util.List)45 Dag (edu.cmu.tetrad.graph.Dag)41 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)41 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)40 ChoiceGenerator (edu.cmu.tetrad.util.ChoiceGenerator)37 Endpoint (edu.cmu.tetrad.graph.Endpoint)29 DisplayNode (edu.cmu.tetradapp.workbench.DisplayNode)26 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)25 Edge (edu.cmu.tetrad.graph.Edge)23 SemIm (edu.cmu.tetrad.sem.SemIm)19 DepthChoiceGenerator (edu.cmu.tetrad.util.DepthChoiceGenerator)19