Search in sources :

Example 1 with CartesianProductEnumeration

use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.

Method sumOutChildAndCreateCPT of the class XFormMarkovToBayes.

/**
 * Splits {@code cFactor} into the conditional probability table (CPT) of the child
 * variable {@code c} given the factor's other variables, and the factor left over those
 * other variables once {@code c} has been summed out.
 *
 * @param markov  network consulted for variable cardinalities
 * @param c       index of the child variable
 * @param cFactor factor containing {@code c}
 * @return a pair whose first element is the summed-out factor over the parents
 *         ({@code null} when the factor has no variable other than {@code c}) and whose
 *         second element is the CPT of {@code c}
 */
private static Pair<FactorTable, ConditionalProbabilityTable> sumOutChildAndCreateCPT(MarkovNetwork markov, Integer c, FactorTable cFactor) {
    List<Integer> parents = new ArrayList<>(cFactor.getVariableIndexes());
    // c is an Integer, so this resolves to remove(Object): the child variable is dropped by value.
    parents.remove(c);
    int childCardinality = markov.cardinality(c);
    int childPosition = cFactor.getVariableIndexes().indexOf(c);

    FactorTable marginalOverParents = null;
    FunctionTable cptFunctionTable;
    if (parents.isEmpty()) {
        // Nothing remains after summing out, so there is no leftover factor; the entries
        // are only normalized, one run of childCardinality values at a time.
        // NOTE(review): a run whose total is zero is not guarded here and divides by zero.
        List<Double> rawEntries = cFactor.getTable().getEntries();
        List<Double> normalized = new ArrayList<>(rawEntries.size());
        for (int start = 0; start < rawEntries.size(); start += childCardinality) {
            double runTotal = 0;
            for (int offset = 0; offset < childCardinality; offset++) {
                runTotal += rawEntries.get(start + offset);
            }
            for (int offset = 0; offset < childCardinality; offset++) {
                normalized.add(rawEntries.get(start + offset) / runTotal);
            }
        }
        cptFunctionTable = new FunctionTable(Arrays.asList(childCardinality), normalized);
    } else {
        List<Integer> cptCardinalities = new ArrayList<>();
        List<Integer> parentCardinalities = new ArrayList<>();
        // Entry k holds the position, within cFactor's variable list, of the k-th parent.
        List<Integer> parentPositionsInFactor = new ArrayList<>();
        int positionInFactor = 0;
        for (Integer var : cFactor.getVariableIndexes()) {
            boolean isParent = !var.equals(c);
            if (isParent) {
                int cardinality = markov.cardinality(var);
                cptCardinalities.add(cardinality);
                parentCardinalities.add(cardinality);
                parentPositionsInFactor.add(positionInFactor);
            }
            positionInFactor++;
        }
        // By convention the child's cardinality comes last in the CPT.
        cptCardinalities.add(markov.cardinality(c));

        List<Double> cptEntries = new ArrayList<>(FunctionTable.numEntriesFor(cptCardinalities));
        List<Double> marginalEntries = new ArrayList<>(FunctionTable.numEntriesFor(parentCardinalities));
        Map<Integer, Integer> assignment = new LinkedHashMap<>();
        CartesianProductEnumeration<Integer> parentAssignments = new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(parentCardinalities));
        while (parentAssignments.hasMoreElements()) {
            List<Integer> parentValues = parentAssignments.nextElement();
            assignment.clear();
            int parentSlot = 0;
            for (Integer parentValue : parentValues) {
                assignment.put(parentPositionsInFactor.get(parentSlot), parentValue);
                parentSlot++;
            }
            // With the child left unassigned, valueFor sums over all of its values.
            Double parentMass = cFactor.getTable().valueFor(assignment);
            marginalEntries.add(parentMass);
            boolean noMass = parentMass.equals(0.0);
            for (int childValue = 0; childValue < childCardinality; childValue++) {
                assignment.put(childPosition, childValue);
                // TODO - is this approach correct?
                // NOTE: a parent assignment with zero mass gets an impossibly small probability
                // instead of a division by zero, to avoid generating an invalid model.
                double cptEntry = noMass
                        ? Double.MIN_NORMAL
                        : cFactor.getTable().valueFor(assignment) / parentMass;
                cptEntries.add(cptEntry);
            }
        }
        marginalOverParents = new FactorTable(parents, new FunctionTable(parentCardinalities, marginalEntries));
        cptFunctionTable = new FunctionTable(cptCardinalities, cptEntries);
    }

    ConditionalProbabilityTable cpt = new ConditionalProbabilityTable(parents, c, cptFunctionTable);
    return new Pair<>(marginalOverParents, cpt);
}
Also used : ConditionalProbabilityTable(com.sri.ai.praise.lang.grounded.bayes.ConditionalProbabilityTable) ArrayList(java.util.ArrayList) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) LinkedHashMap(java.util.LinkedHashMap) FunctionTable(com.sri.ai.praise.lang.grounded.common.FunctionTable) FactorTable(com.sri.ai.praise.lang.grounded.markov.FactorTable) Pair(com.sri.ai.util.base.Pair)

