Usage of edu.cmu.tetrad.regression.LogisticRegression in the tetrad project by cmu-phil.
From the class TestLogisticRegression, method test1.
/**
 * Smoke test for logistic regression: simulates 1000 cases from a SEM over a random DAG on five
 * continuous variables X1..X5, discretizes X1 into two categories, regresses the discretized X1
 * on X2..X5, and prints the graph and the regression object.
 */
@Test
public void test1() {
    final int numVars = 5;

    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= numVars; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 3, 3, 3, false));
    System.out.println(dag);

    SemIm semIm = new SemIm(new SemPm(dag));
    DataSet continuousData = semIm.simulateDataRecursive(1000, false);

    // Turn the response X1 into a two-category variable so it can be the logistic target.
    Discretizer discretizer = new Discretizer(continuousData);
    discretizer.equalCounts(continuousData.getVariable("X1"), 2);
    DataSet discretized = discretizer.discretize();

    // Predictors X2..X5 are looked up in the simulated (undiscretized) data set.
    List<Node> predictors = new ArrayList<>();
    for (int k = 2; k <= numVars; k++) {
        predictors.add(continuousData.getVariable("X" + k));
    }

    LogisticRegression regression = new LogisticRegression(discretized);
    DiscreteVariable response = (DiscreteVariable) discretized.getVariable("X1");
    regression.regress(response, predictors);
    System.out.println(regression);
}
Usage of edu.cmu.tetrad.regression.LogisticRegression in the tetrad project by cmu-phil.
From the class LogisticRegressionRunner, method execute.
// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Runs a logistic regression of the selected response variable on the selected predictor
 * variables, using the data set at {@code getModelIndex()}, and stores the outcome in
 * {@code result}.
 * <p>
 * {@code outGraph} is always reset to a new empty graph. If the response or predictors are not
 * set, the response is also listed as a predictor, the response is missing or not discrete, or a
 * predictor name is not found in the data set, {@code report} is set to an explanatory message
 * and no regression is run.
 */
public void execute() {
    outGraph = new EdgeListGraph();

    if (regressorNames == null || regressorNames.isEmpty() || targetName == null) {
        report = "Response and predictor variables not set.";
        return;
    }

    if (regressorNames.contains(targetName)) {
        report = "Response must not be a predictor.";
        return;
    }

    // Resolve every node against the same data set that is handed to LogisticRegression, so the
    // response and the predictors refer to that data set's own variables.
    DataSet dataSet = dataSets.get(getModelIndex());

    Node target = dataSet.getVariable(targetName);
    if (!(target instanceof DiscreteVariable)) {
        // Also covers target == null (name not present in the data set).
        report = "Response must be a discrete variable in the data set: " + targetName;
        return;
    }

    List<Node> regressorNodes = new ArrayList<>(regressorNames.size());
    for (String name : regressorNames) {
        Node regressor = dataSet.getVariable(name);
        if (regressor == null) {
            report = "Predictor not found in the data set: " + name;
            return;
        }
        regressorNodes.add(regressor);
    }

    LogisticRegression logRegression = new LogisticRegression(dataSet);
    logRegression.setAlpha(alpha);
    this.result = logRegression.regress((DiscreteVariable) target, regressorNodes);
}
Aggregations