
Example 1 with LogisticRegression

Use of edu.cmu.tetrad.regression.LogisticRegression in the tetrad project by cmu-phil.

From the class TestLogisticRegression, method test1.

@Test
public void test1() {
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 5; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }
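    // Build a random DAG over the five variables.
    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 0, 5, 3, 3, 3, false));
    System.out.println(graph);
    // Simulate 1000 cases from a SEM parameterized on that graph.
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateDataRecursive(1000, false);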
    Node x1 = data.getVariable("X1");
    Node x2 = data.getVariable("X2");
    Node x3 = data.getVariable("X3");
    Node x4 = data.getVariable("X4");
    Node x5 = data.getVariable("X5");
    // Discretize X1 into two equal-count categories so it can serve as the response.
    Discretizer discretizer = new Discretizer(data);
    discretizer.equalCounts(x1, 2);
    DataSet d2 = discretizer.discretize();
    LogisticRegression regression = new LogisticRegression(d2);
    List<Node> regressors = new ArrayList<>();
    regressors.add(x2);
    regressors.add(x3);
    regressors.add(x4);
    regressors.add(x5);
    // Regress the discretized X1 on X2 through X5 and print the result.
    DiscreteVariable x1b = (DiscreteVariable) d2.getVariable("X1");
    regression.regress(x1b, regressors);
    System.out.println(regression);
}
Also used: DataSet (edu.cmu.tetrad.data.DataSet), Node (edu.cmu.tetrad.graph.Node), ArrayList (java.util.ArrayList), Dag (edu.cmu.tetrad.graph.Dag), Discretizer (edu.cmu.tetrad.data.Discretizer), ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable), Graph (edu.cmu.tetrad.graph.Graph), SemPm (edu.cmu.tetrad.sem.SemPm), LogisticRegression (edu.cmu.tetrad.regression.LogisticRegression), SemIm (edu.cmu.tetrad.sem.SemIm), Test (org.junit.Test)
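
Example 1 calls equalCounts(x1, 2) to turn a continuous column into a two-category response before regressing on it. The idea behind equal-count binning can be sketched in plain Java as below. This is only an illustration under stated assumptions: the class EqualCountSketch and the method equalCountBins are hypothetical names, not part of Tetrad, and Tetrad's Discretizer may choose its breakpoints or handle ties differently.

```java
import java.util.Arrays;

/** Hypothetical helper, not part of Tetrad: equal-count (quantile) binning. */
final class EqualCountSketch {

    /**
     * Assigns each value a bin index in [0, numBins) so that the bins hold
     * roughly the same number of values. A value equal to a cutoff goes to
     * the upper bin (an assumption made for this sketch).
     */
    static int[] equalCountBins(double[] values, int numBins) {
        if (values.length == 0 || numBins < 1) {
            throw new IllegalArgumentException("need at least one value and one bin");
        }
        double[] sorted = values.clone();
        Arrays.sort(sorted);

        // cutoffs[k - 1] is the smallest sorted value that belongs to bin k.
        double[] cutoffs = new double[numBins - 1];
        for (int k = 1; k < numBins; k++) {
            cutoffs[k - 1] = sorted[k * sorted.length / numBins];
        }

        int[] bins = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            int bin = 0;
            while (bin < cutoffs.length && values[i] >= cutoffs[bin]) {
                bin++;
            }
            bins[i] = bin;
        }
        return bins;
    }

    public static void main(String[] args) {
        double[] x1 = {0.3, -1.2, 2.5, 0.9, -0.4, 1.7};
        // With two bins, the lower half of the values maps to 0 and the upper half to 1.
        System.out.println(Arrays.toString(equalCountBins(x1, 2)));
    }
}
```

With two bins this amounts to a split near the median, which is why the resulting column can be cast to a binary DiscreteVariable and used as the target of a logistic regression.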

Example 2 with LogisticRegression

Use of edu.cmu.tetrad.regression.LogisticRegression in the tetrad project by cmu-phil.

From the class LogisticRegressionRunner, method execute.

// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Executes the algorithm, producing (at least) a result workbench. Must be
 * implemented in the extending class.
 */
public void execute() {
    outGraph = new EdgeListGraph();
    if (regressorNames == null || regressorNames.isEmpty() || targetName == null) {
        report = "Response and predictor variables not set.";
        return;
    }
    if (regressorNames.contains(targetName)) {
        report = "Response must not be a predictor.";
        return;
    }
    DataSet regressorsDataSet = dataSets.get(getModelIndex()).copy();
    Node target = regressorsDataSet.getVariable(targetName);
    regressorsDataSet.removeColumn(target);
    List<String> names = regressorsDataSet.getVariableNames();
    // Get the list of regressors selected by the user
    List<Node> regressorNodes = new ArrayList<>();
    for (String s : regressorNames) {
        regressorNodes.add(dataSets.get(getModelIndex()).getVariable(s));
    }
    // Keep only the columns for the regressors the user selected
    if (regressorNames.size() > 0) {
        for (String name1 : names) {
            Node regressorVar = regressorsDataSet.getVariable(name1);
            if (!regressorNames.contains(regressorVar.getName())) {
                regressorsDataSet.removeColumn(regressorVar);
            }
        }
    }
    int ncases = regressorsDataSet.getNumRows();
    int nvars = regressorsDataSet.getNumColumns();
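    // Copy the regressor columns into a variables-by-cases array.
    // (This array is not used again in this method.)
    double[][] regressors = new double[nvars][ncases];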
    for (int i = 0; i < nvars; i++) {
        for (int j = 0; j < ncases; j++) {
            regressors[i][j] = regressorsDataSet.getDouble(j, i);
        }
    }
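    // Fit the regression on the full data set for the selected model.
    LogisticRegression logRegression = new LogisticRegression(dataSets.get(getModelIndex()));
    logRegression.setAlpha(alpha);
    // The target must be a discrete variable; the cast fails otherwise.
    this.result = logRegression.regress((DiscreteVariable) target, regressorNodes);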
}
Also used: DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable), ColtDataSet (edu.cmu.tetrad.data.ColtDataSet), DataSet (edu.cmu.tetrad.data.DataSet), Node (edu.cmu.tetrad.graph.Node), EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph), ArrayList (java.util.ArrayList), LogisticRegression (edu.cmu.tetrad.regression.LogisticRegression)
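
Both examples hand a binary target and a list of regressors to regress(...) and leave the fitting to the library. For readers who want to see what such a fit computes, here is a minimal sketch of logistic regression trained by batch gradient ascent on the log-likelihood. It is an assumption-laden illustration: LogisticSketch, fit, and predict are hypothetical names, the learning rate and iteration count are arbitrary, and nothing here is meant to describe how Tetrad's LogisticRegression is implemented or what its result object reports.

```java
/** Hypothetical sketch, not Tetrad code: logistic regression by gradient ascent. */
final class LogisticSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Predicted probability that y = 1 for one row; beta[0] is the intercept. */
    static double predict(double[] beta, double[] row) {
        double z = beta[0];
        for (int j = 0; j < row.length; j++) {
            z += beta[j + 1] * row[j];
        }
        return sigmoid(z);
    }

    /**
     * Fits coefficients for a 0/1 target. x is cases-by-regressors,
     * y holds one 0/1 label per case. Returns intercept plus one slope
     * per regressor.
     */
    static double[] fit(double[][] x, int[] y, double learningRate, int iterations) {
        int n = x.length;
        int p = x[0].length;
        double[] beta = new double[p + 1];
        for (int it = 0; it < iterations; it++) {
            // Gradient of the log-likelihood: sum over cases of (y - p) * x.
            double[] grad = new double[p + 1];
            for (int i = 0; i < n; i++) {
                double err = y[i] - predict(beta, x[i]);
                grad[0] += err;
                for (int j = 0; j < p; j++) {
                    grad[j + 1] += err * x[i][j];
                }
            }
            // Step uphill, averaging the gradient over the cases.
            for (int j = 0; j <= p; j++) {
                beta[j] += learningRate * grad[j] / n;
            }
        }
        return beta;
    }

    public static void main(String[] args) {
        // Toy data: larger x makes y = 1 more likely, with one noisy case.
        double[][] x = {{-2}, {-1}, {0}, {1}, {2}, {3}};
        int[] y = {0, 0, 1, 0, 1, 1};
        double[] beta = fit(x, y, 0.1, 5000);
        System.out.println("intercept = " + beta[0] + ", slope = " + beta[1]);
        System.out.println("P(y = 1 | x = 3) = " + predict(beta, new double[]{3}));
    }
}
```

The toy data is deliberately not perfectly separable, so the coefficients settle at finite values instead of growing without bound.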