Example 2 with CartesianProductEnumeration

use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.

Method valueFor of the class FunctionTable.

/**
 * Returns the table value for a (possibly partial) assignment.
 * <p>
 * When the map holds as many entries as the table has variables, the single matching
 * entry is looked up. Otherwise every variable position missing from the map ranges over
 * all of its values and the matching entries are added up.
 *
 * @param assignmentMap maps a variable position in this table to its assigned value
 * @return the looked-up entry, or the sum of all entries consistent with the assignment
 */
public Double valueFor(Map<Integer, Integer> assignmentMap) {
    int numVars = varCardinalities.size();

    if (assignmentMap.size() == numVars) {
        // Complete assignment: read the entry directly rather than enumerating.
        List<Integer> fullAssignment = new ArrayList<>(numVars);
        for (int position = 0; position < numVars; position++) {
            fullAssignment.add(assignmentMap.get(position));
        }
        double exact = entryFor(fullAssignment);
        return exact;
    }

    // Partial assignment: one candidate list per variable, either the fixed value
    // or the full range 0..cardinality-1.
    List<List<Integer>> candidatesPerVariable = new ArrayList<>();
    for (int position = 0; position < numVars; position++) {
        Integer fixed = assignmentMap.get(position);
        if (fixed != null) {
            candidatesPerVariable.add(Arrays.asList(fixed));
            continue;
        }
        int cardinality = varCardinalities.get(position);
        List<Integer> allValues = new ArrayList<>();
        for (int value = 0; value < cardinality; value++) {
            allValues.add(value);
        }
        candidatesPerVariable.add(allValues);
    }

    double total = 0;
    CartesianProductEnumeration<Integer> combinations = new CartesianProductEnumeration<>(candidatesPerVariable);
    while (combinations.hasMoreElements()) {
        total += entryFor(combinations.nextElement());
    }
    return total;
}
Also used : IntStream(java.util.stream.IntStream) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) Arrays(java.util.Arrays) List(java.util.List) MixedRadixNumber(com.sri.ai.util.math.MixedRadixNumber) Map(java.util.Map) BigInteger(java.math.BigInteger) Collections(java.util.Collections) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) Beta(com.google.common.annotations.Beta) BigInteger(java.math.BigInteger) ArrayList(java.util.ArrayList) List(java.util.List) ArrayList(java.util.ArrayList) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration)

Example 3 with CartesianProductEnumeration

use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.

Method newMegaFactor of the class XFormMarkovToBayes.

/**
 * Builds one factor over {@code varIdxs} whose entry for each joint assignment is the
 * product of the corresponding entries of every factor in {@code factorsContainingC}.
 *
 * @param markov             network consulted for variable cardinalities
 * @param varIdxs            variables of the resulting factor
 * @param factorsContainingC factors to multiply together; each is looked up using the
 *                           values assigned to its own variables
 * @return the product factor over {@code varIdxs}
 */
