Search in sources :

Example 1 with ByteBufferWrapper

use of hex.genmodel.utils.ByteBufferWrapper in project h2o-3 by h2oai.

the class SharedTreeMojoModel method scoreTree1.

/**
   * SET IN STONE FOR MOJO VERSION "1.10" - DO NOT CHANGE
   *
   * Walks one compressed tree from its first byte, picking the left or right branch at each node
   * from the value of {@code row[colId]}, until a leaf is reached.
   *
   * @param tree the compressed tree bytes; read sequentially through a {@code ByteBufferWrapper}
   * @param row the input row; indexed by the column id stored in each split node
   * @param nclasses only used to size the skip over a "small leaf" (1 byte when below 256, otherwise 2)
   * @param computeLeafAssignment when true, the decision path is returned instead of the prediction
   * @return the float prediction stored in the reached leaf; or, when {@code computeLeafAssignment} is
   *         true, the accumulated path mask reinterpreted with {@code Double.longBitsToDouble}.
   *         A node whose column id is 65535 always returns its stored float, regardless of the flag.
   */
// Complains that the code is too complex. Well duh!
@SuppressWarnings("ConstantConditions")
public static double scoreTree1(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) {
    ByteBufferWrapper ab = new ByteBufferWrapper(tree);
    // Allocated on the first bitset split and then refilled for each later one
    GenmodelBitSet bs = null;
    // Path mask: one bit per level at which the walk went right, plus an end marker
    long bitsRight = 0;
    int level = 0;
    while (true) {
        int nodeType = ab.get1U();
        int colId = ab.get2();
        // Column id 65535 marks a node that carries only a prediction
        if (colId == 65535)
            return ab.get4f();
        int naSplitDir = ab.get1U();
        boolean naVsRest = naSplitDir == NsdNaVsRest;
        boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
        // 51 == 0x33: keeps bits 0-1 and 4-5 of the node type (describes the left child)
        int lmask = (nodeType & 51);
        // Can be one of 0, 8, 12
        int equal = (nodeType & 12);
        // no longer supported
        assert equal != 4;
        float splitVal = -1;
        // A NA-vs-rest node stores no split value and no bitset
        if (!naVsRest) {
            // Extract value or group to split on
            if (equal == 0) {
                // Standard float-compare test (either < or ==)
                // Get the float to compare
                splitVal = ab.get4f();
            } else {
                // Bitset test
                if (bs == null)
                    bs = new GenmodelBitSet(0);
                if (equal == 8)
                    bs.fill2(tree, ab);
                else
                    bs.fill3_1(tree, ab);
            }
        }
        double d = row[colId];
        // `||` binds tighter than `?:`, so this reads as
        //   (NaN || value outside the bitset range) ? !leftward : (!naVsRest && split test)
        // and a true result means "go right".
        if (Double.isNaN(d) || (equal != 0 && bs != null && !bs.isInRange((int) d)) ? !leftward : !naVsRest && (equal == 0 ? d >= splitVal : bs.contains((int) d))) {
            // go RIGHT
            // Jump over the left subtree; lmask selects how its size is encoded (1-4 byte
            // length prefix) or, for a left leaf, how many bytes the leaf itself occupies.
            switch(lmask) {
                case 0:
                    ab.skip(ab.get1U());
                    break;
                case 1:
                    ab.skip(ab.get2());
                    break;
                case 2:
                    ab.skip(ab.get3());
                    break;
                case 3:
                    ab.skip(ab.get4());
                    break;
                // Small leaf
                case 16:
                    ab.skip(nclasses < 256 ? 1 : 2);
                    break;
                // skip the prediction
                case 48:
                    ab.skip(4);
                    break;
                default:
                    assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
            }
            // NOTE(review): `1 << level` is an int shift (shift distance is taken modulo 32 and the
            // result is sign-extended to long), so it is not a single long bit once level >= 31 even
            // though the guard allows up to 63. Left untouched because this version is frozen.
            if (computeLeafAssignment && level < 64)
                bitsRight |= 1 << level;
            // Replace leftmask with the rightmask
            lmask = (nodeType & 0xC0) >> 2;
        } else {
            // go LEFT
            // Step over the length prefix (lmask + 1 bytes) that precedes a non-leaf left child
            if (lmask <= 3)
                ab.skip(lmask + 1);
        }
        level++;
        // Bit 16 of the (left or right) mask means the chosen child is a leaf
        if ((lmask & 16) != 0) {
            if (computeLeafAssignment) {
                // mark the end of the tree
                bitsRight |= 1 << level;
                return Double.longBitsToDouble(bitsRight);
            } else {
                return ab.get4f();
            }
        }
    }
}
Also used : ByteBufferWrapper(hex.genmodel.utils.ByteBufferWrapper) GenmodelBitSet(hex.genmodel.utils.GenmodelBitSet)

