use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.
the class GeneralizedSemEstimator method calcResiduals.
/**
 * Computes the residual (observed value minus the value of the node's expression) for every
 * cell of {@code data}.
 *
 * <p>Before any row is processed, the error term of every node in {@code tierOrdering} is set
 * to 0.0 in {@code context} and each parameter in {@code params} is bound to the value at the
 * same index of {@code paramValues}. For each row, the observed values are then loaded into
 * {@code context} under the node names and every node expression is evaluated.
 *
 * <p>Rows are skipped as soon as a NaN is met: a NaN observation leaves the whole row of
 * residuals at 0.0, and a NaN expression value leaves that column and all later columns of the
 * row at 0.0 (earlier columns of the row keep the residuals already written).
 *
 * @param data         the observations, indexed {@code [row][column]}; column {@code j} is read
 *                     as the value of {@code tierOrdering.get(j)}.
 * @param tierOrdering the nodes whose expressions are evaluated, in column order.
 * @param params       the parameter names to bind in the context.
 * @param paramValues  the parameter values, parallel to {@code params}.
 * @param pm           the parametric model supplying error nodes and node expressions.
 * @param context      the evaluation context; it is mutated by this call.
 * @return a matrix with the same dimensions as {@code data} holding the residuals, or an empty
 * matrix when {@code data} has no rows.
 * @throws NullPointerException if {@code pm} is null.
 */
private static double[][] calcResiduals(double[][] data, List<Node> tierOrdering, List<String> params, double[] paramValues, GeneralizedSemPm pm, MyContext context) {
    if (pm == null) {
        throw new NullPointerException();
    }

    // With no rows there is no first row to take the column count from; nothing to compute.
    if (data.length == 0) {
        return new double[0][0];
    }

    double[][] residuals = new double[data.length][data[0].length];

    // Error terms are fixed at zero so that each expression yields its predicted value.
    for (Node node : tierOrdering) {
        context.putVariableValue(pm.getErrorNode(node).toString(), 0.0);
    }

    for (int i = 0; i < params.size(); i++) {
        context.putParameterValue(params.get(i), paramValues[i]);
    }

    ROWS:
    for (int row = 0; row < data.length; row++) {
        // Load the observed row into the context; the value is stored before the NaN test,
        // so a NaN observation is still written to the context before the row is abandoned.
        for (int j = 0; j < tierOrdering.size(); j++) {
            context.putVariableValue(tierOrdering.get(j).getName(), data[row][j]);

            if (Double.isNaN(data[row][j])) {
                continue ROWS;
            }
        }

        for (int j = 0; j < tierOrdering.size(); j++) {
            Expression expression = pm.getNodeExpression(tierOrdering.get(j));
            double predicted = expression.evaluate(context);

            if (Double.isNaN(predicted)) {
                continue ROWS;
            }

            residuals[row][j] = data[row][j] - predicted;
        }
    }

    return residuals;
}
use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.
the class GeneralizedSemPm method setParametersTemplate.
/**
 * Stores the template used for parameters after checking that it parses and that the parsed
 * expression does not itself mention any parameters.
 *
 * @param parametersTemplate the template expression string; may not be null.
 * @throws ParseException           if the template cannot be parsed.
 * @throws NullPointerException     if the template is null.
 * @throws IllegalArgumentException if the parser reports parameters in the template.
 */