private static FactorTable newMegaFactor(MarkovNetwork markov, Set<Integer> varIdxs, List<FactorTable> factorsContainingC) {
    List<Integer> megaVars = new ArrayList<>(varIdxs);
    Map<Integer, Integer> positionOfVar = new LinkedHashMap<>();
    List<Integer> cardinalities = new ArrayList<>();
    int nextPosition = 0;
    for (Integer var : megaVars) {
        positionOfVar.put(var, nextPosition);
        nextPosition++;
        cardinalities.add(markov.cardinality(var));
    }

    List<Double> products = new ArrayList<>(FunctionTable.numEntriesFor(cardinalities));
    // Scratch list reused for each factor lookup.
    List<Integer> factorArgs = new ArrayList<>();
    CartesianProductEnumeration<Integer> jointAssignments = new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(cardinalities));
    while (jointAssignments.hasMoreElements()) {
        List<Integer> joint = jointAssignments.nextElement();
        double entry = 1;
        for (FactorTable factor : factorsContainingC) {
            factorArgs.clear();
            // Project the joint assignment onto this factor's own variables, in its order.
            for (Integer factorVar : factor.getVariableIndexes()) {
                Integer position = positionOfVar.get(factorVar);
                factorArgs.add(joint.get(position));
            }
            entry *= factor.getTable().entryFor(factorArgs);
        }
        products.add(entry);
    }

    return new FactorTable(megaVars, new FunctionTable(cardinalities, products));
}
Also used : FunctionTable(com.sri.ai.praise.lang.grounded.common.FunctionTable) ArrayList(java.util.ArrayList) FactorTable(com.sri.ai.praise.lang.grounded.markov.FactorTable) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) LinkedHashMap(java.util.LinkedHashMap)

Example 4 with CartesianProductEnumeration

use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.

Method constructGenericTableExpressionUsingInequalities of the class TranslationOfTableToInequalities.

/**
	 * Returns an {@link Expression} equivalent to a given {@link FunctionTable} but in the form of a
	 * linear if-then-else chain (so hopefully more compact) whose conditions are built from inequalities.
	 * <p>
	 * Function values are rounded to two decimal places before being grouped, so the resulting
	 * expression carries the rounded values rather than the exact table entries.
	 *
	 * @param functionTable the table to translate
	 * @param solverListener NOTE(review): this parameter is not referenced anywhere in this method body;
	 * confirm whether a compilation step that would use it is missing or performed elsewhere.
	 * @return the if-then-else expression; its innermost else branch is the value covering the fewest entries
	 */
