use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
the class GeneralSemSimulationSpecial1 method simulate.
/**
 * Simulates one data set from {@code graph}.
 * <p>
 * Steps:
 * <ol>
 *   <li>builds the parametric model via {@code getPm};</li>
 *   <li>instantiates a {@code GeneralizedSemIm} from it;</li>
 *   <li>draws {@code parameters.getInt("sampleSize")} rows with {@code simulateData(n, false)}.</li>
 * </ol>
 *
 * @param graph      the graph to simulate from.
 * @param parameters supplies the {@code "sampleSize"} setting.
 * @return the simulated data set.
 */
private DataSet simulate(Graph graph, Parameters parameters) {
    GeneralizedSemIm instantiatedModel = new GeneralizedSemIm(getPm(graph));
    int rowCount = parameters.getInt("sampleSize");
    return instantiatedModel.simulateData(rowCount, false);
}
use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
the class GeneralSemSimulationSpecial1 method getPm.
/**
 * Builds a generalized SEM parametric model for {@code graph}.
 * <p>
 * Construction steps:
 * <ul>
 *   <li>Every variable node receives the template {@code funcs[0]} ({@code TSUM(NEW(B)*$)}).
 *       Latent and measured nodes are routed through separate template variables, so each
 *       can be varied independently.</li>
 *   <li>Every error node is given {@code U(-.5,.5)}.</li>
 *   <li>Each free parameter whose name starts with a key of {@code paramMap} is assigned
 *       that key's distribution expression.</li>
 * </ul>
 *
 * @param graph the graph to parameterize.
 * @return the configured parametric model.
 * @throws IllegalStateException if one of the templates or expressions fails to parse.
 *         Previously the error was only printed and a partially configured model was returned.
 */
private GeneralizedSemPm getPm(Graph graph) {
    GeneralizedSemPm pm = new GeneralizedSemPm(graph);
    List<Node> variablesNodes = pm.getVariableNodes();
    List<Node> errorNodes = pm.getErrorNodes();

    // Candidate node templates; only funcs[0] is currently selected below.
    String[] funcs = { "TSUM(NEW(B)*$)", "TSUM(NEW(B)*$+NEW(C)*sin(NEW(T)*$+NEW(A)))", "TSUM(NEW(B)*(.5*$ + .5*(sqrt(abs(NEW(b)*$+NEW(exoErrorType))) ) ) )" };

    // Parameter-name prefix -> distribution expression for parameters with that prefix.
    Map<String, String> paramMap = new HashMap<>();
    paramMap.put("s", "U(1,3)");
    paramMap.put("B", "Split(-1.5,-.5,.5,1.5)");
    paramMap.put("C", "Split(-1.5,-.5,.5,1.5)");
    paramMap.put("T", "U(.5,1.5)");
    paramMap.put("A", "U(0,.25)");
    paramMap.put("exoErrorType", "U(-.5,.5)");
    paramMap.put("funcType", "U(1,5)");

    String nonlinearStructuralEdgesFunction = funcs[0];
    String nonlinearFactorMeasureEdgesFunction = funcs[0];

    try {
        for (Node node : variablesNodes) {
            // Latent nodes get the structural template; measured nodes the measurement template.
            String template = node.getNodeType() == NodeType.LATENT
                    ? nonlinearStructuralEdgesFunction
                    : nonlinearFactorMeasureEdgesFunction;
            String expanded = TemplateExpander.getInstance().expandTemplate(template, pm, node);
            pm.setNodeExpression(node, expanded);
        }

        for (Node node : errorNodes) {
            String expanded = TemplateExpander.getInstance().expandTemplate("U(-.5,.5)", pm, node);
            pm.setNodeExpression(node, expanded);
        }

        Set<String> parameters = pm.getParameters();
        for (String parameter : parameters) {
            for (String type : paramMap.keySet()) {
                if (parameter.startsWith(type)) {
                    pm.setParameterExpression(parameter, paramMap.get(type));
                }
            }
        }
    } catch (ParseException e) {
        // The templates are fixed strings, so a parse failure is a programming error;
        // do not hand back a half-configured model.
        throw new IllegalStateException("Parse error in fixing parameters.", e);
    }

    return pm;
}
use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
the class MixedUtils method GaussianCategoricalPm.
/**
 * Generates a generalized SEM PM from {@code trueGraph} for a mix of Gaussian and categorical
 * variables.
 * <p>
 * The per-node distribution is read from {@code getNodeDists(trueGraph)}:
 * <ul>
 *   <li>0 means Gaussian;</li>
 *   <li>a value greater than 1 is the number of categories;</li>
 *   <li>1 (constant) is rejected.</li>
 * </ul>
 * Categorical nodes are expressed with {@code DiscError} over a {@code U(0,1)} error node.
 * For every categorical parent, its coefficient in the child's expression is replaced by a
 * {@code Switch} over that parent's categories, with one fresh parameter per category.
 *
 * @param trueGraph     the graph to parameterize.
 * @param paramTemplate the distribution template used for parameters starting with B, C and D.
 * @return the configured parametric model.
 * @throws IllegalArgumentException if a node's dist is 1 or no dist is recorded for a node.
 * @throws IllegalStateException    if a generated expression fails to parse.
 */
