Search in sources:

Example 1 with FunctionTable

Use of com.sri.ai.praise.lang.grounded.common.FunctionTable in project aic-praise by aic-sri-international.

The class TranslationOfTableToInequalitiesTest, method test.

/**
 * Checks the translation of grounded function tables into if-then-else expressions over
 * inequalities on the table's variables g0, g1, ...
 * <p>
 * Each case builds a {@link FunctionTable} from the variable cardinalities and the flat entry list.
 * In that list g0 is the most significant (slowest-varying) index, as the expected expressions
 * show. The table and its expected expression are handed to the {@code run} helper (defined
 * elsewhere in the test class).
 */
@Test
public void test() {
    // Constant table: the translation is just the constant.
    run(new FunctionTable(list(2), list(1.0, 1.0)), "1");
    // One variable with a contiguous run of 1s.
    run(new FunctionTable(list(5), list(0.0, 1.0, 1.0, 0.0, 0.0)),
            "if (g0 = 0) or (g0 >= 3) then 0 else 1");
    // Two variables; the g0 = 0 row is only partially 1.
    run(new FunctionTable(list(4, 3), list(0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0)),
            "if (g0 = 0) and (g1 >= 1) or (g0 >= 1) and (g0 <= 2) then 1 else 0");
    // A variable of cardinality 1 in the most significant position.
    run(new FunctionTable(list(1, 3), list(0.0, 1.0, 1.0)),
            "if (g1 >= 1) then 1 else 0");
    // A variable of cardinality 1 in the least significant position.
    run(new FunctionTable(list(4, 1), list(0.0, 1.0, 1.0, 0.0)),
            "if (g0 >= 1) and (g0 <= 2) then 1 else 0");
    // Whole rows of 1s: g1 drops out of the expression.
    run(new FunctionTable(list(4, 3), list(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0)),
            "if (g0 >= 1) and (g0 <= 2) then 1 else 0");
    run(new FunctionTable(list(3, 3), list(0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0)),
            "if (g0 = 0) or (g0 = 1) and (g1 = 0) or (g0 = 2) then 0 else 1");
    // All-zero table.
    run(new FunctionTable(list(3, 3), list(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)), "0");
    run(new FunctionTable(list(3, 3), list(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0)),
            "if g0 = 0 or g0 = 2 then 0 else 1");
    // Larger g0 domain with whole rows of 1s.
    run(new FunctionTable(list(10, 3), list(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0)),
            "if g0 >= 1 and g0 <= 8 then 1 else 0");
    // Same shape as the previous case, but with a 0.0 inside the run of 1s
    // (the isolated "0.0, 1.0" pair below), which splits the range around g0 = 5.
    run(new FunctionTable(list(10, 3), list(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            0.0,
            1.0,
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0)),
            "if (g0 >= 1) and (g0 <= 4) or (g0 = 5) and (g1 >= 1) or (g0 >= 6) and (g0 <= 8) then 1 else 0");
    // Identical rows for both values of g0.
    run(new FunctionTable(list(2, 5), list(0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0)),
            "if (g0 = 0) and (g1 = 0) or " + "   (g0 = 0) and (g1 >= 3) or (g0 = 1) and (g1 = 0) or " + "   (g0 = 1) and (g1 >= 3) then 0 else 1");
    // Three distinct output values produce a nested if-then-else.
    run(new FunctionTable(list(2, 5), list(0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0)),
            "if (g0 = 0) and (g1 = 0) or " + "(g0 = 0) and (g1 >= 3) or (g0 = 1) and (g1 = 0) or (g0 = 1) and (g1 >= 3) then 0 else if (g0 = 0) and (g1 >= 1) and (g1 <= 2) or (g0 = 1) and (g1 = 1) then 1 else 2");
    // The value depends only on g0.
    run(new FunctionTable(list(2, 5), list(1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0)),
            "if g0 = 1 then 2 else 1");
    // Three variables, with a condition nested down to g2.
    run(new FunctionTable(list(5, 4, 3), list(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)),
            "if " + "(g0 = 0) or " + "(g0 = 1) and (g1 = 0) or " + "(g0 = 1) and (g1 = 3) or " + "(g0 = 2) and (g1 <= 1) or " + "(g0 = 2) and ((g1 = 2) and (g2 >= 1) or (g1 = 3)) or (g0 >= 3) " + "then 1 else 0");
}
Also used: FunctionTable(com.sri.ai.praise.lang.grounded.common.FunctionTable) Test(org.junit.Test)

Example 2 with FunctionTable

