use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.
the class XFormMarkovToBayes method sumOutChildAndCreateCPT.
/**
 * Splits {@code cFactor} into a conditional probability table for the child variable {@code c}
 * and a factor over the remaining (parent) variables in which {@code c} has been summed out.
 * <p>
 * Parents keep the relative order they have in {@code cFactor}; in the CPT the child is the
 * last dimension. When {@code c} is the only variable of the factor there is nothing left to
 * sum into, so the first element of the returned pair is {@code null} and the CPT is simply
 * the normalized factor.
 *
 * @param markov  network used to look up variable cardinalities
 * @param c       index of the child variable
 * @param cFactor factor that contains {@code c}
 * @return pair of (summed-out factor, or {@code null} if {@code c} has no parents; CPT of {@code c})
 */
private static Pair<FactorTable, ConditionalProbabilityTable> sumOutChildAndCreateCPT(MarkovNetwork markov, Integer c, FactorTable cFactor) {
    List<Integer> parentVarIdxs = new ArrayList<>(cFactor.getVariableIndexes());
    // c is a boxed Integer, so this resolves to remove(Object) and drops the child by value.
    parentVarIdxs.remove(c);
    final int childCard = markov.cardinality(c);
    final int childPos = cFactor.getVariableIndexes().indexOf(c);

    FactorTable summedOut = null;
    FunctionTable cptTable;

    if (parentVarIdxs.isEmpty()) {
        // Nothing remains after removing the child: only normalize the factor entries,
        // one block of childCard consecutive entries at a time.
        List<Double> rawEntries = cFactor.getTable().getEntries();
        List<Double> normalized = new ArrayList<>(rawEntries.size());
        for (int start = 0; start < rawEntries.size(); start += childCard) {
            double blockTotal = 0;
            for (int k = 0; k < childCard; k++) {
                blockTotal += rawEntries.get(start + k);
            }
            for (int k = 0; k < childCard; k++) {
                normalized.add(rawEntries.get(start + k) / blockTotal);
            }
        }
        cptTable = new FunctionTable(Arrays.asList(childCard), normalized);
    } else {
        // parentPositions.get(p) is the position inside cFactor of the p-th parent.
        List<Integer> parentCards = new ArrayList<>();
        List<Integer> parentPositions = new ArrayList<>();
        int position = 0;
        for (Integer varIdx : cFactor.getVariableIndexes()) {
            if (!varIdx.equals(c)) {
                parentCards.add(markov.cardinality(varIdx));
                parentPositions.add(position);
            }
            position++;
        }
        // By convention the child is the last dimension of the CPT.
        List<Integer> cptCards = new ArrayList<>(parentCards);
        cptCards.add(childCard);

        List<Double> cptEntries = new ArrayList<>(FunctionTable.numEntriesFor(cptCards));
        List<Double> summedEntries = new ArrayList<>(FunctionTable.numEntriesFor(parentCards));
        FunctionTable source = cFactor.getTable();
        Map<Integer, Integer> assignment = new LinkedHashMap<>();

        CartesianProductEnumeration<Integer> parentAssignments = new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(parentCards));
        while (parentAssignments.hasMoreElements()) {
            List<Integer> parentValues = parentAssignments.nextElement();
            assignment.clear();
            for (int p = 0; p < parentValues.size(); p++) {
                assignment.put(parentPositions.get(p), parentValues.get(p));
            }
            // With the child left unassigned, valueFor sums over all of its values.
            Double mass = source.valueFor(assignment);
            summedEntries.add(mass);

            // TODO - is this approach correct?
            // NOTE: a zero total would make the division undefined; an impossibly small
            // probability is assigned instead so that no invalid model is generated.
            final boolean impossible = mass.equals(0.0);
            for (int childValue = 0; childValue < childCard; childValue++) {
                assignment.put(childPos, childValue);
                cptEntries.add(impossible ? Double.MIN_NORMAL : source.valueFor(assignment) / mass);
            }
        }
        summedOut = new FactorTable(parentVarIdxs, new FunctionTable(parentCards, summedEntries));
        cptTable = new FunctionTable(cptCards, cptEntries);
    }

    ConditionalProbabilityTable cpt = new ConditionalProbabilityTable(parentVarIdxs, c, cptTable);
    return new Pair<>(summedOut, cpt);
}
use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.
the class FunctionTable method valueFor.
/**
 * Returns the table value for a full assignment, or the sum of all entries consistent with a
 * partial one.
 *
 * @param assignmentMap maps a variable's position in this table to its assigned value;
 *                      positions without a mapping are treated as unassigned
 * @return the single matching entry when the map holds as many mappings as the table has
 *         variables; otherwise the sum of every entry that agrees with the mapped positions
 */
