Search in sources:

Example 1 with FactorTable

use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.

Method newMegaFactor of class TransformMarkovToBayes.

/**
 * Multiplies a list of factors into a single factor over the given variables.
 *
 * <p>For every joint assignment to {@code varIdxs} (enumerated with
 * {@link CartesianProductEnumeration}), the resulting entry is the product of the matching
 * entries of each factor in {@code factorsContainingC}.
 *
 * @param markov             network used to look up each variable's cardinality
 * @param varIdxs            variables of the resulting factor; their iteration order fixes their
 *                           order in the result. Every variable of every input factor is expected
 *                           to be in this set (otherwise the position lookup yields null).
 * @param factorsContainingC factors to multiply together
 * @return a new factor over {@code varIdxs} holding the pointwise product
 */
private static FactorTable newMegaFactor(MarkovNetwork markov, Set<Integer> varIdxs, List<FactorTable> factorsContainingC) {
    List<Integer> resultVariables = new ArrayList<>(varIdxs);
    // Position of each variable within a joint assignment of resultVariables.
    Map<Integer, Integer> positionOf = new LinkedHashMap<>();
    List<Integer> cardinalities = new ArrayList<>();
    int position = 0;
    for (Integer variable : resultVariables) {
        positionOf.put(variable, position);
        position++;
        cardinalities.add(markov.cardinality(variable));
    }

    List<Double> products = new ArrayList<>(FunctionTable.numEntriesFor(cardinalities));
    // Reused buffer holding one factor's slice of the current joint assignment.
    List<Integer> factorAssignment = new ArrayList<>();
    CartesianProductEnumeration<Integer> jointAssignments =
            new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(cardinalities));
    while (jointAssignments.hasMoreElements()) {
        List<Integer> jointAssignment = jointAssignments.nextElement();
        double product = 1;
        for (FactorTable factor : factorsContainingC) {
            factorAssignment.clear();
            for (Integer variable : factor.getVariableIndexes()) {
                int offset = positionOf.get(variable);
                factorAssignment.add(jointAssignment.get(offset));
            }
            product *= factor.getTable().entryFor(factorAssignment);
        }
        products.add(product);
    }

    return new FactorTable(resultVariables, new FunctionTable(cardinalities, products));
}
Also used : FunctionTable(com.sri.ai.praise.core.representation.classbased.table.core.data.FunctionTable) ArrayList(java.util.ArrayList) FactorTable(com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) LinkedHashMap(java.util.LinkedHashMap)

Example 2 with FactorTable

use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.

Method addVariablesToMap of class TableFactorNetwork.

/**
 * Builds an insertion-ordered map of the model's variables by visiting its factors in order.
 *
 * <p>Each factor is handed to {@code addVariablesFromFactorToMap} together with the map being
 * built; that helper (not shown here) is what actually inserts entries.
 *
 * @param model the UAI model whose factors are visited, from index 0 upwards
 * @return the map populated from all factors of {@code model}
 */
private static LinkedHashMap<Integer, TableVariable> addVariablesToMap(UAIModel model) {
    LinkedHashMap<Integer, TableVariable> variablesByIndex = new LinkedHashMap<>();
    for (int factorIndex = 0, factorCount = model.numberFactors(); factorIndex < factorCount; factorIndex++) {
        addVariablesFromFactorToMap(model.getFactor(factorIndex), variablesByIndex);
    }
    return variablesByIndex;
}
Also used : FactorTable(com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable) LinkedHashMap(java.util.LinkedHashMap)

Example 3 with FactorTable

use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.

Method sumOutChildAndCreateCPT of class TransformMarkovToBayes.

/**
 * Splits an unnormalized factor containing child variable {@code c} into a conditional
 * probability table for {@code c} given the factor's other variables, plus the normalizing
 * factor Z over those other variables.
 *
 * <p>By convention the child's dimension is placed last in the CPT. Whenever the normalizer for
 * a parent assignment (or, with no parents, the whole table) is zero, every child entry for it
 * is set to {@link Double#MIN_NORMAL} instead of producing NaN from 0/0.
 *
 * @param markov  network used to look up variable cardinalities
 * @param c       index of the child variable; expected to be one of {@code cFactor}'s variables
 * @param cFactor the unnormalized factor to split
 * @return a pair (Z, CPT); Z is {@code null} when {@code cFactor} has no variables besides {@code c}
 */