Use of com.sri.ai.praise.lang.grounded.common.FunctionTable in project aic-praise by aic-sri-international.

The class XFormMarkovToBayes, method sumOutChildAndCreateCPT.

/**
 * Sums the child variable {@code c} out of {@code cFactor} and builds the conditional probability
 * table (CPT) of {@code c} given the factor's remaining variables (its parents).
 * <p>
 * The parents keep the order they have in {@code cFactor}. By convention the child is placed last
 * in the CPT, so each consecutive block of {@code cardinality(c)} entries is one CPT row.
 * <p>
 * When a row's normalizer is zero, every entry of that row is set to {@link Double#MIN_NORMAL}
 * instead of being divided by zero (which would yield NaN or infinities). The same convention
 * applies whether or not {@code c} has parents.
 *
 * @param markov  network supplying variable cardinalities
 * @param c       network index of the child variable; expected to occur in {@code cFactor}
 * @param cFactor factor whose variables include {@code c}
 * @return a pair of (factor over the parents with {@code c} summed out, or {@code null} when the
 *         factor has no variable other than {@code c}; the CPT of {@code c} given those parents)
 */
private static Pair<FactorTable, ConditionalProbabilityTable> sumOutChildAndCreateCPT(MarkovNetwork markov, Integer c, FactorTable cFactor) {
    List<Integer> parentVarIdxs = new ArrayList<>(cFactor.getVariableIndexes());
    // c is an Integer, so this resolves to List.remove(Object): it removes the value c, not position c.
    parentVarIdxs.remove(c);
    int cCardinality = markov.cardinality(c);
    // Position of the child within cFactor's variable list.
    int cIdx = cFactor.getVariableIndexes().indexOf(c);
    FactorTable summedOut = null;
    FunctionTable cptTable;
    if (parentVarIdxs.isEmpty()) {
        // No parents: there is no summed-out factor, but the CPT entries still need to be normalized.
        List<Double> factorEntries = cFactor.getTable().getEntries();
        List<Double> cptEntries = new ArrayList<>(factorEntries.size());
        for (int idx = 0; idx < factorEntries.size(); idx += cCardinality) {
            double total = 0;
            for (int i = 0; i < cCardinality; i++) {
                total += factorEntries.get(idx + i);
            }
            for (int i = 0; i < cCardinality; i++) {
                // Use the same zero-normalizer convention as the branch with parents, rather than 0/0 = NaN.
                cptEntries.add(total == 0.0 ? Double.MIN_NORMAL : factorEntries.get(idx + i) / total);
            }
        }
        cptTable = new FunctionTable(Arrays.asList(cCardinality), cptEntries);
    } else {
        List<Integer> cptCardinalities = new ArrayList<>();
        List<Integer> parentVarCardinalities = new ArrayList<>();
        // Maps a parent's position among the parents to that variable's position within cFactor.
        Map<Integer, Integer> parentCardIdxToCFactorIdx = new LinkedHashMap<>();
        int cFactorCardIdx = 0;
        for (Integer varIdx : cFactor.getVariableIndexes()) {
            if (!varIdx.equals(c)) {
                int card = markov.cardinality(varIdx);
                cptCardinalities.add(card);
                parentVarCardinalities.add(card);
                parentCardIdxToCFactorIdx.put(parentCardIdxToCFactorIdx.size(), cFactorCardIdx);
            }
            cFactorCardIdx++;
        }
        // The child index is placed at the end of the table by convention.
        cptCardinalities.add(cCardinality);
        List<Double> cptEntries = new ArrayList<>(FunctionTable.numEntriesFor(cptCardinalities));
        List<Double> parentSummedOutEntries = new ArrayList<>(FunctionTable.numEntriesFor(parentVarCardinalities));
        // Keys are positions within cFactor's variable list, not network variable indexes.
        Map<Integer, Integer> assignmentMap = new LinkedHashMap<>();
        CartesianProductEnumeration<Integer> cpe = new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(parentVarCardinalities));
        while (cpe.hasMoreElements()) {
            List<Integer> parentValues = cpe.nextElement();
            assignmentMap.clear();
            for (int i = 0; i < parentValues.size(); i++) {
                assignmentMap.put(parentCardIdxToCFactorIdx.get(i), parentValues.get(i));
            }
            // The child's position is left unassigned here; presumably valueFor sums over it
            // (this value becomes the summed-out entry) -- TODO confirm against FunctionTable.valueFor.
            double sum = cFactor.getTable().valueFor(assignmentMap);
            parentSummedOutEntries.add(sum);
            for (int i = 0; i < cCardinality; i++) {
                assignmentMap.put(cIdx, i);
                // Primitive comparison so that a signed zero (-0.0) is also treated as zero.
                if (sum == 0.0) {
                    // TODO - is this approach correct?
                    // NOTE: This prevents invalid models being generated by assigning an impossibly
                    // small probability to an event that should never occur.
                    cptEntries.add(Double.MIN_NORMAL);
                } else {
                    cptEntries.add(cFactor.getTable().valueFor(assignmentMap) / sum);
                }
            }
        }
        summedOut = new FactorTable(parentVarIdxs, new FunctionTable(parentVarCardinalities, parentSummedOutEntries));
        cptTable = new FunctionTable(cptCardinalities, cptEntries);
    }
    ConditionalProbabilityTable cpt = new ConditionalProbabilityTable(parentVarIdxs, c, cptTable);
    return new Pair<>(summedOut, cpt);
}
Also used: ConditionalProbabilityTable(com.sri.ai.praise.lang.grounded.bayes.ConditionalProbabilityTable) ArrayList(java.util.ArrayList) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) LinkedHashMap(java.util.LinkedHashMap) FunctionTable(com.sri.ai.praise.lang.grounded.common.FunctionTable) FactorTable(com.sri.ai.praise.lang.grounded.markov.FactorTable) Pair(com.sri.ai.util.base.Pair)