public static GeneralizedSemPm GaussianCategoricalPm(Graph trueGraph, String paramTemplate) throws IllegalStateException {
    Map<String, Integer> nodeDists = getNodeDists(trueGraph);
    GeneralizedSemPm semPm = new GeneralizedSemPm(trueGraph);
    try {
        List<Node> variableNodes = semPm.getVariableNodes();
        semPm.setStartsWithParametersTemplate("B", paramTemplate);
        semPm.setStartsWithParametersTemplate("C", paramTemplate);
        semPm.setStartsWithParametersTemplate("D", paramTemplate);
        // empirically should give us a stddev of 1 - 2
        semPm.setStartsWithParametersTemplate("s", "U(1,2)");

        for (Node node : variableNodes) {
            List<Node> parents = trueGraph.getParents(node);
            Node eNode = semPm.getErrorNode(node);
            // Gaussian nodes keep the default expressions of the PM.
            String curEx = semPm.getNodeExpressionString(node);
            String errEx = semPm.getNodeExpressionString(eNode);

            // dist of 0 means Gaussian; a larger value is the number of categories
            int curDist = requireNodeDist(nodeDists, node);
            if (curDist == 1) {
                throw new IllegalArgumentException("Dist for node " + node.getName() + " is set to one (i.e. constant) which is not supported.");
            }

            // for each discrete node use DiscError for categorical draw
            if (curDist > 0) {
                // One DiscError argument per category: a constant 1 for root nodes,
                // otherwise a fresh linear combination of the parents.
                String categoryTerm = parents.isEmpty() ? ",1" : ", TSUM(NEW(C)*$)";
                // The error-node name is appended literally (no regex substitution).
                StringBuilder discTemplate = new StringBuilder("DiscError(").append(eNode.getName());
                for (int l = 0; l < curDist; l++) {
                    discTemplate.append(categoryTerm);
                }
                discTemplate.append(")");
                curEx = TemplateExpander.getInstance().expandTemplate(discTemplate.toString(), semPm, node);
                errEx = TemplateExpander.getInstance().expandTemplate("U(0,1)", semPm, eNode);
            }

            // now for every discrete parent, swap its coefficient for per-category params
            String newTemp = curEx;
            for (Node parNode : parents) {
                int parDist = requireNodeDist(nodeDists, parNode);
                if (parDist <= 0) {
                    continue;
                }
                String curName = parNode.getName();
                // Discrete children were built from C coefficients above and get D params.
                // Gaussian children are assumed to carry B coefficients in their default
                // expression, and get C params.
                String coefPrefix = curDist > 0 ? "C" : "B";
                String newParam = curDist > 0 ? ",NEW(D)" : ",NEW(C)";
                StringBuilder disRep = new StringBuilder("Switch(").append(curName);
                for (int l = 0; l < parDist; l++) {
                    disRep.append(newParam);
                }
                disRep.append(")");
                // Replaces e.g. B3*X1 (but not B3*X10) with the Switch expression. The node name
                // and replacement are quoted so regex metacharacters are taken literally.
                newTemp = newTemp.replaceAll(
                        "(" + coefPrefix + "[0-9]*\\*" + java.util.regex.Pattern.quote(curName) + ")(?![0-9])",
                        java.util.regex.Matcher.quoteReplacement(disRep.toString()));
            }
            if (!newTemp.isEmpty()) {
                curEx = TemplateExpander.getInstance().expandTemplate(newTemp, semPm, node);
            }

            semPm.setNodeExpression(node, curEx);
            semPm.setNodeExpression(eNode, errEx);
        }
    } catch (ParseException e) {
        throw new IllegalStateException("Parse error in fixing parameters.", e);
    }
    return semPm;
}

