use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.
the class TransformMarkovToBayes method newMegaFactor.
/**
 * Multiplies a group of factors into a single factor over the given variables.
 *
 * For every joint assignment enumerated over {@code varIdxs} (in the set's iteration order),
 * the matching entry of each factor in {@code factorsContainingC} is looked up and the
 * entries are multiplied together. The products are appended in enumeration order.
 *
 * @param markov             supplies the cardinality of each variable index
 * @param varIdxs            variable indexes forming the scope of the resulting factor; every
 *                           variable mentioned by a factor in {@code factorsContainingC} must
 *                           be present, otherwise the offset lookup fails with a NullPointerException
 * @param factorsContainingC factors whose entries are multiplied
 * @return a new factor over {@code varIdxs} holding the entry-wise product
 */
private static FactorTable newMegaFactor(MarkovNetwork markov, Set<Integer> varIdxs, List<FactorTable> factorsContainingC) {
    List<Integer> scope = new ArrayList<>(varIdxs);
    Map<Integer, Integer> positionInScope = new LinkedHashMap<>();
    List<Integer> scopeCardinalities = new ArrayList<>();
    for (int position = 0; position < scope.size(); position++) {
        Integer variable = scope.get(position);
        positionInScope.put(variable, position);
        scopeCardinalities.add(markov.cardinality(variable));
    }

    List<Double> productEntries = new ArrayList<>(FunctionTable.numEntriesFor(scopeCardinalities));
    // Scratch list, cleared and refilled for each factor lookup
    List<Integer> factorAssignment = new ArrayList<>();
    CartesianProductEnumeration<Integer> jointAssignments =
            new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(scopeCardinalities));
    while (jointAssignments.hasMoreElements()) {
        List<Integer> joint = jointAssignments.nextElement();
        double value = 1;
        for (FactorTable factor : factorsContainingC) {
            // Project the joint assignment onto this factor's own variable order
            factorAssignment.clear();
            for (Integer variable : factor.getVariableIndexes()) {
                int position = positionInScope.get(variable);
                factorAssignment.add(joint.get(position));
            }
            value *= factor.getTable().entryFor(factorAssignment);
        }
        productEntries.add(value);
    }

    return new FactorTable(scope, new FunctionTable(scopeCardinalities, productEntries));
}
use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.
the class TableFactorNetwork method addVariablesToMap.
/**
 * Builds an insertion-ordered map from variable index to variable by visiting every factor
 * of the model, in factor order, and passing each one to {@code addVariablesFromFactorToMap}
 * together with the map being built.
 *
 * @param model the model whose factors are visited
 * @return the map after all factors have been processed
 */
private static LinkedHashMap<Integer, TableVariable> addVariablesToMap(UAIModel model) {
    LinkedHashMap<Integer, TableVariable> variablesByIndex = new LinkedHashMap<>();
    // The factor count is read once, before any factor is visited
    final int factorCount = model.numberFactors();
    for (int position = 0; position < factorCount; position++) {
        addVariablesFromFactorToMap(model.getFactor(position), variablesByIndex);
    }
    return variablesByIndex;
}
use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.
the class TransformMarkovToBayes method sumOutChildAndCreateCPT.
/**
 * Splits {@code cFactor} into a factor over the remaining ("parent") variables and a
 * conditional probability table for the child variable {@code c}.
 *
 * When {@code c} is the only variable of {@code cFactor}, no parent factor is produced
 * (the first element of the pair is {@code null}) and the factor entries are simply
 * normalized. Otherwise, for every parent assignment, the value obtained from the factor
 * table for that partial assignment is recorded as the parent factor entry, and each CPT
 * entry is the fully assigned value divided by it. By convention the child is the last
 * dimension of the CPT table.
 *
 * @param markov  supplies variable cardinalities
 * @param c       index of the child variable being eliminated
 * @param cFactor factor containing {@code c}
 * @return pair of (factor over the parents or {@code null}, CPT of {@code c} given the parents)
 */