public static Expression constructGenericTableExpressionUsingInequalities(FunctionTable functionTable, Function<MultiIndexQuantifierEliminator, MultiIndexQuantifierEliminator> solverListener) {
    // The strategy in this method is the following:
    // we collect all the contiguous indices sub-sets of the function table sharing their (rounded) function value.
    // They are kept in a map from each value to a list of indices sub-sets with that value.
    //
    // Then, we sort these groups of indices sub-sets by the sum of their sizes (number of entries), from smallest to largest.
    // This will help us later to create an expression that tests for the largest groups first.
    //
    // Finally, we create an if-then-else expression, starting from the leaf (least common value).
    // For each group of indices sub-sets with the same value, we obtain an inequalities expression describing
    // the conditions for a variable assignment to be in that indices sub-set of the function table.
    // Each portion generates a conjunction, and the group of portions generates a disjunction.
    //
    // The resulting if-then-else expression is linearly organized (only else clauses have nested if-then-else expressions).
    // NOTE(review): no compilation/balancing of the expression happens in this method; if a more balanced
    // representation is wanted, it has to be produced by the caller.
    Map<Double, List<FunctionTableIndicesSubSet>> functionValuesAndCorrespondingIndicesSubSet = map();
    Double currentSubSetFunctionValueIfAny = null;
    List<Integer> firstIndicesOfCurrentSubSetIfAny = null;
    List<Integer> previousIndices = null;
    List<Integer> indices = null;
    CartesianProductEnumeration<Integer> cartesianProduct = new CartesianProductEnumeration<>(UAIUtil.cardinalityValues(functionTable));
    while (cartesianProduct.hasMoreElements()) {
        previousIndices = indices;
        // copy the enumerated tuple: it is kept across iterations as the start/end of a run
        indices = new ArrayList<>(cartesianProduct.nextElement());
        // round to two decimal places so that nearly-equal entries fall in the same run
        Double functionValue = Math.round(functionTable.entryFor(indices) * 100) / 100.0;
        boolean hitNewFunctionValue = currentSubSetFunctionValueIfAny == null || !functionValue.equals(currentSubSetFunctionValueIfAny);
        if (hitNewFunctionValue) {
            // close the run that ended at previousIndices (on the very first element the run arguments are still null)
            storeIndicesSubSetOnAllVariables(functionTable, firstIndicesOfCurrentSubSetIfAny, previousIndices, currentSubSetFunctionValueIfAny, functionValuesAndCorrespondingIndicesSubSet);
            // get information for next indices sub-set
            currentSubSetFunctionValueIfAny = functionValue;
            firstIndicesOfCurrentSubSetIfAny = indices;
        }
    }
    // close the last run, which ends at the final enumerated indices
    previousIndices = indices;
    storeIndicesSubSetOnAllVariables(functionTable, firstIndicesOfCurrentSubSetIfAny, previousIndices, currentSubSetFunctionValueIfAny, functionValuesAndCorrespondingIndicesSubSet);
    // we pair each list of indices sub-sets sharing a function value with the sum of its sub-set sizes,
    // sort those pairs from smaller to greater total size (explicit Collections.sort below),
    // and form the final expression backwards, thus prioritizing larger sub-sets
    // whose conditions will be more often satisfied and leading to greater simplifications during inference.
    List<Pair<BigInteger, List<FunctionTableIndicesSubSet>>> listOfPairsOfSizeAndListsOfIndicesSubSetsWithSameFunctionValue = new ArrayList<>(functionValuesAndCorrespondingIndicesSubSet.size());
    for (Map.Entry<Double, List<FunctionTableIndicesSubSet>> functionValueAndIndicesSubSet : functionValuesAndCorrespondingIndicesSubSet.entrySet()) {
        List<FunctionTableIndicesSubSet> indicesSubSetsWithSameFunctionValue = functionValueAndIndicesSubSet.getValue();
        BigInteger sumOfSizes = BigInteger.ZERO;
        for (FunctionTableIndicesSubSet indicesSubSet : indicesSubSetsWithSameFunctionValue) {
            sumOfSizes = sumOfSizes.add(indicesSubSet.size());
        }
        listOfPairsOfSizeAndListsOfIndicesSubSetsWithSameFunctionValue.add(Pair.make(sumOfSizes, indicesSubSetsWithSameFunctionValue));
    }
    Collections.sort(listOfPairsOfSizeAndListsOfIndicesSubSetsWithSameFunctionValue, (Comparator<? super Pair<BigInteger, List<FunctionTableIndicesSubSet>>>) (p1, p2) -> p1.first.compareTo(p2.first));
    List<List<FunctionTableIndicesSubSet>> listsOfIndicesSubSetsWithSameFunctionValue = mapIntoList(listOfPairsOfSizeAndListsOfIndicesSubSetsWithSameFunctionValue, p -> p.second);
    Iterator<List<FunctionTableIndicesSubSet>> listsOfIndicesSubSetsWithSameFunctionValueIterator = listsOfIndicesSubSetsWithSameFunctionValue.iterator();
    // the smallest group becomes the unconditional innermost else value
    // NOTE(review): next() is called without a hasNext() check, so this presumably assumes the table has at least one entry - confirm
    List<FunctionTableIndicesSubSet> firstListOfIndicesSubSets = listsOfIndicesSubSetsWithSameFunctionValueIterator.next();
    Double valueOfFirstListOfIndicesSubSets = getFirstOrNull(firstListOfIndicesSubSets).getFunctionValue();
    Expression currentExpression = makeSymbol(valueOfFirstListOfIndicesSubSets);
    while (listsOfIndicesSubSetsWithSameFunctionValueIterator.hasNext()) {
        // each larger group wraps the expression built so far, so the largest group ends up tested first
        List<FunctionTableIndicesSubSet> indicesSubSetsWithSameFunctionValue = listsOfIndicesSubSetsWithSameFunctionValueIterator.next();
        Expression functionValueOfIndicesSubSetsWithSameFunctionValue = makeSymbol(getFirstOrNull(indicesSubSetsWithSameFunctionValue).getFunctionValue());
        Expression conditionForThisFunctionValue = Or.make(mapIntoList(indicesSubSetsWithSameFunctionValue, TranslationOfTableToInequalities::getInequalitiesExpressionForFunctionTableIndicesSubSet));
        currentExpression = IfThenElse.make(conditionForThisFunctionValue, functionValueOfIndicesSubSetsWithSameFunctionValue, currentExpression);
    }
    return currentExpression;
}
Also used : LESS_THAN_OR_EQUAL_TO(com.sri.ai.grinder.sgdpllt.library.FunctorConstants.LESS_THAN_OR_EQUAL_TO) ONE(java.math.BigInteger.ONE) FALSE(com.sri.ai.expresso.helper.Expressions.FALSE) MultiIndexQuantifierEliminator(com.sri.ai.grinder.sgdpllt.api.MultiIndexQuantifierEliminator) Expression(com.sri.ai.expresso.api.Expression) Function(java.util.function.Function) Util.forAll(com.sri.ai.util.Util.forAll) ArrayList(java.util.ArrayList) Util.map(com.sri.ai.util.Util.map) Expressions.apply(com.sri.ai.expresso.helper.Expressions.apply) Map(java.util.Map) And(com.sri.ai.grinder.sgdpllt.library.boole.And) IntegerIterator(com.sri.ai.util.collect.IntegerIterator) BigInteger(java.math.BigInteger) Pair(com.sri.ai.util.base.Pair) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) Or(com.sri.ai.grinder.sgdpllt.library.boole.Or) Iterator(java.util.Iterator) MixedRadixNumber(com.sri.ai.util.math.MixedRadixNumber) EQUALITY(com.sri.ai.grinder.sgdpllt.library.FunctorConstants.EQUALITY) Util.mapIntoList(com.sri.ai.util.Util.mapIntoList) GREATER_THAN_OR_EQUAL_TO(com.sri.ai.grinder.sgdpllt.library.FunctorConstants.GREATER_THAN_OR_EQUAL_TO) IfThenElse(com.sri.ai.grinder.sgdpllt.library.controlflow.IfThenElse) Beta(com.google.common.annotations.Beta) Util.getFirstOrNull(com.sri.ai.util.Util.getFirstOrNull) List(java.util.List) Expressions.makeSymbol(com.sri.ai.expresso.helper.Expressions.makeSymbol) UAIUtil(com.sri.ai.praise.model.v1.imports.uai.UAIUtil) Util.arrayListFilledWith(com.sri.ai.util.Util.arrayListFilledWith) FunctionTable(com.sri.ai.praise.lang.grounded.common.FunctionTable) Comparator(java.util.Comparator) Util.putInListValue(com.sri.ai.util.Util.putInListValue) Collections(java.util.Collections) TRUE(com.sri.ai.expresso.helper.Expressions.TRUE) ArrayList(java.util.ArrayList) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) BigInteger(java.math.BigInteger) Expression(com.sri.ai.expresso.api.Expression) 
BigInteger(java.math.BigInteger) ArrayList(java.util.ArrayList) Util.mapIntoList(com.sri.ai.util.Util.mapIntoList) List(java.util.List) Map(java.util.Map) Pair(com.sri.ai.util.base.Pair)

