use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
the class TestRicf method test4.
/**
 * Smoke test for {@code Ricf.ricf}: builds two random graphs over five
 * continuous variables named X1..X5, simulates 1000 rows from a SEM
 * parameterized on the first graph, and runs RICF for both graphs against the
 * covariance matrix of that single data set.
 *
 * <p>No assertions are made; the test passes as long as nothing throws.
 */
@Test
public void test4() {
    // Variables and graph used both for simulation and for the first RICF run.
    List<Node> simulationVars = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        simulationVars.add(new ContinuousVariable("X" + k));
    }
    Graph simulationGraph = GraphUtils.randomGraph(simulationVars, 0, 5, 0, 0, 0, false);

    // A second, independently generated graph over identically named variables.
    List<Node> otherVars = new ArrayList<>();
    for (int k = 1; k <= 5; k++) {
        otherVars.add(new ContinuousVariable("X" + k));
    }
    Graph otherGraph = GraphUtils.randomGraph(otherVars, 0, 5, 0, 0, 0, false);

    // Data comes from the first graph only; both graphs are fitted to it.
    SemIm im = new SemIm(new SemPm(simulationGraph));
    DataSet simulated = im.simulateData(1000, false);
    ICovarianceMatrix cov = new CovarianceMatrix(simulated);

    new Ricf().ricf(new SemGraph(simulationGraph), cov, 0.001);
    new Ricf().ricf(new SemGraph(otherGraph), cov, 0.001);
}
use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
the class TestIndTestWaldLR method testIsIndependent.
/**
 * Checks that the multinomial logistic regression Wald test agrees with
 * d-separation on the generating graph.
 *
 * <p>For each of 10 runs a random graph over X1..X10 is generated, 1000 rows
 * are simulated from a SEM on it, and the variables at index 0 and 3 are given
 * a 2-category equal-counts discretization. The test then asks whether X2 is
 * independent of X1 given {X3, X4, X5}, once statistically and once via
 * d-separation, and counts the runs in which both answers match.
 *
 * <p>The random seed is fixed up front. The original author notes that all 10
 * runs do not always agree, so the final assertion relies on this seed.
 */
@Test
public void testIsIndependent() {
    RandomUtil.getInstance().setSeed(1450705713157L);

    int agreements = 0;

    for (int run = 0; run < 10; run++) {
        List<Node> vars = new ArrayList<>();
        for (int k = 1; k <= 10; k++) {
            vars.add(new ContinuousVariable("X" + k));
        }

        Graph trueGraph = GraphUtils.randomGraph(vars, 0, 10, 3, 3, 3, false);
        SemIm im = new SemIm(new SemPm(trueGraph));
        DataSet continuous = im.simulateData(1000, false);

        // Discretize two of the columns so the data set is mixed.
        Discretizer discretizer = new Discretizer(continuous);
        discretizer.setVariablesCopied(true);
        discretizer.equalCounts(continuous.getVariable(0), 2);
        discretizer.equalCounts(continuous.getVariable(3), 2);
        DataSet mixed = discretizer.discretize();

        // Data-side nodes for the query, and their graph-side counterparts by name.
        Node x1 = mixed.getVariable("X1");
        Node x2 = mixed.getVariable("X2");
        Node x1Graph = trueGraph.getNode(x1.getName());
        Node x2Graph = trueGraph.getNode(x2.getName());

        // Conditioning set {X3, X4, X5}, built in parallel for data and graph.
        List<Node> cond = new ArrayList<>();
        List<Node> condGraph = new ArrayList<>();
        for (int k = 3; k <= 5; k++) {
            Node conditioningVar = mixed.getVariable("X" + k);
            cond.add(conditioningVar);
            condGraph.add(trueGraph.getNode(conditioningVar.getName()));
        }

        // Using the Wald LR test since it's most up to date.
        IndependenceTest waldTest = new IndTestMultinomialLogisticRegressionWald(mixed, 0.05, false);
        IndTestDSep dsep = new IndTestDSep(trueGraph);

        boolean statisticalVerdict = waldTest.isIndependent(x2, x1, cond);
        boolean graphicalVerdict = dsep.isIndependent(x2Graph, x1Graph, condGraph);
        if (statisticalVerdict == graphicalVerdict) {
            agreements++;
        }
    }

    // Do not always get all 10 (per the original author); the fixed seed above is relied on.
    assertEquals(10, agreements);
}
use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
the class TestLingamPattern method test1.
/**
 * Regression test for {@code LingamPattern} p-values under a fixed seed.
 *
 * <p>A random DAG over X1..X6 is generated and 1000 rows are simulated with
 * per-variable error distributions: standard normal for every variable except
 * the fourth, which is uniform on (-1, 1). FGES with a SEM BIC score estimates
 * a pattern, LingamPattern is run on it, and the resulting p-values are
 * compared with stored expectations to within 0.01.
 */
