Use of edu.cmu.tetrad.data.ContinuousVariable in the project tetrad by cmu-phil.
Class DataConvertUtils, method toMixedDataBox.
/**
 * Wraps a {@link MixedTabularDataset} in a {@link BoxDataSet} backed by a {@link MixedDataBox}.
 *
 * <p>One node is created per entry of {@code getMixedVarInfos()}, in array order: a
 * {@link ContinuousVariable} when the entry reports {@code isContinuous()}, otherwise a
 * {@link DiscreteVariable} built from the entry's name and categories. The continuous and
 * discrete arrays of the dataset are handed to the data box as-is (no copy is made here).
 *
 * @param mixedTabularDataset the dataset to convert
 * @return a {@code BoxDataSet} over the newly created nodes
 */
public static DataModel toMixedDataBox(MixedTabularDataset mixedTabularDataset) {
    final int rowCount = mixedTabularDataset.getNumOfRows();
    final MixedVarInfo[] varInfos = mixedTabularDataset.getMixedVarInfos();
    final double[][] continuousValues = mixedTabularDataset.getContinuousData();
    final int[][] discreteValues = mixedTabularDataset.getDiscreteData();

    final List<Node> variables = new LinkedList<>();
    for (int index = 0; index < varInfos.length; index++) {
        final MixedVarInfo info = varInfos[index];
        final Node variable;
        if (!info.isContinuous()) {
            variable = new DiscreteVariable(info.getName(), info.getCategories());
        } else {
            variable = new ContinuousVariable(info.getName());
        }
        variables.add(variable);
    }

    // The same node list is given to both the data box and the data set wrapper.
    final MixedDataBox dataBox = new MixedDataBox(variables, rowCount, continuousValues, discreteValues);
    return new BoxDataSet(dataBox, variables);
}
Use of edu.cmu.tetrad.data.ContinuousVariable in the project tetrad by cmu-phil.
Class GdistanceTest, method main.
/**
 * Command-line driver for {@code Gdistance}.
 *
 * <p>Builds two random graphs over the same 16 continuous variables, reads a comma-delimited
 * location map from the relative path {@code locationMap.txt}, computes the distances between
 * the two graphs, prints them to standard output and writes them to {@code Gdistances.txt}
 * (UTF-8). Any exception raised while reading, computing or writing is caught and its stack
 * trace printed; the method does not rethrow.
 *
 * @param args command-line arguments; not used
 */
