Use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.
The class GeneralSemSimulationSpecial1, method simulate:
    private DataSet simulate(Graph graph, Parameters parameters) {
        GeneralizedSemPm pm = getPm(graph);
        GeneralizedSemIm im = new GeneralizedSemIm(pm);
        return im.simulateData(parameters.getInt("sampleSize"), false);
    }
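For context, a minimal standalone sketch (not part of the Tetrad sources) of how the same call can be driven by hand. The graph string and sample size are arbitrary, and the plain GeneralizedSemPm constructor stands in for the private getPm(graph) that GeneralSemSimulationSpecial1 uses; the Parameters set/getInt accessors are assumed to behave as in the other examples on this page.

    import edu.cmu.tetrad.data.DataSet;
    import edu.cmu.tetrad.graph.Graph;
    import edu.cmu.tetrad.graph.GraphConverter;
    import edu.cmu.tetrad.sem.GeneralizedSemIm;
    import edu.cmu.tetrad.sem.GeneralizedSemPm;
    import edu.cmu.tetrad.util.Parameters;

    public class SimulateSketch {
        public static void main(String[] args) {
            Graph graph = GraphConverter.convert("X1-->X2,X2-->X3");

            Parameters parameters = new Parameters();
            parameters.set("sampleSize", 500);

            // default template expressions; GeneralSemSimulationSpecial1 customizes these via getPm(graph)
            GeneralizedSemPm pm = new GeneralizedSemPm(graph);
            GeneralizedSemIm im = new GeneralizedSemIm(pm);

            DataSet data = im.simulateData(parameters.getInt("sampleSize"), false);
            System.out.println(data.getNumRows() + " rows x " + data.getNumColumns() + " columns");
        }
    }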
Use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.
The class MixedUtils, method GaussianCategoricalIm:
    /**
     * Normalizes edge parameters for an instantiated mixed model (IM).
     * Generates edge parameters for c-d (continuous-discrete) and d-d (discrete-discrete)
     * edges from a single weight, abs(w), drawn by the normal IM constructor.
     * Abs(w) is used for d-d edges.
     *
     * For deterministic generation, c-d parameters are evenly spaced between -w and w,
     * and d-d parameters form a matrix with w on the diagonal and -w/(categories - 1)
     * in the remaining cells.
     * For random generation, c-d parameters are drawn uniformly from 0 to 1 and then
     * transformed so that w is the maximum value and the parameters sum to 0.
     *
     * @param pm            the generalized SEM parametric model whose edge parameters are rescaled
     * @param discParamRand true for random edge-parameter generation, false for deterministic
     * @return the instantiated model with normalized edge parameters
     */
    public static GeneralizedSemIm GaussianCategoricalIm(GeneralizedSemPm pm, boolean discParamRand) {
        Map<String, Integer> nodeDists = getNodeDists(pm.getGraph());
        GeneralizedSemIm im = new GeneralizedSemIm(pm);
        List<Node> nodes = pm.getVariableNodes();

        // this needs to be changed for cyclic graphs...
        for (Node n : nodes) {
            Set<Node> parNodes = pm.getReferencedNodes(n);
            if (parNodes.size() == 0) {
                continue;
            }

            for (Node par : parNodes) {
                if (par.getNodeType() == NodeType.ERROR) {
                    continue;
                }

                // number of categories for the child and parent (0 means continuous)
                int cL = nodeDists.get(n.getName());
                int pL = nodeDists.get(par.getName());

                // c-c edges don't need params changed
                if (cL == 0 && pL == 0) {
                    continue;
                }

                List<String> params = getEdgeParams(n, par, pm);

                // just use the first parameter as the "weight" for the whole edge
                double w = im.getParameterValue(params.get(0));

                // d-d edges use one vector and permute edges, could use different strategy
                if (cL > 0 && pL > 0) {
                    w = Math.abs(w);
                    double bgW = w / ((double) pL - 1.0);

                    // match each child category to one parent category, cycling through
                    // parent categories when the child has more of them
                    int[] weightInds = new int[cL];
                    for (int i = 0; i < cL; i++) {
                        if (i < pL) {
                            weightInds[i] = i;
                        } else {
                            weightInds[i] = i % pL;
                        }
                    }

                    if (discParamRand) {
                        weightInds = arrayPermute(weightInds);
                    }

                    // matched cells get w, all other cells get -bgW
                    for (int i = 0; i < cL; i++) {
                        for (int j = 0; j < pL; j++) {
                            int index = i * pL + j;
                            if (weightInds[i] == j) {
                                im.setParameterValue(params.get(index), w);
                            } else {
                                im.setParameterValue(params.get(index), -bgW);
                            }
                        }
                    }
                } else {
                    // params for c-d edges
                    double[] newWeights;
                    int curL = (pL > 0 ? pL : cL);
                    if (discParamRand) {
                        newWeights = generateMixedEdgeParams(w, curL);
                    } else {
                        newWeights = evenSplitVector(w, curL);
                    }

                    int count = 0;
                    for (String p : params) {
                        im.setParameterValue(p, newWeights[count]);
                        count++;
                    }
                }
            }
        }
        return im;
    }
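A purely illustrative sketch (no Tetrad calls; all names here are chosen for the example) of the deterministic layout the Javadoc describes: for a d-d edge, w on the matched category and -w/(pL - 1) everywhere else; for a c-d edge, values evenly spaced between -w and w, which is what evenSplitVector is assumed to produce.

    public class EdgeWeightLayoutSketch {
        public static void main(String[] args) {
            double w = 1.2;      // the single weight abs(w) drawn for the edge
            int cL = 3, pL = 3;  // category counts of the child and parent

            // d-d edge: w where the child category is matched to a parent category,
            // -w / (pL - 1) in every other cell
            double bgW = w / (pL - 1.0);
            for (int i = 0; i < cL; i++) {
                StringBuilder row = new StringBuilder();
                for (int j = 0; j < pL; j++) {
                    row.append(i % pL == j ? w : -bgW).append('\t');
                }
                System.out.println(row);
            }

            // c-d edge: curL values evenly spaced between -w and w
            int curL = 4;
            for (int k = 0; k < curL; k++) {
                System.out.print((-w + 2.0 * w * k / (curL - 1)) + " ");
            }
            System.out.println();
        }
    }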
Use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.
The class MGM, method runTests2:
    /**
     * Tests non-penalty use cases.
     */
    private static void runTests2() {
        Graph g = GraphConverter.convert("X1-->X2,X3-->X2,X4-->X5");

        // simple graph pm/im generation example: X1 and X2 continuous (0),
        // X3-X5 discrete with 4 categories
        HashMap<String, Integer> nd = new HashMap<>();
        nd.put("X1", 0);
        nd.put("X2", 0);
        nd.put("X3", 4);
        nd.put("X4", 4);
        nd.put("X5", 4);

        g = MixedUtils.makeMixedGraph(g, nd);
        GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(g, "Split(-1.5,-.5,.5,1.5)");
        System.out.println(pm);

        GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);
        System.out.println(im);

        int samps = 1000;
        DataSet ds = im.simulateDataFisher(samps);
        ds = MixedUtils.makeMixedData(ds, nd);

        double lambda = 0;
        MGM model = new MGM(ds, new double[] { lambda, lambda, lambda });
        System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));

        model.learn(1e-8, 1000);

        System.out.println("Learned nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("Learned reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
        System.out.println("params:\n" + model.params);
        System.out.println("adjMat:\n" + model.adjMatFromMGM());
    }
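For comparison, a small hypothetical variation on runTests2 that refits the same data with a nonzero penalty. It reuses only calls that already appear above; the lambda value is arbitrary, and 'ds' is the mixed DataSet built in runTests2.

    // sketch only: refit the mixed data with a nonzero regularization penalty
    private static void runPenalized(DataSet ds) {
        double lambda = 0.2;  // arbitrary illustrative penalty
        MGM model = new MGM(ds, new double[] { lambda, lambda, lambda });
        model.learn(1e-8, 1000);
        System.out.println("Penalized nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("Penalized reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
        System.out.println("adjMat:\n" + model.adjMatFromMGM());
    }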
Use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.
The class TestSimulatedFmri, method testClark:
    // @Test
    public void testClark() {
        double f = .1;
        int N = 512;
        double alpha = 1.0;
        double penaltyDiscount = 1.0;

        for (int i = 0; i < 100; i++) {
            {
                // Model 1: Z is a common cause of X and Y, with X --> Y.
                Node x = new ContinuousVariable("X");
                Node y = new ContinuousVariable("Y");
                Node z = new ContinuousVariable("Z");

                Graph g = new EdgeListGraph();
                g.addNode(x);
                g.addNode(y);
                g.addNode(z);
                g.addDirectedEdge(x, y);
                g.addDirectedEdge(z, x);
                g.addDirectedEdge(z, y);

                GeneralizedSemPm pm = new GeneralizedSemPm(g);

                try {
                    pm.setNodeExpression(g.getNode("X"), "0.5 * Z + E_X");
                    pm.setNodeExpression(g.getNode("Y"), "0.5 * X + 0.5 * Z + E_Y");
                    pm.setNodeExpression(g.getNode("Z"), "E_Z");

                    // skewed (non-Gaussian) error term
                    String error = "pow(Uniform(0, 1), " + f + ")";
                    pm.setNodeExpression(pm.getErrorNode(g.getNode("X")), error);
                    pm.setNodeExpression(pm.getErrorNode(g.getNode("Y")), error);
                    pm.setNodeExpression(pm.getErrorNode(g.getNode("Z")), error);
                } catch (ParseException e) {
                    System.out.println(e);
                }

                GeneralizedSemIm im = new GeneralizedSemIm(pm);
                DataSet data = im.simulateData(N, false);

                edu.cmu.tetrad.search.SemBicScore score =
                        new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
                score.setPenaltyDiscount(penaltyDiscount);

                Fask fask = new Fask(data, score);
                fask.setPenaltyDiscount(penaltyDiscount);
                fask.setAlpha(alpha);
                Graph out = fask.search();
                System.out.println(out);
            }

            {
                // Model 2: X is a common cause of Y and Z, with Y --> Z.
                Node x = new ContinuousVariable("X");
                Node y = new ContinuousVariable("Y");
                Node z = new ContinuousVariable("Z");

                Graph g = new EdgeListGraph();
                g.addNode(x);
                g.addNode(y);
                g.addNode(z);
                g.addDirectedEdge(x, y);
                g.addDirectedEdge(x, z);
                g.addDirectedEdge(y, z);

                GeneralizedSemPm pm = new GeneralizedSemPm(g);

                try {
                    pm.setNodeExpression(g.getNode("X"), "E_X");
                    pm.setNodeExpression(g.getNode("Y"), "0.4 * X + E_Y");
                    pm.setNodeExpression(g.getNode("Z"), "0.4 * X + 0.4 * Y + E_Z");

                    String error = "pow(Uniform(0, 1), " + f + ")";
                    pm.setNodeExpression(pm.getErrorNode(g.getNode("X")), error);
                    pm.setNodeExpression(pm.getErrorNode(g.getNode("Y")), error);
                    pm.setNodeExpression(pm.getErrorNode(g.getNode("Z")), error);
                } catch (ParseException e) {
                    System.out.println(e);
                }

                GeneralizedSemIm im = new GeneralizedSemIm(pm);
                DataSet data = im.simulateData(N, false);

                edu.cmu.tetrad.search.SemBicScore score =
                        new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
                score.setPenaltyDiscount(penaltyDiscount);

                Fask fask = new Fask(data, score);
                fask.setPenaltyDiscount(penaltyDiscount);
                fask.setAlpha(alpha);
                Graph out = fask.search();
                System.out.println(out);
            }
        }
    }
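The error expression pow(Uniform(0, 1), f) with f = .1 yields a strongly skewed, clearly non-Gaussian error, which is the kind of asymmetry FASK relies on when orienting edges. A standalone sketch of sampling that expression directly (plain java.util.Random, no Tetrad involved):

    import java.util.Random;

    public class SkewedErrorSketch {
        public static void main(String[] args) {
            double f = 0.1;
            int n = 100_000;
            Random rng = new Random(42);

            double sum = 0;
            double sumSq = 0;
            for (int i = 0; i < n; i++) {
                double e = Math.pow(rng.nextDouble(), f); // most draws land near 1, with a long left tail
                sum += e;
                sumSq += e * e;
            }

            double mean = sum / n;
            double variance = sumSq / n - mean * mean;
            System.out.println("mean ~ " + mean + ", variance ~ " + variance);
        }
    }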
Use of edu.cmu.tetrad.sem.GeneralizedSemIm in project tetrad by cmu-phil.
The class LeeHastieSimulation, method simulate:
    private DataSet simulate(Graph dag, Parameters parameters) {
        HashMap<String, Integer> nd = new HashMap<>();
        List<Node> nodes = dag.getNodes();
        Collections.shuffle(nodes);

        // cache one shuffled node order so repeated simulations discretize the same variables
        if (this.shuffledOrder == null) {
            List<Node> shuffledNodes = new ArrayList<>(nodes);
            Collections.shuffle(shuffledNodes);
            this.shuffledOrder = shuffledNodes;
        }

        for (int i = 0; i < nodes.size(); i++) {
            // "percentDiscrete" is a percentage, hence the 0.01 factor
            if (i < nodes.size() * parameters.getDouble("percentDiscrete") * 0.01) {
                final int minNumCategories = parameters.getInt("numCategories");
                final int maxNumCategories = parameters.getInt("numCategories");
                final int value = pickNumCategories(minNumCategories, maxNumCategories);
                nd.put(shuffledOrder.get(i).getName(), value);
            } else {
                nd.put(shuffledOrder.get(i).getName(), 0);
            }
        }

        Graph graph = MixedUtils.makeMixedGraph(dag, nd);
        GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(graph, "Split(-1.5,-.5,.5,1.5)");
        GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);

        boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
        DataSet ds = im.simulateDataAvoidInfinity(parameters.getInt("sampleSize"), saveLatentVars);
        return MixedUtils.makeMixedData(ds, nd);
    }
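A small sketch of the Parameters keys this simulate method reads. The key names come from the method body above; the values shown are arbitrary, and the standard edu.cmu.tetrad.util.Parameters set(...) accessor is assumed.

    import edu.cmu.tetrad.util.Parameters;

    public class LeeHastieParamsSketch {
        public static void main(String[] args) {
            Parameters parameters = new Parameters();
            parameters.set("sampleSize", 1000);      // rows to simulate
            parameters.set("percentDiscrete", 50.0); // percentage of nodes made discrete
            parameters.set("numCategories", 3);      // used as both min and max above
            parameters.set("saveLatentVars", false); // whether latent columns are kept
            System.out.println(parameters);
        }
    }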