@Test
public void test1() {
    RandomUtil.getInstance().setSeed(4938492L);

    final int sampleSize = 1000;
    final int numVars = 6;

    List<Node> vars = new ArrayList<>();
    for (int k = 1; k <= numVars; k++) {
        vars.add(new ContinuousVariable("X" + k));
    }
    Graph trueDag = new Dag(GraphUtils.randomGraph(vars, 0, 6, 4, 4, 4, false));

    // One distribution per variable, in variable order; only index 3 is non-normal.
    List<Distribution> variableDistributions = new ArrayList<>();
    for (int idx = 0; idx < numVars; idx++) {
        if (idx == 3) {
            variableDistributions.add(new Uniform(-1, 1));
        } else {
            variableDistributions.add(new Normal(0, 1));
        }
    }

    SemIm im = new SemIm(new SemPm(trueDag));
    DataSet simulated = simulateDataNonNormal(im, sampleSize, variableDistributions);

    Score bic = new SemBicScore(new CovarianceMatrixOnTheFly(simulated));
    Graph estimatedPattern = new Fges(bic).search();

    LingamPattern lingam = new LingamPattern(estimatedPattern, simulated);
    lingam.search();

    double[] actual = lingam.getPValues();
    double[] expected = { 0.18, 0.29, 0.88, 0.00, 0.01, 0.58 };
    for (int idx = 0; idx < actual.length; idx++) {
        assertEquals(expected[idx], actual[idx], 0.01);
    }
}
use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
the class TestPc method printStats.
/**
 * Runs one search algorithm repeatedly on simulated data, prints its average
 * endpoint statistics to standard output, and returns them.
 *
 * <p>Each run generates a random forward-edge DAG over
 * {@code numMeasures + numLatents} variables named X1, X2, ..., simulates 1000
 * rows from a SEM on it, runs the selected search, and scores every endpoint
 * of the output graph against the generating DAG:
 * <ul>
 *   <li>an arrowhead at a node counts as a true positive when that node is not
 *       an ancestor of the edge's other node in the DAG and a trek exists
 *       between the two nodes;</li>
 *   <li>a tail at a node counts as a true positive when the DAG has a directed
 *       path from that node to the edge's other node;</li>
 *   <li>a bidirected edge counts as a true positive when neither node is an
 *       ancestor of the other and a trek exists between them.</li>
 * </ul>
 * A precision whose denominator is zero in a given run (NaN) is left out of
 * that precision's average.
 *
 * @param algorithms      display names, indexed by {@code t}; only
 *                        {@code algorithms[t]} is read
 * @param t               algorithm selector: 0 = Pc, 1 = Cpc, 2 = Fges,
 *                        3 = Fci, 4 = GFci, 5 = Rfci, 6 = Cfci
 * @param directed        if true, output edges that are neither directed nor
 *                        bidirected are skipped when scoring
 * @param numRuns         number of simulate-and-search repetitions
 * @param alpha           significance level passed to the Fisher Z test
 * @param penaltyDiscount penalty discount set on the SEM BIC score
 * @param numMeasures     number of measured variables
 * @param numLatents      number of latent variables requested from the graph
 *                        generator
 * @param edgeFactor      edges per node; the edge count is
 *                        {@code (int) (edgeFactor * (numMeasures + numLatents))}
 * @return a 7-element array: average arrow precision, average tail precision,
 *         average bidirected precision, average number of arrows, average
 *         number of tails, average number of bidirected edges, and the
 *         negated average elapsed search time in milliseconds
 * @throws IllegalStateException if {@code t} is outside 0..6 and at least one
 *                               run is executed
 */
