Search in sources:

Example 11 with ExpressionVariable

Use of com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable in project aic-praise by aic-sri-international.

Method convertToAnExpressionBasedModelAfterLearning of class ExpressionBayesianModel.

/**
 * Converts this (already learned) ExpressionBayesianModel into an equivalent ExpressionBasedModel
 * that can be used for inference.
 * <p>
 * Assumption: the initial expressions of the nodes contain no constants (constants would cause
 * errors while learning the parameters), so both constant-definition maps of the resulting model
 * are left empty.
 *
 * @param mapFromCategoricalTypeNameToSizeString categorical types used, as specified by the user (usually an empty map)
 * @param additionalTypes additional types used, as specified by the user (usually an empty list)
 *
 * @return an ExpressionBasedModel equivalent to this Bayesian model, ready for inference
 */
public ExpressionBasedModel convertToAnExpressionBasedModelAfterLearning(Map<String, String> mapFromCategoricalTypeNameToSizeString, Collection<Type> additionalTypes) {
    // Every node of this Bayesian model is used directly as a factor of the converted model.
    List<? extends Expression> nodesAsFactors = this.getNodes();

    // Declare each node's child variable with the type that the node's own context registers for it.
    Map<String, String> randomVariableTypes = map();
    for (ExpressionBayesianNode bayesianNode : this.getNodes()) {
        ExpressionVariable childVariable = bayesianNode.getChildVariable();
        Type childType = bayesianNode.getContext().getTypeOfRegisteredSymbol(childVariable);
        randomVariableTypes.put(childVariable.toString(), childType.toString());
    }

    // Node expressions are assumed to be constant-free (see above), so no constants are declared,
    // neither non-uniquely named nor uniquely named ones.
    Map<String, String> nonUniquelyNamedConstantTypes = map();
    Map<String, String> uniquelyNamedConstantTypes = map();

    // An ExpressionBayesianModel is a Bayesian network by construction.
    final boolean bayesianNetwork = true;

    return new DefaultExpressionBasedModel(
            nodesAsFactors,
            randomVariableTypes,
            nonUniquelyNamedConstantTypes,
            uniquelyNamedConstantTypes,
            mapFromCategoricalTypeNameToSizeString,
            additionalTypes,
            bayesianNetwork);
}
Also used: Context(com.sri.ai.grinder.api.Context) Type(com.sri.ai.expresso.api.Type) DefaultExpressionBasedModel(com.sri.ai.praise.core.representation.classbased.expressionbased.core.DefaultExpressionBasedModel) ExpressionBasedModel(com.sri.ai.praise.core.representation.classbased.expressionbased.api.ExpressionBasedModel) ExpressionVariable(com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable)

Example 12 with ExpressionVariable

Use of com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable in project aic-praise by aic-sri-international.

Method expressionFactorIsingModel of class TestCases.

/**
 * Builds the factors of a {@code gridSize x gridSize} Ising model over variables {@code A_i_j}.
 * <p>
 * All factors share one context in which every {@code A_i_j} is registered with type {@code Boolean}.
 * Factors are returned in this order:
 * <ol>
 * <li>one pairwise factor per vertically adjacent pair {@code (i, j)} / {@code (i + 1, j)};</li>
 * <li>one pairwise factor per horizontally adjacent pair {@code (i, j)} / {@code (i, j + 1)};</li>
 * <li>one single-variable factor per cell, but only if {@code singleVariableFactorEntries}
 *     returns a non-null list for cell {@code (0, 0)}.</li>
 * </ol>
 *
 * @param gridSize number of rows and columns of the grid
 * @param pairwiseFactorEntries maps two adjacent cells (X, Y) to the four factor values for
 *        (X=true, Y=true), (X=true, Y=false), (X=false, Y=true), (X=false, Y=false), in that order
 * @param singleVariableFactorEntries maps a cell X to the two factor values for X=true and X=false;
 *        returning null for (0, 0) disables single-variable factors altogether
 * @return the factors of the model, in the order described above
 */
