Search in sources :

Example 6 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class GeneralizedSemIm method simulateDataAvoidInfinity.

/**
 * Simulates {@code sampleSize} rows of data, one row at a time. For each row the error
 * expressions are evaluated, every variable is started at 0.0, and the variable expressions are
 * re-evaluated repeatedly until successive values agree to within a small tolerance. Rows whose
 * iteration does not settle are discarded and re-simulated.
 *
 * @param sampleSize      the number of rows to simulate.
 * @param latentDataSaved if true the full data set is returned; otherwise the result of
 *                        {@code DataUtils.restrictToMeasured} applied to it is returned.
 * @return the simulated data set.
 * @throws NullPointerException     if the PM returns no error node for some variable.
 * @throws IllegalArgumentException if an error expression evaluates to NaN, or if an expression
 *                                  refers to a term with no recorded parameter or variable value.
 */
public DataSet simulateDataAvoidInfinity(int sampleSize, boolean latentDataSaved) {
    // Values of the error terms and variables for the row currently being simulated, keyed by node name.
    final Map<String, Double> variableValues = new HashMap<>();
    List<Node> continuousVariables = new LinkedList<>();
    final List<Node> variableNodes = pm.getVariableNodes();
    // Work with a copy of the variables, because their type can be set externally.
    for (Node node : variableNodes) {
        ContinuousVariable var = new ContinuousVariable(node.getName());
        var.setNodeType(node.getNodeType());
        // Copies whose node type is ERROR are left out of the data set's columns.
        if (var.getNodeType() != NodeType.ERROR) {
            continuousVariables.add(var);
        }
    }
    // NOTE(review): rows are later written by index into variableNodes, while the columns come from
    // continuousVariables (ERROR-typed nodes filtered out) -- assumes the two lists line up; confirm.
    DataSet fullDataSet = new ColtDataSet(sampleSize, continuousVariables);
    // Resolves a term to a number: parameter values take precedence over the per-row variable values.
    final Context context = new Context() {

        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            }
            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };
    boolean allInRange = true;
    // Do the simulation.
    ROW: for (int row = 0; row < sampleSize; row++) {
        // Take random draws from error distributions.
        for (Node variable : variableNodes) {
            Node error = pm.getErrorNode(variable);
            if (error == null) {
                throw new NullPointerException();
            }
            Expression expression = pm.getNodeExpression(error);
            double value;
            value = expression.evaluate(context);
            if (Double.isNaN(value)) {
                throw new IllegalArgumentException("Undefined value for expression: " + expression);
            }
            variableValues.put(error.getName(), value);
        }
        // Start every variable at zero.
        for (Node variable : variableNodes) {
            Node error = pm.getErrorNode(variable);
            Expression expression = pm.getNodeExpression(error);
            // NOTE(review): the error expression is evaluated a second time here, but the result is
            // used only for the NaN check below; it is never stored. If evaluation draws random
            // numbers this consumes an extra draw per variable -- confirm whether that is intended.
            double value = expression.evaluate(context);
            if (Double.isNaN(value)) {
                throw new IllegalArgumentException("Undefined value for expression: " + expression);
            }
            // The starting point is always 0.0, not the value evaluated above.
            variableValues.put(variable.getName(), 0.0);
        }
        // Repeatedly update variable values until one of them hits infinity or negative infinity or
        // convergence within delta.
        double delta = 1e-10;
        int count = -1;
        // At most 5000 update passes per attempt; count is 5000 if the loop runs to exhaustion.
        while (++count < 5000) {
            // Evaluate every variable's expression against the values from the previous pass.
            double[] values = new double[variableNodes.size()];
            for (int i = 0; i < values.length; i++) {
                Node node = variableNodes.get(i);
                Expression expression = pm.getNodeExpression(node);
                double value = expression.evaluate(context);
                values[i] = value;
            }
            allInRange = true;
            for (int i = 0; i < values.length; i++) {
                Node node = variableNodes.get(i);
                // The negated "<" tests also catch NaN. A variable has not converged when its new
                // value differs from the previous one by delta or more.
                if (!(Math.abs(variableValues.get(node.getName()) - values[i]) < delta)) {
                    // If the previous value is outside (-1e6, 1e6) and fewer than 1000 passes have
                    // run, judge nonconvergence: discard this row and redo it with fresh error draws.
                    if (!(Math.abs(variableValues.get(node.getName())) < 1e6)) {
                        if (count < 1000) {
                            row--;
                            continue ROW;
                        }
                    }
                    allInRange = false;
                    break;
                }
            }
            // Commit the new values whether or not this pass converged.
            for (int i = 0; i < variableNodes.size(); i++) {
                variableValues.put(variableNodes.get(i).getName(), values[i]);
            }
            if (allInRange) {
                break;
            }
        }
        if (!allInRange) {
            // NOTE(review): the loop above stops at count == 5000, so "count < 10000" is always true
            // here and the else branch cannot be reached; a model that never converges would be
            // retried indefinitely. Confirm the intended bound.
            if (count < 10000) {
                row--;
                System.out.println("Trying another starting point...");
                continue;
            } else {
                System.out.println("Couldn't converge in simulation.");
                for (int i = 0; i < variableNodes.size(); i++) {
                    fullDataSet.setDouble(row, i, Double.NaN);
                }
                return fullDataSet;
            }
        }
        // Copy the converged values into the current row.
        for (int i = 0; i < variableNodes.size(); i++) {
            double value = variableValues.get(variableNodes.get(i).getName());
            // In positive-only mode a negative value rejects the whole row; cells already written for
            // this row are overwritten when the row is simulated again.
            if (isSimulatePositiveDataOnly() && value < 0) {
                row--;
                continue ROW;
            }
            // When selfLoopCoef is not NaN, add that multiple of the same column's previous-row value.
            if (!Double.isNaN(selfLoopCoef) && row > 0) {
                value += selfLoopCoef * fullDataSet.getDouble(row - 1, i);
            }
            fullDataSet.setDouble(row, i, value);
        }
    }
    if (latentDataSaved) {
        return fullDataSet;
    } else {
        return DataUtils.restrictToMeasured(fullDataSet);
    }
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Example 7 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class GeneralizedSemIm method simulateOneRecord.