private static Pair<FactorTable, ConditionalProbabilityTable> sumOutChildAndCreateCPT(MarkovNetwork markov, Integer c, FactorTable cFactor) {
    List<Integer> parentVarIdxs = new ArrayList<>(cFactor.getVariableIndexes());
    // c is an Integer, so this is remove(Object): it removes the variable, not a position.
    parentVarIdxs.remove(c);
    int cCardinality = markov.cardinality(c);
    int cIdx = cFactor.getVariableIndexes().indexOf(c);
    FactorTable summedOut = null;
    FunctionTable cptTable;
    if (parentVarIdxs.isEmpty()) {
        // No summed out factor as no remaining elements, but the CPT values still need normalizing.
        List<Double> factorEntries = cFactor.getTable().getEntries();
        List<Double> cptEntries = new ArrayList<>(factorEntries.size());
        for (int idx = 0; idx < factorEntries.size(); idx += cCardinality) {
            double total = 0;
            for (int i = 0; i < cCardinality; i++) {
                total += factorEntries.get(idx + i);
            }
            for (int i = 0; i < cCardinality; i++) {
                // Same zero-normalizer policy as the parent branch below, avoiding NaN entries.
                cptEntries.add(total == 0.0 ? Double.MIN_NORMAL : factorEntries.get(idx + i) / total);
            }
        }
        cptTable = new FunctionTable(Arrays.asList(cCardinality), cptEntries);
    } else {
        List<Integer> cptCardinalities = new ArrayList<>();
        List<Integer> parentVarCardinalities = new ArrayList<>();
        // Position among the parents -> position within cFactor's variable list.
        Map<Integer, Integer> parentCardIdxToCFactorIdx = new LinkedHashMap<>();
        int cFactorCardIdx = 0;
        for (Integer varIdx : cFactor.getVariableIndexes()) {
            if (!varIdx.equals(c)) {
                int card = markov.cardinality(varIdx);
                cptCardinalities.add(card);
                parentVarCardinalities.add(card);
                parentCardIdxToCFactorIdx.put(parentCardIdxToCFactorIdx.size(), cFactorCardIdx);
            }
            cFactorCardIdx++;
        }
        // The child index will be placed at the end of the table by convention
        cptCardinalities.add(cCardinality);
        List<Double> cptEntries = new ArrayList<>(FunctionTable.numEntriesFor(cptCardinalities));
        List<Double> parentSummedOutEntries = new ArrayList<>(FunctionTable.numEntriesFor(parentVarCardinalities));
        // Keys are positions within cFactor's variable list, values are assigned values.
        Map<Integer, Integer> assignmentMap = new LinkedHashMap<>();
        CartesianProductEnumeration<Integer> cpe = new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(parentVarCardinalities));
        while (cpe.hasMoreElements()) {
            List<Integer> parentValues = cpe.nextElement();
            assignmentMap.clear();
            for (int i = 0; i < parentValues.size(); i++) {
                assignmentMap.put(parentCardIdxToCFactorIdx.get(i), parentValues.get(i));
            }
            // c is left unassigned here, so this relies on valueFor summing over unassigned
            // variables to yield Z(parents). NOTE(review): confirm against FunctionTable.valueFor.
            double sum = cFactor.getTable().valueFor(assignmentMap);
            parentSummedOutEntries.add(sum);
            for (int i = 0; i < cCardinality; i++) {
                assignmentMap.put(cIdx, i);
                // Primitive comparison also catches -0.0, which Double.equals(0.0) would miss.
                if (sum == 0.0) {
                    // TODO - is this approach correct?
                    // NOTE: This prevents invalid models being generated by assigning an impossibly small probability to an event that should never occur
                    cptEntries.add(Double.MIN_NORMAL);
                } else {
                    cptEntries.add(cFactor.getTable().valueFor(assignmentMap) / sum);
                }
            }
        }
        summedOut = new FactorTable(parentVarIdxs, new FunctionTable(parentVarCardinalities, parentSummedOutEntries));
        cptTable = new FunctionTable(cptCardinalities, cptEntries);
    }
    ConditionalProbabilityTable cpt = new ConditionalProbabilityTable(parentVarIdxs, c, cptTable);
    return new Pair<>(summedOut, cpt);
}
Also used : ConditionalProbabilityTable(com.sri.ai.praise.core.representation.classbased.table.core.data.bayes.ConditionalProbabilityTable) ArrayList(java.util.ArrayList) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) LinkedHashMap(java.util.LinkedHashMap) FunctionTable(com.sri.ai.praise.core.representation.classbased.table.core.data.FunctionTable) FactorTable(com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable) Pair(com.sri.ai.util.base.Pair)