Example 3 with FunctionTable

Use of com.sri.ai.praise.lang.grounded.common.FunctionTable in project aic-praise by aic-sri-international.

The class UAIModelReader, method readFunctionTables.

/**
 * Reads, in order, one function table for each table declared in the preamble.
 * <p>
 * Within each table, the first variable of the table's scope plays the role of the most
 * significant "digit" of the entry index.
 *
 * @param preamble supplies the number of tables and each table's variable cardinalities
 * @param br       reader positioned at the first function table
 * @return tables keyed by their 0-based position in the file, in insertion (file) order
 * @throws IOException if reading from {@code br} fails
 */
private static Map<Integer, FunctionTable> readFunctionTables(Preamble preamble, BufferedReader br) throws IOException {
    Map<Integer, FunctionTable> tablesByIndex = new LinkedHashMap<>();
    int tableIdx = 0;
    while (tableIdx < preamble.numTables()) {
        tablesByIndex.put(tableIdx, createFunctionTable(preamble.cardinalitiesForTable(tableIdx), br));
        tableIdx++;
    }
    return tablesByIndex;
}
Also used: FunctionTable(com.sri.ai.praise.lang.grounded.common.FunctionTable) LinkedHashMap(java.util.LinkedHashMap)

Example 4 with FunctionTable

Use of com.sri.ai.praise.lang.grounded.common.FunctionTable in project aic-praise by aic-sri-international.

The class UAIModelReader, method createFunctionTable.

/**
 * Reads a single function table from {@code br}.
 * <p>
 * For each factor table, first the number of entries is given (this should be equal to the
 * product of the domain sizes of the variables in the scope). Then, one by one, separated by
 * whitespace, the values for each assignment to the variables in the function's scope are
 * enumerated. Values may start on the same line as the entry count and may continue over any
 * number of following lines.
 *
 * @param variableCardinalities cardinalities of the variables in the table's scope
 * @param br                    reader positioned at the line holding the table's entry count
 * @return the table built from {@code variableCardinalities} and the values read
 * @throws IOException if reading fails, if the declared entry count is negative, or if the
 *                     lines read hold more values than the declared entry count (which would
 *                     otherwise silently produce a table of the wrong size)
 */
private static FunctionTable createFunctionTable(List<Integer> variableCardinalities, BufferedReader br) throws IOException {
    String[] entryLineEntries = split(readLine(br));
    int numberEntries = Integer.parseInt(entryLineEntries[0]);
    if (numberEntries < 0) {
        throw new IOException("Function table declares a negative number of entries: " + numberEntries);
    }
    List<Double> entries = new ArrayList<>(numberEntries);
    // Table entries may appear on the same line as the #entries information.
    for (int i = 1; i < entryLineEntries.length; i++) {
        entries.add(Double.parseDouble(entryLineEntries[i]));
    }
    while (entries.size() < numberEntries) {
        for (String entry : split(readLine(br))) {
            entries.add(Double.parseDouble(entry));
        }
    }
    // Lines are consumed whole, so an overshoot means the data does not match the declared count.
    if (entries.size() != numberEntries) {
        throw new IOException("Function table declares " + numberEntries + " entries but " + entries.size() + " values were read");
    }
    return new FunctionTable(variableCardinalities, entries);
}
Also used: FunctionTable(com.sri.ai.praise.lang.grounded.common.FunctionTable) ArrayList(java.util.ArrayList)