/**
 * Computes one record of variable values from a supplied vector of error-term values. Every
 * variable starts at 0.0 and all variable expressions are re-evaluated in passes until no
 * variable changes by 1e-6 or more between passes, or 10000 passes have run. No failure is
 * signalled if the pass limit is reached; the values from the last pass are returned.
 *
 * @param e the error-term values; entry i is used for the error node of the i-th variable node
 *          returned by the PM.
 * @return a new vector of size {@code e.size()} whose entry i is the final value of the i-th
 *         variable node.
 * @throws NullPointerException     if the PM returns no error node for some variable.
 * @throws IllegalArgumentException if an expression refers to a term with no recorded parameter
 *                                  or variable value.
 */
public TetradVector simulateOneRecord(TetradVector e) {
    final Map<String, Double> currentValues = new HashMap<>();
    final List<Node> nodes = pm.getVariableNodes();

    // Term lookup used by the expressions: parameters first, then the working values.
    final Context lookup = new Context() {

        public Double getValue(String term) {
            Double found = parameterValues.get(term);
            if (found == null) {
                found = currentValues.get(term);
            }
            if (found == null) {
                throw new IllegalArgumentException("No value recorded for '" + term + "'");
            }
            return found;
        }
    };

    // Seed the error terms from the caller's vector, position by position.
    int position = 0;
    for (Node node : nodes) {
        Node errorNode = pm.getErrorNode(node);
        if (errorNode == null) {
            throw new NullPointerException();
        }
        currentValues.put(errorNode.getName(), e.get(position));
        position++;
    }

    // Every variable starts the iteration at zero.
    for (Node node : nodes) {
        currentValues.put(node.getName(), 0.0);
    }

    // Fixed-point iteration: recompute all variables from the previous pass's values.
    final double tolerance = 1e-6;
    for (int pass = 0; pass < 10000; pass++) {
        double[] updated = new double[nodes.size()];
        for (int k = 0; k < updated.length; k++) {
            updated[k] = pm.getNodeExpression(nodes.get(k)).evaluate(lookup);
        }

        // Stop scanning at the first variable that moved by the tolerance or more (NaN counts as moved).
        boolean converged = true;
        for (int k = 0; converged && k < updated.length; k++) {
            double previous = currentValues.get(nodes.get(k).getName());
            converged = Math.abs(previous - updated[k]) < tolerance;
        }

        // The new values are committed even on the converging pass.
        for (int k = 0; k < nodes.size(); k++) {
            currentValues.put(nodes.get(k).getName(), updated[k]);
        }

        if (converged) {
            break;
        }
    }

    TetradVector record = new TetradVector(e.size());
    for (int k = 0; k < nodes.size(); k++) {
        double settled = currentValues.get(nodes.get(k).getName());
        record.set(k, settled);
    }
    return record;
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Example 8 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class GeneralizedSemEstimator method calcOneResiduals.

/**
 * Computes, for every data row, the difference between the observed value of one node and the
 * value of that node's expression evaluated on the row.
 *
 * Before the rows are processed, the context entry for each node's error term is set to 0.0 and
 * the given parameter values are stored in the context.
 *
 * @param index        position in {@code tierOrdering} (and column in {@code data}) of the node
 *                     whose residuals are computed.
 * @param data         the data, indexed as {@code data[row][column]}.
 * @param tierOrdering the nodes; node i is given the value in column i of each row.
 * @param params       parameter names.
 * @param values       parameter values, parallel to {@code params}.
 * @param pm           the PM supplying error nodes and node expressions.
 * @param context      the context that is filled in and used for evaluation.
 * @return one residual per row of {@code data}.
 * @throws NullPointerException if {@code pm} is null.
 */
private static double[] calcOneResiduals(int index, double[][] data, List<Node> tierOrdering, List<String> params, double[] values, GeneralizedSemPm pm, MyContext context) {
    if (pm == null) {
        throw new NullPointerException();
    }

    final int rowCount = data.length;
    double[] residuals = new double[rowCount];

    // Error terms are fixed at zero for the evaluation.
    for (Node ordered : tierOrdering) {
        String errorKey = pm.getErrorNode(ordered).toString();
        context.putVariableValue(errorKey, 0.0);
    }

    int slot = 0;
    for (String name : params) {
        context.putParameterValue(name, values[slot]);
        slot++;
    }

    for (int r = 0; r < rowCount; r++) {
        // Load this row's observed values into the context, one per node.
        for (int col = 0; col < tierOrdering.size(); col++) {
            context.putVariableValue(tierOrdering.get(col).getName(), data[r][col]);
        }

        Node target = tierOrdering.get(index);
        Expression equation = pm.getNodeExpression(target);
        double predicted = equation.evaluate(context);
        residuals[r] = data[r][index] - predicted;
    }

    return residuals;
}
Also used : Expression(edu.cmu.tetrad.calculator.expression.Expression) Node(edu.cmu.tetrad.graph.Node)

Example 9 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class GeneralizedSemEstimator method estimate.

/**
 * Estimates parameter values equation by equation. For each non-error node of the PM's graph,
 * the parameters referenced by the node and by its error node are given starting values from
 * their initialization expressions, optimized with a likelihood fitting function, and written
 * into the returned IM. An Anderson-Darling statistic is appended to the report for each
 * equation, and a combined statistic over all residuals is stored with the report.
 * (The original description states that the equations are assumed recursive with exactly one
 * error term each.)
 *
 * @param pm   the parameterized model to estimate.
 * @param data the data passed to the fitting functions.
 * @return a new IM built on {@code pm} carrying the optimized parameter values.
 */
public GeneralizedSemIm estimate(GeneralizedSemPm pm, DataSet data) {
    StringBuilder equationReport = new StringBuilder();
    GeneralizedSemIm estIm = new GeneralizedSemIm(pm);

    // All graph nodes except the error nodes.
    List<Node> nodes = pm.getGraph().getNodes();
    nodes.removeAll(pm.getErrorNodes());

    MyContext context = new MyContext();
    List<List<Double>> allResiduals = new ArrayList<>();
    List<RealDistribution> allDistributions = new ArrayList<>();

    for (int index = 0; index < nodes.size(); index++) {
        Node node = nodes.get(index);

        // Parameters of the equation followed by parameters of its error term.
        List<String> equationParams = new ArrayList<>(pm.getReferencedParameters(node));
        Node error = pm.getErrorNode(node);
        equationParams.addAll(pm.getReferencedParameters(error));

        LikelihoodFittingFunction2 fitter = new LikelihoodFittingFunction2(index, pm, equationParams, nodes, data, context);

        // Each starting value is evaluated in its own fresh context.
        double[] start = new double[equationParams.size()];
        int slot = 0;
        for (String name : equationParams) {
            start[slot] = pm.getParameterEstimationInitializationExpression(name).evaluate(new MyContext());
            slot++;
        }

        double[] optimum = optimize(fitter, start, 1);
        for (int j = 0; j < equationParams.size(); j++) {
            estIm.setParameterValue(equationParams.get(j), optimum[j]);
        }

        List<Double> residuals = fitter.getResiduals();
        allResiduals.add(residuals);
        RealDistribution distribution = fitter.getDistribution();
        allDistributions.add(distribution);

        GeneralAndersonDarlingTest equationTest = new GeneralAndersonDarlingTest(residuals, distribution);

        equationReport.append("\nEquation: ");
        equationReport.append(node);
        equationReport.append(" := ");
        equationReport.append(estIm.getNodeSubstitutedString(node));

        equationReport.append("\n\twhere ");
        equationReport.append(pm.getErrorNode(node));
        equationReport.append(" ~ ");
        equationReport.append(estIm.getNodeSubstitutedString(pm.getErrorNode(node)));

        equationReport.append("\nAnderson Darling A^2* for this equation =  ");
        equationReport.append(equationTest.getASquaredStar());
        equationReport.append("\n");
    }

    // NOTE(review): this list is created empty and never filled, so the copy loop below does
    // nothing and the joint fitting function is built and optimized with zero parameters; its
    // result is discarded. Confirm whether the joint pass was meant to cover all parameters.
    List<String> jointParams = new ArrayList<>();
    double[] jointStart = new double[jointParams.size()];
    for (int i = 0; i < jointParams.size(); i++) {
        jointStart[i] = estIm.getParameterValue(jointParams.get(i));
    }
    LikelihoodFittingFunction jointFitter = new LikelihoodFittingFunction(pm, jointParams, nodes, data, context);
    optimize(jointFitter, jointStart, 1);

    MultiGeneralAndersonDarlingTest modelTest = new MultiGeneralAndersonDarlingTest(allResiduals, allDistributions);
    double modelStatistic = modelTest.getASquaredStar();
    this.aSquaredStar = modelStatistic;
    this.report = "Report:\n\nModel A^2* (Anderson Darling) = " + modelStatistic + "\n" + equationReport;
    return estIm;
}
Also used : Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) MultiGeneralAndersonDarlingTest(edu.cmu.tetrad.data.MultiGeneralAndersonDarlingTest) RealDistribution(org.apache.commons.math3.distribution.RealDistribution) Expression(edu.cmu.tetrad.calculator.expression.Expression) GeneralAndersonDarlingTest(edu.cmu.tetrad.data.GeneralAndersonDarlingTest) MultiGeneralAndersonDarlingTest(edu.cmu.tetrad.data.MultiGeneralAndersonDarlingTest) ArrayList(java.util.ArrayList) List(java.util.List)