Example 4 with FactorTable

use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.

Method transform of class TransformMarkovToBayes.

/**
 * Converts a Markov network into a Bayesian network by repeated variable elimination.
 *
 * <p>On each iteration a variable {@code c} is chosen with {@code pickMinimal}. All factors
 * mentioning {@code c} are multiplied together and removed. The product is then split into a
 * CPT for {@code c}, which is passed to {@code bayesListener.newCPT}, and a normalizing factor
 * over the remaining variables, which is added back to the working set. The loop ends once
 * every variable has been eliminated.
 *
 * @param markov        the source Markov network (read only; its factors are copied into a working list)
 * @param bayesListener receives one CPT per eliminated variable, in elimination order
 */
public static void transform(MarkovNetwork markov, BayesOutputListener bayesListener) {
    // Determine the random variable to factor associations
    Map<Integer, List<FactorTable>> randomVariableToFactors = new LinkedHashMap<>();
    List<FactorTable> factors = new ArrayList<>();
    for (int i = 0; i < markov.numberFactors(); i++) {
        FactorTable factor = markov.getFactor(i);
        factors.add(factor);
        for (Integer rvIdx : factor.getVariableIndexes()) {
            // Type-safe replacement for creating the list reflectively from a raw ArrayList.class.
            randomVariableToFactors.computeIfAbsent(rvIdx, k -> new ArrayList<>()).add(factor);
        }
    }
    // Repeat the steps below for the remainder of the factor graph until no variables are left.
    while (!randomVariableToFactors.isEmpty()) {
        // STEP 1:
        // Pick a variable such that number of random variable neighbors
        // (that is, other vars sharing factors) is minimal.
        Integer c = pickMinimal(factors);
        // STEP 2:
        // Multiply all factors containing this var - remove them from
        // the factor graph - to give an unnormalized CPT phi for the var
        List<FactorTable> factorsContainingC = randomVariableToFactors.get(c);
        Set<Integer> p = new LinkedHashSet<>();
        for (FactorTable f : factorsContainingC) {
            p.addAll(f.getVariableIndexes());
        }
        FactorTable cFactor = newMegaFactor(markov, p, factorsContainingC);
        // Now ensure c is not included in p
        p.remove(c);
        randomVariableToFactors.remove(c);
        factors.removeAll(factorsContainingC);
        for (Integer pv : p) {
            randomVariableToFactors.get(pv).removeAll(factorsContainingC);
        }
        // STEP 3:
        // In the unnormalized CPT phi, let C be the child
        // (latest in the ordering) and P the other variables in it.
        // For each assignment to parents P, compute a new factor on P,
        // Z(P) = sum_{C} phi(C,P).
        // Replace each entry phi(C,P) by phi(C,P)/Z(P).
        // Store the now normalized CPT in the Bayesian network, and
        // add factor Z to the factor network.
        Pair<FactorTable, ConditionalProbabilityTable> zFactorAndcCPT = sumOutChildAndCreateCPT(markov, c, cFactor);
        FactorTable zFactor = zFactorAndcCPT.first;
        ConditionalProbabilityTable cCPT = zFactorAndcCPT.second;
        bayesListener.newCPT(cCPT);
        // zFactor is null when c had no other variables in its factors, so there is nothing to add back.
        if (zFactor != null) {
            factors.add(zFactor);
            for (Integer pv : p) {
                randomVariableToFactors.get(pv).add(zFactor);
            }
        }
    }
}
Also used : LinkedHashSet(java.util.LinkedHashSet) ConditionalProbabilityTable(com.sri.ai.praise.core.representation.classbased.table.core.data.bayes.ConditionalProbabilityTable) ArrayList(java.util.ArrayList) ArrayList(java.util.ArrayList) List(java.util.List) FactorTable(com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable) LinkedHashMap(java.util.LinkedHashMap)