Example 5 with CartesianProductEnumeration

use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.

Method constructGenericTableExpressionUsingEqualities of the class UAIUtil.

/**
	 * Returns an {@link Expression} equivalent to a given {@link FunctionTable} but in the form of a decision tree
	 * (so hopefully more compact) using equalities.
	 * <p>
	 * The table is first written out as a textual chain of
	 * {@code if <assignment> then <value> else ...} with one link per table entry, the last entry
	 * being the bare final else value. That text is parsed and then compiled with an equality theory.
	 *
	 * @param functionTable the table to translate
	 * @param solverListener passed on to {@code Compilation.compile}; if not null, invoked on solver used for compilation,
	 * before and after compilation is performed; returned solver from "before" invocation is used
	 * (it may be the same one used as argument, of course).
	 * @return the compiled expression
	 */
public static Expression constructGenericTableExpressionUsingEqualities(FunctionTable functionTable, Function<MultiIndexQuantifierEliminator, MultiIndexQuantifierEliminator> solverListener) {
    StringBuilder table = new StringBuilder();
    CartesianProductEnumeration<Integer> cartesianProduct = new CartesianProductEnumeration<>(cardinalityValues(functionTable));
    while (cartesianProduct.hasMoreElements()) {
        List<Integer> values = cartesianProduct.nextElement();
        Double entryValue = functionTable.entryFor(values);
        // The last entry is recognised by the enumeration being exhausted, rather than by comparing
        // a counter against size().intValue(), which was recomputed every iteration and narrowed to int.
        boolean isFinalValue = !cartesianProduct.hasMoreElements();
        if (isFinalValue) {
            table.append(entryValue);
        } else {
            table.append("if ");
            for (int i = 0; i < values.size(); i++) {
                if (i > 0) {
                    table.append(" and ");
                }
                String value = genericConstantValueForVariable(values.get(i), i, functionTable.cardinality(i));
                if (value.equals("true")) {
                    // boolean constants are written as the variable itself or its negation
                    table.append(genericVariableName(i));
                } else if (value.equals("false")) {
                    table.append("not ").append(genericVariableName(i));
                } else {
                    table.append(genericVariableName(i));
                    table.append(" = ");
                    table.append(value);
                }
            }
            table.append(" then ");
            table.append(entryValue);
            table.append(" else ");
        }
    }
    Expression inputExpression = Expressions.parse(table.toString());
    // Type information for the compiler: one categorical type per variable, sized by its cardinality,
    // plus the type of every constant that can appear in the conditions.
    Map<String, String> mapFromCategoricalTypeNameToSizeString = new LinkedHashMap<>();
    Map<String, String> mapFromVariableNameToTypeName = new LinkedHashMap<>();
    Map<String, String> mapFromUniquelyNamedConstantToTypeName = new LinkedHashMap<>();
    for (int i = 0; i < functionTable.numberVariables(); i++) {
        String typeName = genericTypeNameForVariable(i, functionTable.cardinality(i));
        mapFromCategoricalTypeNameToSizeString.put(typeName, "" + functionTable.cardinality(i));
        mapFromVariableNameToTypeName.put(genericVariableName(i), typeName);
        for (int j = 0; j != functionTable.cardinality(i); j++) {
            String jThConstant = genericConstantValueForVariable(j, i, functionTable.cardinality(i));
            mapFromUniquelyNamedConstantToTypeName.put(jThConstant, typeName);
        }
    }
    com.sri.ai.grinder.sgdpllt.api.Theory theory = new EqualityTheory(true, true);
    Expression result = Compilation.compile(inputExpression, theory, mapFromVariableNameToTypeName, mapFromUniquelyNamedConstantToTypeName, mapFromCategoricalTypeNameToSizeString, list(), solverListener);
    return result;
}
Also used : CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) MultiIndexQuantifierEliminator(com.sri.ai.grinder.sgdpllt.api.MultiIndexQuantifierEliminator) SyntacticSubstitute(com.sri.ai.grinder.sgdpllt.library.SyntacticSubstitute) Compilation(com.sri.ai.grinder.sgdpllt.application.Compilation) Expressions(com.sri.ai.expresso.helper.Expressions) Util.list(com.sri.ai.util.Util.list) IOException(java.io.IOException) Expression(com.sri.ai.expresso.api.Expression) EqualityTheory(com.sri.ai.grinder.sgdpllt.theory.equality.EqualityTheory) Context(com.sri.ai.grinder.sgdpllt.api.Context) Function(java.util.function.Function) ArrayList(java.util.ArrayList) Beta(com.google.common.annotations.Beta) LinkedHashMap(java.util.LinkedHashMap) List(java.util.List) TrueContext(com.sri.ai.grinder.sgdpllt.core.TrueContext) Map(java.util.Map) FunctionTable(com.sri.ai.praise.lang.grounded.common.FunctionTable) BufferedReader(java.io.BufferedReader) EqualityTheory(com.sri.ai.grinder.sgdpllt.theory.equality.EqualityTheory) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) LinkedHashMap(java.util.LinkedHashMap) Expression(com.sri.ai.expresso.api.Expression)

