Search in sources :

Example 1 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class TestExpressionParser method test5.

/**
 * Tests distribution means: parses each distribution formula, evaluates it repeatedly, and
 * checks that the sample mean is within 0.1 of the expected mean paired with the formula.
 *
 * <p>A formula that fails to parse fails the test instead of being logged and ignored.
 */
@Test
public void test5() {
    final Map<String, Double> values = new HashMap<>();
    Context context = new Context() {

        public Double getValue(String var) {
            return values.get(var);
        }
    };
    // Formula -> expected mean. LinkedHashMap keeps the checks in declaration order.
    Map<String, Double> formulas = new LinkedHashMap<>();
    formulas.put("ChiSquare(1)", 1.0);
    formulas.put("Gamma(2, .5)", 1.0);
    formulas.put("Beta(1, 2)", 0.33);
    formulas.put("Normal(2, 3)", 2.0);
    formulas.put("N(2, 3)", 2.0);
    formulas.put("StudentT(5)", 0.0);
    formulas.put("U(0, 1)", 0.5);
    formulas.put("Uniform(0, 1)", 0.5);
    formulas.put("Split(0, 1, 5, 6)", 3.0);
    ExpressionParser parser = new ExpressionParser();
    final int sampleSize = 10000;
    for (Map.Entry<String, Double> entry : formulas.entrySet()) {
        String formula = entry.getKey();
        Expression expression;
        try {
            expression = parser.parseExpression(formula);
        } catch (ParseException e) {
            // Previously this was only printed, which let the test pass on a parse failure.
            throw new AssertionError("Could not parse formula '" + formula + "'", e);
        }
        double sum = 0.0;
        for (int i = 0; i < sampleSize; i++) {
            sum += expression.evaluate(context);
        }
        double mean = sum / sampleSize;
        // The formula is passed as the message so a failure names the offending distribution.
        assertEquals(formula, entry.getValue(), mean, 0.1);
    }
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression) ExpressionParser(edu.cmu.tetrad.calculator.parser.ExpressionParser) ParseException(java.text.ParseException) Test(org.junit.Test)

Example 2 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class EmpiricalDistributionForExpression method getDist.

/**
 * Builds an empirical distribution from repeated evaluations of the expression that
 * {@code semPm} holds for the node in {@code error}.
 *
 * @return an {@code EmpiricalCdf} over 5000 values, each obtained by evaluating that
 *         expression in {@code context}.
 */
public RealDistribution getDist() {
    final Expression errorExpression = semPm.getNodeExpression(error);
    final List<Double> samples = new ArrayList<>();
    int remaining = 5000;
    while (remaining-- > 0) {
        samples.add(errorExpression.evaluate(context));
    }
    return new EmpiricalCdf(samples);
}
Also used : Expression(edu.cmu.tetrad.calculator.expression.Expression) ArrayList(java.util.ArrayList)

Example 3 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class GeneralizedSemIm method simulateDataNSteps.

/**
 * Simulates {@code sampleSize} rows of data. For each row, every error term is drawn by
 * evaluating its expression, all variables are reset to zero, and every variable's expression
 * is then evaluated once against those values. A row in which any variable evaluates to an
 * infinite value is discarded and simulated again.
 *
 * @param sampleSize      the number of rows to simulate.
 * @param latentDataSaved if true the full data set is returned; otherwise the result of
 *                        {@code DataUtils.restrictToMeasured} on it is returned.
 * @return the simulated data set.
 * @throws NullPointerException     if a variable has no error node.
 * @throws IllegalArgumentException if an expression evaluates to NaN, or refers to a term that
 *                                  has neither a parameter value nor a variable value.
 */
