Search in sources:

Example 1 with ConditionalProbabilityTable

Use of com.sri.ai.praise.core.representation.classbased.table.core.data.bayes.ConditionalProbabilityTable in project aic-praise by aic-sri-international.

From class TransformMarkovToBayes, method sumOutChildAndCreateCPT.

/**
 * Builds a conditional probability table for the child variable {@code c} from the unnormalized
 * factor {@code cFactor}, and produces the factor over the remaining (parent) variables that is
 * left over once {@code c} is summed out.
 *
 * <p>The parents are the variables of {@code cFactor} other than {@code c}, kept in the order in
 * which they occur in {@code cFactor}. In the resulting CPT table the parents come first, in that
 * order, and the child is placed last.
 *
 * <p>For each parent assignment P, the normalizer Z(P) is read from
 * {@code cFactor.getTable().valueFor(...)} with only the parent positions assigned. This relies on
 * {@code valueFor} summing over the unassigned child position (TODO confirm against
 * {@code FunctionTable.valueFor}). Each CPT entry is then phi(c, P) / Z(P).
 *
 * <p>A row whose normalizer is zero cannot be normalized. Every entry of such a row is set to
 * {@link Double#MIN_NORMAL} instead of a NaN/infinite quotient. This applies both when {@code c}
 * has parents and when it has none.
 *
 * @param markov  network supplying the cardinality of each variable
 * @param c       index of the child variable being eliminated; expected to occur in {@code cFactor}
 * @param cFactor product of all factors that mention {@code c}
 * @return a pair whose first element is the factor Z over the parents ({@code null} when {@code c}
 *         has no parents) and whose second element is the CPT for {@code c}
 */
private static Pair<FactorTable, ConditionalProbabilityTable> sumOutChildAndCreateCPT(MarkovNetwork markov, Integer c, FactorTable cFactor) {
    List<Integer> factorVarIdxs = cFactor.getVariableIndexes();
    List<Integer> parentVarIdxs = new ArrayList<>(factorVarIdxs);
    // c is an Integer, so this resolves to List.remove(Object): it removes the variable, not a position.
    parentVarIdxs.remove(c);
    int cCardinality = markov.cardinality(c);
    // Position of the child within cFactor; assignment maps below are keyed by position in cFactor.
    int cIdx = factorVarIdxs.indexOf(c);
    FactorTable summedOut = null;
    FunctionTable cptTable;
    if (parentVarIdxs.isEmpty()) {
        // No parents, so there is no summed-out factor, but the entries still have to be normalized,
        // one group of cCardinality consecutive entries at a time.
        List<Double> factorEntries = cFactor.getTable().getEntries();
        List<Double> cptEntries = new ArrayList<>(factorEntries.size());
        for (int start = 0; start < factorEntries.size(); start += cCardinality) {
            double total = 0;
            for (int i = 0; i < cCardinality; i++) {
                total += factorEntries.get(start + i);
            }
            for (int i = 0; i < cCardinality; i++) {
                // Same zero guard as the parent case below; without it this row would become NaN.
                cptEntries.add(total == 0.0 ? Double.MIN_NORMAL : factorEntries.get(start + i) / total);
            }
        }
        cptTable = new FunctionTable(Arrays.asList(cCardinality), cptEntries);
    } else {
        // For each parent (in cFactor order): its cardinality and its position within cFactor.
        List<Integer> parentVarCardinalities = new ArrayList<>(parentVarIdxs.size());
        List<Integer> parentPositionsInCFactor = new ArrayList<>(parentVarIdxs.size());
        for (int pos = 0; pos < factorVarIdxs.size(); pos++) {
            Integer varIdx = factorVarIdxs.get(pos);
            if (!varIdx.equals(c)) {
                parentVarCardinalities.add(markov.cardinality(varIdx));
                parentPositionsInCFactor.add(pos);
            }
        }
        // The child index will be placed at the end of the table by convention.
        List<Integer> cptCardinalities = new ArrayList<>(parentVarCardinalities);
        cptCardinalities.add(cCardinality);
        List<Double> cptEntries = new ArrayList<>(FunctionTable.numEntriesFor(cptCardinalities));
        List<Double> parentSummedOutEntries = new ArrayList<>(FunctionTable.numEntriesFor(parentVarCardinalities));
        Map<Integer, Integer> assignmentMap = new LinkedHashMap<>();
        CartesianProductEnumeration<Integer> cpe = new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(parentVarCardinalities));
        while (cpe.hasMoreElements()) {
            List<Integer> parentValues = cpe.nextElement();
            assignmentMap.clear();
            for (int i = 0; i < parentValues.size(); i++) {
                assignmentMap.put(parentPositionsInCFactor.get(i), parentValues.get(i));
            }
            // The child position is not assigned yet, so this is taken as Z(P) (see method doc).
            double sum = cFactor.getTable().valueFor(assignmentMap);
            parentSummedOutEntries.add(sum);
            for (int i = 0; i < cCardinality; i++) {
                assignmentMap.put(cIdx, i);
                if (sum == 0.0) {
                    // TODO - is this approach correct?
                    // NOTE: This prevents invalid models being generated by assigning an impossibly
                    // small probability to an event that should never occur.
                    cptEntries.add(Double.MIN_NORMAL);
                } else {
                    cptEntries.add(cFactor.getTable().valueFor(assignmentMap) / sum);
                }
            }
        }
        summedOut = new FactorTable(parentVarIdxs, new FunctionTable(parentVarCardinalities, parentSummedOutEntries));
        cptTable = new FunctionTable(cptCardinalities, cptEntries);
    }
    ConditionalProbabilityTable cpt = new ConditionalProbabilityTable(parentVarIdxs, c, cptTable);
    return new Pair<>(summedOut, cpt);
}
Also used: ConditionalProbabilityTable (com.sri.ai.praise.core.representation.classbased.table.core.data.bayes.ConditionalProbabilityTable), ArrayList (java.util.ArrayList), CartesianProductEnumeration (com.sri.ai.util.collect.CartesianProductEnumeration), LinkedHashMap (java.util.LinkedHashMap), FunctionTable (com.sri.ai.praise.core.representation.classbased.table.core.data.FunctionTable), FactorTable (com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable), Pair (com.sri.ai.util.base.Pair)

