Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
In class TestSimulatedFmri, method testClark:
/**
 * Simulates two three-variable generalized SEMs 100 times each and prints the graph FASK returns
 * for every simulated data set.
 *
 * <p>Model 1 has edges X -> Y, Z -> X, Z -> Y with X = 0.5*Z + E_X, Y = 0.5*X + 0.5*Z + E_Y,
 * Z = E_Z. Model 2 has edges X -> Y, X -> Z, Y -> Z with X = E_X, Y = 0.4*X + E_Y,
 * Z = 0.4*X + 0.4*Y + E_Z. Every error term is {@code pow(Uniform(0, 1), f)}.
 *
 * @throws IllegalStateException if any node or error expression fails to parse
 */
// @Test
public void testClark() {
    double f = .1;
    int sampleSize = 512;
    double alpha = 1.0;
    double penaltyDiscount = 1.0;
    String error = "pow(Uniform(0, 1), " + f + ")";

    for (int i = 0; i < 100; i++) {
        Graph out1 = simulateClarkAndRunFask(
                new String[][]{{"X", "Y"}, {"Z", "X"}, {"Z", "Y"}},
                "0.5 * Z + E_X", "0.5 * X + 0.5 * Z + E_Y", "E_Z",
                error, sampleSize, alpha, penaltyDiscount);
        System.out.println(out1);

        Graph out2 = simulateClarkAndRunFask(
                new String[][]{{"X", "Y"}, {"X", "Z"}, {"Y", "Z"}},
                "E_X", "0.4 * X + E_Y", "0.4 * X + 0.4 * Y + E_Z",
                error, sampleSize, alpha, penaltyDiscount);
        System.out.println(out2);
    }
}

/**
 * Builds a graph over continuous nodes X, Y, Z and adds the given directed edges. It then
 * configures a generalized SEM with the given node and error expressions, simulates
 * {@code sampleSize} rows, and runs FASK with a SEM-BIC score.
 *
 * @param edges           directed edges as {from, to} name pairs
 * @param exprX           expression for X
 * @param exprY           expression for Y
 * @param exprZ           expression for Z
 * @param errorExpr       expression used for every error node
 * @param sampleSize      number of rows to simulate
 * @param alpha           FASK alpha
 * @param penaltyDiscount penalty discount for both the score and FASK
 * @return the graph returned by {@code Fask.search()}
 * @throws IllegalStateException if any expression fails to parse
 */
private static Graph simulateClarkAndRunFask(String[][] edges, String exprX, String exprY,
                                            String exprZ, String errorExpr, int sampleSize,
                                            double alpha, double penaltyDiscount) {
    String[] names = {"X", "Y", "Z"};
    Graph g = new EdgeListGraph();
    for (String name : names) {
        g.addNode(new ContinuousVariable(name));
    }
    for (String[] edge : edges) {
        g.addDirectedEdge(g.getNode(edge[0]), g.getNode(edge[1]));
    }

    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        pm.setNodeExpression(g.getNode("X"), exprX);
        pm.setNodeExpression(g.getNode("Y"), exprY);
        pm.setNodeExpression(g.getNode("Z"), exprZ);
        for (String name : names) {
            pm.setNodeExpression(pm.getErrorNode(g.getNode(name)), errorExpr);
        }
    } catch (ParseException e) {
        // Continuing would simulate from a PM that is only partly configured, so the
        // run would not exercise the intended model; fail loudly instead.
        throw new IllegalStateException("Could not parse SEM expression for Clark test model", e);
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(sampleSize, false);
    edu.cmu.tetrad.search.SemBicScore score =
            new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
    score.setPenaltyDiscount(penaltyDiscount);
    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(penaltyDiscount);
    fask.setAlpha(alpha);
    return fask.search();
}
Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
In class GraphWrapper, method getStrongestInfluenceGraph:
/**
 * Builds a graph over the non-error nodes of the IM's graph.
 *
 * <p>When exactly one edge joins a pair of retained nodes, that edge is copied unchanged. When
 * more than one edge joins the pair, the method looks up the parameter each endpoint's expression
 * uses for the other endpoint. It then keeps a single directed edge pointing into the node whose
 * parameter for the other node has the larger value. If either parameter cannot be found, or
 * the two values are equal, the pair is left unconnected.
 *
 * @param im the generalized SEM IM supplying the graph, expressions and parameter values
 * @return a new graph with no error nodes and at most one edge per adjacent pair
 */