private double[] printStats(String[] algorithms, int t, boolean directed, int numRuns, double alpha, double penaltyDiscount, int numMeasures, int numLatents, double edgeFactor) {
    NumberFormat nf = new DecimalFormat("0.00");

    double sumArrowPrecision = 0.0;
    double sumTailPrecision = 0.0;
    double sumBidirectedPrecision = 0.0;
    int numArrows = 0;
    int numTails = 0;
    int numBidirected = 0;
    int count = 0;
    // Accumulated as long: the per-run elapsed time is a long, and "int += long"
    // would narrow silently.
    long totalElapsed = 0L;
    int countAP = 0;
    int countTP = 0;
    int countBP = 0;

    // Graph dimensions do not change between runs.
    final int numNodes = numMeasures + numLatents;
    final int numEdges = (int) (edgeFactor * numNodes);

    for (int i = 0; i < numRuns; i++) {
        List<Node> nodes = new ArrayList<>();
        for (int r = 0; r < numNodes; r++) {
            nodes.add(new ContinuousVariable("X" + (r + 1)));
        }

        Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, numLatents, numEdges, 10, 10, 10, false);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        DataSet data = im.simulateData(1000, false);

        // Both the test and the score are always built, whichever algorithm is selected.
        IndTestFisherZ test = new IndTestFisherZ(data, alpha);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(data));
        score.setPenaltyDiscount(penaltyDiscount);

        GraphSearch search;
        switch (t) {
            case 0:
                search = new Pc(test);
                break;
            case 1:
                search = new Cpc(test);
                break;
            case 2:
                search = new Fges(score);
                break;
            case 3:
                search = new Fci(test);
                break;
            case 4:
                search = new GFci(test, score);
                break;
            case 5:
                search = new Rfci(test);
                break;
            case 6:
                search = new Cfci(test);
                break;
            default:
                throw new IllegalStateException();
        }

        // Only the search itself is timed.
        long start = System.currentTimeMillis();
        Graph out = search.search();
        long stop = System.currentTimeMillis();
        totalElapsed += stop - start;

        // Re-express the output over the DAG's own node objects so ancestry queries work.
        out = GraphUtils.replaceNodes(out, dag.getNodes());

        int arrowsTp = 0;
        int arrowsFp = 0;
        int tailsTp = 0;
        int tailsFp = 0;
        int bidirectedTp = 0;
        int bidirectedFp = 0;

        for (Edge edge : out.getEdges()) {
            if (directed && !(edge.isDirected() || Edges.isBidirectedEdge(edge))) {
                continue;
            }

            // Arrowhead at node1.
            if (edge.getEndpoint1() == Endpoint.ARROW) {
                if (!dag.isAncestorOf(edge.getNode1(), edge.getNode2()) && dag.existsTrek(edge.getNode1(), edge.getNode2())) {
                    arrowsTp++;
                } else {
                    arrowsFp++;
                }
                numArrows++;
            }

            // Arrowhead at node2.
            if (edge.getEndpoint2() == Endpoint.ARROW) {
                if (!dag.isAncestorOf(edge.getNode2(), edge.getNode1()) && dag.existsTrek(edge.getNode1(), edge.getNode2())) {
                    arrowsTp++;
                } else {
                    arrowsFp++;
                }
                numArrows++;
            }

            // Tail at node1.
            if (edge.getEndpoint1() == Endpoint.TAIL) {
                if (dag.existsDirectedPathFromTo(edge.getNode1(), edge.getNode2())) {
                    tailsTp++;
                } else {
                    tailsFp++;
                }
                numTails++;
            }

            // Tail at node2.
            if (edge.getEndpoint2() == Endpoint.TAIL) {
                if (dag.existsDirectedPathFromTo(edge.getNode2(), edge.getNode1())) {
                    tailsTp++;
                } else {
                    tailsFp++;
                }
                numTails++;
            }

            if (Edges.isBidirectedEdge(edge)) {
                if (!dag.isAncestorOf(edge.getNode1(), edge.getNode2()) && !dag.isAncestorOf(edge.getNode2(), edge.getNode1()) && dag.existsTrek(edge.getNode1(), edge.getNode2())) {
                    bidirectedTp++;
                } else {
                    bidirectedFp++;
                }
                numBidirected++;
            }
        }

        // 0 / 0.0 yields NaN when a run produced no endpoints of a kind; such runs
        // are excluded from that precision's average below.
        double arrowPrecision = arrowsTp / (double) (arrowsTp + arrowsFp);
        double tailPrecision = tailsTp / (double) (tailsTp + tailsFp);
        double bidirectedPrecision = bidirectedTp / (double) (bidirectedTp + bidirectedFp);

        if (!Double.isNaN(arrowPrecision)) {
            sumArrowPrecision += arrowPrecision;
            countAP++;
        }
        if (!Double.isNaN(tailPrecision)) {
            sumTailPrecision += tailPrecision;
            countTP++;
        }
        if (!Double.isNaN(bidirectedPrecision)) {
            sumBidirectedPrecision += bidirectedPrecision;
            countBP++;
        }

        count++;
    }

    double avgArrowPrecision = sumArrowPrecision / (double) countAP;
    double avgTailPrecision = sumTailPrecision / (double) countTP;
    double avgBidirectedPrecision = sumBidirectedPrecision / (double) countBP;
    double avgNumArrows = numArrows / (double) count;
    double avgNumTails = numTails / (double) count;
    double avgNumBidirected = numBidirected / (double) count;
    double avgElapsed = totalElapsed / (double) numRuns;

    double[] ret = new double[] {
        avgArrowPrecision,
        avgTailPrecision,
        avgBidirectedPrecision,
        avgNumArrows,
        avgNumTails,
        avgNumBidirected,
        // Elapsed time is to be minimized, so it is returned negated
        // (presumably so that larger is better for every entry; confirm with callers).
        -avgElapsed
    };

    System.out.println();
    System.out.println(algorithms[t] + " arrow precision " + nf.format(avgArrowPrecision));
    System.out.println(algorithms[t] + " tail precision " + nf.format(avgTailPrecision));
    System.out.println(algorithms[t] + " bidirected precision " + nf.format(avgBidirectedPrecision));
    System.out.println(algorithms[t] + " avg num arrow " + nf.format(avgNumArrows));
    System.out.println(algorithms[t] + " avg num tails " + nf.format(avgNumTails));
    System.out.println(algorithms[t] + " avg num bidirected " + nf.format(avgNumBidirected));
    System.out.println(algorithms[t] + " avg elapsed " + nf.format(avgElapsed));

    return ret;
}
use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
the class TestStandardizedSem method rtest8.
/**
 * Exercises {@code StandardizedSemIm.setErrorCovariance} on a three-variable
 * model X -> Y, X <-> Y, X -> Z, Y -> Z.
 *
 * <p>A random value is drawn from the admissible covariance range for the
 * X/Y error pair (infinite bounds are clamped to +/-1000). The test asserts
 * that the value is accepted and that the model is still standardized.
 */