private static ArrayList<ExpressionFactor> expressionFactorIsingModel(int gridSize, BiFunction<Pair<Integer, Integer>, Pair<Integer, Integer>, ArrayList<Double>> pairwiseFactorEntries, Function<Pair<Integer, Integer>, ArrayList<Double>> singleVariableFactorEntries) {
    // variables.get(i).get(j) is A_i_j; each variable is registered as Boolean in the shared context.
    ArrayList<ArrayList<ExpressionVariable>> variables = new ArrayList<>();
    Context context = new TrueContext(new CommonTheory());
    for (int i = 0; i < gridSize; i++) {
        ArrayList<ExpressionVariable> column = new ArrayList<>();
        variables.add(column);
        for (int j = 0; j < gridSize; j++) {
            ExpressionVariable variable = new DefaultExpressionVariable(makeSymbol("A_" + i + "_" + j));
            column.add(variable);
            context = context.extendWithSymbolsAndTypes(variable, parse("Boolean"));
        }
    }

    ArrayList<ExpressionFactor> result = new ArrayList<>();

    // Vertical edges: (i, j) -- (i + 1, j).
    for (int i = 0; i < gridSize - 1; i++) {
        for (int j = 0; j < gridSize; j++) {
            ArrayList<Double> entries = pairwiseFactorEntries.apply(new Pair<>(i, j), new Pair<>(i + 1, j));
            result.add(isingPairwiseFactor(variables.get(i).get(j), variables.get(i + 1).get(j), entries, context));
        }
    }

    // Horizontal edges: (i, j) -- (i, j + 1).
    for (int i = 0; i < gridSize; i++) {
        for (int j = 0; j < gridSize - 1; j++) {
            ArrayList<Double> entries = pairwiseFactorEntries.apply(new Pair<>(i, j), new Pair<>(i, j + 1));
            result.add(isingPairwiseFactor(variables.get(i).get(j), variables.get(i).get(j + 1), entries, context));
        }
    }

    // Cell (0, 0) is probed once to decide whether single-variable factors are wanted at all.
    // The probe is an extra call: singleVariableFactorEntries is invoked again for (0, 0) inside the loop.
    if (singleVariableFactorEntries.apply(new Pair<>(0, 0)) != null) {
        for (int i = 0; i < gridSize; i++) {
            for (int j = 0; j < gridSize; j++) {
                ArrayList<Double> entries = singleVariableFactorEntries.apply(new Pair<>(i, j));
                result.add(isingSingleVariableFactor(variables.get(i).get(j), entries, context));
            }
        }
    }
    return result;
}

/**
 * Builds the pairwise factor
 * {@code if x = true then (if y = true then e0 else e1) else (if y = true then e2 else e3)},
 * where {@code e0..e3} are the first four values of {@code entries}.
 */
private static ExpressionFactor isingPairwiseFactor(Expression x, Expression y, ArrayList<Double> entries, Context context) {
    Expression trueSymbol = makeSymbol("true");
    Expression whenXIsTrue = apply(IF_THEN_ELSE, apply(EQUAL, y, trueSymbol), makeSymbol(entries.get(0)), makeSymbol(entries.get(1)));
    Expression whenXIsNotTrue = apply(IF_THEN_ELSE, apply(EQUAL, y, trueSymbol), makeSymbol(entries.get(2)), makeSymbol(entries.get(3)));
    return new DefaultExpressionFactor(apply(IF_THEN_ELSE, apply(EQUAL, x, trueSymbol), whenXIsTrue, whenXIsNotTrue), context);
}

/**
 * Builds the single-variable factor {@code if x = true then e0 else e1},
 * where {@code e0, e1} are the first two values of {@code entries}.
 */
private static ExpressionFactor isingSingleVariableFactor(Expression x, ArrayList<Double> entries, Context context) {
    return new DefaultExpressionFactor(apply(IF_THEN_ELSE, apply(EQUAL, x, makeSymbol("true")), makeSymbol(entries.get(0)), makeSymbol(entries.get(1))), context);
}
Also used: TrueContext(com.sri.ai.grinder.core.TrueContext) Context(com.sri.ai.grinder.api.Context) DefaultExpressionVariable(com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.DefaultExpressionVariable) DefaultExpressionFactor(com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.DefaultExpressionFactor) Util.mapIntoArrayList(com.sri.ai.util.Util.mapIntoArrayList) ArrayList(java.util.ArrayList) CommonTheory(com.sri.ai.grinder.application.CommonTheory) ExpressionFactor(com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionFactor) Expression(com.sri.ai.expresso.api.Expression) ExpressionVariable(com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable) Pair(com.sri.ai.util.base.Pair)

Aggregations

ExpressionVariable (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionVariable)12 DefaultExpressionVariable (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.DefaultExpressionVariable)9 Expression (com.sri.ai.expresso.api.Expression)6 Context (com.sri.ai.grinder.api.Context)6 CommonTheory (com.sri.ai.grinder.application.CommonTheory)5 TrueContext (com.sri.ai.grinder.core.TrueContext)5 Type (com.sri.ai.expresso.api.Type)3 ExpressionFactor (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.api.ExpressionFactor)3 DefaultExpressionFactor (com.sri.ai.praise.core.representation.interfacebased.factor.core.expression.core.DefaultExpressionFactor)3 ArrayList (java.util.ArrayList)3 Theory (com.sri.ai.grinder.api.Theory)2 Factor (com.sri.ai.praise.core.representation.interfacebased.factor.api.Factor)2 ConstantFactor (com.sri.ai.praise.core.representation.interfacebased.factor.core.ConstantFactor)2 DefaultDatapoint (com.sri.ai.praise.learning.parameterlearning.representation.dataset.DefaultDatapoint)2 DefaultDataset (com.sri.ai.praise.learning.parameterlearning.representation.dataset.DefaultDataset)2 Test (org.junit.Test)2 DefaultExistentiallyQuantifiedFormula (com.sri.ai.expresso.core.DefaultExistentiallyQuantifiedFormula)1 DefaultIntensionalMultiSet (com.sri.ai.expresso.core.DefaultIntensionalMultiSet)1 IntegerExpressoType (com.sri.ai.expresso.type.IntegerExpressoType)1 RealExpressoType (com.sri.ai.expresso.type.RealExpressoType)1