Use of com.sri.ai.util.math.MixedRadixNumber in project aic-praise by aic-sri-international.
From the class HOGModelGrounding, method fullGrounding.
/**
 * Grounds {@code factor} over every combination of values of its random variables and
 * reports each resulting numeric value to {@code listener}.
 *
 * <p>A {@link MixedRadixNumber} with one numeral per random variable is used as a counter
 * over all combinations. For each counter state, every random variable in the factor is
 * replaced by the value at the corresponding numeral's index in that variable's type value
 * list, the grounded expression is simplified, and the result is passed to
 * {@link Listener#factorValue}.
 *
 * @param factor
 *            the factor expression to ground.
 * @param randomVariablesInFactor
 *            the random variables to substitute in {@code factor}; position {@code i}
 *            corresponds to numeral {@code i} of the counter.
 * @param listener
 *            receives one {@code factorValue} call per combination, with the total number of
 *            combinations, first/last flags, and the rational value.
 * @param randomVariableNameToTypeSizeAndUniqueConstants
 *            maps each random variable to a triple whose first component is used as the key
 *            into {@code typeToValues} and whose second component is used as the radix for
 *            that variable (presumably the type and its size, going by the name).
 * @param typeToValues
 *            maps a type to the list of values indexed by the counter's numerals.
 * @param inferencer
 *            used to simplify each grounded factor.
 * @param context
 *            the context passed to the replacement and simplification calls.
 * @throws IllegalStateException
 *             if the number of combinations does not fit in an {@code int}, or if a grounded
 *             factor does not simplify to a number.
 */
private static void fullGrounding(Expression factor, List<Expression> randomVariablesInFactor, Listener listener, Map<Expression, Triple<Expression, Integer, List<Expression>>> randomVariableNameToTypeSizeAndUniqueConstants, Map<Expression, List<Expression>> typeToValues, InferenceForFactorGraphAndEvidence inferencer, Context context) {
    int numberOfVariables = randomVariablesInFactor.size();
    int[] radices = new int[numberOfVariables];
    List<List<Expression>> factorRandomVariableTypeValues = new ArrayList<>(numberOfVariables);
    for (int i = 0; i < numberOfVariables; i++) {
        // Look the triple up once and read both the radix and the type from it.
        Triple<Expression, Integer, List<Expression>> typeSizeAndUniqueConstants = randomVariableNameToTypeSizeAndUniqueConstants.get(randomVariablesInFactor.get(i));
        radices[i] = typeSizeAndUniqueConstants.second;
        factorRandomVariableTypeValues.add(typeToValues.get(typeSizeAndUniqueConstants.first));
    }

    MixedRadixNumber mrn = new MixedRadixNumber(BigInteger.ZERO, radices);
    BigInteger lastCounterValue = mrn.getMaxAllowedValue();
    BigInteger totalFactorValues = lastCounterValue.add(BigInteger.ONE);
    // The listener takes the total as an int; fail loudly instead of silently truncating
    // (a non-negative BigInteger fits in an int exactly when its bit length is at most 31).
    if (totalFactorValues.bitLength() > 31) {
        throw new IllegalStateException("Number of grounded factor values [" + totalFactorValues + "] does not fit in an int for factor [" + factor + "]");
    }
    int numberFactorValues = totalFactorValues.intValue();

    while (true) {
        Expression groundedFactor = factor;
        for (int i = 0; i < numberOfVariables; i++) {
            int valueIndex = mrn.getCurrentNumeralValue(i);
            groundedFactor = groundedFactor.replaceAllOccurrences(randomVariablesInFactor.get(i), factorRandomVariableTypeValues.get(i).get(valueIndex), context);
        }

        Expression value = inferencer.simplify(groundedFactor, context);
        if (!Expressions.isNumber(value)) {
            throw new IllegalStateException("Unable to compute a number for the grounded factor [" + groundedFactor + "], instead got:" + value);
        }

        // Compare the counter as a BigInteger rather than through a truncating intValue().
        BigInteger currentCounterValue = mrn.getValue();
        boolean isFirstValue = currentCounterValue.signum() == 0;
        boolean isLastValue = currentCounterValue.equals(lastCounterValue);
        listener.factorValue(numberFactorValues, isFirstValue, isLastValue, value.rationalValue());

        if (!mrn.canIncrement()) {
            break;
        }
        mrn.increment();
    }
}
Aggregations