public DataSet simulateDataNSteps(int sampleSize, boolean latentDataSaved) {
    final List<Node> nodes = pm.getVariableNodes();
    final Map<String, Double> currentValues = new HashMap<>();

    // Work with a copy of the variables, because their type can be set externally.
    List<Node> columns = new LinkedList<>();
    for (Node original : nodes) {
        ContinuousVariable copy = new ContinuousVariable(original.getName());
        copy.setNodeType(original.getNodeType());
        if (copy.getNodeType() == NodeType.ERROR) {
            continue;
        }
        columns.add(copy);
    }
    DataSet simulated = new ColtDataSet(sampleSize, columns);

    // Resolves a term first as a parameter, then as a variable or error value.
    final Context lookup = new Context() {

        public Double getValue(String term) {
            Double parameter = parameterValues.get(term);
            if (parameter != null) {
                return parameter;
            }
            Double variable = currentValues.get(term);
            if (variable == null) {
                throw new IllegalArgumentException("No value recorded for '" + term + "'");
            }
            return variable;
        }
    };

    // Do the simulation. The row index only advances once a row has been accepted.
    int row = 0;
    while (row < sampleSize) {
        // Take random draws from error distributions.
        for (Node node : nodes) {
            Node errorNode = pm.getErrorNode(node);
            if (errorNode == null) {
                throw new NullPointerException();
            }
            Expression errorExpression = pm.getNodeExpression(errorNode);
            double draw = errorExpression.evaluate(lookup);
            if (Double.isNaN(draw)) {
                throw new IllegalArgumentException("Undefined value for expression: " + errorExpression);
            }
            currentValues.put(errorNode.getName(), draw);
        }

        // Set the variable nodes to zero.
        for (Node node : nodes) {
            currentValues.put(node.getName(), 0.0);
        }

        // Evaluate every variable against the zeroed values before storing any result.
        double[] updated = new double[nodes.size()];
        for (int i = 0; i < updated.length; i++) {
            Expression nodeExpression = pm.getNodeExpression(nodes.get(i));
            double result = nodeExpression.evaluate(lookup);
            if (Double.isNaN(result)) {
                throw new IllegalArgumentException("Undefined value for expression: " + nodeExpression);
            }
            updated[i] = result;
        }

        boolean diverged = false;
        for (double candidate : updated) {
            if (Double.isInfinite(candidate)) {
                diverged = true;
                break;
            }
        }
        if (diverged) {
            // Redo this row with fresh error draws.
            continue;
        }

        for (int i = 0; i < nodes.size(); i++) {
            currentValues.put(nodes.get(i).getName(), updated[i]);
        }
        for (int i = 0; i < nodes.size(); i++) {
            double stored = currentValues.get(nodes.get(i).getName());
            simulated.setDouble(row, i, stored);
        }
        row++;
    }

    return latentDataSaved ? simulated : DataUtils.restrictToMeasured(simulated);
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Example 4 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class GeneralizedSemIm method simulateDataMinimizeSurface.

/**
 * Simulates {@code sampleSize} rows of data. For each row, error terms are drawn by evaluating
 * their expressions, and the variable values are then found by minimizing the squared distance
 * between a candidate vector of values and the vector obtained by evaluating every variable's
 * expression at that candidate.
 *
 * @param sampleSize      the number of rows to simulate.
 * @param latentDataSaved if true the full data set is returned; otherwise the result of
 *                        {@code DataUtils.restrictToMeasured} on it is returned.
 * @return the simulated data set.
 * @throws NullPointerException     if a variable has no error node.
 * @throws IllegalArgumentException if an expression evaluates to NaN, or refers to a term that
 *                                  has neither a parameter value nor a variable value.
 */
public DataSet simulateDataMinimizeSurface(int sampleSize, boolean latentDataSaved) {
    final Map<String, Double> variableValues = new HashMap<>();
    List<Node> continuousVariables = new LinkedList<>();
    final List<Node> variableNodes = pm.getVariableNodes();
    // Work with a copy of the variables, because their type can be set externally.
    for (Node node : variableNodes) {
        ContinuousVariable var = new ContinuousVariable(node.getName());
        var.setNodeType(node.getNodeType());
        if (var.getNodeType() != NodeType.ERROR) {
            continuousVariables.add(var);
        }
    }
    DataSet fullDataSet = new ColtDataSet(sampleSize, continuousVariables);
    // Resolves a term first as a parameter, then as a variable/error value; unknown terms are an error.
    final Context context = new Context() {

        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            }
            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };
    // One-element holder so the last objective value can be read outside the anonymous class.
    final double[] _metric = new double[1];
    // Objective: sum of squared differences between the candidate point and its image under the
    // node expressions. Note that every evaluation also overwrites variableValues.
    MultivariateFunction function = new MultivariateFunction() {

        double metric;

        public double value(double[] doubles) {
            // Expose the candidate point to the expressions through the shared value map.
            for (int i = 0; i < variableNodes.size(); i++) {
                variableValues.put(variableNodes.get(i).getName(), doubles[i]);
            }
            // Evaluate every node expression at the candidate point.
            double[] image = new double[doubles.length];
            for (int i = 0; i < variableNodes.size(); i++) {
                Node node = variableNodes.get(i);
                Expression expression = pm.getNodeExpression(node);
                image[i] = expression.evaluate(context);
                if (Double.isNaN(image[i])) {
                    throw new IllegalArgumentException("Undefined value for expression " + expression);
                }
            }
            metric = 0.0;
            for (int i = 0; i < variableNodes.size(); i++) {
                double diff = doubles[i] - image[i];
                metric += diff * diff;
            }
            // Leave the evaluated image, not the candidate, in the shared value map.
            for (int i = 0; i < variableNodes.size(); i++) {
                variableValues.put(variableNodes.get(i).getName(), image[i]);
            }
            _metric[0] = metric;
            return metric;
        }
    };
    MultivariateOptimizer search = new PowellOptimizer(1e-7, 1e-7);
    // Do the simulation.
    ROW: for (int row = 0; row < sampleSize; row++) {
        // Take random draws from error distributions.
        for (Node variable : variableNodes) {
            Node error = pm.getErrorNode(variable);
            if (error == null) {
                throw new NullPointerException();
            }
            Expression expression = pm.getNodeExpression(error);
            double value = expression.evaluate(context);
            if (Double.isNaN(value)) {
                throw new IllegalArgumentException("Undefined value for expression: " + expression);
            }
            variableValues.put(error.getName(), value);
        }
        // Start every variable at zero before optimizing.
        for (Node variable : variableNodes) {
            variableValues.put(variable.getName(), 0.0);
        }
        // Re-run the optimizer from the current values until the last recorded metric is below 0.01.
        // NOTE(review): there is no cap on the number of passes, so this can loop indefinitely
        // if the metric never drops below the threshold.
        while (true) {
            double[] values = new double[variableNodes.size()];
            for (int i = 0; i < values.length; i++) {
                values[i] = variableValues.get(variableNodes.get(i).getName());
            }
            PointValuePair pair = search.optimize(new InitialGuess(values), new ObjectiveFunction(function), GoalType.MINIMIZE, new MaxEval(100000));
            values = pair.getPoint();
            for (int i = 0; i < variableNodes.size(); i++) {
                // Reject the row and simulate it again (row-- cancels the loop's row++).
                // Columns already written for this row are overwritten on the retry.
                if (isSimulatePositiveDataOnly() && values[i] < 0) {
                    row--;
                    continue ROW;
                }
                // With a self-loop coefficient set, add its multiple of the previous row's value
                // in the same column.
                if (!Double.isNaN(selfLoopCoef) && row > 0) {
                    values[i] += selfLoopCoef * fullDataSet.getDouble(row - 1, i);
                }
                variableValues.put(variableNodes.get(i).getName(), values[i]);
                fullDataSet.setDouble(row, i, values[i]);
            }
            if (_metric[0] < 0.01) {
                // Close enough: leave the while loop and move on to the next row.
                break;
            }
        }
    }
    if (latentDataSaved) {
        return fullDataSet;
    } else {
        return DataUtils.restrictToMeasured(fullDataSet);
    }
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) MultivariateOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer) InitialGuess(org.apache.commons.math3.optim.InitialGuess) MaxEval(org.apache.commons.math3.optim.MaxEval) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) PowellOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer) PointValuePair(org.apache.commons.math3.optim.PointValuePair) MultivariateFunction(org.apache.commons.math3.analysis.MultivariateFunction) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Example 5 with Expression

