Search in sources:

Example 16 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

From class GeneralizedSemEstimator, method calcResiduals.

/**
 * Computes per-cell residuals of {@code data} under the generalized SEM {@code pm}.
 *
 * <p>First, the error term of every node in {@code tierOrdering} is set to 0.0 in
 * {@code context}, and each {@code params.get(i)} is bound to {@code paramValues[i]}. Then, for
 * each row, the observed values are bound to the node names and each node's expression is
 * evaluated. The residual is the observed value minus that prediction.
 *
 * <p>Rows containing a NaN observation are skipped entirely and keep residuals of 0.0. If an
 * expression evaluates to NaN, the rest of that row is skipped. Residuals already computed for
 * earlier nodes in the row are kept, and the remaining ones stay 0.0.
 *
 * <p>Side effect: {@code context} is mutated (error terms, parameters and variable values).
 *
 * @param data         observed data indexed {@code [row][column]}; column {@code j} is read as
 *                     the value of node {@code tierOrdering.get(j)}.
 * @param tierOrdering the nodes whose residuals are computed, in column order.
 * @param params       parameter names, parallel to {@code paramValues}.
 * @param paramValues  parameter values, parallel to {@code params}.
 * @param pm           the model supplying node expressions and error nodes; must not be null.
 * @param context      the evaluation context that is populated and used for evaluation.
 * @return a matrix with the same shape as {@code new double[data.length][data[0].length]}
 *         holding the residuals; an empty matrix when {@code data} has no rows.
 * @throws NullPointerException if {@code pm} is null.
 */
private static double[][] calcResiduals(double[][] data, List<Node> tierOrdering, List<String> params, double[] paramValues, GeneralizedSemPm pm, MyContext context) {
    if (pm == null) {
        throw new NullPointerException();
    }

    // No rows means no residuals. Without this guard, data[0] below would throw.
    if (data.length == 0) {
        return new double[0][0];
    }

    final int numNodes = tierOrdering.size();
    double[][] residuals = new double[data.length][data[0].length];

    // Residuals are taken with every error term held at zero.
    for (Node node : tierOrdering) {
        context.putVariableValue(pm.getErrorNode(node).toString(), 0.0);
    }
    for (int i = 0; i < params.size(); i++) {
        context.putParameterValue(params.get(i), paramValues[i]);
    }

    // Names and expressions do not change from row to row, so look them up once.
    String[] names = new String[numNodes];
    Expression[] expressions = new Expression[numNodes];
    for (int j = 0; j < numNodes; j++) {
        Node node = tierOrdering.get(j);
        names[j] = node.getName();
        expressions[j] = pm.getNodeExpression(node);
    }

    ROWS:
    for (int row = 0; row < data.length; row++) {
        // Bind the observed values; any missing (NaN) value makes the row unusable.
        for (int j = 0; j < numNodes; j++) {
            context.putVariableValue(names[j], data[row][j]);
            if (Double.isNaN(data[row][j])) {
                continue ROWS;
            }
        }
        for (int j = 0; j < numNodes; j++) {
            double predicted = expressions[j].evaluate(context);
            if (Double.isNaN(predicted)) {
                continue ROWS;
            }
            residuals[row][j] = data[row][j] - predicted;
        }
    }
    return residuals;
}
Also used : Expression(edu.cmu.tetrad.calculator.expression.Expression) Node(edu.cmu.tetrad.graph.Node)

Example 17 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

From class GeneralizedSemPm, method setParametersTemplate.

/**
 * Replaces the template used as the initial distribution for parameters.
 *
 * <p>The template is parsed once to check its syntax. It must not itself mention any
 * parameters; otherwise it is rejected and the stored template is left unchanged.
 *
 * @param parametersTemplate the template expression text; must not be null.
 * @throws NullPointerException     if {@code parametersTemplate} is null.
 * @throws ParseException           if the template cannot be parsed.
 * @throws IllegalArgumentException if the parsed template refers to parameters.
 */