Example 5 with FactorTable

use of com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable in project aic-praise by aic-sri-international.

Method pickMinimal of class TransformMarkovToBayes.

// 
// PRIVATE
// 
/**
 * Picks the random variable that shares factors with the fewest other random variables.
 *
 * <p>Ties are broken in favor of the variable encountered first. The choice is made while
 * iterating {@code factors} in order and each factor's variables in order.
 *
 * @param factors the current factors; must be non-empty
 * @return the index of the chosen variable
 * @throws IllegalArgumentException if {@code factors} is empty or none of them has any variables
 * @throws IllegalStateException    if the chosen variable would have more than
 *                                  {@code MAX_NUM_ALLOWED_PARENTS_IN_CPT} parents
 */
private static Integer pickMinimal(List<FactorTable> factors) {
    if (factors.isEmpty()) {
        throw new IllegalArgumentException("Need > 0 factors to pick a minimum");
    }
    Map<Integer, Set<Integer>> randomVariableNeighbors = new LinkedHashMap<>();
    for (FactorTable factor : factors) {
        for (Integer rv : factor.getVariableIndexes()) {
            // NOTE: include self for efficiency as all counts will end up being + 1
            randomVariableNeighbors.computeIfAbsent(rv, k -> new LinkedHashSet<>()).addAll(factor.getVariableIndexes());
        }
    }
    Map.Entry<Integer, Set<Integer>> smallestEntry = randomVariableNeighbors.entrySet().stream()
            .min((e1, e2) -> Integer.compare(e1.getValue().size(), e2.getValue().size()))
            .orElseThrow(() -> new IllegalArgumentException("Need > 0 random variables in the factors to pick a minimum"));
    // The neighbor set includes the variable itself, so its parent count is one less than the set size.
    int numberOfParents = smallestEntry.getValue().size() - 1;
    if (numberOfParents > MAX_NUM_ALLOWED_PARENTS_IN_CPT) {
        throw new IllegalStateException("Too large a CPT will need be generated as #parents=" + numberOfParents);
    }
    return smallestEntry.getKey();
}
Also used : MarkovNetwork(com.sri.ai.praise.core.representation.classbased.table.api.MarkovNetwork) FactorTable(com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) Arrays(java.util.Arrays) Set(java.util.Set) ConditionalProbabilityTable(com.sri.ai.praise.core.representation.classbased.table.core.data.bayes.ConditionalProbabilityTable) ArrayList(java.util.ArrayList) Beta(com.google.common.annotations.Beta) LinkedHashMap(java.util.LinkedHashMap) List(java.util.List) Map(java.util.Map) Util(com.sri.ai.util.Util) FunctionTable(com.sri.ai.praise.core.representation.classbased.table.core.data.FunctionTable) LinkedHashSet(java.util.LinkedHashSet) Pair(com.sri.ai.util.base.Pair) Set(java.util.Set) LinkedHashSet(java.util.LinkedHashSet) FactorTable(com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map) LinkedHashMap(java.util.LinkedHashMap)

Aggregations

FactorTable (com.sri.ai.praise.core.representation.classbased.table.core.data.markov.FactorTable)5 LinkedHashMap (java.util.LinkedHashMap)5 ArrayList (java.util.ArrayList)4 FunctionTable (com.sri.ai.praise.core.representation.classbased.table.core.data.FunctionTable)3 ConditionalProbabilityTable (com.sri.ai.praise.core.representation.classbased.table.core.data.bayes.ConditionalProbabilityTable)3 CartesianProductEnumeration (com.sri.ai.util.collect.CartesianProductEnumeration)3 Pair (com.sri.ai.util.base.Pair)2 LinkedHashSet (java.util.LinkedHashSet)2 List (java.util.List)2 Beta (com.google.common.annotations.Beta)1 MarkovNetwork (com.sri.ai.praise.core.representation.classbased.table.api.MarkovNetwork)1 Util (com.sri.ai.util.Util)1 Arrays (java.util.Arrays)1 Map (java.util.Map)1 Set (java.util.Set)1