Usage of `com.sri.ai.grinder.application.CommonTheory` in project aic-praise by aic-sri-international.
In class `ParameterEstimationForExpressionBasedModel`, method `optimize`:
/**
 * Optimizes the parameters of an Expression-based model given its queries and evidence.
 * <p>
 * The objective is the expression returned by {@code buildListOfMarginals()}. Its free variables,
 * computed in a {@code TrueContext} over a {@code CommonTheory}, are the parameters being
 * optimized. {@code returnArgopt(..., MAXIMIZE, ...)} searches for the point that maximizes the
 * objective, starting from {@code startPoint}.
 *
 * @param startPoint initial value for each parameter to optimize; its length must equal the
 *                   number of free variables of the objective
 * @return map from each parameter expression to its optimized value, as assembled by
 *         {@code buildHashMapContainingResult}
 * @throws NullPointerException if {@code startPoint} is {@code null}
 * @throws Error if {@code startPoint.length} differs from the number of free variables
 */
public HashMap<Expression, Double> optimize(double[] startPoint) {
    // Fail fast, before the (potentially costly) construction of the objective.
    java.util.Objects.requireNonNull(startPoint, "startPoint");

    Context context = new TrueContext(new CommonTheory());
    Expression marginalFunctionLog = buildListOfMarginals();
    Set<Expression> listOfVariables = freeVariables(marginalFunctionLog, context);

    if (listOfVariables.size() != startPoint.length) {
        // NOTE(review): java.lang.Error is kept so existing callers see the same failure type;
        // IllegalArgumentException would be the idiomatic choice for a bad argument.
        throw new Error("Length of double[] startPoint doesn't match the number of variables to optimize with evidence");
    }

    // NOTE(review): assumes returnArgopt orders its coordinates consistently with the iteration
    // order of the set returned by freeVariables(...) — confirm.
    double[] argmax = returnArgopt(startPoint, MAXIMIZE, marginalFunctionLog);
    return buildHashMapContainingResult(listOfVariables, argmax);
}
Usage of `com.sri.ai.grinder.application.CommonTheory` in project aic-praise by aic-sri-international.
In class `ExpressionBayesianNode`, method `main`:
/**
 * Small demonstration of {@code ExpressionBayesianNode}.
 * <p>
 * The demo declares a child variable and its single parent, both typed {@code 1..5}, plus three
 * real-valued parameters. It builds a node for each variable that shares one inner expression,
 * then initializes the child's counts. It optionally feeds it repeated observations (currently
 * none), normalizes its parameters, and prints the inner expression before and after.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    Context baseContext = new TrueContext(new CommonTheory());

    // A single child with a single parent.
    ExpressionVariable childVariable = new DefaultExpressionVariable(parse("Child"));
    ExpressionVariable parentVariable = new DefaultExpressionVariable(parse("Parent"));

    Expression firstParameter = parse("Param1");
    Expression secondParameter = parse("Param2");
    Expression thirdParameter = parse("Param3");

    Context context = baseContext.extendWithSymbolsAndTypes(
            "Child", "1..5",
            "Parent", "1..5",
            "Param1", "Real",
            "Param2", "Real",
            "Param3", "Real");

    LinkedHashSet<Expression> parameters = Util.set(firstParameter, secondParameter);
    parameters.add(thirdParameter);

    // Inner expression shared by both nodes. Other expressions that have been tried here:
    //   if Child < 5 then Param1 else Param2
    //   if Parent != 5 then if Child < 5 then Param1 else Param2 else Param3
    //   if Parent != 5 then if Child < Parent then Param1 else Param2 else Param3   (partial intersection)
    //   if Child > 3 then Param1 else Param2
    Expression innerExpression = parse("if Parent != 5 then Param1 else Param2");
    println("E = " + innerExpression + "\n");

    // NOTE(review): parentNode is constructed but not used afterwards.
    ExpressionBayesianNode parentNode = new ExpressionBayesianNode(innerExpression, context, parentVariable, list(), parameters);
    ExpressionBayesianNode childNode = new ExpressionBayesianNode(innerExpression, context, childVariable, list(parentVariable), parameters);
    println("Families = " + childNode.getFamilies());

    childNode.setInitialCountsForAllPossibleChildAndParentsAssignments();

    // Simulated datapoints: the same assignment (presumably child value first, then parent values)
    // is counted numberOfObservations times. With zero, no counts are added.
    LinkedList<Expression> observedAssignment = list(parse("1"), parse("1"));
    int numberOfObservations = 0;
    for (int observation = 0; observation < numberOfObservations; observation++) {
        childNode.incrementCountForChildAndParentsAssignment(observedAssignment);
    }

    childNode.normalizeParameters();
    println("\nnew E = " + childNode.getInnerExpression());
    println("\nEnd of Program");
}
Usage of `com.sri.ai.grinder.application.CommonTheory` in project aic-praise by aic-sri-international.
In class `TestCases`, method `expressionFactorIsingModel`:
/**
 * Builds the factors of an Ising-style model on a {@code gridSize} x {@code gridSize} grid of
 * Boolean variables named {@code A_i_j}.
 * <p>
 * The returned list contains, in this order:
 * <ol>
 *   <li>one pairwise factor for each neighbor pair {@code (i, j)}–{@code (i + 1, j)};</li>
 *   <li>one pairwise factor for each neighbor pair {@code (i, j)}–{@code (i, j + 1)};</li>
 *   <li>optionally, one single-variable factor per cell.</li>
 * </ol>
 * All factors share one context in which every grid variable is declared {@code Boolean}.
 *
 * @param gridSize number of cells along each side of the grid; values {@code <= 0} yield an empty list
 * @param pairwiseFactorentries given the coordinates of two neighbors X and Y, returns the four
 *        potentials for (X, Y) = (true, true), (true, false), (false, true), (false, false), in
 *        that order; it must return at least four entries
 * @param singleVariableFactorEntries given a cell's coordinates, returns the potentials for
 *        X = true and X = false. If it returns {@code null} for cell (0, 0), no single-variable
 *        factors are created at all.
 * @return the factors, in the order described above
 */