public void setParametersTemplate(String parametersTemplate) throws ParseException {
    if (parametersTemplate == null) {
        throw new NullPointerException();
    }

    // Parsing is only a validity check; a ParseException propagates to the caller.
    final ExpressionParser templateParser = new ExpressionParser();
    final Expression parsedTemplate = templateParser.parseExpression(parametersTemplate);

    if (templateParser.getParameters().isEmpty()) {
        this.parametersTemplate = parametersTemplate;
        return;
    }

    throw new IllegalArgumentException("Initial distribution for a parameter may not "
            + "contain parameters: " + parsedTemplate.toString());
}
Also used : Expression(edu.cmu.tetrad.calculator.expression.Expression) ExpressionParser(edu.cmu.tetrad.calculator.parser.ExpressionParser)

Example 18 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

From class GeneralizedSemPm, method setNodeExpression.

/**
 * Sets the expression for {@code node} from {@code expressionString}. It also updates the
 * bookkeeping of which nodes reference which parameters and which variables.
 *
 * <p>The method proceeds in these steps:
 * <ol>
 *   <li>Parse the string.</li>
 *   <li>Derive the parameter names it uses: the parser's parameter list, minus anything in
 *       {@code variableNames} or named like a node in {@code nodes}.</li>
 *   <li>Remove {@code node} from the old parameter and node reference sets, discarding entries
 *       left empty.</li>
 *   <li>Register {@code node} against its new parameters and against its graph parents.</li>
 *   <li>Store the parsed expression together with the original string.</li>
 * </ol>
 *
 * @param node             the node whose expression is being set; must not be null.
 * @param expressionString the expression text; must not be null. It is stored verbatim so
 *                         the user's spacing is preserved.
 * @throws NullPointerException if {@code node} or {@code expressionString} is null.
 * @throws ParseException       if {@code expressionString} cannot be parsed.
 */