public static void main(String... args) {
    // First generate a couple of random graphs over one shared variable list.
    int numVars = 16;
    int numEdges = 16;
    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }
    Graph testdag1 = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
    Graph testdag2 = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);

    // Load the location map. The path is relative, so the working directory is printed
    // alongside it to make a missing-file failure easy to diagnose.
    String workingDirectory = System.getProperty("user.dir");
    System.out.println(workingDirectory);
    Path mapPath = Paths.get("locationMap.txt");
    System.out.println(mapPath);
    TabularDataReader dataReaderMap = new ContinuousTabularDataFileReader(mapPath.toFile(), Delimiter.COMMA);

    try {
        DataSet locationMap = (DataSet) DataConvertUtils.toDataModel(dataReaderMap.readInData());

        // Then compare the distance between the two graphs.
        // NOTE(review): the three values are passed straight to the Gdistance constructor;
        // their units/meaning are not visible here -- confirm against Gdistance.
        double xdist = 2.4;
        double ydist = 2.4;
        double zdist = 2;
        Gdistance gdist = new Gdistance(locationMap, xdist, ydist, zdist);
        List<Double> output = gdist.distances(testdag1, testdag2);
        System.out.println(output);

        // try-with-resources guarantees the file handle is released even if println fails.
        try (PrintWriter writer = new PrintWriter("Gdistances.txt", "UTF-8")) {
            writer.println(output);
        }
    } catch (Exception e) {
        // Best-effort driver: report the failure and fall through to normal exit.
        e.printStackTrace();
    }
}
Use of edu.cmu.tetrad.data.ContinuousVariable in the project tetrad by cmu-phil.
Class HsimRobustCompare, method run.
// *************Public Methods*****************//
/**
 * Simulates data from a random graph, runs FGES on it, and compares the FGES error profile on
 * the original data with the error profiles obtained from a full resimulation and from a
 * schedule of hybrid resimulations ({@link HsimRepeatAutoRun}).
 *
 * <p>Steps, in execution order (the order matters because the shared {@code RandomUtil}
 * instance is seeded at the start):
 * <ol>
 *   <li>Seed {@code RandomUtil} with a fixed value.</li>
 *   <li>Generate a random graph over {@code numVars} variables with
 *       {@code (int) (numVars * edgesPerNode)} edges, build a random {@code MlBayesIm} on it
 *       and simulate {@code numCases} rows.</li>
 *   <li>Run FGES (BDeu score) on that data and evaluate it against the generating graph.</li>
 *   <li>Pick a DAG from the FGES output, fit a Dirichlet Bayes IM to the original data,
 *       simulate {@code numCases} rows from it.</li>
 *   <li>Run the {@code HsimRepeatAutoRun} study on the original data.</li>
 *   <li>Run FGES on the fully resimulated data and evaluate it against the picked DAG.</li>
 * </ol>
 *
 * @param numVars         number of variables in the random graph
 * @param edgesPerNode    multiplied by {@code numVars} and truncated to get the edge count
 * @param numCases        number of rows to simulate, both originally and in the full resimulation
 * @param penaltyDiscount value passed to {@code Fges.setPenaltyDiscount} for both searches
 * @param resimSize       first argument forwarded to {@code HsimRepeatAutoRun.run}
 * @param repeat          second argument forwarded to {@code HsimRepeatAutoRun.run}
 * @param verbose         when true, prints the FGES output graph and the error rows
 * @return a list of three arrays, in order: errors on the original data, errors on the full
 *         resimulation, errors reported by the hsim study. The code reads indices 0..4 of each
 *         when printing; the meaning of each index is defined by {@code HsimUtils.errorEval}
 *         and {@code HsimRepeatAutoRun.run} and is not visible here.
 */
public static List<double[]> run(int numVars, double edgesPerNode, int numCases, double penaltyDiscount, int resimSize, int repeat, boolean verbose) {
    // First generate the data.
    // NOTE(review): fixed seed, presumably so runs are reproducible -- confirm this is intended.
    RandomUtil.getInstance().setSeed(1450184147770L);
    final int numEdges = (int) (numVars * edgesPerNode);

    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }

    Graph odag = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
    BayesPm bayesPm = new BayesPm(odag, 2, 2);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    // oData is the original data set, and odag is the original dag.
    DataSet oData = bayesIm.simulateData(numCases, false);

    // Then run FGES on the original data.
    Graph oGraphOut = searchWithFges(oData, penaltyDiscount);
    if (verbose) {
        System.out.println(oGraphOut);
    }

    // Calculate FGES errors against the generating graph.
    double[] oErrors = HsimUtils.errorEval(oGraphOut, odag);
    if (verbose) {
        System.out.println(joinErrors(oErrors));
    }

    // Create various simulated data sets.
    // Full simulated data set first: a dag in the FGES pattern fit to the data set.
    PatternToDag pickdag = new PatternToDag(oGraphOut);
    Graph fgesDag = pickdag.patternToDagMeek();
    Dag fgesdag2 = new Dag(fgesDag);
    BayesPm simBayesPm = new BayesPm(fgesdag2, bayesPm);
    DirichletBayesIm simIM = DirichletBayesIm.symmetricDirichletIm(simBayesPm, 1.0);
    DirichletEstimator simEstimator = new DirichletEstimator();
    DirichletBayesIm fittedIM = simEstimator.estimate(simIM, oData);
    DataSet simData = fittedIM.simulateData(numCases, false);

    // Next, a schedule of small hsims on the original data.
    HsimRepeatAutoRun study = new HsimRepeatAutoRun(oData);
    double[] hsimErrors = study.run(resimSize, repeat);

    // Calculate errors for the full simulation, measured against the DAG it was simulated from.
    Graph simGraphOut = searchWithFges(simData, penaltyDiscount);
    double[] simErrors = HsimUtils.errorEval(simGraphOut, fgesdag2);

    // First, let's just see what the errors are.
    if (verbose) {
        System.out.println("Original erors are: " + joinErrors(oErrors));
        System.out.println("Full resim errors are: " + joinErrors(simErrors));
        System.out.println("HSim errors are: " + joinErrors(hsimErrors));
    }

    // A condensed per-index comparison, |o - sim| - |o - hsim|, was sketched here at one point
    // but is not computed; callers receive the three raw error arrays instead.
    List<double[]> output = new ArrayList<>();
    output.add(oErrors);
    output.add(simErrors);
    output.add(hsimErrors);
    return output;
}