Example 5 with FunctionTable

Use of com.sri.ai.praise.lang.grounded.common.FunctionTable in project aic-praise by aic-sri-international.

The class XFormMarkovToBayes, method newMegaFactor.

/**
 * Multiplies the given factors into a single factor over {@code varIdxs}.
 * <p>
 * The new factor's variables follow the iteration order of {@code varIdxs}. For every joint
 * assignment to those variables, the entry is the product of each factor's entry for the
 * assignment restricted to that factor's own variables.
 *
 * @param markov             network supplying variable cardinalities
 * @param varIdxs            variables of the resulting factor; presumably a superset of every
 *                           factor's variables (a missing one would fail the position lookup) -- TODO confirm with callers
 * @param factorsContainingC factors to multiply together
 * @return the product factor
 */
private static FactorTable newMegaFactor(MarkovNetwork markov, Set<Integer> varIdxs, List<FactorTable> factorsContainingC) {
    List<Integer> variableIndexes = new ArrayList<>(varIdxs);
    // Position of each variable within variableIndexes (and thus within each joint assignment).
    Map<Integer, Integer> positionOfVar = new LinkedHashMap<>();
    List<Integer> varCardinalities = new ArrayList<>();
    for (int pos = 0; pos < variableIndexes.size(); pos++) {
        Integer varIdx = variableIndexes.get(pos);
        positionOfVar.put(varIdx, pos);
        varCardinalities.add(markov.cardinality(varIdx));
    }
    List<Double> productEntries = new ArrayList<>(FunctionTable.numEntriesFor(varCardinalities));
    // Reused scratch list holding one factor's slice of the current joint assignment.
    List<Integer> factorAssignment = new ArrayList<>();
    CartesianProductEnumeration<Integer> jointAssignments = new CartesianProductEnumeration<>(FunctionTable.cardinalityValues(varCardinalities));
    while (jointAssignments.hasMoreElements()) {
        List<Integer> joint = jointAssignments.nextElement();
        double product = 1;
        for (FactorTable factor : factorsContainingC) {
            factorAssignment.clear();
            for (Integer varIdx : factor.getVariableIndexes()) {
                factorAssignment.add(joint.get(positionOfVar.get(varIdx)));
            }
            product *= factor.getTable().entryFor(factorAssignment);
        }
        productEntries.add(product);
    }
    return new FactorTable(variableIndexes, new FunctionTable(varCardinalities, productEntries));
}
Also used: FunctionTable(com.sri.ai.praise.lang.grounded.common.FunctionTable) ArrayList(java.util.ArrayList) FactorTable(com.sri.ai.praise.lang.grounded.markov.FactorTable) CartesianProductEnumeration(com.sri.ai.util.collect.CartesianProductEnumeration) LinkedHashMap(java.util.LinkedHashMap)

Aggregations

FunctionTable (com.sri.ai.praise.lang.grounded.common.FunctionTable)10 ArrayList (java.util.ArrayList)7 LinkedHashMap (java.util.LinkedHashMap)5 CartesianProductEnumeration (com.sri.ai.util.collect.CartesianProductEnumeration)4 Expression (com.sri.ai.expresso.api.Expression)3 Map (java.util.Map)3 Beta (com.google.common.annotations.Beta)2 MultiIndexQuantifierEliminator (com.sri.ai.grinder.sgdpllt.api.MultiIndexQuantifierEliminator)2 FactorTable (com.sri.ai.praise.lang.grounded.markov.FactorTable)2 Pair (com.sri.ai.util.base.Pair)2 BufferedReader (java.io.BufferedReader)2 List (java.util.List)2 Function (java.util.function.Function)2 Expressions (com.sri.ai.expresso.helper.Expressions)1 FALSE (com.sri.ai.expresso.helper.Expressions.FALSE)1 TRUE (com.sri.ai.expresso.helper.Expressions.TRUE)1 Expressions.apply (com.sri.ai.expresso.helper.Expressions.apply)1 Expressions.makeSymbol (com.sri.ai.expresso.helper.Expressions.makeSymbol)1 Context (com.sri.ai.grinder.sgdpllt.api.Context)1 Compilation (com.sri.ai.grinder.sgdpllt.application.Compilation)1