public void setNodeExpression(Node node, String expressionString) throws ParseException {
    if (node == null) {
        throw new NullPointerException("Node was null.");
    }
    if (expressionString == null) {
        // return;
        throw new NullPointerException("Expression string was null.");
    }
    // Parse the expression. This may throw a ParseException; it is deliberately passed up the
    // chain, because the interface needs it (e.g. to report the error to the user).
    ExpressionParser parser = new ExpressionParser();
    Expression expression = parser.parseExpression(expressionString);
    // Names the parser classified as parameters. This list is pruned of variable names below.
    List<String> parameterNames = parser.getParameters();
    // Make a list of parent names.
    List<Node> parents = this.graph.getParents(node);
    List<String> parentNames = new LinkedList<>();
    for (Node parent : parents) {
        parentNames.add(parent.getName());
    }
    // NOTE(review): disabled check that used to reject variables other than the parents.
    // List<String> _params = new ArrayList<String>(parameterNames);
    // _params.retainAll(variableNames);
    // _params.removeAll(parentNames);
    // 
    // if (!_params.isEmpty()) {
    // throw new IllegalArgumentException("Conditioning on a variable other than the parents: " + node);
    // }
    // Make the final list of parameter names by removing from the parser's list any name that
    // belongs to a variable: first everything in variableNames, then the name of every node in
    // the `nodes` field (presumably this also covers error terms — confirm).
    // NOTE: nothing here enforces the graph any more. Variable names that are not parents are
    // silently dropped rather than rejected (the throw below is commented out), and parents
    // missing from the expression are not reported either.
    parameterNames.removeAll(variableNames);
    for (Node variable : nodes) {
        if (parameterNames.contains(variable.getName())) {
            parameterNames.remove(variable.getName());
        // throw new IllegalArgumentException("The list of parameter names may not include variables: " + variable.getNode());
        }
    }
    // Remove old parameter references: detach this node from every parameter. Any parameter
    // whose set of referencing nodes is then empty is dropped from all parameter tables.
    List<String> parametersToRemove = new LinkedList<>();
    for (String parameter : this.referencedParameters.keySet()) {
        Set<Node> nodes = this.referencedParameters.get(parameter);
        if (nodes.contains(node)) {
            nodes.remove(node);
        }
        if (nodes.isEmpty()) {
            parametersToRemove.add(parameter);
        }
    }
    // Collected first and removed afterwards, so the key set is not modified while iterating.
    for (String parameter : parametersToRemove) {
        this.referencedParameters.remove(parameter);
        this.parameterExpressions.remove(parameter);
        this.parameterExpressionStrings.remove(parameter);
        this.parameterEstimationInitializationExpressions.remove(parameter);
        this.parameterEstimationInitializationExpressionStrings.remove(parameter);
    }
    // Add new parameter references: record that this node uses each parameter in its new
    // expression. setSuitableParameterDistribution (defined elsewhere) is called for every
    // such parameter, including parameters that were already known.
    for (String parameter : parameterNames) {
        if (this.referencedParameters.get(parameter) == null) {
            this.referencedParameters.put(parameter, new HashSet<Node>());
        }
        Set<Node> nodes = this.referencedParameters.get(parameter);
        nodes.add(node);
        setSuitableParameterDistribution(parameter);
    }
    // Remove old node references: detach this node from every referenced-node set, and drop
    // entries whose set becomes (or already was) empty.
    List<Node> nodesToRemove = new LinkedList<>();
    for (Node _node : this.referencedNodes.keySet()) {
        Set<Node> nodes = this.referencedNodes.get(_node);
        if (nodes.contains(node)) {
            nodes.remove(node);
        }
        if (nodes.isEmpty()) {
            nodesToRemove.add(_node);
        }
    }
    for (Node _node : nodesToRemove) {
        this.referencedNodes.remove(_node);
    }
    // Add new node references. Every name in variableNames gets an entry in referencedNodes
    // (possibly an empty set). This node is then added to the entry of each of its graph
    // parents.
    // NOTE(review): this is driven by the graph's parents, not by whether the new expression
    // actually mentions the parent.
    for (String variableString : variableNames) {
        Node _node = getNode(variableString);
        if (this.referencedNodes.get(_node) == null) {
            this.referencedNodes.put(_node, new HashSet<Node>());
        }
        for (String s : parentNames) {
            if (s.equals(variableString)) {
                Set<Node> nodes = this.referencedNodes.get(_node);
                nodes.add(node);
            }
        }
    }
    // Finally, save the parsed expression and the original string that the user entered. No need to annoy
    // the user by changing spacing.
    nodeExpressions.put(node, expression);
    nodeExpressionStrings.put(node, expressionString);
}
Also used : Expression(edu.cmu.tetrad.calculator.expression.Expression) ExpressionParser(edu.cmu.tetrad.calculator.parser.ExpressionParser)

Example 19 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

From class GeneralizedSemIm, method simulateDataRecursive.

/**
 * Simulates data by drawing values for the exogenous terms and percolating them down
 * through the SEM in the graph's full tier ordering. This assumes the model is acyclic. It is
 * fast for large simulations, but it is not suitable for cyclic models (the original
 * documentation notes that it hangs for them).
 *
 * <p>Each row is simulated independently. Every node in the tier ordering is evaluated once,
 * in order, against a context that resolves a name first from the model's parameter values
 * and then from the values already computed for that row.
 *
 * @param sampleSize      the number of rows to simulate; expected to be > 0.
 * @param latentDataSaved if true, the full simulated data set is returned; otherwise it is
 *                        passed through {@code DataUtils.restrictToMeasured} first.
 * @return the simulated data set.
 * @throws IllegalArgumentException if an expression refers to a name that has neither a
 *                                  parameter value nor a value computed earlier in the row.
 */