Example 2 with ByteBufferWrapper

use of hex.genmodel.utils.ByteBufferWrapper in project h2o-3 by h2oai.

the class SharedTreeMojoModel method scoreTree.

/**
   * Highly efficient (critical path) tree scoring
   *
   * Given a tree (in the form of a byte array) and the row of input data, compute either this tree's
   * predicted value when `computeLeafAssignment` is false, or the decision path within the tree (but no more
   * than 64 levels) when `computeLeafAssignment` is true.
   *
   * The decision path is a 64-bit mask: bit `k` is set when the walk went right at depth `k` (the first
   * node is depth 0), and one additional bit is set at the index equal to the depth of the reached leaf
   * to mark the end of the path. Levels that do not fit into 64 bits are not recorded.
   *
   * Note: this function is also used from the `hex.tree.CompressedTree` class in `h2o-algos` project.
   *
   * @param tree the compressed tree bytes
   * @param row the input row; indexed by the column id stored in each split node
   * @param nclasses only used to size the skip over a "small leaf" (1 byte when below 256, otherwise 2)
   * @param computeLeafAssignment when true, return the decision path instead of the prediction
   * @param domains optional per-column domains (may be null, entries may be null); a value whose integer
   *                part is not below the column's domain length is routed like a missing value
   * @return the float prediction of the reached leaf; or, when `computeLeafAssignment` is true, the path
   *         mask reinterpreted with `Double.longBitsToDouble`. A node whose column id is 65535 always
   *         returns its stored float, regardless of the flag.
   */