private static Pair<FactorTable, ConditionalProbabilityTable> sumOutChildAndCreateCPT(MarkovNetwork markov, Integer c, FactorTable cFactor) {
    List<Integer> parents = new ArrayList<>(cFactor.getVariableIndexes());
    // c is a boxed Integer: this removes the element equal to c, not the element at position c
    parents.remove((Object) c);
    int childCardinality = markov.cardinality(c);
    int childPosition = cFactor.getVariableIndexes().indexOf(c);

    FactorTable summedOut = null;
    FunctionTable cptTable;
    if (parents.isEmpty()) {
        // Nothing remains after removing the child, so there is no parent factor;
        // the entries still have to be normalized to form the CPT.
        // NOTE(review): a zero total makes these divisions yield NaN; the branch below
        // substitutes Double.MIN_NORMAL in the analogous zero case - confirm whether this
        // branch should do the same.
        List<Double> rawEntries = cFactor.getTable().getEntries();
        List<Double> normalizedEntries = new ArrayList<>(rawEntries.size());
        for (int start = 0; start < rawEntries.size(); start += childCardinality) {
            double total = 0;
            for (int offset = 0; offset < childCardinality; offset++) {
                total += rawEntries.get(start + offset);
            }
            for (int offset = 0; offset < childCardinality; offset++) {
                normalizedEntries.add(rawEntries.get(start + offset) / total);
            }
        }
        cptTable = new FunctionTable(Arrays.asList(childCardinality), normalizedEntries);
    } else {
        List<Integer> cptCardinalities = new ArrayList<>();
        List<Integer> parentCardinalities = new ArrayList<>();
        // k-th parent -> position of that parent inside cFactor's variable list
        Map<Integer, Integer> parentOrdinalToFactorPosition = new LinkedHashMap<>();
        int factorPosition = 0;
        for (Integer variable : cFactor.getVariableIndexes()) {
            boolean isChild = variable.equals(c);
            if (!isChild) {
                int cardinality = markov.cardinality(variable);
                cptCardinalities.add(cardinality);
                parentCardinalities.add(cardinality);
                parentOrdinalToFactorPosition.put(parentOrdinalToFactorPosition.size(), factorPosition);
            }
            factorPosition++;
        }
        // The child dimension goes last in the CPT table by convention
        cptCardinalities.add(markov.cardinality(c));

        List<Double> cptEntries = new ArrayList<>(FunctionTable.numEntriesFor(cptCardinalities));
        List<Double> marginalEntries = new ArrayList<>(FunctionTable.numEntriesFor(parentCardinalities));
        Map<Integer, Integer> assignment = new LinkedHashMap<>();
        CartesianProductEnumeration<Integer> parentAssignments =
                new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(parentCardinalities));
        while (parentAssignments.hasMoreElements()) {
            List<Integer> parentValues = parentAssignments.nextElement();
            assignment.clear();
            for (int ordinal = 0; ordinal < parentValues.size(); ordinal++) {
                assignment.put(parentOrdinalToFactorPosition.get(ordinal), parentValues.get(ordinal));
            }
            // Child left unassigned here; presumably valueFor aggregates over the
            // unassigned position - confirm against FunctionTable.valueFor
            Double marginal = cFactor.getTable().valueFor(assignment);
            marginalEntries.add(marginal);
            for (int childValue = 0; childValue < childCardinality; childValue++) {
                assignment.put(childPosition, childValue);
                // TODO - is this approach correct?
                // A zero marginal gets an impossibly small probability instead of a
                // division by zero, so that no invalid model is generated
                double entry = marginal.equals(0.0)
                        ? Double.MIN_NORMAL
                        : cFactor.getTable().valueFor(assignment) / marginal;
                cptEntries.add(entry);
            }
        }
        summedOut = new FactorTable(parents, new FunctionTable(parentCardinalities, marginalEntries));
        cptTable = new FunctionTable(cptCardinalities, cptEntries);
    }

    return new Pair<>(summedOut, new ConditionalProbabilityTable(parents, c, cptTable));
}
use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.
the class TransformMarkovToBayes method transform.
/**
 * Converts a Markov network into a sequence of conditional probability tables, reporting
 * each one to {@code bayesListener} as it is produced.
 *
 * Variables are eliminated one at a time until none is left. Each round picks a variable
 * with {@code pickMinimal}, multiplies all factors mentioning it, splits the product into
 * a CPT for that variable and a factor over the other variables involved, and puts that
 * factor (when there is one) back into the pool for the following rounds.
 *
 * @param markov        the network to transform
 * @param bayesListener receives every generated CPT through {@code newCPT}
 */