private static Graph getStrongestInfluenceGraph(GeneralizedSemIm im) {
    GeneralizedSemPm pm = im.getGeneralizedSemPm();
    Graph semGraph = im.getGeneralizedSemPm().getGraph();

    List<Node> measured = new ArrayList<>();
    for (Node candidate : semGraph.getNodes()) {
        if (candidate.getNodeType() != NodeType.ERROR) {
            measured.add(candidate);
        }
    }
    Graph result = new EdgeListGraph(measured);

    for (Edge edge : semGraph.getEdges()) {
        Node a = edge.getNode1();
        Node b = edge.getNode2();

        // Skip edges touching error nodes, and pairs already resolved on an earlier edge.
        if (!result.containsNode(a) || !result.containsNode(b) || result.isAdjacentTo(a, b)) {
            continue;
        }

        List<Edge> between = semGraph.getEdges(a, b);
        if (between.size() == 1) {
            result.addEdge(between.get(0));
            continue;
        }

        // NOTE(review): findParameter presumably returns the name of the parameter that
        // multiplies the named variable in the expression, or null if absent — confirm.
        Expression exprA = pm.getNodeExpression(a);
        Expression exprB = pm.getNodeExpression(b);
        String paramForBInA = findParameter(exprA, b.getName());
        String paramForAInB = findParameter(exprB, a.getName());
        if (paramForBInA == null || paramForAInB == null) {
            continue;
        }

        double valueForBInA = im.getParameterValue(paramForBInA);
        double valueForAInB = im.getParameterValue(paramForAInB);
        if (valueForAInB > valueForBInA) {
            result.addDirectedEdge(a, b);
        } else if (valueForBInA > valueForAInB) {
            result.addDirectedEdge(b, a);
        }
    }
    return result;
}
Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
In class LeeHastieSimulation, method simulate:
/**
 * Simulates one mixed continuous/discrete data set from {@code dag}.
 *
 * <p>On the first call a shuffled node order is stored in {@code shuffledOrder}; later calls
 * reuse it. The first {@code nodes.size() * percentDiscrete / 100} positions of that order
 * (strictly less than that product) become discrete variables, with a category count from
 * {@code pickNumCategories}. All remaining nodes are continuous (category count 0).
 *
 * @param dag        the DAG to simulate from
 * @param parameters supplies "percentDiscrete", "numCategories", "saveLatentVars" and "sampleSize"
 * @return the simulated data set, converted to mixed data by {@code MixedUtils.makeMixedData}
 */