// Complains that the code is too complex. Well duh!
@SuppressWarnings("ConstantConditions")
public static double scoreTree(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment, String[][] domains) {
    ByteBufferWrapper ab = new ByteBufferWrapper(tree);
    // Allocated on the first bitset split and then refilled for each later one
    GenmodelBitSet bs = null;
    long bitsRight = 0;
    int level = 0;
    while (true) {
        int nodeType = ab.get1U();
        int colId = ab.get2();
        // Column id 65535 marks a node that carries only a prediction
        if (colId == 65535)
            return ab.get4f();
        int naSplitDir = ab.get1U();
        boolean naVsRest = naSplitDir == NsdNaVsRest;
        boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
        // 51 == 0x33: keeps bits 0-1 and 4-5 of the node type (describes the left child)
        int lmask = (nodeType & 51);
        // Can be one of 0, 8, 12
        int equal = (nodeType & 12);
        // no longer supported
        assert equal != 4;
        float splitVal = -1;
        if (!naVsRest) {
            // Extract value or group to split on
            if (equal == 0) {
                // Standard float-compare test (either < or ==)
                // Get the float to compare
                splitVal = ab.get4f();
            } else {
                // Bitset test
                if (bs == null)
                    bs = new GenmodelBitSet(0);
                if (equal == 8)
                    bs.fill2(tree, ab);
                else
                    bs.fill3(tree, ab);
            }
        }
        // The condition below (note that `||` binds tighter than `?:`) really does this:
        //
        //        if (value is NaN or value is not in the range of the bitset or is outside the domain map length (but an integer) ) {
        //            if (leftward) {
        //                go left
        //            }
        //            else {
        //                go right
        //            }
        //        }
        //        else {
        //            if (naVsRest) {
        //                go left
        //            }
        //            else {
        //                if (numeric) {
        //                    if (value < split value) {
        //                        go left
        //                    }
        //                    else {
        //                        go right
        //                    }
        //                }
        //                else {
        //                    if (value not in bitset) {
        //                        go left
        //                    }
        //                    else {
        //                        go right
        //                    }
        //                }
        //            }
        //        }
        double d = row[colId];
        if (Double.isNaN(d) || (equal != 0 && bs != null && !bs.isInRange((int) d)) || (domains != null && domains[colId] != null && domains[colId].length <= (int) d) ? !leftward : !naVsRest && (equal == 0 ? d >= splitVal : bs.contains((int) d))) {
            // go RIGHT
            // Jump over the left subtree; lmask selects how its size is encoded (1-4 byte
            // length prefix) or, for a left leaf, how many bytes the leaf itself occupies.
            switch(lmask) {
                case 0:
                    ab.skip(ab.get1U());
                    break;
                case 1:
                    ab.skip(ab.get2());
                    break;
                case 2:
                    ab.skip(ab.get3());
                    break;
                case 3:
                    ab.skip(ab.get4());
                    break;
                // Small leaf
                case 16:
                    ab.skip(nclasses < 256 ? 1 : 2);
                    break;
                // skip the prediction
                case 48:
                    ab.skip(4);
                    break;
                default:
                    assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
            }
            // The shift must be done on a long: an int `1 << level` takes the distance modulo 32
            // and sign-extends, which corrupts the mask for every level from 31 upwards.
            if (computeLeafAssignment && level < 64)
                bitsRight |= 1L << level;
            // Replace leftmask with the rightmask
            lmask = (nodeType & 0xC0) >> 2;
        } else {
            // go LEFT
            // Step over the length prefix (lmask + 1 bytes) that precedes a non-leaf left child
            if (lmask <= 3)
                ab.skip(lmask + 1);
        }
        level++;
        // Bit 16 of the (left or right) mask means the chosen child is a leaf
        if ((lmask & 16) != 0) {
            if (computeLeafAssignment) {
                // mark the end of the tree; a long shift by 64 or more would wrap around and
                // overwrite a recorded decision, so the marker is dropped when it does not fit
                if (level < 64)
                    bitsRight |= 1L << level;
                return Double.longBitsToDouble(bitsRight);
            } else {
                return ab.get4f();
            }
        }
    }
}
Also used : ByteBufferWrapper(hex.genmodel.utils.ByteBufferWrapper) GenmodelBitSet(hex.genmodel.utils.GenmodelBitSet)

Example 3 with ByteBufferWrapper

use of hex.genmodel.utils.ByteBufferWrapper in project h2o-3 by h2oai.

the class SharedTreeMojoModel method _computeGraph.

/**
     * Build a graph of the forest, or of one tree group only.
     *
     * For every visited (tree group, tree within group) pair a subgraph named
     * "Tree <group>" is created, with ", Class <value>" appended when the response
     * column has domain values, and then filled in by {@code computeTreeGraph}.
     *
     * @param treeToPrint index of the tree group to graph; any negative value graphs all tree groups
     * @return the graph holding the created subgraphs
     * @throws IllegalArgumentException if {@code treeToPrint} is not below the number of tree groups
     */
