use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
the class SemEstimatorWrapper method serializableInstance.
// public SemEstimatorWrapper(DataWrapper dataWrapper,
// SemPmWrapper semPmWrapper,
// SemImWrapper semImWrapper,
// Parameters params) {
// if (dataWrapper == null) {
// throw new NullPointerException();
// }
//
// if (semPmWrapper == null) {
// throw new NullPointerException();
// }
//
// if (semImWrapper == null) {
// throw new NullPointerException();
// }
//
// DataSet dataSet =
// (DataSet) dataWrapper.getSelectedDataModel();
// SemPm semPm = semPmWrapper.getSemPm();
// SemIm semIm = semImWrapper.getSemIm();
//
// this.semEstimator = new SemEstimator(dataSet, semPm, getOptimizer());
// if (!degreesOfFreedomCheck(semPm)) return;
// this.semEstimator.setTrueSemIm(semIm);
// this.semEstimator.setNumRestarts(getParams().getInt("numRestarts", 1));
// this.semEstimator.estimate();
//
// this.params = params;
//
// log();
// }
/**
 * Builds a minimal instance of this class for use in serialization tests.
 *
 * <p>The exemplar is assembled from a single continuous variable named "X": a
 * {@code ColtDataSet} created with size argument 10 over that variable, with every cell
 * set to {@code RandomUtil.getInstance().nextDouble()}, and a {@code SemPm} over a
 * {@code Dag} holding the same variable. A fresh, empty {@code Parameters} object is
 * passed along.
 *
 * @return the exemplar wrapper
 * @see TetradSerializableUtils
 */
public static SemEstimatorWrapper serializableInstance() {
    ContinuousVariable x = new ContinuousVariable("X");
    List<Node> nodes = new LinkedList<>();
    nodes.add(x);

    // Fill every cell of the data set with a random double.
    DataSet data = new ColtDataSet(10, nodes);
    for (int row = 0; row < data.getNumRows(); row++) {
        for (int col = 0; col < data.getNumColumns(); col++) {
            double value = RandomUtil.getInstance().nextDouble();
            data.setDouble(row, col, value);
        }
    }

    // The parametric model is built over a graph containing only X.
    Dag graph = new Dag();
    graph.addNode(x);
    SemPm semPm = new SemPm(graph);

    return new SemEstimatorWrapper(data, semPm, new Parameters());
}
use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
the class CovMatrixTable method setVariableName.
/**
 * Renames the variable at the given position in the covariance matrix's variable list,
 * unless some variable in that list already has the requested name.
 *
 * @param index position of the variable to rename
 * @param name  the new name; if any listed variable is already called this, nothing changes
 */
private void setVariableName(int index, String name) {
    List<?> variables = getCovMatrix().getVariables();

    // A name that is already taken (by any variable, including the target) means no-op.
    for (Object candidate : variables) {
        String existing = ((ContinuousVariable) candidate).getName();
        if (name.equals(existing)) {
            return;
        }
    }

    ((ContinuousVariable) variables.get(index)).setName(name);
}
use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
the class IndTestMultinomialLogisticRegressionWald method isIndependentRegression.
/**
 * Judges whether {@code x} is independent of {@code y} given {@code z} by regressing
 * {@code x} on the regressors associated with {@code y} and with every member of {@code z}.
 *
 * <p>The p-value used for the decision is the entry at index 1 of the regression result's
 * p-value array when {@code y} is a {@code ContinuousVariable}; otherwise it is the smallest
 * of the entries at indices 1..k (capped at 1), where k is the number of variables
 * registered for {@code y} in {@code variablesPerNode}. That p-value is stored in
 * {@code lastP}, and the outcome is logged as an independence or a dependence fact.
 *
 * @param x the node regressed on the others
 * @param y the node whose coefficient p-value(s) are inspected
 * @param z the conditioning nodes
 * @return {@code true} if the p-value exceeds {@code alpha}; {@code false} otherwise, and
 *         also {@code false} (without updating {@code lastP} or logging) if the
 *         regression throws
 * @throws IllegalArgumentException if any of the nodes is not a key of {@code variablesPerNode}
 */
