Use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.
Class GeneralizedSemIm, method simulateOneRecord.
/**
 * Simulates one record by fixed-point iteration, which also works for cyclic models.
 *
 * <p>The error term of each variable node is set from {@code e}, and every variable starts
 * at zero. All node expressions are then re-evaluated together; each sweep reads only the
 * previous sweep's values. Iteration stops when:
 * <ul>
 *   <li>every variable changes by less than {@code 1e-6}, or</li>
 *   <li>some variable becomes positive or negative infinity, or</li>
 *   <li>10000 sweeps have been run.</li>
 * </ul>
 *
 * @param e error values, indexed in the same order as {@code pm.getVariableNodes()}.
 * @return a vector of length {@code e.size()} whose first {@code n} entries hold the final
 *         variable values in {@code pm.getVariableNodes()} order, where {@code n} is the
 *         number of variable nodes.
 * @throws NullPointerException if some variable node has no error node.
 */
public TetradVector simulateOneRecord(TetradVector e) {
    final Map<String, Double> variableValues = new HashMap<>();
    final List<Node> variableNodes = pm.getVariableNodes();

    // Resolves a term as a model parameter first, then as a variable or error value.
    final Context context = new Context() {
        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            }
            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };

    // Record the supplied error values; e is indexed like variableNodes.
    for (int i = 0; i < variableNodes.size(); i++) {
        Node variable = variableNodes.get(i);
        Node error = pm.getErrorNode(variable);
        if (error == null) {
            throw new NullPointerException("No error node for variable " + variable.getName());
        }
        variableValues.put(error.getName(), e.get(i));
    }

    // Start every variable at zero.
    for (Node variable : variableNodes) {
        variableValues.put(variable.getName(), 0.0);
    }

    final double delta = 1e-6;
    final int maxIterations = 10000;

    // Repeatedly update variable values until one of them hits infinity or negative infinity,
    // or all converge within delta, or maxIterations sweeps have been run.
    for (int iteration = 0; iteration < maxIterations; iteration++) {
        double[] values = new double[variableNodes.size()];
        for (int i = 0; i < values.length; i++) {
            values[i] = pm.getNodeExpression(variableNodes.get(i)).evaluate(context);
        }

        boolean converged = true;
        boolean diverged = false;
        for (int i = 0; i < values.length; i++) {
            double previous = variableValues.get(variableNodes.get(i).getName());
            // Written as !(x < delta) so that a NaN difference counts as "not converged".
            if (!(Math.abs(previous - values[i]) < delta)) {
                converged = false;
            }
            if (Double.isInfinite(values[i])) {
                diverged = true;
            }
        }

        // Jacobi-style update: commit the whole sweep only after every value is computed.
        for (int i = 0; i < values.length; i++) {
            variableValues.put(variableNodes.get(i).getName(), values[i]);
        }

        if (converged || diverged) {
            break;
        }
    }

    TetradVector result = new TetradVector(e.size());
    for (int i = 0; i < variableNodes.size(); i++) {
        double value = variableValues.get(variableNodes.get(i).getName());
        result.set(i, value);
    }
    return result;
}
Use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.
Class GeneralizedSemIm, method simulateDataRecursive.
/**
 * This simulates data by picking random values for the exogenous terms and
 * percolating this information down through the SEM, assuming it is
 * acyclic. Fast for large simulations but hangs for cyclic models.
 *
 * <p>Each row is produced by evaluating the node expressions in the order given by
 * {@code graph.getFullTierOrdering()}. Every evaluated value (including error terms) is
 * visible to later expressions in the same row. Only nodes that are data-set columns are
 * written to the data set.
 *
 * @param sampleSize      the number of rows to simulate; should be > 0.
 * @param latentDataSaved if true, the full simulated data set is returned; otherwise it is
 *                        passed through {@code DataUtils.restrictToMeasured} first.
 * @return the simulated data set.
 */