private DataSet simulate(Graph dag, Parameters parameters) {
    HashMap<String, Integer> nd = new HashMap<>();

    List<Node> nodes = dag.getNodes();
    // NOTE(review): this shuffles whatever list dag.getNodes() returns in place; if that list
    // is not a defensive copy, the DAG's own node order is changed — confirm.
    Collections.shuffle(nodes);

    if (this.shuffledOrder == null) {
        List<Node> order = new ArrayList<>(nodes);
        Collections.shuffle(order);
        this.shuffledOrder = order;
    }

    int count = nodes.size();
    for (int idx = 0; idx < count; idx++) {
        boolean discrete = idx < count * parameters.getDouble("percentDiscrete") * 0.01;
        int categories;
        if (discrete) {
            // Min and max are read from the same key, so the range is a single value.
            int lo = parameters.getInt("numCategories");
            int hi = parameters.getInt("numCategories");
            categories = pickNumCategories(lo, hi);
        } else {
            categories = 0;
        }
        nd.put(shuffledOrder.get(idx).getName(), categories);
    }

    Graph mixedGraph = MixedUtils.makeMixedGraph(dag, nd);
    GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(mixedGraph, "Split(-1.5,-.5,.5,1.5)");
    GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);

    boolean keepLatents = parameters.getBoolean("saveLatentVars");
    int rows = parameters.getInt("sampleSize");
    DataSet simulated = im.simulateDataAvoidInfinity(rows, keepLatents);
    return MixedUtils.makeMixedData(simulated, nd);
}
Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
In class MixedUtils, method main:
/**
 * Manual smoke test for the mixed Gaussian/categorical PM and IM helpers.
 *
 * <p>Builds the five-node graph X1-->X2, X2-->X3, X3-->X4, X5-->X4. It makes X3 and X4 discrete
 * with 4 categories and leaves the others continuous. It prints the PM before and after
 * replacing the "C" starts-with template, then prints the IM, a 15-row sample, and X4's category
 * count.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    Graph g = GraphConverter.convert("X1-->X2,X2-->X3,X3-->X4, X5-->X4");

    // 0 marks a continuous variable; any other value is the number of categories.
    String[] names = {"X1", "X2", "X3", "X4", "X5"};
    int[] categoryCounts = {0, 0, 4, 4, 0};
    HashMap<String, Integer> nd = new HashMap<>();
    for (int k = 0; k < names.length; k++) {
        nd.put(names[k], categoryCounts[k]);
    }
    g = makeMixedGraph(g, nd);

    GeneralizedSemPm pm = GaussianCategoricalPm(g, "Split(-1.5,-1,1,1.5)");
    System.out.println(pm);

    String prefix = "C";
    System.out.println("STARTS WITH");
    System.out.println(pm.getStartsWithParameterTemplate(prefix));
    try {
        MixedUtils.setStartsWith(prefix, "Split(-.9,-.5,.5,.9)", pm);
    } catch (Throwable t) {
        // Demo code: report the failure and carry on printing.
        t.printStackTrace();
    }
    System.out.println("STARTS WITH");
    System.out.println(pm.getStartsWithParameterTemplate(prefix));
    System.out.println(pm);

    GeneralizedSemIm im = GaussianCategoricalIm(pm);
    System.out.println(im);

    final int sampleCount = 15;
    DataSet ds = im.simulateDataFisher(sampleCount);
    System.out.println(ds);

    DiscreteVariable x4 = (DiscreteVariable) g.getNode("X4");
    System.out.println("num cats " + x4.getNumCategories());
}
Use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
In class TestSimulatedFmri, method testClark2:
/**
 * Simulates the SEM X -> Y, X -> Z, Y -> Z and prints the graph FASK returns.
 *
 * <p>Expressions are X = E_X, Y = 0.4*X + E_Y and Z = 0.4*X + 0.4*Y + E_Z. Every error term is
 * {@code pow(Uniform(0, 1), 1.5)}. The run uses 1000 rows, penalty discount 1 and alpha 0.5.
 *
 * @throws IllegalStateException if any node or error expression fails to parse
 */
// @Test
public void testClark2() {
    Node x = new ContinuousVariable("X");
    Node y = new ContinuousVariable("Y");
    Node z = new ContinuousVariable("Z");
    Graph g = new EdgeListGraph();
    g.addNode(x);
    g.addNode(y);
    g.addNode(z);
    g.addDirectedEdge(x, y);
    g.addDirectedEdge(x, z);
    g.addDirectedEdge(y, z);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        pm.setNodeExpression(g.getNode("X"), "E_X");
        pm.setNodeExpression(g.getNode("Y"), "0.4 * X + E_Y");
        pm.setNodeExpression(g.getNode("Z"), "0.4 * X + 0.4 * Y + E_Z");
        String error = "pow(Uniform(0, 1), 1.5)";
        for (String name : new String[]{"X", "Y", "Z"}) {
            pm.setNodeExpression(pm.getErrorNode(g.getNode(name)), error);
        }
    } catch (ParseException e) {
        // Continuing would simulate from a PM that is only partly configured, so the
        // run would not exercise the intended model; fail loudly instead.
        throw new IllegalStateException("Could not parse SEM expression for Clark2 test model", e);
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(1000, false);
    edu.cmu.tetrad.search.SemBicScore score =
            new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(1);
    fask.setAlpha(0.5);
    Graph out = fask.search();
    System.out.println(out);
}
Aggregations