public DataSet simulateDataRecursive(int sampleSize, boolean latentDataSaved) {
    final Map<String, Double> variableValues = new HashMap<>();
    Context context = new Context() {

        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            }
            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };

    List<Node> nonErrorVariables = pm.getVariableNodes();
    List<Node> continuousVariables = new LinkedList<>();
    // Work with a copy of the variables, because their type can be set externally.
    for (Node node : nonErrorVariables) {
        ContinuousVariable var = new ContinuousVariable(node.getName());
        var.setNodeType(node.getNodeType());
        if (var.getNodeType() != NodeType.ERROR) {
            continuousVariables.add(var);
        }
    }
    DataSet fullDataSet = new ColtDataSet(sampleSize, continuousVariables);

    // For each position in the tier ordering, cache the node's name, its expression and the
    // data-set column it is written to. The column is -1 when the node is not among
    // pm.getVariableNodes(), presumably for error terms. Caching keeps lookups out of the
    // per-row loop. The arrays are sized by the tier ordering itself, so they always match it.
    // NOTE(review): the column index comes from nonErrorVariables; it lines up with
    // continuousVariables only if no variable node has NodeType.ERROR.
    SemGraph graph = pm.getGraph();
    List<Node> tierOrdering = graph.getFullTierOrdering();
    int numTiers = tierOrdering.size();
    String[] tierNames = new String[numTiers];
    Expression[] tierExpressions = new Expression[numTiers];
    int[] tierIndices = new int[numTiers];
    for (int i = 0; i < numTiers; i++) {
        Node node = tierOrdering.get(i);
        tierNames[i] = node.getName();
        tierExpressions[i] = pm.getNodeExpression(node);
        tierIndices[i] = nonErrorVariables.indexOf(node);
    }

    // Do the simulation.
    for (int row = 0; row < sampleSize; row++) {
        variableValues.clear();
        for (int tier = 0; tier < numTiers; tier++) {
            double value = tierExpressions[tier].evaluate(context);
            variableValues.put(tierNames[tier], value);
            int col = tierIndices[tier];
            if (col != -1) {
                fullDataSet.setDouble(row, col, value);
            }
        }
    }

    if (latentDataSaved) {
        return fullDataSet;
    } else {
        return DataUtils.restrictToMeasured(fullDataSet);
    }
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Example 20 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

From class GeneralizedSemIm, method simulateDataFisher.

/**
 * Simulates data using the model of R. A. Fisher for a linear model.
 *
 * <p>For each sample, a shock is applied by drawing a fresh value for every error term. The
 * variables are then updated in repeated sweeps, for at most {@code intervalBetweenShocks}
 * sweeps. Within a sweep, variables are updated in the order of {@code pm.getVariableNodes()},
 * and later variables see values just computed for earlier ones. The sweeps stop early once
 * no variable changes by more than {@code epsilon}. The state after the last sweep is
 * recorded as the sample, and it carries over as the starting state for the next sample.
 *
 * <p>All variables start from 0.0. The model may be cyclic: an expression that refers to a
 * variable not yet updated in the current sweep reads that variable's previous value. If the
 * model is cyclic, all eigenvalues of the coefficient matrix must be less than 1 in absolute
 * value; this is not checked.
 *
 * @param sampleSize            the number of samples to be drawn.
 * @param intervalBetweenShocks the maximum number of update sweeps per shock; must be >= 1.
 * @param epsilon               the convergence criterion, |x_t - x_{t-1}| <= epsilon for every
 *                              variable; must be > 0.
 * @return the simulated data, restricted to measured variables.
 * @throws IllegalArgumentException if {@code intervalBetweenShocks < 1} or
 *                                  {@code epsilon <= 0}, if an error expression evaluates to
 *                                  NaN, or if an expression refers to an unknown name.
 * @throws NullPointerException     if some variable has no error node.
 */