public DataSet simulateDataRecursive(int sampleSize, boolean latentDataSaved) {
    final Map<String, Double> variableValues = new HashMap<>();

    // Resolves a term as a model parameter first, then as a value computed earlier in this row.
    Context context = new Context() {
        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            }
            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };

    List<Node> nonErrorVariables = pm.getVariableNodes();

    // Work with a copy of the variables, because their type can be set externally.
    // columnNodes keeps the original node for each data-set column, in column order, so that
    // column indices stay aligned even if some variable node currently has type ERROR.
    List<Node> continuousVariables = new ArrayList<>();
    List<Node> columnNodes = new ArrayList<>();
    for (Node node : nonErrorVariables) {
        ContinuousVariable var = new ContinuousVariable(node.getName());
        var.setNodeType(node.getNodeType());
        if (var.getNodeType() != NodeType.ERROR) {
            continuousVariables.add(var);
            columnNodes.add(node);
        }
    }

    DataSet fullDataSet = new ColtDataSet(sampleSize, continuousVariables);

    // For each position in the tier ordering, the data-set column it is written to,
    // or -1 if that node is not a column (e.g. an error term).
    SemGraph graph = pm.getGraph();
    List<Node> tierOrdering = graph.getFullTierOrdering();
    int[] tierIndices = new int[tierOrdering.size()];
    for (int i = 0; i < tierIndices.length; i++) {
        tierIndices[i] = columnNodes.indexOf(tierOrdering.get(i));
    }

    // Do the simulation.
    for (int row = 0; row < sampleSize; row++) {
        variableValues.clear();
        for (int tier = 0; tier < tierOrdering.size(); tier++) {
            Node node = tierOrdering.get(tier);
            Expression expression = pm.getNodeExpression(node);
            double value = expression.evaluate(context);
            variableValues.put(node.getName(), value);

            int col = tierIndices[tier];
            if (col != -1) {
                fullDataSet.setDouble(row, col, value);
            }
        }
    }

    return latentDataSaved ? fullDataSet : DataUtils.restrictToMeasured(fullDataSet);
}
Use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.
Class GeneralizedSemIm, method simulateDataFisher.
/**
 * Simulates data using the model of R. A. Fisher, for a linear model. Shocks are
 * applied every so many steps. A data point is recorded before each shock is
 * administered. If convergence happens before that number of steps has been reached,
 * a data point is recorded and a new shock immediately applied. The model may be
 * cyclic. If cyclic, all eigenvalues for the coefficient matrix must be less than 1
 * in modulus, though this is not checked.
 *
 * <p>Within a sweep, variables are updated in place in {@code pm.getVariableNodes()}
 * order, so later variables see values already computed in the same sweep. All
 * variables start at 0 before the first shock. Each subsequent row starts from the
 * previous row's final values.
 *
 * @param sampleSize The number of samples to be drawn.
 * @param intervalBetweenShocks External shock is applied every this many steps.
 * Must be positive integer.
 * @param epsilon The convergence criterion; |xi.t - xi.t-1| < epsilon.
 * @return the simulated data, passed through {@code DataUtils.restrictToMeasured}.
 * @throws IllegalArgumentException if {@code intervalBetweenShocks < 1}, if
 *         {@code epsilon <= 0}, or if an error expression evaluates to NaN.
 * @throws NullPointerException if some variable node has no error node.
 */
public DataSet simulateDataFisher(int sampleSize, int intervalBetweenShocks, double epsilon) {
    if (intervalBetweenShocks < 1) {
        throw new IllegalArgumentException("Interval between shocks must be >= 1: " + intervalBetweenShocks);
    }
    if (epsilon <= 0.0) {
        throw new IllegalArgumentException("Epsilon must be > 0: " + epsilon);
    }

    final Map<String, Double> variableValues = new HashMap<>();

    // Resolves a term as a model parameter first, then as a variable or error value.
    final Context context = new Context() {
        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            }
            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };

    final List<Node> variableNodes = pm.getVariableNodes();
    final int numVars = variableNodes.size();

    // Seed every variable at 0 (matching the zero-initialized t1). Otherwise, on the very first
    // sweep, an expression referring to a variable later in the ordering (always the case
    // somewhere in a cycle) would make the context throw "No value recorded".
    for (Node node : variableNodes) {
        variableValues.put(node.getName(), 0.0);
    }

    double[] t1 = new double[numVars]; // values after the most recent completed sweep
    double[] t2 = new double[numVars]; // values of the sweep in progress
    double[][] all = new double[numVars][sampleSize];

    // Do the simulation.
    for (int row = 0; row < sampleSize; row++) {
        // Apply a new shock: draw a fresh value for every error term.
        for (int j = 0; j < numVars; j++) {
            Node variable = variableNodes.get(j);
            Node error = pm.getErrorNode(variable);
            if (error == null) {
                throw new NullPointerException("No error node for variable " + variable.getName());
            }
            Expression expression = pm.getNodeExpression(error);
            double value = expression.evaluate(context);
            if (Double.isNaN(value)) {
                throw new IllegalArgumentException("Undefined value for expression: " + expression);
            }
            variableValues.put(error.getName(), value);
        }

        // Propagate the shock for at most intervalBetweenShocks sweeps, stopping early on convergence.
        for (int i = 0; i < intervalBetweenShocks; i++) {
            for (int j = 0; j < numVars; j++) {
                Node node = variableNodes.get(j);
                t2[j] = pm.getNodeExpression(node).evaluate(context);
                variableValues.put(node.getName(), t2[j]);
            }

            boolean converged = true;
            for (int j = 0; j < numVars; j++) {
                if (Math.abs(t2[j] - t1[j]) > epsilon) {
                    converged = false;
                    break;
                }
            }

            // After the swap, t1 holds the sweep just completed.
            double[] tmp = t1;
            t1 = t2;
            t2 = tmp;

            if (converged) {
                break;
            }
        }

        for (int j = 0; j < numVars; j++) {
            all[j][row] = t1[j];
        }
    }

    List<Node> continuousVars = new ArrayList<>();
    for (Node node : variableNodes) {
        final ContinuousVariable var = new ContinuousVariable(node.getName());
        var.setNodeType(node.getNodeType());
        continuousVars.add(var);
    }

    BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(all), continuousVars);
    return DataUtils.restrictToMeasured(boxDataSet);
}
Use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.
Class TestExpressionParser, method test2.
/**
 * Smoke test: parses each formula (mostly random-distribution expressions) and evaluates it
 * once against a fixed context. Only parse or evaluation failures are detected; the sampled
 * values are not checked.
 *
 * <p>NOTE(review): this method has no {@code @Test} annotation, unlike {@code test1}, so an
 * annotation-driven runner will not execute it. Confirm whether that is intentional.
 */