public void rtest8() {
    // RandomUtil.getInstance().setSeed(2958442283L);
    SemGraph graph = new SemGraph();
    Node x = new ContinuousVariable("X");
    Node y = new ContinuousVariable("Y");
    Node z = new ContinuousVariable("Z");

    graph.addNode(x);
    graph.addNode(y);
    graph.addNode(z);

    graph.addDirectedEdge(x, y);
    graph.addBidirectedEdge(x, y);
    graph.addDirectedEdge(x, z);
    graph.addDirectedEdge(y, z);
    graph.setShowErrorTerms(true);

    SemPm semPm = new SemPm(graph);
    SemIm initialIm = new SemIm(semPm);
    StandardizedSemIm sem = new StandardizedSemIm(initialIm, StandardizedSemIm.Initialization.CALCULATE_FROM_SEM);

    // Simulate, standardize the columns, and re-estimate the SEM from the result.
    DataSet raw = initialIm.simulateDataCholesky(1000, false);
    DataSet standardized = ColtDataSet.makeContinuousData(raw.getVariables(), DataUtils.standardizeData(raw.getDoubleData()));
    SemIm estimatedIm = new SemEstimator(standardized, semPm).estimate();

    // The simulated data sets are not inspected. The calls are kept, in the
    // original order, so any side effects they have still occur.
    // NOTE(review): presumably they draw from the shared RandomUtil; confirm before removing.
    estimatedIm.simulateDataReducedForm(1000, false);
    sem.simulateDataReducedForm(1000, false);

    StandardizedSemIm.ParameterRange xyRange = sem.getCovarianceRange(x, y);
    double rawHigh = xyRange.getHigh();
    double rawLow = xyRange.getLow();

    // Clamp unbounded ends so a finite value can be sampled.
    double high = (rawHigh == Double.POSITIVE_INFINITY) ? 1000 : rawHigh;
    double low = (rawLow == Double.NEGATIVE_INFINITY) ? -1000 : rawLow;

    double coef = low + RandomUtil.getInstance().nextDouble() * (high - low);

    assertTrue(sem.setErrorCovariance(x, y, coef));
    assertTrue(isStandardized(sem));
}
Aggregations