/**
 * Returns the recorded distribution code for {@code node}.
 * Fails with a descriptive message instead of an unboxing NPE when the entry is missing.
 */
private static int requireNodeDist(Map<String, Integer> nodeDists, Node node) {
    Integer dist = nodeDists.get(node.getName());
    if (dist == null) {
        throw new IllegalArgumentException("No distribution recorded for node " + node.getName() + ".");
    }
    return dist;
}
use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
the class MixedUtils method GaussianTrinaryPm.
/**
 * Generates a generalized SEM PM from {@code trueGraph} for a mix of Gaussian and three-category
 * ("trinary") variables.
 * <p>
 * Nodes whose {@code nodeDists} entry equals {@code "Disc"} are expressed with a three-argument
 * {@code DiscError} over a {@code U(0,1)} error node.
 * <p>
 * Every {@code B} coefficient on a {@code "Disc"} parent is replaced by a nested {@code IF} on
 * the parent's value (0, 1, otherwise), with a fresh {@code D} parameter per case. Substitutions
 * for several discrete parents accumulate, so each one is applied.
 * <p>
 * NOTE(review): the original author marked this method "Don't use, buggy". Known remaining
 * limitations:
 * <ul>
 *   <li>{@code maxSample} is unused;</li>
 *   <li>three categories are hard-coded;</li>
 *   <li>a node missing from {@code nodeDists} raises a NullPointerException.</li>
 * </ul>
 * {@code GaussianCategoricalPm} is the general-purpose alternative.
 *
 * @param trueGraph     the graph to parameterize.
 * @param nodeDists     node name to distribution label; {@code "Disc"} marks discrete nodes.
 * @param maxSample     unused.
 * @param paramTemplate the distribution template for parameters starting with B and D.
 * @return the configured parametric model.
 * @throws IllegalStateException if a generated expression fails to parse.
 */