public SharedTreeGraph _computeGraph(int treeToPrint) {
    SharedTreeGraph graph = new SharedTreeGraph();
    if (treeToPrint >= _ntree_groups) {
        throw new IllegalArgumentException("Tree " + treeToPrint + " does not exist (max " + _ntree_groups + ")");
    }
    // A non-negative request selects exactly one group; otherwise walk all of them.
    final boolean singleGroup = treeToPrint >= 0;
    final int firstGroup = singleGroup ? treeToPrint : 0;
    final int endGroup = singleGroup ? treeToPrint + 1 : _ntree_groups;
    for (int group = firstGroup; group < endGroup; group++) {
        for (int inGroup = 0; inGroup < _ntrees_per_group; inGroup++) {
            String[] responseDomain = getDomainValues(getResponseIdx());
            String classSuffix = responseDomain == null ? "" : ", Class " + responseDomain[inGroup];
            int treeIdx = treeIndex(group, inGroup);
            SharedTreeSubgraph subgraph = graph.makeSubgraph("Tree " + group + classSuffix);
            SharedTreeNode root = subgraph.makeRootNode();
            root.setSquaredError(Float.NaN);
            root.setPredValue(Float.NaN);
            byte[] treeBytes = _compressed_trees[treeIdx];
            ByteBufferWrapper treeReader = new ByteBufferWrapper(treeBytes);
            // Decode all auxiliary records of this tree up front, keyed by their node id
            ByteBufferWrapper auxReader = new ByteBufferWrapper(_compressed_trees_aux[treeIdx]);
            HashMap<Integer, AuxInfo> auxByNodeId = new HashMap<>();
            while (auxReader.hasRemaining()) {
                AuxInfo record = new AuxInfo(auxReader);
                auxByNodeId.put(record.nid, record);
            }
            computeTreeGraph(subgraph, root, treeBytes, treeReader, auxByNodeId, _nclasses);
        }
    }
    return graph;
}
Also used : HashMap(java.util.HashMap) ByteBufferWrapper(hex.genmodel.utils.ByteBufferWrapper)

Example 4 with ByteBufferWrapper

use of hex.genmodel.utils.ByteBufferWrapper in project h2o-3 by h2oai.

the class SharedTreeMojoModel method scoreTree0.

// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
/////////////////////////////////////////////////////
/**
   * SET IN STONE FOR MOJO VERSION "1.00" - DO NOT CHANGE
   *
   * Walks one compressed tree from its first byte, picking the left or right branch at each node
   * from the value of {@code row[colId]}, until a leaf is reached.
   *
   * @param tree the compressed tree bytes; read sequentially through a {@code ByteBufferWrapper}
   * @param row the input row; indexed by the column id stored in each split node
   * @param nclasses only used to size the skip over a "small leaf" (1 byte when below 256, otherwise 2)
   * @param computeLeafAssignment when true, the decision path is returned instead of the prediction
   * @return the float prediction stored in the reached leaf; or, when {@code computeLeafAssignment} is
   *         true, the accumulated path mask reinterpreted with {@code Double.longBitsToDouble}.
   *         A node whose column id is 65535 always returns its stored float, regardless of the flag.
   */
// Complains that the code is too complex. Well duh!
@SuppressWarnings("ConstantConditions")
public static double scoreTree0(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) {
    ByteBufferWrapper ab = new ByteBufferWrapper(tree);
    // Lazily set on hitting first group test
    GenmodelBitSet bs = null;
    // Path mask: one bit per level at which the walk went right, plus an end marker
    long bitsRight = 0;
    int level = 0;
    while (true) {
        int nodeType = ab.get1U();
        int colId = ab.get2();
        // Column id 65535 marks a node that carries only a prediction
        if (colId == 65535)
            return ab.get4f();
        int naSplitDir = ab.get1U();
        boolean naVsRest = naSplitDir == NsdNaVsRest;
        boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
        // 51 == 0x33: keeps bits 0-1 and 4-5 of the node type (describes the left child)
        int lmask = (nodeType & 51);
        // Can be one of 0, 8, 12
        int equal = (nodeType & 12);
        // no longer supported
        assert equal != 4;
        float splitVal = -1;
        // A NA-vs-rest node stores no split value and no bitset
        if (!naVsRest) {
            // Extract value or group to split on
            if (equal == 0) {
                // Standard float-compare test (either < or ==)
                // Get the float to compare
                splitVal = ab.get4f();
            } else {
                // Bitset test
                if (bs == null)
                    bs = new GenmodelBitSet(0);
                if (equal == 8)
                    bs.fill2(tree, ab);
                else
                    bs.fill3_1(tree, ab);
            }
        }
        double d = row[colId];
        // NaN is routed by the node's NA direction; any other value goes right when the node is not
        // NA-vs-rest and the split test (float compare or bitset membership via contains0) holds.
        if (Double.isNaN(d) ? !leftward : !naVsRest && (equal == 0 ? d >= splitVal : bs.contains0((int) d))) {
            // go RIGHT
            // Jump over the left subtree; lmask selects how its size is encoded (1-4 byte
            // length prefix) or, for a left leaf, how many bytes the leaf itself occupies.
            switch(lmask) {
                case 0:
                    ab.skip(ab.get1U());
                    break;
                case 1:
                    ab.skip(ab.get2());
                    break;
                case 2:
                    ab.skip(ab.get3());
                    break;
                case 3:
                    ab.skip(ab.get4());
                    break;
                // Small leaf
                case 16:
                    ab.skip(nclasses < 256 ? 1 : 2);
                    break;
                // skip the prediction
                case 48:
                    ab.skip(4);
                    break;
                default:
                    assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
            }
            // NOTE(review): `1 << level` is an int shift (shift distance is taken modulo 32 and the
            // result is sign-extended to long), so it is not a single long bit once level >= 31 even
            // though the guard allows up to 63. Left untouched because this version is frozen.
            if (computeLeafAssignment && level < 64)
                bitsRight |= 1 << level;
            // Replace leftmask with the rightmask
            lmask = (nodeType & 0xC0) >> 2;
        } else {
            // go LEFT
            // Step over the length prefix (lmask + 1 bytes) that precedes a non-leaf left child
            if (lmask <= 3)
                ab.skip(lmask + 1);
        }
        level++;
        // Bit 16 of the (left or right) mask means the chosen child is a leaf
        if ((lmask & 16) != 0) {
            if (computeLeafAssignment) {
                // mark the end of the tree
                bitsRight |= 1 << level;
                return Double.longBitsToDouble(bitsRight);
            } else {
                return ab.get4f();
            }
        }
    }
}
Also used : ByteBufferWrapper(hex.genmodel.utils.ByteBufferWrapper) GenmodelBitSet(hex.genmodel.utils.GenmodelBitSet)

