
Example 56 with CommonTheory

Use of com.sri.ai.grinder.application.CommonTheory in the project aic-praise by aic-sri-international.

In class ParameterEstimationForExpressionBasedModel, method optimize:

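/**
 * Main method to optimize the parameters of the model given the queries and
 * evidence when the model is expression-based.
 *
 * @param startPoint initial values, one per free variable of the objective
 * @return map from each free variable to its optimized value
 */
public HashMap<Expression, Double> optimize(double[] startPoint) {
    Theory theory = new CommonTheory();
    Context context = new TrueContext(theory);
    // Objective to maximize, built from the marginals of the queries
    Expression marginalFunctionLog = buildListOfMarginals();
    System.out.println(marginalFunctionLog); // debug output of the objective
    // The free variables of the objective are the parameters being optimized
    Set<Expression> listOfVariables = freeVariables(marginalFunctionLog, context);
    if (listOfVariables.size() != startPoint.length) {
        throw new IllegalArgumentException("Length of double[] startPoint doesn't match the number of variables to optimize with evidence");
    }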
    double[] argmax = returnArgopt(startPoint, MAXIMIZE, marginalFunctionLog);
    HashMap<Expression, Double> result = buildHashMapContainingResult(listOfVariables, argmax);
    return result;
}
Also used: CommonTheory (com.sri.ai.grinder.application.CommonTheory), TrueContext (com.sri.ai.grinder.core.TrueContext), Context (com.sri.ai.grinder.api.Context), Theory (com.sri.ai.grinder.api.Theory), Expression (com.sri.ai.expresso.api.Expression)
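The last two steps of optimize() check that startPoint has one value per free variable and then pair each variable with its optimized value. The helper buildHashMapContainingResult is not shown, so the following is a minimal plain-Java sketch under stated assumptions: variables are strings rather than Expression objects, the name zipVariablesWithValues is hypothetical, and the i-th entry of argmax is assumed to belong to the i-th variable in the set's iteration order.

```java
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Hypothetical sketch (not aic-praise API): pair each free variable with its
 * optimized value, mirroring the size check and result map in optimize().
 */
public class ArgmaxResultSketch {

    // Assumes argmax[i] belongs to the i-th variable in the set's iteration order.
    public static Map<String, Double> zipVariablesWithValues(Set<String> variables, double[] argmax) {
        if (variables.size() != argmax.length) {
            throw new IllegalArgumentException(
                "Expected " + variables.size() + " values but got " + argmax.length);
        }
        Map<String, Double> result = new LinkedHashMap<>(); // keeps variable order
        int i = 0;
        for (String variable : variables) {
            result.put(variable, argmax[i++]);
        }
        return result;
    }

    public static void main(String[] args) {
        Set<String> variables = new LinkedHashSet<>(List.of("Param1", "Param2"));
        double[] argmax = {0.25, 0.75};
        System.out.println(zipVariablesWithValues(variables, argmax)); // prints {Param1=0.25, Param2=0.75}
    }
}
```

The pairing is only meaningful if the variable order is stable between building the objective and reading back argmax; a LinkedHashSet makes that order deterministic in this sketch.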

Example 57 with CommonTheory

Use of com.sri.ai.grinder.application.CommonTheory in the project aic-praise by aic-sri-international.

In class ExpressionBayesianNode, method main:

public static void main(String[] args) {
    Theory theory = new CommonTheory();
    Context context = new TrueContext(theory);
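    // One child and one parent, each ranging over 1..5; three parameters are declared
    // (Param1, Param2, Param3), though the active expression E uses only Param1 and Param2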
    ExpressionVariable child = new DefaultExpressionVariable(parse("Child"));
    ExpressionVariable parent = new DefaultExpressionVariable(parse("Parent"));
    Expression param1 = parse("Param1");
    Expression param2 = parse("Param2");
    Expression param3 = parse("Param3");
    context = context.extendWithSymbolsAndTypes("Child", "1..5", "Parent", "1..5", "Param1", "Real", "Param2", "Real", "Param3", "Real");
    LinkedHashSet<Expression> parameters = Util.set(param1, param2);
    parameters.add(param3);
    // Expression E = parse("if Child < 5 then Param1 else Param2");
    Expression E = parse("if Parent != 5 then Param1 else Param2");
    // Expression E = parse("if Parent != 5 then if Child < 5 then Param1 else Param2 else Param3");
    // Expression E = parse("if Parent != 5 then if Child < Parent then Param1 else Param2 else Param3"); // partial intersection
    // Expression E = parse("if Child > 3 then Param1 else Param2");
    println("E = " + E + "\n");
    ExpressionBayesianNode parentNode = new ExpressionBayesianNode(E, context, parent, list(), parameters);
    ExpressionBayesianNode childNode = new ExpressionBayesianNode(E, context, child, list(parent), parameters);
    println("Families = " + childNode.getFamilies());
    childNode.setInitialCountsForAllPossibleChildAndParentsAssignments();
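    // Increment counts from data points: values are given child first, then parents (Child = 1, Parent = 1)
    LinkedList<Expression> childAndParentsValues = list(parse("1"), parse("1"));
    // With nIncrements = 0 the loop body never runs; raise it to add observations
    int nIncrements = 0;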
    for (int i = 1; i <= nIncrements; i++) {
        childNode.incrementCountForChildAndParentsAssignment(childAndParentsValues);
    }
    childNode.normalizeParameters();
    println("\nnew E = " + childNode.getInnerExpression());
    println("\nEnd of Program");
}
Also used: CommonTheory (com.sri.ai.grinder.application.CommonTheory), TrueContext (com.sri.ai.grinder.core.TrueContext), Context (com.sri.ai.grinder.api.Context), DefaultExpressionVariable (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.DefaultExpressionVariable), Theory (com.sri.ai.grinder.api.Theory), Expression (com.sri.ai.expresso.api.Expression), ExpressionVariable (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable)
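The expression E = "if Parent != 5 then Param1 else Param2" ties one parameter to many child/parent assignments. The sketch below enumerates the 5 x 5 assignment space and counts how many assignments each parameter governs. It is plain Java, not the ExpressionBayesianNode API; the class and method names are hypothetical, and it does not reproduce how getFamilies or normalizeParameters work internally.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical sketch (not aic-praise API): for every (Child, Parent) pair in
 * 1..5 x 1..5, record which parameter "if Parent != 5 then Param1 else Param2" selects.
 */
public class ParameterCoverageSketch {

    // Mirrors expression E; child is unused here but kept so alternatives
    // such as "if Child < 5 then Param1 else Param2" can be tried.
    public static String selectedParameter(int child, int parent) {
        return parent != 5 ? "Param1" : "Param2";
    }

    // Counts how many child/parent assignments each parameter governs.
    public static Map<String, Integer> countAssignmentsPerParameter(int lo, int hi) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (int parent = lo; parent <= hi; parent++) {
            for (int child = lo; child <= hi; child++) {
                counts.merge(selectedParameter(child, parent), 1, Integer::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(countAssignmentsPerParameter(1, 5)); // prints {Param1=20, Param2=5}
    }
}
```

Param1 covers the 20 assignments with Parent in 1..4 and Param2 the 5 with Parent = 5, so a data point such as (Child = 1, Parent = 1) falls under Param1.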

Example 58 with CommonTheory

Use of com.sri.ai.grinder.application.CommonTheory in the project aic-praise by aic-sri-international.

In class TestCases, method expressionFactorIsingModel:

private static ArrayList<ExpressionFactor> expressionFactorIsingModel(int gridSize, BiFunction<Pair<Integer, Integer>, Pair<Integer, Integer>, ArrayList<Double>> pairwiseFactorentries, Function<Pair<Integer, Integer>, ArrayList<Double>> singleVariableFactorEntries) {
    ArrayList<ArrayList<ExpressionVariable>> variables = new ArrayList<>();
    Context context = new TrueContext(new CommonTheory());
    for (int i = 0; i < gridSize; i++) {
        ArrayList<ExpressionVariable> col = new ArrayList<>();
        variables.add(col);
        for (int j = 0; j < gridSize; j++) {
            ExpressionVariable v = new DefaultExpressionVariable(makeSymbol("A_" + i + "_" + j));
            col.add(j, v);
            // Declare each grid variable A_i_j as Boolean in the context
            context = context.extendWithSymbolsAndTypes(v, parse("Boolean"));
        }
    }
    ArrayList<ExpressionFactor> result = new ArrayList<>();
    for (int i = 0; i < gridSize - 1; i++) {
        for (int j = 0; j < gridSize; j++) {
            // Vertical neighbors (i, j) and (i + 1, j); the four entries are the factor values
            // for (X, Y) = (true, true), (true, false), (false, true), (false, false)
            ArrayList<Double> entryList = pairwiseFactorentries.apply(new Pair<>(i, j), new Pair<>(i + 1, j));
            Expression entry1 = makeSymbol(entryList.get(0));
            Expression entry2 = makeSymbol(entryList.get(1));
            Expression entry3 = makeSymbol(entryList.get(2));
            Expression entry4 = makeSymbol(entryList.get(3));
            Expression X = variables.get(i).get(j);
            Expression Y = variables.get(i + 1).get(j);
            result.add(new DefaultExpressionFactor(apply(IF_THEN_ELSE, apply(EQUAL, X, makeSymbol("true")), apply(IF_THEN_ELSE, apply(EQUAL, Y, makeSymbol("true")), entry1, entry2), apply(IF_THEN_ELSE, apply(EQUAL, Y, makeSymbol("true")), entry3, entry4)), context));
        }
    }
    // Horizontal neighbors (i, j) and (i, j + 1), with the same entry layout
    for (int i = 0; i < gridSize; i++) {
        for (int j = 0; j < gridSize - 1; j++) {
            ArrayList<Double> entryList = pairwiseFactorentries.apply(new Pair<>(i, j), new Pair<>(i, j + 1));
            Expression entry1 = makeSymbol(entryList.get(0));
            Expression entry2 = makeSymbol(entryList.get(1));
            Expression entry3 = makeSymbol(entryList.get(2));
            Expression entry4 = makeSymbol(entryList.get(3));
            Expression X = variables.get(i).get(j);
            Expression Y = variables.get(i).get(j + 1);
            result.add(new DefaultExpressionFactor(apply(IF_THEN_ELSE, apply(EQUAL, X, makeSymbol("true")), apply(IF_THEN_ELSE, apply(EQUAL, Y, makeSymbol("true")), entry1, entry2), apply(IF_THEN_ELSE, apply(EQUAL, Y, makeSymbol("true")), entry3, entry4)), context));
        }
    }
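    // Unary factors are added only if singleVariableFactorEntries returns entries (checked at cell (0, 0))
    if (singleVariableFactorEntries.apply(new Pair<>(0, 0)) != null) {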
        for (int i = 0; i < gridSize; i++) {
            for (int j = 0; j < gridSize; j++) {
                ArrayList<Double> entryList = singleVariableFactorEntries.apply(new Pair<>(i, j));
                Expression entry1 = makeSymbol(entryList.get(0));
                Expression entry2 = makeSymbol(entryList.get(1));
                Expression X = variables.get(i).get(j);
                result.add(new DefaultExpressionFactor(apply(IF_THEN_ELSE, apply(EQUAL, X, makeSymbol("true")), entry1, entry2), context));
            }
        }
    }
    return result;
}
Also used: TrueContext (com.sri.ai.grinder.core.TrueContext), Context (com.sri.ai.grinder.api.Context), DefaultExpressionVariable (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.DefaultExpressionVariable), DefaultExpressionFactor (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.DefaultExpressionFactor), Util.mapIntoArrayList (com.sri.ai.util.Util.mapIntoArrayList), ArrayList (java.util.ArrayList), CommonTheory (com.sri.ai.grinder.application.CommonTheory), ExpressionFactor (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionFactor), Expression (com.sri.ai.expresso.api.Expression), ExpressionVariable (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable), Pair (com.sri.ai.util.base.Pair)
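expressionFactorIsingModel creates one pairwise factor per vertical pair (i, j)-(i + 1, j) and per horizontal pair (i, j)-(i, j + 1), so an n x n grid gets 2n(n - 1) pairwise factors, plus n*n unary factors when singleVariableFactorEntries returns entries. The plain-Java sketch below reproduces that edge enumeration and the lookup encoded by the nested if-then-else; the class and method names are hypothetical, and it builds no aic-praise factors.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch (not aic-praise API) of the grid structure built by
 * expressionFactorIsingModel and of how a pairwise factor's entries are indexed.
 */
public class IsingGridSketch {

    // Vertical edges first, then horizontal edges, in the original loop order.
    // Each edge is {i1, j1, i2, j2}.
    public static List<int[]> pairwiseEdges(int gridSize) {
        List<int[]> edges = new ArrayList<>();
        for (int i = 0; i < gridSize - 1; i++) {
            for (int j = 0; j < gridSize; j++) {
                edges.add(new int[] {i, j, i + 1, j});
            }
        }
        for (int i = 0; i < gridSize; i++) {
            for (int j = 0; j < gridSize - 1; j++) {
                edges.add(new int[] {i, j, i, j + 1});
            }
        }
        return edges;
    }

    // Same lookup as: if X = true then (if Y = true then e1 else e2)
    //                            else (if Y = true then e3 else e4)
    public static double pairwiseValue(List<Double> entries, boolean x, boolean y) {
        if (x) {
            return y ? entries.get(0) : entries.get(1);
        }
        return y ? entries.get(2) : entries.get(3);
    }

    public static void main(String[] args) {
        int gridSize = 3;
        int pairwise = pairwiseEdges(gridSize).size();
        int unary = gridSize * gridSize;
        System.out.println(pairwise + " pairwise + " + unary + " unary factors"); // prints 12 pairwise + 9 unary factors
        System.out.println(pairwiseValue(List.of(1.0, 2.0, 3.0, 4.0), false, true)); // prints 3.0
    }
}
```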

Aggregations (classes used together with CommonTheory, with the number of examples that use each)

Context (com.sri.ai.grinder.api.Context): 58
CommonTheory (com.sri.ai.grinder.application.CommonTheory): 58
TrueContext (com.sri.ai.grinder.core.TrueContext): 58
Expression (com.sri.ai.expresso.api.Expression): 55
Test (org.junit.Test): 47
Theory (com.sri.ai.grinder.api.Theory): 36
Factor (com.sri.ai.praise.core.representation.interfacebased.factor.api.Factor): 13
ExpressionFactor (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionFactor): 13
DefaultExpressionFactor (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.DefaultExpressionFactor): 12
ConstantFactor (com.sri.ai.praise.core.representation.interfacebased.factor.core.ConstantFactor): 11
CompoundTheory (com.sri.ai.grinder.theory.compound.CompoundTheory): 9
DifferenceArithmeticTheory (com.sri.ai.grinder.theory.differencearithmetic.DifferenceArithmeticTheory): 9
EqualityTheory (com.sri.ai.grinder.theory.equality.EqualityTheory): 9
LinearRealArithmeticTheory (com.sri.ai.grinder.theory.linearrealarithmetic.LinearRealArithmeticTheory): 9
PropositionalTheory (com.sri.ai.grinder.theory.propositional.PropositionalTheory): 9
TupleTheory (com.sri.ai.grinder.theory.tuple.TupleTheory): 9
DefaultExpressionVariable (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.DefaultExpressionVariable): 7
ExpressionVariable (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable): 5
ExpressionFactorNetwork (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.ExpressionFactorNetwork): 4
ArrayList (java.util.ArrayList): 3