Use of com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable in project aic-praise by aic-sri-international.
In the class ExpressionBayesianModel, method convertToAnExpressionBasedModelAfterLearning:
/**
 * Builds an {@link ExpressionBasedModel} out of this (already learned) ExpressionBayesianModel so that
 * it can be used for inference.
 * <p>
 * The nodes of this model are passed on as the factors of the new model, and each node's child variable
 * is declared as a random variable whose type name is the one registered in that node's context.
 * Both constant maps (uniquely and non-uniquely named) are left empty: the node expressions are assumed
 * to contain no constants, since constants would already have caused errors while learning the parameters.
 * The resulting model is always flagged as a Bayesian network.
 *
 * @param mapFromCategoricalTypeNameToSizeString the categorical types used, as specified by the user (usually an empty map)
 * @param additionalTypes the additional types used, as specified by the user (usually an empty list)
 *
 * @return the equivalent ExpressionBasedModel for inference
 */
public ExpressionBasedModel convertToAnExpressionBasedModelAfterLearning(Map<String, String> mapFromCategoricalTypeNameToSizeString, Collection<Type> additionalTypes) {
    // No constants are expected in the node expressions (see the assumption above),
    // so neither kind of constant gets any declaration.
    Map<String, String> uniquelyNamedConstantTypes = map();
    Map<String, String> nonUniquelyNamedConstantTypes = map();

    // The nodes themselves play the role of the factors of the new model.
    List<? extends Expression> nodesAsFactors = this.getNodes();

    // Declare every node's child variable with the type name registered in the node's own context.
    Map<String, String> randomVariableTypes = map();
    for (ExpressionBayesianNode bayesianNode : this.getNodes()) {
        ExpressionVariable childVariable = bayesianNode.getChildVariable();
        Type childType = bayesianNode.getContext().getTypeOfRegisteredSymbol(childVariable);
        randomVariableTypes.put(childVariable.toString(), childType.toString());
    }

    // The final 'true' marks the model as a Bayesian network, which an ExpressionBayesianModel always is.
    return new DefaultExpressionBasedModel(
            nodesAsFactors,
            randomVariableTypes,
            nonUniquelyNamedConstantTypes,
            uniquelyNamedConstantTypes,
            mapFromCategoricalTypeNameToSizeString,
            additionalTypes,
            true);
}
Use of com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable in project aic-praise by aic-sri-international.
In the class TestCases, method expressionFactorIsingModel:
/**
 * Builds the expression factors of a grid-shaped (Ising-style) model over {@code gridSize * gridSize}
 * Boolean variables named {@code A_i_j}.
 * <p>
 * Factors are appended to the result in this order:
 * <ol>
 * <li>one pairwise factor per pair {@code (i, j)} / {@code (i + 1, j)};</li>
 * <li>one pairwise factor per pair {@code (i, j)} / {@code (i, j + 1)};</li>
 * <li>one single-variable factor per grid cell, unless {@code singleVariableFactorEntries}
 *     returns {@code null} for cell {@code (0, 0)}.</li>
 * </ol>
 * All factors share the same context, in which every grid variable is registered with type Boolean.
 *
 * @param gridSize the number of rows (and of columns) of the grid
 * @param pairwiseFactorentries given the coordinates of two neighboring cells X and Y, provides (at least) four
 *        entries, used in this order: (X true, Y true), (X true, Y not true), (X not true, Y true), (X not true, Y not true)
 * @param singleVariableFactorEntries given the coordinates of a cell X, provides (at least) two entries, used in
 *        this order: (X true), (X not true); it is always probed once with {@code (0, 0)} (even for an empty grid),
 *        and a {@code null} answer there disables the single-variable factors altogether
 *
 * @return the list of factors described above
 */
private static ArrayList<ExpressionFactor> expressionFactorIsingModel(int gridSize, BiFunction<Pair<Integer, Integer>, Pair<Integer, Integer>, ArrayList<Double>> pairwiseFactorentries, Function<Pair<Integer, Integer>, ArrayList<Double>> singleVariableFactorEntries) {
    // Create the grid variables and register each one as a Boolean in the shared context.
    ArrayList<ArrayList<ExpressionVariable>> variables = new ArrayList<>();
    Context context = new TrueContext(new CommonTheory());
    for (int i = 0; i < gridSize; i++) {
        ArrayList<ExpressionVariable> col = new ArrayList<>();
        variables.add(col);
        for (int j = 0; j < gridSize; j++) {
            ExpressionVariable v = new DefaultExpressionVariable(makeSymbol("A_" + i + "_" + j));
            col.add(v);
            context = context.extendWithSymbolsAndTypes(v, parse("Boolean"));
        }
    }

    ArrayList<ExpressionFactor> result = new ArrayList<>();

    // Pairwise factors linking (i, j) with (i + 1, j).
    for (int i = 0; i < gridSize - 1; i++) {
        for (int j = 0; j < gridSize; j++) {
            ArrayList<Double> entryList = pairwiseFactorentries.apply(new Pair<>(i, j), new Pair<>(i + 1, j));
            result.add(makePairwiseIsingFactor(variables.get(i).get(j), variables.get(i + 1).get(j), entryList, context));
        }
    }

    // Pairwise factors linking (i, j) with (i, j + 1).
    for (int i = 0; i < gridSize; i++) {
        for (int j = 0; j < gridSize - 1; j++) {
            ArrayList<Double> entryList = pairwiseFactorentries.apply(new Pair<>(i, j), new Pair<>(i, j + 1));
            result.add(makePairwiseIsingFactor(variables.get(i).get(j), variables.get(i).get(j + 1), entryList, context));
        }
    }

    // Single-variable factors are optional: a null answer for cell (0, 0) means "none at all".
    if (singleVariableFactorEntries.apply(new Pair<>(0, 0)) != null) {
        for (int i = 0; i < gridSize; i++) {
            for (int j = 0; j < gridSize; j++) {
                ArrayList<Double> entryList = singleVariableFactorEntries.apply(new Pair<>(i, j));
                Expression entry1 = makeSymbol(entryList.get(0));
                Expression entry2 = makeSymbol(entryList.get(1));
                Expression X = variables.get(i).get(j);
                result.add(new DefaultExpressionFactor(apply(IF_THEN_ELSE, makeIsTrueCondition(X), entry1, entry2), context));
            }
        }
    }
    return result;
}

/** Makes the factor "if X = true then (if Y = true then e0 else e1) else (if Y = true then e2 else e3)" from the first four entries. */
private static ExpressionFactor makePairwiseIsingFactor(Expression X, Expression Y, ArrayList<Double> entryList, Context context) {
    Expression entry1 = makeSymbol(entryList.get(0));
    Expression entry2 = makeSymbol(entryList.get(1));
    Expression entry3 = makeSymbol(entryList.get(2));
    Expression entry4 = makeSymbol(entryList.get(3));
    Expression ifXIsTrue = apply(IF_THEN_ELSE, makeIsTrueCondition(Y), entry1, entry2);
    Expression ifXIsNotTrue = apply(IF_THEN_ELSE, makeIsTrueCondition(Y), entry3, entry4);
    return new DefaultExpressionFactor(apply(IF_THEN_ELSE, makeIsTrueCondition(X), ifXIsTrue, ifXIsNotTrue), context);
}

/** Makes a fresh condition expression "variable = true". */
private static Expression makeIsTrueCondition(Expression variable) {
    return apply(EQUAL, variable, makeSymbol("true"));
}
Aggregations