public void setParametersTemplate(String parametersTemplate) throws ParseException {
    if (parametersTemplate == null) {
        throw new NullPointerException();
    }

    // Parsing doubles as validation: an unparsable template raises ParseException here.
    ExpressionParser templateParser = new ExpressionParser();
    Expression parsedTemplate = templateParser.parseExpression(parametersTemplate);

    boolean mentionsParameters = !templateParser.getParameters().isEmpty();
    if (mentionsParameters) {
        String message = "Initial distribution for a parameter may not contain parameters: "
                + parsedTemplate.toString();
        throw new IllegalArgumentException(message);
    }

    this.parametersTemplate = parametersTemplate;
}
use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.
the class GeneralizedSemPm method setNodeExpression.
/**
 * Parses {@code expressionString}, records it as the expression for {@code node}, and brings
 * the parameter and node reference bookkeeping up to date.
 *
 * <p>Steps, in order: the parser's parameter list is stripped of every name in
 * {@code variableNames} and of the name of every node in {@code nodes}; {@code node} is dropped
 * from all existing parameter references (parameters left without any referencing node are
 * removed from all parameter maps); the remaining parameter names are registered as referenced
 * by {@code node} and given a suitable distribution; {@code node} is dropped from all existing
 * node references; every variable gets a (possibly empty) reference set, and {@code node} is
 * added to the set of each variable that is one of its parents in the graph; finally the parsed
 * expression and the original string are stored.
 *
 * @param node             the node whose expression is being set; may not be null.
 * @param expressionString the expression as entered by the user; may not be null.
 * @throws ParseException       if the expression cannot be parsed; passed up unchanged so the
 *                              caller can report it.
 * @throws NullPointerException if either argument is null.
 */
public void setNodeExpression(Node node, String expressionString) throws ParseException {
    if (node == null) {
        throw new NullPointerException("Node was null.");
    }

    if (expressionString == null) {
        throw new NullPointerException("Expression string was null.");
    }

    // Parse first: a ParseException must reach the caller before any state is touched.
    ExpressionParser parser = new ExpressionParser();
    Expression parsedExpression = parser.parseExpression(expressionString);
    List<String> freeParameterNames = parser.getParameters();

    // Names of the node's parents in the graph.
    List<String> parentNames = new LinkedList<>();
    for (Node parent : this.graph.getParents(node)) {
        parentNames.add(parent.getName());
    }

    // What the parser calls parameters includes variable and node names; strip those so that
    // only genuine free parameters remain.
    freeParameterNames.removeAll(variableNames);
    for (Node knownNode : nodes) {
        freeParameterNames.remove(knownNode.getName());
    }

    // Drop this node from every existing parameter reference, remembering the parameters that
    // are left with no referencing node.
    List<String> orphanedParameters = new LinkedList<>();
    for (String parameter : this.referencedParameters.keySet()) {
        Set<Node> referencingNodes = this.referencedParameters.get(parameter);
        referencingNodes.remove(node);

        if (referencingNodes.isEmpty()) {
            orphanedParameters.add(parameter);
        }
    }

    for (String orphan : orphanedParameters) {
        this.referencedParameters.remove(orphan);
        this.parameterExpressions.remove(orphan);
        this.parameterExpressionStrings.remove(orphan);
        this.parameterEstimationInitializationExpressions.remove(orphan);
        this.parameterEstimationInitializationExpressionStrings.remove(orphan);
    }

    // Register the parameters used by the new expression.
    for (String parameter : freeParameterNames) {
        Set<Node> referencingNodes = this.referencedParameters.get(parameter);
        if (referencingNodes == null) {
            referencingNodes = new HashSet<Node>();
            this.referencedParameters.put(parameter, referencingNodes);
        }

        referencingNodes.add(node);
        setSuitableParameterDistribution(parameter);
    }

    // Drop this node from every existing node reference, then discard emptied entries.
    List<Node> orphanedNodes = new LinkedList<>();
    for (Node referenced : this.referencedNodes.keySet()) {
        Set<Node> referencingNodes = this.referencedNodes.get(referenced);
        referencingNodes.remove(node);

        if (referencingNodes.isEmpty()) {
            orphanedNodes.add(referenced);
        }
    }

    for (Node orphan : orphanedNodes) {
        this.referencedNodes.remove(orphan);
    }

    // Make sure every variable has a reference set, and record this node against each variable
    // that is one of its parents.
    for (String variableName : variableNames) {
        Node variable = getNode(variableName);

        Set<Node> referencingNodes = this.referencedNodes.get(variable);
        if (referencingNodes == null) {
            referencingNodes = new HashSet<Node>();
            this.referencedNodes.put(variable, referencingNodes);
        }

        for (String parentName : parentNames) {
            if (parentName.equals(variableName)) {
                referencingNodes.add(node);
            }
        }
    }

    // Keep the user's original text alongside the parsed form so spacing is not altered.
    nodeExpressions.put(node, parsedExpression);
    nodeExpressionStrings.put(node, expressionString);
}
use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.
the class GeneralizedSemIm method simulateDataRecursive.
/**
 * This simulates data by picking random values for the exogenous terms and
 * percolating this information down through the SEM, assuming it is
 * acyclic. Fast for large simulations but hangs for cyclic models.
 *
 * <p>For each row, every node in the graph's full tier ordering is evaluated in order and its
 * value is made available to later expressions under the node's name. Values of nodes that
 * appear among the model's variable nodes are also written to the data set.
 *
 * @param sampleSize      > 0.
 * @param latentDataSaved if true, the full data set is returned; otherwise the data set is
 *                        passed through {@code DataUtils.restrictToMeasured} first.
 * @return the simulated data set.
 * @throws IllegalArgumentException if an expression refers to a term that is neither a
 *                                  parameter nor an already-evaluated node.
 */