public static void transform(MarkovNetwork markov, BayesOutputListener bayesListener) {
    // Index the factors by each random variable they mention
    Map<Integer, List<FactorTable>> factorsByVariable = new LinkedHashMap<>();
    List<FactorTable> liveFactors = new ArrayList<>();
    for (int i = 0; i < markov.numberFactors(); i++) {
        FactorTable factor = markov.getFactor(i);
        liveFactors.add(factor);
        for (Integer variable : factor.getVariableIndexes()) {
            List<FactorTable> bucket = Util.getValuePossiblyCreatingIt(factorsByVariable, variable, ArrayList.class);
            bucket.add(factor);
        }
    }

    // One variable is eliminated per iteration until the whole factor graph is consumed
    while (!factorsByVariable.isEmpty()) {
        // STEP 1: choose the variable sharing factors with the fewest other variables
        Integer c = pickMinimal(liveFactors);

        // STEP 2: multiply every factor containing c into an unnormalized CPT phi
        // and take those factors out of the factor graph
        List<FactorTable> factorsContainingC = factorsByVariable.get(c);
        Set<Integer> p = new LinkedHashSet<>();
        factorsContainingC.forEach(factor -> p.addAll(factor.getVariableIndexes()));
        // The product is built while c is still part of the scope
        FactorTable cFactor = newMegaFactor(markov, p, factorsContainingC);

        // From here on p holds only the other variables (the parents)
        p.remove(c);
        factorsByVariable.remove(c);
        liveFactors.removeAll(factorsContainingC);
        for (Integer parent : p) {
            factorsByVariable.get(parent).removeAll(factorsContainingC);
        }

        // STEP 3: with C the child and P the remaining variables of phi, compute
        // Z(P) = sum_{C} phi(C,P) and replace each phi(C,P) by phi(C,P)/Z(P).
        // The normalized CPT goes to the Bayesian network, Z goes back to the factor graph.
        Pair<FactorTable, ConditionalProbabilityTable> eliminated = sumOutChildAndCreateCPT(markov, c, cFactor);
        FactorTable zFactor = eliminated.first;
        bayesListener.newCPT(eliminated.second);

        if (zFactor == null) {
            // c had no parents, so there is no Z factor to add back
            continue;
        }
        liveFactors.add(zFactor);
        for (Integer parent : p) {
            factorsByVariable.get(parent).add(zFactor);
        }
    }
}
use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.
the class TransformMarkovToBayes method pickMinimal.
//
// PRIVATE
//
/**
 * Picks a random variable that shares factors with the smallest number of other variables.
 *
 * For each variable, the set of all variables appearing with it in some factor is
 * collected. That set includes the variable itself, so the number of parents the variable
 * would get in its CPT is the set size minus one.
 *
 * @param factors the factors currently in the factor graph; must not be empty
 * @return the index of a variable whose neighbor set is smallest
 * @throws IllegalArgumentException if {@code factors} is empty or none of the factors
 *                                  mentions any variable
 * @throws IllegalStateException    if even the chosen variable would have more than
 *                                  {@code MAX_NUM_ALLOWED_PARENTS_IN_CPT} parents
 */
private static Integer pickMinimal(List<FactorTable> factors) {
    if (factors.isEmpty()) {
        throw new IllegalArgumentException("Need > 0 factors to pick a minimum");
    }
    Map<Integer, Set<Integer>> randomVariableNeighbors = new LinkedHashMap<>();
    for (FactorTable factor : factors) {
        for (Integer rv : factor.getVariableIndexes()) {
            Set<Integer> neighbors = Util.getValuePossiblyCreatingIt(randomVariableNeighbors, rv, LinkedHashSet.class);
            // NOTE: self is included for efficiency; every count is uniformly + 1,
            // which does not affect which variable is minimal
            neighbors.addAll(factor.getVariableIndexes());
        }
    }
    Map.Entry<Integer, Set<Integer>> smallestEntry = randomVariableNeighbors.entrySet().stream()
            .min((e1, e2) -> Integer.compare(e1.getValue().size(), e2.getValue().size()))
            // Only reachable when every factor has an empty variable list
            .orElseThrow(() -> new IllegalArgumentException("Need > 0 variables in the factors to pick a minimum"));
    // The neighbor set contains the variable itself, so discount it before comparing
    // against the limit on the number of parents (and when reporting it)
    int numParents = smallestEntry.getValue().size() - 1;
    if (numParents > MAX_NUM_ALLOWED_PARENTS_IN_CPT) {
        throw new IllegalStateException("Too large a CPT will need be generated as #parents=" + numParents);
    }
    return smallestEntry.getKey();
}
Aggregations