public DataSet simulateDataFisher(int sampleSize, int intervalBetweenShocks, double epsilon) {
    if (intervalBetweenShocks < 1) {
        throw new IllegalArgumentException("Interval between shocks must be >= 1: " + intervalBetweenShocks);
    }
    if (epsilon <= 0.0) {
        throw new IllegalArgumentException("Epsilon must be > 0: " + epsilon);
    }

    final Map<String, Double> variableValues = new HashMap<>();
    final Context context = new Context() {

        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            }
            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };

    final List<Node> variableNodes = pm.getVariableNodes();
    final int numVars = variableNodes.size();

    // Cache names and expressions once; they do not change between samples or sweeps.
    final String[] nodeNames = new String[numVars];
    final Expression[] nodeExpressions = new Expression[numVars];
    final String[] errorNames = new String[numVars];
    final Expression[] errorExpressions = new Expression[numVars];
    for (int j = 0; j < numVars; j++) {
        Node node = variableNodes.get(j);
        Node error = pm.getErrorNode(node);
        if (error == null) {
            throw new NullPointerException("No error node for variable " + node.getName());
        }
        nodeNames[j] = node.getName();
        nodeExpressions[j] = pm.getNodeExpression(node);
        errorNames[j] = error.getName();
        errorExpressions[j] = pm.getNodeExpression(error);

        // Seed the initial state (0.0, matching t1). Otherwise the first sweep fails with
        // "No value recorded" whenever an expression reads a variable that has not been
        // updated yet, which is unavoidable in a cyclic model.
        variableValues.put(nodeNames[j], 0.0);
    }

    double[] t1 = new double[numVars];   // state after the most recent sweep
    double[] t2 = new double[numVars];   // state being computed by the current sweep
    double[][] all = new double[numVars][sampleSize];

    // Do the simulation.
    for (int row = 0; row < sampleSize; row++) {
        // Apply a shock: draw a fresh value for every error term.
        for (int j = 0; j < numVars; j++) {
            double value = errorExpressions[j].evaluate(context);
            if (Double.isNaN(value)) {
                throw new IllegalArgumentException("Undefined value for expression: " + errorExpressions[j]);
            }
            variableValues.put(errorNames[j], value);
        }

        for (int i = 0; i < intervalBetweenShocks; i++) {
            for (int j = 0; j < numVars; j++) {
                t2[j] = nodeExpressions[j].evaluate(context);
                variableValues.put(nodeNames[j], t2[j]);
            }

            boolean converged = true;
            for (int j = 0; j < numVars; j++) {
                if (Math.abs(t2[j] - t1[j]) > epsilon) {
                    converged = false;
                    break;
                }
            }

            // Swap buffers so t1 always holds the latest state.
            double[] swap = t1;
            t1 = t2;
            t2 = swap;

            if (converged) {
                break;
            }
        }

        for (int j = 0; j < numVars; j++) {
            all[j][row] = t1[j];
        }
    }

    List<Node> continuousVars = new ArrayList<>();
    for (Node node : variableNodes) {
        final ContinuousVariable var = new ContinuousVariable(node.getName());
        var.setNodeType(node.getNodeType());
        continuousVars.add(var);
    }
    BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(all), continuousVars);
    return DataUtils.restrictToMeasured(boxDataSet);
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Aggregations

Expression (edu.cmu.tetrad.calculator.expression.Expression)24 Context (edu.cmu.tetrad.calculator.expression.Context)10 ExpressionParser (edu.cmu.tetrad.calculator.parser.ExpressionParser)10 Test (org.junit.Test)5 ConstantExpression (edu.cmu.tetrad.calculator.expression.ConstantExpression)4 ParseException (java.text.ParseException)4 Node (edu.cmu.tetrad.graph.Node)3 ArrayList (java.util.ArrayList)3 VariableExpression (edu.cmu.tetrad.calculator.expression.VariableExpression)2 GeneralAndersonDarlingTest (edu.cmu.tetrad.data.GeneralAndersonDarlingTest)1 MultiGeneralAndersonDarlingTest (edu.cmu.tetrad.data.MultiGeneralAndersonDarlingTest)1 GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm)1 List (java.util.List)1 MultivariateFunction (org.apache.commons.math3.analysis.MultivariateFunction)1 RealDistribution (org.apache.commons.math3.distribution.RealDistribution)1 InitialGuess (org.apache.commons.math3.optim.InitialGuess)1 MaxEval (org.apache.commons.math3.optim.MaxEval)1 PointValuePair (org.apache.commons.math3.optim.PointValuePair)1 MultivariateOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer)1 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)1