public DataSet simulateDataRecursive(int sampleSize, boolean latentDataSaved) {
    final Map<String, Double> variableValues = new HashMap<>();

    // Resolves a term as a parameter first, then as a node value computed earlier in the row.
    Context context = new Context() {
        @Override
        public Double getValue(String term) {
            Double value = parameterValues.get(term);

            if (value != null) {
                return value;
            }

            value = variableValues.get(term);

            if (value != null) {
                return value;
            }

            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };

    List<Node> nonErrorVariables = pm.getVariableNodes();
    List<Node> continuousVariables = new LinkedList<>();

    // Work with a copy of the variables, because their type can be set externally.
    for (Node node : nonErrorVariables) {
        ContinuousVariable copy = new ContinuousVariable(node.getName());
        copy.setNodeType(node.getNodeType());

        if (copy.getNodeType() != NodeType.ERROR) {
            continuousVariables.add(copy);
        }
    }

    DataSet fullDataSet = new ColtDataSet(sampleSize, continuousVariables);

    // Precompute, for each position in the tier ordering, the matching data set column
    // (-1 when the node is not among the variable nodes, e.g. an error term).
    SemGraph graph = pm.getGraph();
    List<Node> tierOrdering = graph.getFullTierOrdering();
    int[] tierIndices = new int[tierOrdering.size()];

    for (int i = 0; i < tierIndices.length; i++) {
        tierIndices[i] = nonErrorVariables.indexOf(tierOrdering.get(i));
    }

    // Do the simulation.
    for (int row = 0; row < sampleSize; row++) {
        variableValues.clear();

        for (int tier = 0; tier < tierOrdering.size(); tier++) {
            Node node = tierOrdering.get(tier);
            Expression expression = pm.getNodeExpression(node);
            double value = expression.evaluate(context);
            variableValues.put(node.getName(), value);

            int col = tierIndices[tier];

            if (col == -1) {
                continue;
            }

            fullDataSet.setDouble(row, col, value);
        }
    }

    if (latentDataSaved) {
        return fullDataSet;
    } else {
        return DataUtils.restrictToMeasured(fullDataSet);
    }
}
use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.
the class GeneralizedSemIm method simulateDataFisher.
/**
 * Simulates data using the model of R. A. Fisher, for a linear model. Shocks are
 * applied every so many steps. A data point is recorded before each shock is
 * administered. If convergence happens before that number of steps has been reached,
 * a data point is recorded and a new shock immediately applied. The model may be
 * cyclic. If cyclic, all eigenvalues for the coefficient matrix must be less than 1,
 * though this is not checked.
 *
 * @param sampleSize            The number of samples to be drawn.
 * @param intervalBetweenShocks External shock is applied every this many steps.
 *                              Must be positive integer.
 * @param epsilon               The convergence criterion; |xi.t - xi.t-1| < epsilon.
 * @return the simulated data set, passed through {@code DataUtils.restrictToMeasured}.
 * @throws IllegalArgumentException if {@code intervalBetweenShocks < 1}, if
 *                                  {@code epsilon <= 0}, if an error expression evaluates to
 *                                  NaN, or if an expression refers to a term with no recorded
 *                                  value.
 * @throws NullPointerException     if a variable node has no error node.
 */
public DataSet simulateDataFisher(int sampleSize, int intervalBetweenShocks, double epsilon) {
    if (intervalBetweenShocks < 1) {
        throw new IllegalArgumentException("Interval between shocks must be >= 1: " + intervalBetweenShocks);
    }

    if (epsilon <= 0.0) {
        throw new IllegalArgumentException("Epsilon must be > 0: " + epsilon);
    }

    final Map<String, Double> variableValues = new HashMap<>();

    // Resolves a term as a parameter first, then as the most recently stored variable value.
    final Context context = new Context() {
        @Override
        public Double getValue(String term) {
            Double value = parameterValues.get(term);

            if (value != null) {
                return value;
            }

            value = variableValues.get(term);

            if (value != null) {
                return value;
            }

            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };

    final List<Node> variableNodes = pm.getVariableNodes();

    // t1 holds the previous step's values, t2 receives the current step's values; the two
    // buffers are swapped after every step. t1 is not reset between rows.
    double[] t1 = new double[variableNodes.size()];
    double[] t2 = new double[variableNodes.size()];
    double[][] all = new double[variableNodes.size()][sampleSize];

    // Do the simulation.
    for (int row = 0; row < sampleSize; row++) {

        // Apply a new shock: draw each error term and expose it to the node expressions.
        for (int j = 0; j < t1.length; j++) {
            Node error = pm.getErrorNode(variableNodes.get(j));

            if (error == null) {
                throw new NullPointerException();
            }

            Expression expression = pm.getNodeExpression(error);
            double value = expression.evaluate(context);

            if (Double.isNaN(value)) {
                throw new IllegalArgumentException("Undefined value for expression: " + expression);
            }

            variableValues.put(error.getName(), value);
        }

        // Iterate until the interval is exhausted or successive steps agree within epsilon.
        // NOTE(review): on the very first sweep no variable node has a recorded value yet, so
        // an expression that reads a variable evaluated later in the sweep appears to hit the
        // "No value recorded" exception above — confirm whether this is intended for cyclic
        // models.
        for (int i = 0; i < intervalBetweenShocks; i++) {
            for (int j = 0; j < t1.length; j++) {
                Node node = variableNodes.get(j);
                Expression expression = pm.getNodeExpression(node);
                t2[j] = expression.evaluate(context);
                variableValues.put(node.getName(), t2[j]);
            }

            boolean converged = true;

            for (int j = 0; j < t1.length; j++) {
                if (Math.abs(t2[j] - t1[j]) > epsilon) {
                    converged = false;
                    break;
                }
            }

            // Swap buffers so that t1 always refers to the latest values.
            double[] t3 = t1;
            t1 = t2;
            t2 = t3;

            if (converged) {
                break;
            }
        }

        // Record the data point for this row, one variable per leading index.
        for (int j = 0; j < t1.length; j++) {
            all[j][row] = t1[j];
        }
    }

    List<Node> continuousVars = new ArrayList<>();

    for (Node node : variableNodes) {
        final ContinuousVariable copy = new ContinuousVariable(node.getName());
        copy.setNodeType(node.getNodeType());
        continuousVars.add(copy);
    }

    BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(all), continuousVars);
    return DataUtils.restrictToMeasured(boxDataSet);
}
Aggregations