Example 10 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class GeneralizedSemPm method setParameterExpression.

/**
 * Parses the given formula and records it as the expression for the given parameter. The parsed
 * expression and its source string are stored under {@code parameter}, and the source string is
 * also stored as the template for {@code startsWith}. All arguments are validated before anything
 * is stored.
 *
 * @param startsWith       the key under which the formula is stored as a template; must not be
 *                         null or contain a space.
 * @param parameter        the parameter whose expression is being set; must not be null.
 * @param expressionString the formula to parse; must not be null.
 * @throws ParseException           if parsing the formula throws it; it is passed up unchanged.
 * @throws NullPointerException     if any argument is null.
 * @throws IllegalArgumentException if {@code startsWith} contains a space, or if the parser
 *                                  reports that the formula refers to parameters.
 */
public void setParameterExpression(String startsWith, String parameter, String expressionString) throws ParseException {
    if (parameter == null) {
        throw new NullPointerException("Parameter was null.");
    }

    if (startsWith == null) {
        throw new NullPointerException("StartsWith expression was null.");
    }

    if (startsWith.indexOf(' ') >= 0) {
        throw new IllegalArgumentException("StartsWith expression contains spaces.");
    }

    if (expressionString == null) {
        throw new NullPointerException("Expression string was null.");
    }

    // A ParseException from the parser is deliberately not caught here: callers (the interface)
    // need to see it.
    ExpressionParser parser = new ExpressionParser();
    Expression parsed = parser.parseExpression(expressionString);

    List<String> referenced = parser.getParameters();
    if (!referenced.isEmpty()) {
        throw new IllegalArgumentException("Initial distribution may not contain parameters: " + expressionString);
    }

    parameterExpressions.put(parameter, parsed);
    parameterExpressionStrings.put(parameter, expressionString);
    startsWithParametersTemplates.put(startsWith, expressionString);
}
Also used : Expression(edu.cmu.tetrad.calculator.expression.Expression) ExpressionParser(edu.cmu.tetrad.calculator.parser.ExpressionParser)