Example 5 with ByteBufferWrapper

use of hex.genmodel.utils.ByteBufferWrapper in project h2o-3 by h2oai.

the class SharedTreeMojoModel method computeTreeGraph.

//------------------------------------------------------------------------------------------------------------------
// Computing a Tree Graph
//------------------------------------------------------------------------------------------------------------------
/**
 * Decode the node at the reader's current position into {@code node} and recurse into both children.
 *
 * The right child is processed before the left child. Each child gets its own reader over
 * {@code tree}, positioned from {@code ab}'s position after this node's header was read.
 *
 * @param sg subgraph that creates the child nodes
 * @param node graph node to fill in for the tree node being decoded
 * @param tree the compressed tree bytes
 * @param ab reader positioned at the start of the node to decode
 * @param auxMap auxiliary records keyed by node number; the record looked up for {@code node} is
 *               dereferenced without a null check, so it must be present for every split node
 * @param nclasses only used to size the skip over a "small leaf" (1 byte when below 256, otherwise 2)
 */
private void computeTreeGraph(SharedTreeSubgraph sg, SharedTreeNode node, byte[] tree, ByteBufferWrapper ab, HashMap<Integer, AuxInfo> auxMap, int nclasses) {
    final int nodeType = ab.get1U();
    final int colId = ab.get2();
    if (colId == 65535) {
        // Column id 65535: this node carries only a prediction
        float storedPrediction = ab.get4f();
        node.setPredValue(storedPrediction);
        return;
    }
    node.setCol(colId, getNames()[colId]);

    final int naSplitDir = ab.get1U();
    final boolean naVsRest = naSplitDir == NsdNaVsRest;
    final boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
    node.setLeftward(leftward);
    node.setNaVsRest(naVsRest);

    // 51 == 0x33: bits 0-1 and 4-5 of the node type describe the left child
    final int leftMask = nodeType & 51;
    // Split kind; can be one of 0, 8, 12 (4 is no longer supported)
    final int splitKind = nodeType & 12;
    assert splitKind != 4;

    if (!naVsRest) {
        if (splitKind != 0) {
            // Categorical split: the node stores a bitset
            GenmodelBitSet groups = new GenmodelBitSet(0);
            if (splitKind == 8) {
                groups.fill2(tree, ab);
            } else {
                groups.fill3(tree, ab);
            }
            node.setBitset(getDomainValues(colId), groups);
        } else {
            // Numeric split: the node stores the float to compare against
            float threshold = ab.get4f();
            node.setSplitValue(threshold);
        }
    }

    final AuxInfo aux = auxMap.get(node.getNodeNumber());

    // ---- RIGHT child: jump over the left subtree first ----
    ByteBufferWrapper rightReader = new ByteBufferWrapper(tree);
    rightReader.skip(ab.position());
    if (leftMask == 0) {
        rightReader.skip(rightReader.get1U());
    } else if (leftMask == 1) {
        rightReader.skip(rightReader.get2());
    } else if (leftMask == 2) {
        rightReader.skip(rightReader.get3());
    } else if (leftMask == 3) {
        rightReader.skip(rightReader.get4());
    } else if (leftMask == 16) {
        // Small leaf
        rightReader.skip(nclasses < 256 ? 1 : 2);
    } else if (leftMask == 48) {
        // skip the prediction
        rightReader.skip(4);
    } else {
        assert false : "illegal lmask value " + lmask(leftMask) + " in tree " + Arrays.toString(tree);
    }
    // Bits 6-7 of the node type, moved down to the position of the leaf bits
    final int rightMask = (nodeType & 0xC0) >> 2;
    SharedTreeNode rightChild = sg.makeRightChildNode(node);
    rightChild.setWeight(aux.weightR);
    rightChild.setNodeNumber(aux.nidR);
    rightChild.setPredValue(aux.predR);
    rightChild.setSquaredError(aux.sqErrR);
    if ((rightMask & 16) == 0) {
        computeTreeGraph(sg, rightChild, tree, rightReader, auxMap, nclasses);
    } else {
        // Right child is a leaf: the value stored in the tree overrides the auxiliary prediction
        float rightLeaf = rightReader.get4f();
        rightChild.setPredValue(rightLeaf);
        aux.predR = rightLeaf;
    }

    // ---- LEFT child: it follows this node directly, after an optional length prefix ----
    ByteBufferWrapper leftReader = new ByteBufferWrapper(tree);
    leftReader.skip(ab.position());
    if (leftMask <= 3) {
        leftReader.skip(leftMask + 1);
    }
    SharedTreeNode leftChild = sg.makeLeftChildNode(node);
    leftChild.setWeight(aux.weightL);
    leftChild.setNodeNumber(aux.nidL);
    leftChild.setPredValue(aux.predL);
    leftChild.setSquaredError(aux.sqErrL);
    if ((leftMask & 16) == 0) {
        computeTreeGraph(sg, leftChild, tree, leftReader, auxMap, nclasses);
    } else {
        // Left child is a leaf: the value stored in the tree overrides the auxiliary prediction
        float leftLeaf = leftReader.get4f();
        leftChild.setPredValue(leftLeaf);
        aux.predL = leftLeaf;
    }

    // Node number 0 gets its statistics derived from its two children
    if (node.getNodeNumber() == 0) {
        double weightedSum = (double) aux.predL * (double) aux.weightL + (double) aux.predR * (double) aux.weightR;
        double totalWeight = (double) aux.weightL + (double) aux.weightR;
        float rootPrediction = (float) (weightedSum / totalWeight);
        if (Math.abs(rootPrediction) < 1e-7)
            rootPrediction = 0;
        node.setPredValue(rootPrediction);
        node.setSquaredError(aux.sqErrR + aux.sqErrL);
        node.setWeight(aux.weightL + aux.weightR);
    }
    checkConsistency(aux, node);
}

// Identity helper so the assertion message keeps the exact text of the original expression.
private static int lmask(int value) {
    return value;
}
Also used : ByteBufferWrapper(hex.genmodel.utils.ByteBufferWrapper) GenmodelBitSet(hex.genmodel.utils.GenmodelBitSet)

Aggregations

ByteBufferWrapper (hex.genmodel.utils.ByteBufferWrapper)5 GenmodelBitSet (hex.genmodel.utils.GenmodelBitSet)4 HashMap (java.util.HashMap)1