public Double valueFor(Map<Integer, Integer> assignmentMap) {
    final int numVars = varCardinalities.size();

    // As many mappings as variables: read the entry directly instead of summing.
    if (assignmentMap.size() == numVars) {
        List<Integer> fullAssignment = new ArrayList<>(numVars);
        for (int pos = 0; pos < numVars; pos++) {
            fullAssignment.add(assignmentMap.get(pos));
        }
        double direct = entryFor(fullAssignment);
        return direct;
    }

    // Otherwise enumerate every completion: assigned positions contribute their single value,
    // unassigned positions contribute their whole range.
    List<List<Integer>> candidates = new ArrayList<>();
    for (int pos = 0; pos < numVars; pos++) {
        Integer fixed = assignmentMap.get(pos);
        if (fixed != null) {
            candidates.add(Arrays.asList(fixed));
        } else {
            candidates.add(IntStream.range(0, varCardinalities.get(pos)).boxed().collect(Collectors.toList()));
        }
    }

    double total = 0;
    for (CartesianProductEnumeration<Integer> completions = new CartesianProductEnumeration<>(candidates); completions.hasMoreElements();) {
        total += entryFor(completions.nextElement());
    }
    return total;
}
use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.
the class XFormMarkovToBayes method newMegaFactor.
/**
 * Multiplies the given factors into a single factor over {@code varIdxs}.
 * <p>
 * For every joint assignment of the variables in {@code varIdxs}, the resulting entry is the
 * product of the entries each input factor holds for its own variables under that assignment.
 *
 * @param markov             network used to look up variable cardinalities
 * @param varIdxs            variables of the resulting factor; their iteration order fixes the
 *                           variable order of the result
 * @param factorsContainingC factors to multiply together
 * @return the product factor over {@code varIdxs}
 */
private static FactorTable newMegaFactor(MarkovNetwork markov, Set<Integer> varIdxs, List<FactorTable> factorsContainingC) {
    final List<Integer> megaVars = new ArrayList<>(varIdxs);
    final Map<Integer, Integer> positionOf = new LinkedHashMap<>();
    final List<Integer> megaCards = new ArrayList<>();
    for (int pos = 0; pos < megaVars.size(); pos++) {
        Integer varIdx = megaVars.get(pos);
        positionOf.put(varIdx, pos);
        megaCards.add(markov.cardinality(varIdx));
    }

    final List<Double> products = new ArrayList<>(FunctionTable.numEntriesFor(megaCards));
    // Scratch list reused for each factor lookup.
    final List<Integer> factorValues = new ArrayList<>();

    CartesianProductEnumeration<Integer> jointAssignments = new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(megaCards));
    while (jointAssignments.hasMoreElements()) {
        List<Integer> joint = jointAssignments.nextElement();
        double product = 1;
        for (FactorTable factor : factorsContainingC) {
            // Project the joint assignment onto this factor's own variables, in its order.
            factorValues.clear();
            for (Integer varIdx : factor.getVariableIndexes()) {
                factorValues.add(joint.get(positionOf.get(varIdx)));
            }
            product *= factor.getTable().entryFor(factorValues);
        }
        products.add(product);
    }

    return new FactorTable(megaVars, new FunctionTable(megaCards, products));
}
use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.
the class TranslationOfTableToInequalities method constructGenericTableExpressionUsingInequalities.
/**
 * Returns an {@link Expression} representing a given {@link FunctionTable} as a chain of
 * if-then-else expressions whose conditions are built from inequalities
 * (so hopefully more compact than listing every entry).
 * <p>
 * Table entries are rounded to two decimal places before being compared and grouped.
 *
 * @param functionTable the table to translate
 * @param solverListener described by the original documentation as a hook invoked on the solver
 *        used for compilation; NOTE(review): this method's body never references it - confirm
 *        whether callers are expected to perform the compilation themselves.
 * @return the if-then-else expression; only else branches contain nested if-then-else expressions
 */
