Search in sources :

Example 6 with Context

use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.

the class GeneralizedSemIm method simulateOneRecord.

/**
 * Simulates one record from the supplied error-term values by repeatedly
 * re-evaluating every variable node's expression until the values settle.
 *
 * <p>All variable values start at zero. On each sweep every expression is
 * evaluated against the values of the previous sweep; the new values are
 * stored only after the whole sweep has been computed. Sweeping stops as
 * soon as every variable changed by less than 1e-6, or after 10000 sweeps,
 * whichever comes first. Note that no explicit check for infinite values is
 * made; a non-finite or NaN difference simply counts as "not converged".
 *
 * @param e the error-term values; entry {@code i} is used for the error
 *          node of the {@code i}-th node of {@code pm.getVariableNodes()}.
 * @return a vector of size {@code e.size()} whose {@code i}-th entry is the
 * final value of the {@code i}-th variable node.
 * @throws NullPointerException     if some variable node has no error node.
 * @throws IllegalArgumentException if an expression refers to a term that
 *                                  is neither a parameter nor a recorded value.
 */
public TetradVector simulateOneRecord(TetradVector e) {
    final Map<String, Double> variableValues = new HashMap<>();
    final List<Node> variableNodes = pm.getVariableNodes();
    final Context context = new Context() {

        // Parameters take precedence over variable/error values.
        public Double getValue(String term) {
            Double parameter = parameterValues.get(term);
            if (parameter != null) {
                return parameter;
            }
            Double current = variableValues.get(term);
            if (current == null) {
                throw new IllegalArgumentException("No value recorded for '" + term + "'");
            }
            return current;
        }
    };
    final int numVariables = variableNodes.size();

    // Record the given error draws under the error nodes' names.
    for (int index = 0; index < numVariables; index++) {
        Node errorNode = pm.getErrorNode(variableNodes.get(index));
        if (errorNode == null) {
            throw new NullPointerException();
        }
        variableValues.put(errorNode.getName(), e.get(index));
    }

    // Every variable starts out at zero.
    for (int index = 0; index < numVariables; index++) {
        variableValues.put(variableNodes.get(index).getName(), 0.0);
    }

    final double tolerance = 1e-6;
    final int maxSweeps = 10000;
    for (int sweep = 0; sweep < maxSweeps; sweep++) {
        // Evaluate all expressions against the previous sweep's values.
        double[] updated = new double[numVariables];
        for (int index = 0; index < updated.length; index++) {
            Expression expression = pm.getNodeExpression(variableNodes.get(index));
            updated[index] = expression.evaluate(context);
        }

        // Converged only if every change is strictly below the tolerance
        // (a NaN difference therefore counts as not converged).
        boolean converged = true;
        for (int index = 0; converged && index < updated.length; index++) {
            double previous = variableValues.get(variableNodes.get(index).getName());
            converged = Math.abs(previous - updated[index]) < tolerance;
        }

        // Publish the new values, then stop if nothing moved.
        for (int index = 0; index < numVariables; index++) {
            variableValues.put(variableNodes.get(index).getName(), updated[index]);
        }
        if (converged) {
            break;
        }
    }

    TetradVector record = new TetradVector(e.size());
    for (int index = 0; index < numVariables; index++) {
        double settled = variableValues.get(variableNodes.get(index).getName());
        record.set(index, settled);
    }
    return record;
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Example 7 with Context

use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.

the class GeneralizedSemIm method simulateDataRecursive.

/**
 * This simulates data by picking random values for the exogenous terms and
 * percolating this information down through the SEM, assuming it is
 * acyclic. Fast for large simulations but hangs for cyclic models.
 *
 * <p>For each row, every node in the graph's full tier ordering has its
 * expression evaluated in order; the value is remembered under the node's
 * name so later expressions can refer to it, and is written to the data
 * set if the node is one of the model's variable nodes.
 *
 * @param sampleSize      the number of rows to simulate; > 0.
 * @param latentDataSaved if true, the full data set is returned; otherwise
 *                        the result of {@code DataUtils.restrictToMeasured}
 *                        on it is returned.
 * @return the simulated data set.
 * @throws IllegalArgumentException if an expression refers to a term that
 *                                  is neither a parameter nor an already
 *                                  evaluated node.
 */
public DataSet simulateDataRecursive(int sampleSize, boolean latentDataSaved) {
    final Map<String, Double> variableValues = new HashMap<>();
    Context context = new Context() {

        // Parameters take precedence over node values recorded for the current row.
        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            }
            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };
    List<Node> continuousVariables = new LinkedList<>();
    List<Node> nonErrorVariables = pm.getVariableNodes();
    // Work with a copy of the variables, because their type can be set externally.
    for (Node node : nonErrorVariables) {
        ContinuousVariable var = new ContinuousVariable(node.getName());
        var.setNodeType(node.getNodeType());
        if (var.getNodeType() != NodeType.ERROR) {
            continuousVariables.add(var);
        }
    }
    DataSet fullDataSet = new ColtDataSet(sampleSize, continuousVariables);
    // Precompute, for each position in the tier ordering, the data column of
    // that node (-1 if the node is not a variable node, e.g. an error term).
    SemGraph graph = pm.getGraph();
    List<Node> tierOrdering = graph.getFullTierOrdering();
    int[] tierIndices = new int[tierOrdering.size()];
    for (int i = 0; i < tierIndices.length; i++) {
        tierIndices[i] = nonErrorVariables.indexOf(tierOrdering.get(i));
    }
    // Do the simulation.
    for (int row = 0; row < sampleSize; row++) {
        // Values from the previous row must not leak into this one.
        variableValues.clear();
        for (int tier = 0; tier < tierOrdering.size(); tier++) {
            Node node = tierOrdering.get(tier);
            Expression expression = pm.getNodeExpression(node);
            double value = expression.evaluate(context);
            variableValues.put(node.getName(), value);
            int col = tierIndices[tier];
            if (col == -1) {
                // Not a data column; the value is only needed by later expressions.
                continue;
            }
            fullDataSet.setDouble(row, col, value);
        }
    }
    if (latentDataSaved) {
        return fullDataSet;
    } else {
        return DataUtils.restrictToMeasured(fullDataSet);
    }
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Example 8 with Context

use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.

the class GeneralizedSemIm method simulateDataFisher.

/**
 * Simulates data using the model of R. A. Fisher, for a linear model. Shocks are
 * applied every so many steps. A data point is recorded before each shock is
 * administered. If convergence happens before that number of steps has been reached,
 * a data point is recorded and a new shock immediately applied. The model may be
 * cyclic. If cyclic, all eigenvalues for the coefficient matrix must be less than 1,
 * though this is not checked.
 *
 * <p>Variable values are carried over from one row to the next; only the error
 * terms are redrawn for each row.
 *
 * @param sampleSize            The number of samples to be drawn.
 * @param intervalBetweenShocks External shock is applied every this many steps.
 *                              Must be positive integer.
 * @param epsilon               The convergence criterion; |xi.t - xi.t-1| < epsilon.
 * @return the simulated data set, passed through {@code DataUtils.restrictToMeasured}.
 * @throws IllegalArgumentException if {@code intervalBetweenShocks < 1}, if
 *                                  {@code epsilon <= 0}, or if an error expression
 *                                  evaluates to NaN.
 * @throws NullPointerException     if some variable node has no error node.
 */
public DataSet simulateDataFisher(int sampleSize, int intervalBetweenShocks, double epsilon) {
    if (intervalBetweenShocks < 1)
        throw new IllegalArgumentException("Interval between shocks must be >= 1: " + intervalBetweenShocks);
    if (epsilon <= 0.0)
        throw new IllegalArgumentException("Epsilon must be > 0: " + epsilon);
    final Map<String, Double> variableValues = new HashMap<>();
    final Context context = new Context() {

        // Parameters take precedence over recorded error/variable values.
        public Double getValue(String term) {
            Double value = parameterValues.get(term);
            if (value != null) {
                return value;
            }
            value = variableValues.get(term);
            if (value != null) {
                return value;
            }
            throw new IllegalArgumentException("No value recorded for '" + term + "'");
        }
    };
    final List<Node> variableNodes = pm.getVariableNodes();
    // t1 holds the values of the previous step, t2 those of the step being computed.
    double[] t1 = new double[variableNodes.size()];
    double[] t2 = new double[variableNodes.size()];
    double[][] all = new double[variableNodes.size()][sampleSize];
    // Do the simulation.
    for (int row = 0; row < sampleSize; row++) {
        // Draw a fresh shock for every variable's error term.
        for (int j = 0; j < t1.length; j++) {
            Node variable = variableNodes.get(j);
            Node error = pm.getErrorNode(variable);
            if (error == null) {
                throw new NullPointerException("No error node for variable: " + variable.getName());
            }
            Expression expression = pm.getNodeExpression(error);
            double value = expression.evaluate(context);
            if (Double.isNaN(value)) {
                throw new IllegalArgumentException("Undefined value for expression: " + expression);
            }
            variableValues.put(error.getName(), value);
        }
        // Let the system evolve until it converges or the next shock is due.
        for (int i = 0; i < intervalBetweenShocks; i++) {
            for (int j = 0; j < t1.length; j++) {
                Node node = variableNodes.get(j);
                Expression expression = pm.getNodeExpression(node);
                t2[j] = expression.evaluate(context);
                // Stored immediately, so later variables in this step see the new value.
                variableValues.put(node.getName(), t2[j]);
            }
            boolean converged = true;
            for (int j = 0; j < t1.length; j++) {
                if (Math.abs(t2[j] - t1[j]) > epsilon) {
                    converged = false;
                    break;
                }
            }
            // Swap buffers: the step just computed becomes the "previous" step.
            double[] t3 = t1;
            t1 = t2;
            t2 = t3;
            if (converged) {
                break;
            }
        }
        // Record the data point reached before the next shock.
        for (int j = 0; j < t1.length; j++) {
            all[j][row] = t1[j];
        }
    }
    List<Node> continuousVars = new ArrayList<>();
    for (Node node : variableNodes) {
        final ContinuousVariable var = new ContinuousVariable(node.getName());
        var.setNodeType(node.getNodeType());
        continuousVars.add(var);
    }
    BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(all), continuousVars);
    return DataUtils.restrictToMeasured(boxDataSet);
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression)