private static ArrayList<ExpressionFactor> expressionFactorIsingModel(int gridSize, BiFunction<Pair<Integer, Integer>, Pair<Integer, Integer>, ArrayList<Double>> pairwiseFactorentries, Function<Pair<Integer, Integer>, ArrayList<Double>> singleVariableFactorEntries) {
    ArrayList<ArrayList<ExpressionVariable>> variables = new ArrayList<>();
    Context context = new TrueContext(new CommonTheory());
    for (int i = 0; i < gridSize; i++) {
        ArrayList<ExpressionVariable> line = new ArrayList<>(gridSize);
        variables.add(line);
        for (int j = 0; j < gridSize; j++) {
            ExpressionVariable v = new DefaultExpressionVariable(makeSymbol("A_" + i + "_" + j));
            line.add(v);
            context = context.extendWithSymbolsAndTypes(v, parse("Boolean"));
        }
    }

    ArrayList<ExpressionFactor> result = new ArrayList<>();

    // Neighbors along the first index: (i, j) -- (i + 1, j).
    // "i + 1 < gridSize" rather than "i < gridSize - 1" so that gridSize - 1 cannot overflow.
    for (int i = 0; i + 1 < gridSize; i++) {
        for (int j = 0; j < gridSize; j++) {
            ArrayList<Double> entries = pairwiseFactorentries.apply(new Pair<>(i, j), new Pair<>(i + 1, j));
            result.add(makeIsingPairwiseFactor(variables.get(i).get(j), variables.get(i + 1).get(j), entries, context));
        }
    }

    // Neighbors along the second index: (i, j) -- (i, j + 1).
    for (int i = 0; i < gridSize; i++) {
        for (int j = 0; j + 1 < gridSize; j++) {
            ArrayList<Double> entries = pairwiseFactorentries.apply(new Pair<>(i, j), new Pair<>(i, j + 1));
            result.add(makeIsingPairwiseFactor(variables.get(i).get(j), variables.get(i).get(j + 1), entries, context));
        }
    }

    // The caller signals "no single-variable factors" by returning null for cell (0, 0).
    // With an empty grid there is no cell (0, 0), so the function is not consulted at all.
    if (gridSize > 0 && singleVariableFactorEntries.apply(new Pair<>(0, 0)) != null) {
        for (int i = 0; i < gridSize; i++) {
            for (int j = 0; j < gridSize; j++) {
                ArrayList<Double> entries = singleVariableFactorEntries.apply(new Pair<>(i, j));
                result.add(makeIsingSingleVariableFactor(variables.get(i).get(j), entries, context));
            }
        }
    }
    return result;
}

/** Returns a fresh expression {@code variable = true}. */
private static Expression isingEqualsTrue(Expression variable) {
    return apply(EQUAL, variable, makeSymbol("true"));
}

/**
 * Builds the 2x2 table factor
 * {@code if x = true then (if y = true then e0 else e1) else (if y = true then e2 else e3)},
 * where {@code e0..e3} are the first four elements of {@code entries}.
 */
private static ExpressionFactor makeIsingPairwiseFactor(Expression x, Expression y, ArrayList<Double> entries, Context context) {
    Expression whenXIsTrue = apply(IF_THEN_ELSE, isingEqualsTrue(y), makeSymbol(entries.get(0)), makeSymbol(entries.get(1)));
    Expression whenXIsFalse = apply(IF_THEN_ELSE, isingEqualsTrue(y), makeSymbol(entries.get(2)), makeSymbol(entries.get(3)));
    return new DefaultExpressionFactor(apply(IF_THEN_ELSE, isingEqualsTrue(x), whenXIsTrue, whenXIsFalse), context);
}

/**
 * Builds the factor {@code if x = true then e0 else e1}, where {@code e0, e1} are the first two
 * elements of {@code entries}.
 */
private static ExpressionFactor makeIsingSingleVariableFactor(Expression x, ArrayList<Double> entries, Context context) {
    return new DefaultExpressionFactor(apply(IF_THEN_ELSE, isingEqualsTrue(x), makeSymbol(entries.get(0)), makeSymbol(entries.get(1))), context);
}
Aggregations