/** Runs a non-verbose FGES search with a BDeu score over {@code data} and returns its output graph. */
private static Graph searchWithFges(DataSet data, double penaltyDiscount) {
    BDeuScore score = new BDeuScore(data);
    Fges fges = new Fges(score);
    fges.setVerbose(false);
    fges.setNumPatternsToStore(0);
    fges.setPenaltyDiscount(penaltyDiscount);
    return fges.search();
}

/** Formats the first five entries of {@code errors} as a single space-separated string. */
private static String joinErrors(double[] errors) {
    return errors[0] + " " + errors[1] + " " + errors[2] + " " + errors[3] + " " + errors[4];
}
Use of edu.cmu.tetrad.data.ContinuousVariable in the project tetrad by cmu-phil.
Class TestLogisticRegression, method test1.
/**
 * Smoke test for {@link LogisticRegression}.
 *
 * <p>Simulates 1000 rows from a SEM over a random DAG on X1..X5, discretizes X1 through
 * {@code Discretizer.equalCounts(x1, 2)}, regresses the discretized X1 on X2..X5 and prints
 * the graph and the regression object. The test makes no assertions; it only checks that the
 * pipeline runs without throwing.
 */
@Test
public void test1() {
    // Five continuous variables named X1..X5.
    List<Node> variables = new ArrayList<>();
    for (int index = 1; index <= 5; index++) {
        variables.add(new ContinuousVariable("X" + index));
    }

    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 3, 3, 3, false));
    System.out.println(dag);

    SemIm semIm = new SemIm(new SemPm(dag));
    DataSet continuousData = semIm.simulateDataRecursive(1000, false);

    // Target and predictors are looked up in the continuous data set, before discretization.
    Node target = continuousData.getVariable("X1");
    List<Node> predictors = new ArrayList<>();
    for (int index = 2; index <= 5; index++) {
        predictors.add(continuousData.getVariable("X" + index));
    }

    Discretizer discretizer = new Discretizer(continuousData);
    discretizer.equalCounts(target, 2);
    DataSet discretizedData = discretizer.discretize();

    // The regression target must be the discrete X1 from the discretized data set.
    DiscreteVariable discreteTarget = (DiscreteVariable) discretizedData.getVariable("X1");
    LogisticRegression regression = new LogisticRegression(discretizedData);
    regression.regress(discreteTarget, predictors);
    System.out.println(regression);
}
Use of edu.cmu.tetrad.data.ContinuousVariable in the project tetrad by cmu-phil.
Class TestDeltaTetradTest, method makePm.
/**
 * Builds a {@link SemPm} over five continuous variables X1..X5 in which X5 has a directed
 * edge to each of X1, X2, X3 and X4.
 *
 * @return the parametric model for that graph
 */
private SemPm makePm() {
    final int variableCount = 5;

    // Nodes are created and stored in order X1..X5.
    List<Node> nodes = new ArrayList<>();
    for (int number = 1; number <= variableCount; number++) {
        nodes.add(new ContinuousVariable("X" + number));
    }

    Graph baseGraph = new EdgeListGraph(nodes);
    SemGraph semGraph = new SemGraph(baseGraph);

    // X5 (the last node) points at every other node, added in order X1..X4.
    Node parent = nodes.get(variableCount - 1);
    for (int childIndex = 0; childIndex < variableCount - 1; childIndex++) {
        semGraph.addDirectedEdge(parent, nodes.get(childIndex));
    }

    return new SemPm(semGraph);
}
Aggregations