Example 9 with Context

use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.

the class TestExpressionParser method test2.

/**
 * Smoke test for distribution formulas: each formula must parse and
 * evaluate without throwing. The results are random draws, so no
 * particular value is asserted.
 */
public void test2() {
    final Map<String, Double> values = new HashMap<>();
    values.put("b11", 1.0);
    values.put("X1", 2.0);
    values.put("X2", 3.0);
    values.put("B22", 4.0);
    values.put("B12", 5.0);
    values.put("X4", 6.0);
    values.put("b13", 7.0);
    values.put("X5", 8.0);
    values.put("b10", 9.0);
    values.put("X", 10.0);
    values.put("Y", 11.0);
    values.put("Z", 12.0);
    values.put("W", 13.0);
    values.put("T", 14.0);
    values.put("R", 15.0);
    values.put("s2", 0.0);
    values.put("s3", 1.0);
    Context context = new Context() {

        public Double getValue(String var) {
            return values.get(var);
        }
    };
    List<String> formulas = new ArrayList<>();
    formulas.add("ChiSquare(s3)");
    formulas.add("Gamma(1, 1)");
    formulas.add("Beta(3, 5)");
    formulas.add("Poisson(5)");
    formulas.add("Indicator(0.3)");
    formulas.add("ExponentialPower(3)");
    // Log normal
    formulas.add("exp(Normal(s2, s3))");
    formulas.add("Normal(0, s3)");
    // Gaussian Power
    formulas.add("abs(Normal(s2, s3) ^ 3)");
    formulas.add("Discrete(3, 1, 5)");
    // Mixture of Gaussians
    formulas.add("0.3 * Normal(-2.0e2, 0.5) + 0.7 * Normal(2.0, 0.5)");
    formulas.add("StudentT(s3)");
    // Single value.
    formulas.add("s3");
    formulas.add("Hyperbolic(5, 3)");
    formulas.add("Uniform(s2, s3)");
    formulas.add("VonMises(s3)");
    formulas.add("Split(0, 1, 5, 6)");
    formulas.add("Mixture(0.5, N(-2, 0.5), 0.5, N(2, 0.5))");
    ExpressionParser parser = new ExpressionParser();
    for (String formula : formulas) {
        try {
            Expression expression = parser.parseExpression(formula);
            // Only checking that evaluation succeeds; the value is a random draw.
            expression.evaluate(context);
        } catch (ParseException e) {
            // A formula that does not parse is a failure, not something to log and ignore.
            throw new AssertionError("Failed to parse formula: " + formula, e);
        }
    }
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression) ExpressionParser(edu.cmu.tetrad.calculator.parser.ExpressionParser) ParseException(java.text.ParseException)