Aggregations

Expression (edu.cmu.tetrad.calculator.expression.Expression)24 Context (edu.cmu.tetrad.calculator.expression.Context)10 ExpressionParser (edu.cmu.tetrad.calculator.parser.ExpressionParser)10 Test (org.junit.Test)5 ConstantExpression (edu.cmu.tetrad.calculator.expression.ConstantExpression)4 ParseException (java.text.ParseException)4 Node (edu.cmu.tetrad.graph.Node)3 ArrayList (java.util.ArrayList)3 VariableExpression (edu.cmu.tetrad.calculator.expression.VariableExpression)2 GeneralAndersonDarlingTest (edu.cmu.tetrad.data.GeneralAndersonDarlingTest)1 MultiGeneralAndersonDarlingTest (edu.cmu.tetrad.data.MultiGeneralAndersonDarlingTest)1 GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm)1 List (java.util.List)1 MultivariateFunction (org.apache.commons.math3.analysis.MultivariateFunction)1 RealDistribution (org.apache.commons.math3.distribution.RealDistribution)1 InitialGuess (org.apache.commons.math3.optim.InitialGuess)1 MaxEval (org.apache.commons.math3.optim.MaxEval)1 PointValuePair (org.apache.commons.math3.optim.PointValuePair)1 MultivariateOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer)1 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)1