public void test2() {
    final Map<String, Double> values = new HashMap<>();
    values.put("b11", 1.0);
    values.put("X1", 2.0);
    values.put("X2", 3.0);
    values.put("B22", 4.0);
    values.put("B12", 5.0);
    values.put("X4", 6.0);
    values.put("b13", 7.0);
    values.put("X5", 8.0);
    values.put("b10", 9.0);
    values.put("X", 10.0);
    values.put("Y", 11.0);
    values.put("Z", 12.0);
    values.put("W", 13.0);
    values.put("T", 14.0);
    values.put("R", 15.0);
    values.put("s2", 0.0);
    values.put("s3", 1.0);

    Context context = new Context() {
        public Double getValue(String var) {
            return values.get(var);
        }
    };

    List<String> formulas = new ArrayList<>();
    formulas.add("ChiSquare(s3)");
    formulas.add("Gamma(1, 1)");
    formulas.add("Beta(3, 5)");
    formulas.add("Poisson(5)");
    formulas.add("Indicator(0.3)");
    formulas.add("ExponentialPower(3)");
    // Log normal
    formulas.add("exp(Normal(s2, s3))");
    formulas.add("Normal(0, s3)");
    // Gaussian Power
    formulas.add("abs(Normal(s2, s3) ^ 3)");
    formulas.add("Discrete(3, 1, 5)");
    // Mixture of Gaussians
    formulas.add("0.3 * Normal(-2.0e2, 0.5) + 0.7 * Normal(2.0, 0.5)");
    formulas.add("StudentT(s3)");
    // Single value.
    formulas.add("s3");
    formulas.add("Hyperbolic(5, 3)");
    formulas.add("Uniform(s2, s3)");
    formulas.add("VonMises(s3)");
    formulas.add("Split(0, 1, 5, 6)");
    formulas.add("Mixture(0.5, N(-2, 0.5), 0.5, N(2, 0.5))");

    ExpressionParser parser = new ExpressionParser();
    for (String formula : formulas) {
        try {
            Expression expression = parser.parseExpression(formula);
            expression.evaluate(context);
        } catch (ParseException e) {
            // Fail loudly and name the offending formula instead of printing and passing.
            throw new AssertionError("Could not parse formula: " + formula, e);
        }
    }
}
Use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.
Class TestExpressionParser, method test1.
/**
 * Parses each deterministic formula, evaluates it against a fixed variable context, and
 * checks the result against the expected value to within 0.01. A formula that fails to
 * parse fails the test.
 */