Example 10 with Context

use of edu.cmu.tetrad.calculator.expression.Context in project tetrad by cmu-phil.

the class TestExpressionParser method test1.

/**
 * Parses and evaluates a table of formulas (arithmetic, functions, logical
 * and comparison operators, IF) and checks each result against its expected
 * value to within 0.01.
 */
@Test
public void test1() {
    final Map<String, Double> values = new HashMap<>();
    values.put("b11", 1.0);
    values.put("X1", 2.0);
    values.put("X2", 3.0);
    values.put("B22", 4.0);
    values.put("B12", 5.0);
    values.put("X4", 6.0);
    values.put("b13", 7.0);
    values.put("X5", 8.0);
    values.put("b10", 9.0);
    values.put("X", 10.0);
    values.put("Y", 11.0);
    values.put("Z", 12.0);
    values.put("W", 13.0);
    values.put("T", 14.0);
    values.put("R", 15.0);
    Context context = new Context() {

        public Double getValue(String var) {
            return values.get(var);
        }
    };
    Map<String, Double> formulasToEvaluations = new HashMap<>();
    formulasToEvaluations.put("0", 0.0);
    formulasToEvaluations.put("b11*X1 + sin(X2) + B22*X2 + B12*X4+b13*X5", 100.14);
    formulasToEvaluations.put("X5*X4*X4", 288.0);
    formulasToEvaluations.put("sin(b10*X1)", -0.75097);
    formulasToEvaluations.put("((X + ((Y * (Z ^ W)) * T)) + R)", 16476953628377113.0);
    formulasToEvaluations.put("X + Y * Z ^ W * T + R", 16476953628377113.0);
    formulasToEvaluations.put("pow(2, 5)", 32.0);
    formulasToEvaluations.put("2^5", 32.0);
    formulasToEvaluations.put("exp(1)", 2.718);
    formulasToEvaluations.put("sqrt(2)", 1.414);
    formulasToEvaluations.put("cos(0)", 1.0);
    formulasToEvaluations.put("cos(3.14/2)", 0.0);
    formulasToEvaluations.put("sin(0)", 0.0);
    formulasToEvaluations.put("sin(3.14/2)", 1.0);
    formulasToEvaluations.put("tan(1)", 1.56);
    formulasToEvaluations.put("cosh(1)", 1.54);
    formulasToEvaluations.put("sinh(1)", 1.18);
    formulasToEvaluations.put("tanh(1)", 0.76);
    formulasToEvaluations.put("acos(1)", 0.0);
    formulasToEvaluations.put("asin(1)", 1.57);
    formulasToEvaluations.put("atan(1)", 0.78);
    formulasToEvaluations.put("ln(1)", 0.0);
    formulasToEvaluations.put("log10(10)", 1.0);
    formulasToEvaluations.put("ceil(2.5)", 3.0);
    formulasToEvaluations.put("floor(2.5)", 2.0);
    formulasToEvaluations.put("abs(-5)", 5.0);
    formulasToEvaluations.put("max(2, 5, 3, 1, 10, -3)", 10.0);
    formulasToEvaluations.put("min(2, 5, 3, 1, 10, -3)", -3.0);
    // Logical.
    formulasToEvaluations.put("AND(1, 1)", 1.0);
    formulasToEvaluations.put("AND(1, 0)", 0.0);
    formulasToEvaluations.put("AND(0, 1)", 0.0);
    formulasToEvaluations.put("AND(0, 0)", 0.0);
    formulasToEvaluations.put("AND(0, 0.5)", 0.0);
    formulasToEvaluations.put("1 AND 1", 1.0);
    formulasToEvaluations.put("OR(1, 1)", 1.0);
    formulasToEvaluations.put("OR(1, 0)", 1.0);
    formulasToEvaluations.put("OR(0, 1)", 1.0);
    formulasToEvaluations.put("OR(0, 0)", 0.0);
    formulasToEvaluations.put("OR(0, 0.5)", 0.0);
    formulasToEvaluations.put("1 OR 1", 1.0);
    formulasToEvaluations.put("XOR(1, 1)", 0.0);
    formulasToEvaluations.put("XOR(1, 0)", 1.0);
    formulasToEvaluations.put("XOR(0, 1)", 1.0);
    formulasToEvaluations.put("XOR(0, 0)", 0.0);
    formulasToEvaluations.put("XOR(0, 0.5)", 0.0);
    formulasToEvaluations.put("1 XOR 1", 0.0);
    formulasToEvaluations.put("1 AND 0 OR 1 XOR 1 + 1", 1.0);
    // Comparisons.
    formulasToEvaluations.put("1 < 2", 1.0);
    formulasToEvaluations.put("1 < 0", 0.0);
    formulasToEvaluations.put("1 < 1", 0.0);
    formulasToEvaluations.put("1 <= 2", 1.0);
    formulasToEvaluations.put("1 <= 1", 1.0);
    formulasToEvaluations.put("1 <= -1", 0.0);
    formulasToEvaluations.put("1 = 2", 0.0);
    formulasToEvaluations.put("1 = 1", 1.0);
    formulasToEvaluations.put("1 = -1", 0.0);
    formulasToEvaluations.put("1 > 2", 0.0);
    formulasToEvaluations.put("1 > 1", 0.0);
    formulasToEvaluations.put("1 > -1", 1.0);
    formulasToEvaluations.put("1 >= 2", 0.0);
    formulasToEvaluations.put("1 >= 1", 1.0);
    formulasToEvaluations.put("1 >= -1", 1.0);
    // Conditionals.
    formulasToEvaluations.put("IF(1 > 2, 1, 2)", 2.0);
    formulasToEvaluations.put("IF(1 < 2, 1, 2)", 1.0);
    formulasToEvaluations.put("IF(1 < 2 AND 3 < 4, 1, 2)", 1.0);
    ExpressionParser parser = new ExpressionParser();
    for (Map.Entry<String, Double> entry : formulasToEvaluations.entrySet()) {
        String formula = entry.getKey();
        double expected = entry.getValue();
        try {
            Expression expression = parser.parseExpression(formula);
            double value = expression.evaluate(context);
            // The formula is the message so a failure identifies the offending expression.
            assertEquals(formula, expected, value, 0.01);
        } catch (ParseException e) {
            // A parse failure must fail the test; swallowing it would skip the remaining checks.
            throw new AssertionError("Failed to parse formula: " + formula, e);
        }
    }
}
Also used : Context(edu.cmu.tetrad.calculator.expression.Context) Expression(edu.cmu.tetrad.calculator.expression.Expression) ExpressionParser(edu.cmu.tetrad.calculator.parser.ExpressionParser) ParseException(java.text.ParseException) Test(org.junit.Test)

Aggregations

Context (edu.cmu.tetrad.calculator.expression.Context)10 Expression (edu.cmu.tetrad.calculator.expression.Expression)10 ExpressionParser (edu.cmu.tetrad.calculator.parser.ExpressionParser)3 ParseException (java.text.ParseException)3 Test (org.junit.Test)2 MultivariateFunction (org.apache.commons.math3.analysis.MultivariateFunction)1 InitialGuess (org.apache.commons.math3.optim.InitialGuess)1 MaxEval (org.apache.commons.math3.optim.MaxEval)1 PointValuePair (org.apache.commons.math3.optim.PointValuePair)1 MultivariateOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer)1 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)1 PowellOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer)1