Aggregations

CartesianProductEnumeration (com.sri.ai.util.collect.CartesianProductEnumeration)5 ArrayList (java.util.ArrayList)5 FunctionTable (com.sri.ai.praise.lang.grounded.common.FunctionTable)4 Beta (com.google.common.annotations.Beta)3 LinkedHashMap (java.util.LinkedHashMap)3 List (java.util.List)3 Map (java.util.Map)3 Expression (com.sri.ai.expresso.api.Expression)2 MultiIndexQuantifierEliminator (com.sri.ai.grinder.sgdpllt.api.MultiIndexQuantifierEliminator)2 FactorTable (com.sri.ai.praise.lang.grounded.markov.FactorTable)2 Pair (com.sri.ai.util.base.Pair)2 MixedRadixNumber (com.sri.ai.util.math.MixedRadixNumber)2 BigInteger (java.math.BigInteger)2 Collections (java.util.Collections)2 Function (java.util.function.Function)2 Expressions (com.sri.ai.expresso.helper.Expressions)1 FALSE (com.sri.ai.expresso.helper.Expressions.FALSE)1 TRUE (com.sri.ai.expresso.helper.Expressions.TRUE)1 Expressions.apply (com.sri.ai.expresso.helper.Expressions.apply)1 Expressions.makeSymbol (com.sri.ai.expresso.helper.Expressions.makeSymbol)1