private boolean isIndependentRegression(Node x, Node y, List<Node> z) {
    // Every node involved must be known to this test.
    if (!variablesPerNode.containsKey(x)) {
        throw new IllegalArgumentException("Unrecogized node: " + x);
    }
    if (!variablesPerNode.containsKey(y)) {
        throw new IllegalArgumentException("Unrecogized node: " + y);
    }
    for (Node conditioning : z) {
        if (!variablesPerNode.containsKey(conditioning)) {
            throw new IllegalArgumentException("Unrecogized node: " + conditioning);
        }
    }

    final boolean continuousY = y instanceof ContinuousVariable;

    // Regressors for y come first, so their p-values start at index 1 of the result.
    List<Node> predictors = new ArrayList<>();
    if (continuousY) {
        predictors.add(internalData.getVariable(y.getName()));
    } else {
        predictors.addAll(variablesPerNode.get(y));
    }
    for (Node conditioning : z) {
        predictors.addAll(variablesPerNode.get(conditioning));
    }

    int[] usableRows = getNonMissingRows(x, y, z);
    regression.setRows(usableRows);

    RegressionResult fit;
    try {
        fit = regression.regress(x, predictors);
    } catch (Exception e) {
        // A failed regression is reported as "not independent".
        return false;
    }

    double pValue = 1;
    if (continuousY) {
        pValue = fit.getP()[1];
    } else {
        // Take the smallest p-value among the regressors that stand for y.
        int count = variablesPerNode.get(y).size();
        for (int k = 1; k <= count; k++) {
            double candidate = fit.getP()[k];
            if (candidate < pValue) {
                pValue = candidate;
            }
        }
    }

    this.lastP = pValue;

    boolean independent = pValue > alpha;
    if (independent) {
        TetradLogger.getInstance().log("independencies",
                SearchLogUtils.independenceFactMsg(x, y, z, pValue));
    } else {
        TetradLogger.getInstance().log("dependencies",
                SearchLogUtils.dependenceFactMsg(x, y, z, pValue));
    }
    return independent;
}
use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
the class MGM method runTests1.
/**
 * Ad hoc driver that loads two delimited test files from a hard-coded directory, builds
 * MGM models over 24 continuous ("X0".."X23") and 24 discrete ("Y0".."Y23") variables,
 * times two matrix-copy loops, and then runs {@code learnEdges(700)} on the first model,
 * printing timings, objective values, parameters and the adjacency matrix to standard out.
 *
 * <p>An {@code IOException} raised while loading is caught and its stack trace printed.
 */