public static GeneralizedSemPm GaussianTrinaryPm(Graph trueGraph, HashMap<String, String> nodeDists, int maxSample, String paramTemplate) throws IllegalStateException {
    GeneralizedSemPm semPm = new GeneralizedSemPm(trueGraph);
    try {
        List<Node> variableNodes = semPm.getVariableNodes();
        semPm.setStartsWithParametersTemplate("B", paramTemplate);
        semPm.setStartsWithParametersTemplate("D", paramTemplate);
        // empirically should give us a stddev of 1 - 2
        semPm.setStartsWithParametersTemplate("al", "U(.3,1.3)");
        semPm.setStartsWithParametersTemplate("s", "U(1,2)");

        // Three-category templates. Nodes with parents use a linear combination of the parents
        // per category; root nodes use equal constant arguments.
        String templateDisc = "DiscError(err, (TSUM(NEW(B)*$)), (TSUM(NEW(B)*$)), (TSUM(NEW(B)*$)))";
        String templateDisc0 = "DiscError(err, .001,.001,.001)";

        for (Node node : variableNodes) {
            List<Node> parents = trueGraph.getParents(node);
            Node eNode = semPm.getErrorNode(node);
            // normal nodes keep the default expressions
            String curEx = semPm.getNodeExpressionString(node);
            String errEx = semPm.getNodeExpressionString(eNode);

            if (nodeDists.get(node.getName()).equals("Disc")) {
                // Literal (non-regex) substitution of the error-node name.
                String discTemplate = (parents.isEmpty() ? templateDisc0 : templateDisc)
                        .replace("err", eNode.getName());
                curEx = TemplateExpander.getInstance().expandTemplate(discTemplate, semPm, node);
                errEx = TemplateExpander.getInstance().expandTemplate("U(0,1)", semPm, eNode);
            }

            // Now for every discrete parent, swap its B coefficient for discrete params.
            // Start from curEx once and accumulate, so no earlier substitution is lost.
            String newTemp = curEx;
            boolean hasDiscreteParent = false;
            for (Node parNode : parents) {
                if (nodeDists.get(parNode.getName()).equals("Disc")) {
                    String curName = parNode.getName();
                    String disRep = "IF(" + curName + "=0,NEW(D),IF(" + curName + "=1,NEW(D),NEW(D)))";
                    // Replaces e.g. B3*X1 (but not B3*X10); name and replacement are quoted
                    // so regex metacharacters are taken literally.
                    newTemp = newTemp.replaceAll(
                            "(B[0-9]*\\*" + java.util.regex.Pattern.quote(curName) + ")(?![0-9])",
                            java.util.regex.Matcher.quoteReplacement(disRep));
                    hasDiscreteParent = true;
                }
            }
            if (hasDiscreteParent && !newTemp.isEmpty()) {
                curEx = TemplateExpander.getInstance().expandTemplate(newTemp, semPm, node);
            }

            semPm.setNodeExpression(node, curEx);
            semPm.setNodeExpression(eNode, errEx);
        }
    } catch (ParseException e) {
        throw new IllegalStateException("Parse error in fixing parameters.", e);
    }
    return semPm;
}
use of edu.cmu.tetrad.sem.GeneralizedSemPm in project tetrad by cmu-phil.
the class MGM method runTests2.
/**
 * Manual check of MGM without a penalty (all three lambdas are 0).
 * <p>
 * Steps:
 * <ol>
 *   <li>builds a five-node mixed graph (X1, X2 Gaussian; X3, X4, X5 with 4 categories);</li>
 *   <li>prints its PM and IM;</li>
 *   <li>simulates 1000 rows;</li>
 *   <li>fits an MGM, printing the objective terms before and after learning, then the
 *       parameters and the adjacency matrix.</li>
 * </ol>
 */
private static void runTests2() {
    Graph graph = GraphConverter.convert("X1-->X2,X3-->X2,X4-->X5");

    // 0 = Gaussian; otherwise the number of categories.
    HashMap<String, Integer> levels = new HashMap<>();
    levels.put("X1", 0);
    levels.put("X2", 0);
    levels.put("X3", 4);
    levels.put("X4", 4);
    levels.put("X5", 4);
    graph = MixedUtils.makeMixedGraph(graph, levels);

    GeneralizedSemPm semPm = MixedUtils.GaussianCategoricalPm(graph, "Split(-1.5,-.5,.5,1.5)");
    System.out.println(semPm);
    GeneralizedSemIm semIm = MixedUtils.GaussianCategoricalIm(semPm);
    System.out.println(semIm);

    final int sampleCount = 1000;
    DataSet data = MixedUtils.makeMixedData(semIm.simulateDataFisher(sampleCount), levels);

    final double penalty = 0;
    MGM fitted = new MGM(data, new double[] { penalty, penalty, penalty });
    System.out.println("Init nll: " + fitted.smoothValue(fitted.params.toMatrix1D()));
    System.out.println("Init reg term: " + fitted.nonSmoothValue(fitted.params.toMatrix1D()));

    fitted.learn(1e-8, 1000);
    System.out.println("Learned nll: " + fitted.smoothValue(fitted.params.toMatrix1D()));
    System.out.println("Learned reg term: " + fitted.nonSmoothValue(fitted.params.toMatrix1D()));
    System.out.println("params:\n" + fitted.params);
    System.out.println("adjMat:\n" + fitted.adjMatFromMGM());
}
Aggregations