use of edu.cmu.tetrad.calculator.expression.Expression in project tetrad by cmu-phil.

the class GeneralizedSemIm method simulateTimeSeries.

/**
 * Simulates {@code sampleSize} time steps over the lag-0 nodes of the time-lag graph of the
 * SEM PM's graph.
 *
 * <p>At each step the non-error lag-0 nodes are evaluated in the causal ordering of the
 * contemporaneous subgraph. The values just written are then copied under the names of the
 * corresponding lagged nodes, so the next step's expressions can refer to them.
 *
 * @param sampleSize the number of time steps (rows) to simulate.
 * @return a data set with one column per non-error lag-0 node and one row per time step.
 */
private DataSet simulateTimeSeries(int sampleSize) {
    // NOTE(review): semGraph is never read in this method; kept in case constructing it
    // matters elsewhere -- confirm and remove if not.
    SemGraph semGraph = new SemGraph(getSemPm().getGraph());
    semGraph.setShowErrorTerms(true);
    TimeLagGraph timeLagGraph = getSemPm().getGraph().getTimeLagGraph();
    // One output column per non-error lag-0 node, named by the node's id.
    List<Node> variables = new ArrayList<>();
    for (Node node : timeLagGraph.getLag0Nodes()) {
        if (node.getNodeType() == NodeType.ERROR) {
            continue;
        }
        variables.add(new ContinuousVariable(timeLagGraph.getNodeId(node).getName()));
    }
    List<Node> lag0Nodes = timeLagGraph.getLag0Nodes();
    // Iterate over a copy so that removing from lag0Nodes is safe.
    for (Node node : new ArrayList<>(lag0Nodes)) {
        if (node.getNodeType() == NodeType.ERROR) {
            lag0Nodes.remove(node);
        }
    }
    DataSet fullData = new ColtDataSet(sampleSize, variables);
    // Column index of each non-error lag-0 node.
    Map<Node, Integer> nodeIndices = new HashMap<>();
    for (int i = 0; i < lag0Nodes.size(); i++) {
        nodeIndices.put(lag0Nodes.get(i), i);
    }
    Graph contemporaneousDag = timeLagGraph.subgraph(timeLagGraph.getLag0Nodes());
    List<Node> tierOrdering = contemporaneousDag.getCausalOrdering();
    for (Node node : new ArrayList<>(tierOrdering)) {
        if (node.getNodeType() == NodeType.ERROR) {
            tierOrdering.remove(node);
        }
    }
    final Map<String, Double> variableValues = new HashMap<>();
    // Resolves a term as a parameter, then as a recorded variable value; a term with no
    // recorded value gets a fresh draw from nextNormal(0, 1).
    Context context = new Context() {

        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            } else {
                return RandomUtil.getInstance().nextNormal(0, 1);
            }
        }
    };
    ROW: for (int currentStep = 0; currentStep < sampleSize; currentStep++) {
        for (Node node : tierOrdering) {
            Expression expression = pm.getNodeExpression(node);
            double value = expression.evaluate(context);
            // Reject the step and simulate it again (currentStep-- cancels the loop's ++).
            if (isSimulatePositiveDataOnly() && value < 0) {
                currentStep--;
                continue ROW;
            }
            int col = nodeIndices.get(node);
            fullData.setDouble(currentStep, col, value);
            variableValues.put(node.getName(), value);
        }
        // Publish this and earlier rows under the lagged node names for the next step.
        for (Node node : lag0Nodes) {
            TimeLagGraph.NodeId _id = timeLagGraph.getNodeId(node);
            // The column depends only on the node, so resolve it once per node from the
            // index map instead of scanning lag0Nodes again for every lag.
            int col = nodeIndices.get(node);
            for (int lag = 1; lag <= timeLagGraph.getMaxLag(); lag++) {
                Node _node = timeLagGraph.getNode(_id.getName(), lag);
                if (_node == null) {
                    continue;
                }
                // On the next step, lag L refers to row (currentStep - lag + 1); skip rows
                // before the start of the series.
                if (currentStep - lag + 1 >= 0) {
                    double _value = fullData.getDouble((currentStep - lag + 1), col);
                    variableValues.put(_node.getName(), _value);
                }
            }
        }
    }
    return fullData;
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Aggregations