private static void runTests1() {
    try {
        // Alternative kept from earlier revisions: MGM.class.getResource("test_data").getPath()
        String path = "/Users/ajsedgewick/tetrad_master/tetrad/tetrad-lib/src/main/java/edu/pitt/csb/mgm/test_data";
        System.out.println(path);

        DoubleMatrix2D xIn = DoubleFactory2D.dense.make(
                MixedUtils.loadDelim(path, "med_test_C.txt").getDoubleData().toArray());
        DoubleMatrix2D yIn = DoubleFactory2D.dense.make(
                MixedUtils.loadDelim(path, "med_test_D.txt").getDoubleData().toArray());

        final int perKind = 24;
        int[] levels = new int[perKind];
        Arrays.fill(levels, 2);

        // Continuous and discrete variables are created pairwise, in the same order as before.
        Node[] vars = new Node[2 * perKind];
        for (int k = 0; k < perKind; k++) {
            vars[k] = new ContinuousVariable("X" + k);
            vars[k + perKind] = new DiscreteVariable("Y" + k);
        }

        double lam = .2;
        MGM model = new MGM(xIn, yIn, new ArrayList<>(Arrays.asList(vars)), levels,
                new double[] { lam, lam, lam });
        // Second model is constructed but not used further in this method.
        MGM model2 = new MGM(xIn, yIn, new ArrayList<>(Arrays.asList(vars)), levels,
                new double[] { lam, lam, lam });
        System.out.println("Weights: " + Arrays.toString(model.weights.toArray()));

        final int repetitions = 50000;
        DoubleMatrix2D test = xIn.copy();
        DoubleMatrix2D test2 = xIn.copy();

        // Timing 1: copy, then assign the copy's contents into an existing matrix.
        long started = System.currentTimeMillis();
        for (int rep = 0; rep < repetitions; rep++) {
            test2 = xIn.copy();
            test.assign(test2);
        }
        System.out.println("assign Time: " + (System.currentTimeMillis() - started));

        // Timing 2: copy, then just rebind the reference; stops early if interrupted.
        started = System.currentTimeMillis();
        // Only referenced by the commented-out alternative inside the loop below.
        double[][] xArr = xIn.toArray();
        for (int rep = 0; rep < repetitions; rep++) {
            if (Thread.currentThread().isInterrupted()) {
                break;
            }
            // test = DoubleFactory2D.dense.make(xArr);
            test2 = xIn.copy();
            test = test2;
        }
        System.out.println("equals Time: " + (System.currentTimeMillis() - started));

        System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));

        started = System.currentTimeMillis();
        model.learnEdges(700);
        // Alternative: model.learn(1e-7, 700);
        System.out.println("Orig Time: " + (System.currentTimeMillis() - started));

        System.out.println("nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
        System.out.println("params:\n" + model.params);
        System.out.println("adjMat:\n" + model.adjMatFromMGM());
    } catch (IOException ex) {
        ex.printStackTrace();
    }
}
use of edu.cmu.tetrad.data.ContinuousVariable in project tetrad by cmu-phil.
the class ADTreeTest method main.
/**
 * Demonstration run: creates 40 variables "X1".."X40", draws a random graph over them,
 * simulates 500 rows from a Bayes IM built with {@code MlBayesIm.RANDOM}, copies the data
 * into a {@code DataTableImpl} as {@code Short} values, builds an {@code ADTree} from it,
 * and prints the counts and timings of two queries.
 *
 * @param args ignored
 * @throws Exception declared so that any failure from the calls below propagates
 */
public static void main(String[] args) throws Exception {
    final int numVars = 40;
    final int numEdges = 40;
    final int numRows = 500;

    List<Node> variables = new ArrayList<>();
    // Names are collected alongside the variables but not read again in this method.
    List<String> varNames = new ArrayList<>();
    for (int k = 1; k <= numVars; k++) {
        final String label = "X" + k;
        varNames.add(label);
        variables.add(new ContinuousVariable(label));
    }

    Graph graph = GraphUtils.randomGraphRandomForwardEdges(
            variables, 0, numEdges, 30, 15, 15, false, true);
    BayesPm pm = new BayesPm(graph);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(numRows, false);

    // The data table is typed by variable (first parameter) and by value (second parameter).
    DataTableImpl<Node, Short> dataTable = new DataTableImpl<>(variables);
    for (int r = 0; r < numRows; r++) {
        ArrayList<Short> rowValues = new ArrayList<>();
        for (int c = 0; c < numVars; c++) {
            short cell = (short) data.getInt(r, c);
            rowValues.add(cell);
        }
        dataTable.addRow(rowValues);
    }

    // Build the tree and report how long construction took.
    long start = System.currentTimeMillis();
    ADTree<Node, Short> adTree = new ADTree<>(dataTable);
    String built = String.format("Generated tree in %s millis", System.currentTimeMillis() - start);
    System.out.println(built);

    // A query is an arbitrary map from variables to the values they must take.
    TreeMap<Node, Short> query = new TreeMap<>();

    // First query: two variables.
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    start = System.currentTimeMillis();
    String firstCount = String.format("Count is %d", adTree.count(query));
    System.out.println(firstCount);
    String firstTime = String.format("Query in %s ms", System.currentTimeMillis() - start);
    System.out.println(firstTime);

    // Second query: four variables, reusing the same map.
    query.clear();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X2"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    query.put(node(pm, "X10"), (short) 1);
    start = System.currentTimeMillis();
    String secondCount = String.format("Count is %d", adTree.count(query));
    System.out.println(secondCount);
    String secondTime = String.format("Query in %s ms", System.currentTimeMillis() - start);
    System.out.println(secondTime);
}
Aggregations