@Test
public void test1() {
    final Map<String, Double> values = new HashMap<>();
    values.put("b11", 1.0);
    values.put("X1", 2.0);
    values.put("X2", 3.0);
    values.put("B22", 4.0);
    values.put("B12", 5.0);
    values.put("X4", 6.0);
    values.put("b13", 7.0);
    values.put("X5", 8.0);
    values.put("b10", 9.0);
    values.put("X", 10.0);
    values.put("Y", 11.0);
    values.put("Z", 12.0);
    values.put("W", 13.0);
    values.put("T", 14.0);
    values.put("R", 15.0);

    Context context = new Context() {
        public Double getValue(String var) {
            return values.get(var);
        }
    };

    Map<String, Double> formulasToEvaluations = new HashMap<>();
    formulasToEvaluations.put("0", 0.0);
    formulasToEvaluations.put("b11*X1 + sin(X2) + B22*X2 + B12*X4+b13*X5", 100.14);
    formulasToEvaluations.put("X5*X4*X4", 288.0);
    formulasToEvaluations.put("sin(b10*X1)", -0.75097);
    formulasToEvaluations.put("((X + ((Y * (Z ^ W)) * T)) + R)", 16476953628377113.0);
    formulasToEvaluations.put("X + Y * Z ^ W * T + R", 16476953628377113.0);
    formulasToEvaluations.put("pow(2, 5)", 32.0);
    formulasToEvaluations.put("2^5", 32.0);
    formulasToEvaluations.put("exp(1)", 2.718);
    formulasToEvaluations.put("sqrt(2)", 1.414);
    formulasToEvaluations.put("cos(0)", 1.0);
    formulasToEvaluations.put("cos(3.14/2)", 0.0);
    formulasToEvaluations.put("sin(0)", 0.0);
    formulasToEvaluations.put("sin(3.14/2)", 1.0);
    formulasToEvaluations.put("tan(1)", 1.56);
    formulasToEvaluations.put("cosh(1)", 1.54);
    formulasToEvaluations.put("sinh(1)", 1.18);
    formulasToEvaluations.put("tanh(1)", 0.76);
    formulasToEvaluations.put("acos(1)", 0.0);
    formulasToEvaluations.put("asin(1)", 1.57);
    formulasToEvaluations.put("atan(1)", 0.78);
    formulasToEvaluations.put("ln(1)", 0.0);
    formulasToEvaluations.put("log10(10)", 1.0);
    formulasToEvaluations.put("ceil(2.5)", 3.0);
    formulasToEvaluations.put("floor(2.5)", 2.0);
    formulasToEvaluations.put("abs(-5)", 5.0);
    formulasToEvaluations.put("max(2, 5, 3, 1, 10, -3)", 10.0);
    formulasToEvaluations.put("min(2, 5, 3, 1, 10, -3)", -3.0);
    // Logical.
    formulasToEvaluations.put("AND(1, 1)", 1.0);
    formulasToEvaluations.put("AND(1, 0)", 0.0);
    formulasToEvaluations.put("AND(0, 1)", 0.0);
    formulasToEvaluations.put("AND(0, 0)", 0.0);
    formulasToEvaluations.put("AND(0, 0.5)", 0.0);
    formulasToEvaluations.put("1 AND 1", 1.0);
    formulasToEvaluations.put("OR(1, 1)", 1.0);
    formulasToEvaluations.put("OR(1, 0)", 1.0);
    formulasToEvaluations.put("OR(0, 1)", 1.0);
    formulasToEvaluations.put("OR(0, 0)", 0.0);
    formulasToEvaluations.put("OR(0, 0.5)", 0.0);
    formulasToEvaluations.put("1 OR 1", 1.0);
    formulasToEvaluations.put("XOR(1, 1)", 0.0);
    formulasToEvaluations.put("XOR(1, 0)", 1.0);
    formulasToEvaluations.put("XOR(0, 1)", 1.0);
    formulasToEvaluations.put("XOR(0, 0)", 0.0);
    formulasToEvaluations.put("XOR(0, 0.5)", 0.0);
    formulasToEvaluations.put("1 XOR 1", 0.0);
    formulasToEvaluations.put("1 AND 0 OR 1 XOR 1 + 1", 1.0);
    formulasToEvaluations.put("1 < 2", 1.0);
    formulasToEvaluations.put("1 < 0", 0.0);
    formulasToEvaluations.put("1 < 1", 0.0);
    formulasToEvaluations.put("1 <= 2", 1.0);
    formulasToEvaluations.put("1 <= 1", 1.0);
    formulasToEvaluations.put("1 <= -1", 0.0);
    formulasToEvaluations.put("1 = 2", 0.0);
    formulasToEvaluations.put("1 = 1", 1.0);
    formulasToEvaluations.put("1 = -1", 0.0);
    formulasToEvaluations.put("1 > 2", 0.0);
    formulasToEvaluations.put("1 > 1", 0.0);
    formulasToEvaluations.put("1 > -1", 1.0);
    formulasToEvaluations.put("1 >= 2", 0.0);
    formulasToEvaluations.put("1 >= 1", 1.0);
    formulasToEvaluations.put("1 >= -1", 1.0);
    formulasToEvaluations.put("IF(1 > 2, 1, 2)", 2.0);
    formulasToEvaluations.put("IF(1 < 2, 1, 2)", 1.0);
    formulasToEvaluations.put("IF(1 < 2 AND 3 < 4, 1, 2)", 1.0);

    ExpressionParser parser = new ExpressionParser();
    for (Map.Entry<String, Double> entry : formulasToEvaluations.entrySet()) {
        String formula = entry.getKey();
        Expression expression;
        try {
            expression = parser.parseExpression(formula);
        } catch (ParseException e) {
            // A parse failure must fail the test rather than be printed and ignored.
            throw new AssertionError("Could not parse formula: " + formula, e);
        }
        double value = expression.evaluate(context);
        assertEquals(entry.getValue(), value, 0.01);
    }
}
Aggregations