Expression (edu.cmu.tetrad.calculator.expression.Expression): 24; Context (edu.cmu.tetrad.calculator.expression.Context): 10; ExpressionParser (edu.cmu.tetrad.calculator.parser.ExpressionParser): 10; Test (org.junit.Test): 5; ConstantExpression (edu.cmu.tetrad.calculator.expression.ConstantExpression): 4; ParseException (java.text.ParseException): 4; Node (edu.cmu.tetrad.graph.Node): 3; ArrayList (java.util.ArrayList): 3; VariableExpression (edu.cmu.tetrad.calculator.expression.VariableExpression): 2; GeneralAndersonDarlingTest (edu.cmu.tetrad.data.GeneralAndersonDarlingTest): 1; MultiGeneralAndersonDarlingTest (edu.cmu.tetrad.data.MultiGeneralAndersonDarlingTest): 1; GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm): 1; List (java.util.List): 1; MultivariateFunction (org.apache.commons.math3.analysis.MultivariateFunction): 1; RealDistribution (org.apache.commons.math3.distribution.RealDistribution): 1; InitialGuess (org.apache.commons.math3.optim.InitialGuess): 1; MaxEval (org.apache.commons.math3.optim.MaxEval): 1; PointValuePair (org.apache.commons.math3.optim.PointValuePair): 1; MultivariateOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer): 1; ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction): 1