Example 2 with ConditionalProbabilityTable

Use of com.sri.ai.praise.core.representation.classbased.table.core.data.bayes.ConditionalProbabilityTable in project aic-praise by aic-sri-international.

From class TransformMarkovToBayes, method transform.

/**
 * Converts a Markov network into a Bayesian network by repeatedly eliminating one variable at a
 * time, reporting each conditional probability table produced to {@code bayesListener}.
 *
 * <p>Each round picks a child variable via {@code pickMinimal} and multiplies every remaining
 * factor that mentions it into a single unnormalized factor phi. That factor is then split by
 * {@code sumOutChildAndCreateCPT} into a CPT for the child and a factor Z over its parents. The
 * CPT is emitted, Z (when present) replaces the consumed factors, and the loop repeats until
 * every variable has been eliminated.
 *
 * @param markov        the network to transform
 * @param bayesListener receives one CPT per eliminated variable, in elimination order
 */
public static void transform(MarkovNetwork markov, BayesOutputListener bayesListener) {
    // Index every variable to the factors that mention it, and keep a working list of all factors.
    Map<Integer, List<FactorTable>> factorsByVariable = new LinkedHashMap<>();
    List<FactorTable> remainingFactors = new ArrayList<>();
    for (int factorIdx = 0; factorIdx < markov.numberFactors(); factorIdx++) {
        FactorTable factor = markov.getFactor(factorIdx);
        remainingFactors.add(factor);
        for (Integer variable : factor.getVariableIndexes()) {
            factorsByVariable.computeIfAbsent(variable, key -> new ArrayList<>()).add(factor);
        }
    }

    while (!factorsByVariable.isEmpty()) {
        // STEP 1: choose the variable with the fewest neighbouring variables
        // (other variables it shares factors with).
        Integer child = pickMinimal(remainingFactors);

        // STEP 2: multiply all factors that mention the child into an unnormalized CPT phi,
        // and detach those factors from the graph.
        List<FactorTable> childFactors = factorsByVariable.get(child);
        Set<Integer> scope = new LinkedHashSet<>();
        childFactors.forEach(f -> scope.addAll(f.getVariableIndexes()));
        FactorTable phi = newMegaFactor(markov, scope, childFactors);
        // Only after phi is built: drop the child, so scope now holds just its parents.
        scope.remove(child);
        factorsByVariable.remove(child);
        remainingFactors.removeAll(childFactors);
        for (Integer parent : scope) {
            factorsByVariable.get(parent).removeAll(childFactors);
        }

        // STEP 3: with C the child and P the rest of phi, compute Z(P) = sum_C phi(C, P),
        // normalize phi(C, P) / Z(P) into a CPT, emit it, and put Z back into the graph.
        Pair<FactorTable, ConditionalProbabilityTable> zAndCpt = sumOutChildAndCreateCPT(markov, child, phi);
        FactorTable z = zAndCpt.first;
        ConditionalProbabilityTable childCpt = zAndCpt.second;
        bayesListener.newCPT(childCpt);
        if (z == null) {
            // The child had no parents, so there is no leftover factor to reinsert.
            continue;
        }
        remainingFactors.add(z);
        for (Integer parent : scope) {
            factorsByVariable.get(parent).add(z);
        }
        // Repeat for the remainder of the factor graph.
    }
}
Also used: LinkedHashSet (java.util.LinkedHashSet), ConditionalProbabilityTable (com.sri.ai.praise.core.representation.classbased.table.core.data.bayes.ConditionalProbabilityTable), ArrayList (java.util.ArrayList), List (java.util.List), FactorTable (com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable), LinkedHashMap (java.util.LinkedHashMap)

Aggregations

ConditionalProbabilityTable (com.sri.ai.praise.core.representation.classbased.table.core.data.bayes.ConditionalProbabilityTable): 2; FactorTable (com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable): 2; ArrayList (java.util.ArrayList): 2; LinkedHashMap (java.util.LinkedHashMap): 2; FunctionTable (com.sri.ai.praise.core.representation.classbased.table.core.data.FunctionTable): 1; Pair (com.sri.ai.util.base.Pair): 1; CartesianProductEnumeration (com.sri.ai.util.collect.CartesianProductEnumeration): 1; LinkedHashSet (java.util.LinkedHashSet): 1; List (java.util.List): 1