public static Expression constructGenericTableExpressionUsingInequalities(FunctionTable functionTable, Function<MultiIndexQuantifierEliminator, MultiIndexQuantifierEliminator> solverListener) {
    // Strategy:
    // 1. Walk the table in enumeration order and cut it into runs of consecutive indices that
    //    share the same (rounded) value; record each run under its value.
    // 2. Order the values by how many entries their runs cover in total, smallest first.
    // 3. Build the if-then-else chain starting from the least common value as the innermost
    //    leaf, so that the most common values end up being tested first. Each run yields one
    //    condition; the runs of a value are combined in a disjunction.

    // Step 1: collect runs of equal value.
    Map<Double, List<FunctionTableIndicesSubSet>> subSetsByValue = map();
    Double runValue = null;
    List<Integer> runStart = null;
    List<Integer> current = null;
    CartesianProductEnumeration<Integer> allIndices = new CartesianProductEnumeration<>(UAIUtil.cardinalityValues(functionTable));
    while (allIndices.hasMoreElements()) {
        List<Integer> previous = current;
        current = new ArrayList<>(allIndices.nextElement());
        Double value = Math.round(functionTable.entryFor(current) * 100) / 100.0;
        // Double.equals(null) is false, so this also fires on the very first entry.
        if (!value.equals(runValue)) {
            // Close the run that ended at the previous indices (arguments are null on the first entry).
            storeIndicesSubSetOnAllVariables(functionTable, runStart, previous, runValue, subSetsByValue);
            runValue = value;
            runStart = current;
        }
    }
    // Close the last run, which ends at the final indices enumerated.
    storeIndicesSubSetOnAllVariables(functionTable, runStart, current, runValue, subSetsByValue);

    // Step 2: pair every group of runs with its total size and sort ascending by that size.
    List<Pair<BigInteger, List<FunctionTableIndicesSubSet>>> sizedGroups = new ArrayList<>(subSetsByValue.size());
    for (List<FunctionTableIndicesSubSet> group : subSetsByValue.values()) {
        BigInteger totalSize = BigInteger.ZERO;
        for (FunctionTableIndicesSubSet subSet : group) {
            totalSize = totalSize.add(subSet.size());
        }
        sizedGroups.add(Pair.make(totalSize, group));
    }
    sizedGroups.sort((left, right) -> left.first.compareTo(right.first));
    List<List<FunctionTableIndicesSubSet>> groupsBySize = mapIntoList(sizedGroups, sized -> sized.second);

    // Step 3: the smallest group becomes the innermost leaf; every later (larger) group wraps
    // what has been built so far, so the largest group is tested first.
    Iterator<List<FunctionTableIndicesSubSet>> groups = groupsBySize.iterator();
    Double leafValue = getFirstOrNull(groups.next()).getFunctionValue();
    Expression expression = makeSymbol(leafValue);
    while (groups.hasNext()) {
        List<FunctionTableIndicesSubSet> group = groups.next();
        Expression groupValue = makeSymbol(getFirstOrNull(group).getFunctionValue());
        Expression groupCondition = Or.make(mapIntoList(group, TranslationOfTableToInequalities::getInequalitiesExpressionForFunctionTableIndicesSubSet));
        expression = IfThenElse.make(groupCondition, groupValue, expression);
    }
    return expression;
}
use of com.sri.ai.util.collect.CartesianProductEnumeration in project aic-praise by aic-sri-international.
the class UAIUtil method constructGenericTableExpressionUsingEqualities.
/**
 * Returns an {@link Expression} equivalent to a given {@link FunctionTable} but in the form of a decision tree
 * (so hopefully more compact) using equalities.
 * <p>
 * The table is first written out as the text of one nested if-then-else per entry, parsed,
 * and then handed to {@code Compilation.compile} together with generic type, variable and
 * constant names derived from the table.
 *
 * @param functionTable the table to translate
 * @param solverListener passed through to {@code Compilation.compile}; per the original
 *        documentation, if not null it is invoked on the solver used for compilation, before
 *        and after compilation is performed, and the solver returned from the "before"
 *        invocation is used (it may be the same one used as argument, of course).
 * @return the compiled expression
 */
public static Expression constructGenericTableExpressionUsingEqualities(FunctionTable functionTable, Function<MultiIndexQuantifierEliminator, MultiIndexQuantifierEliminator> solverListener) {
    StringBuilder text = new StringBuilder();
    CartesianProductEnumeration<Integer> assignments = new CartesianProductEnumeration<>(cardinalityValues(functionTable));
    for (int seen = 1; assignments.hasMoreElements(); seen++) {
        List<Integer> values = assignments.nextElement();
        Double entryValue = functionTable.entryFor(values);

        if (seen == assignments.size().intValue()) {
            // The final entry is the innermost else branch and needs no condition.
            text.append(entryValue);
            continue;
        }

        text.append("if ");
        for (int i = 0; i < values.size(); i++) {
            if (i > 0) {
                text.append(" and ");
            }
            String constant = genericConstantValueForVariable(values.get(i), i, functionTable.cardinality(i));
            switch (constant) {
            case "true":
                text.append(genericVariableName(i));
                break;
            case "false":
                text.append("not ").append(genericVariableName(i));
                break;
            default:
                text.append(genericVariableName(i)).append(" = ").append(constant);
                break;
            }
        }
        text.append(" then ").append(entryValue).append(" else ");
    }
    Expression inputExpression = Expressions.parse(text.toString());

    // Describe the generic types, variables and constants that appear in the parsed expression.
    Map<String, String> mapFromCategoricalTypeNameToSizeString = new LinkedHashMap<>();
    Map<String, String> mapFromVariableNameToTypeName = new LinkedHashMap<>();
    Map<String, String> mapFromUniquelyNamedConstantToTypeName = new LinkedHashMap<>();
    for (int i = 0; i < functionTable.numberVariables(); i++) {
        int cardinality = functionTable.cardinality(i);
        String typeName = genericTypeNameForVariable(i, cardinality);
        mapFromCategoricalTypeNameToSizeString.put(typeName, String.valueOf(cardinality));
        mapFromVariableNameToTypeName.put(genericVariableName(i), typeName);
        for (int j = 0; j != cardinality; j++) {
            mapFromUniquelyNamedConstantToTypeName.put(genericConstantValueForVariable(j, i, cardinality), typeName);
        }
    }

    com.sri.ai.grinder.sgdpllt.api.Theory theory = new EqualityTheory(true, true);
    return Compilation.compile(inputExpression, theory, mapFromVariableNameToTypeName, mapFromUniquelyNamedConstantToTypeName, mapFromCategoricalTypeNameToSizeString, list(), solverListener);
}
Aggregations