Use of hex.genmodel.utils.GenmodelBitSet in project h2o-3 by h2oai.
Class SharedTreeMojoModel, method scoreTree1.
/**
* SET IN STONE FOR MOJO VERSION "1.10" - DO NOT CHANGE
*
* Walks one compressed tree from the root and returns either the reached leaf's prediction or, when
* computeLeafAssignment is true, the left/right decision path packed into the bits of a long and
* returned through Double.longBitsToDouble.
*
* Routing at each split: a NaN value, or a value outside the range of the node's bitset (bitset splits
* only), follows the node's NA direction (right unless leftward). Otherwise an NA-vs-rest split goes left,
* a numeric split goes right when value >= splitVal, and a bitset split goes right when the bitset
* contains (int) value. Bitsets are decoded with fill2 (equal == 8) or fill3_1 (otherwise).
*
* @param tree compressed tree bytes
* @param row input row, indexed by column id
* @param nclasses number of classes; a "small leaf" left child is skipped as 1 byte when below 256, else 2
* @param computeLeafAssignment when true, return the encoded decision path instead of the prediction
* @return the leaf prediction, or the encoded decision path when computeLeafAssignment is true
*/
// Suppresses IDE "ConstantConditions" warnings (presumably for the possibly-null `bs`, which is only
// dereferenced on bitset splits after it has been created).
@SuppressWarnings("ConstantConditions")
public static double scoreTree1(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) {
ByteBufferWrapper ab = new ByteBufferWrapper(tree);
// Created on the first bitset split and refilled for every later bitset split.
GenmodelBitSet bs = null;
// Bit i is set when the split at depth i sent the row right (only depths < 64 are recorded).
long bitsRight = 0;
int level = 0;
while (true) {
int nodeType = ab.get1U();
int colId = ab.get2();
// A column id of 65535 marks a leaf node header; only its prediction follows.
// NOTE(review): this returns the prediction even when computeLeafAssignment is true.
if (colId == 65535)
return ab.get4f();
int naSplitDir = ab.get1U();
boolean naVsRest = naSplitDir == NsdNaVsRest;
boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
// Left-child descriptor (bits 0,1,4,5 of nodeType): 0-3 give the byte width of the left subtree's
// size field; 16 and 48 mark a leaf left child (see the switch below).
int lmask = (nodeType & 51);
// Can be one of 0, 8, 12
int equal = (nodeType & 12);
// no longer supported
assert equal != 4;
float splitVal = -1;
// NA-vs-rest splits carry no split value or bitset.
if (!naVsRest) {
// Extract value or group to split on
if (equal == 0) {
// Standard float-compare test (either < or ==)
// Get the float to compare
splitVal = ab.get4f();
} else {
// Bitset test
if (bs == null)
bs = new GenmodelBitSet(0);
if (equal == 8)
bs.fill2(tree, ab);
else
bs.fill3_1(tree, ab);
}
}
double d = row[colId];
// (NaN or outside bitset range) ? follow NA direction : (!naVsRest && (numeric ? d >= splitVal : in bitset))
if (Double.isNaN(d) || (equal != 0 && bs != null && !bs.isInRange((int) d)) ? !leftward : !naVsRest && (equal == 0 ? d >= splitVal : bs.contains((int) d))) {
// go RIGHT
// Skip over the encoded left subtree (or left leaf) to reach the right child.
switch(lmask) {
case 0:
ab.skip(ab.get1U());
break;
case 1:
ab.skip(ab.get2());
break;
case 2:
ab.skip(ab.get3());
break;
case 3:
ab.skip(ab.get4());
break;
// Small leaf
case 16:
ab.skip(nclasses < 256 ? 1 : 2);
break;
// skip the prediction
case 48:
ab.skip(4);
break;
default:
assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
}
// NOTE(review): `1 << level` is an int shift: at level 31 it sign-extends into bits 31..63 of the
// long, and from level 32 the shift distance wraps mod 32. A long shift (1L << level) would be needed.
// Left unchanged because this scoring version is frozen.
if (computeLeafAssignment && level < 64)
bitsRight |= 1 << level;
// Replace leftmask with the rightmask
lmask = (nodeType & 0xC0) >> 2;
} else {
// go LEFT
// The left child immediately follows the (lmask + 1)-byte left-subtree size field.
if (lmask <= 3)
ab.skip(lmask + 1);
}
level++;
// Bit 16 of the (left or right) mask means the chosen child is a leaf.
if ((lmask & 16) != 0) {
if (computeLeafAssignment) {
// mark the end of the tree
// NOTE(review): same int-shift issue as above, and this marker is not guarded by level < 64.
bitsRight |= 1 << level;
return Double.longBitsToDouble(bitsRight);
} else {
return ab.get4f();
}
}
}
}
Use of hex.genmodel.utils.GenmodelBitSet in project h2o-3 by h2oai.
Class SharedTreeMojoModel, method scoreTree.
/**
 * Highly efficient (critical path) tree scoring.
 *
 * <p>Given a tree (in the form of a byte array) and a row of input data, computes either this tree's
 * predicted value when {@code computeLeafAssignment} is false, or the decision path within the tree
 * (no more than 64 levels) when {@code computeLeafAssignment} is true.
 *
 * <p>The decision path is packed into a {@code long} and returned via {@link Double#longBitsToDouble(long)}:
 * bit {@code i} is set when the split at depth {@code i} sent the row right, and one extra bit is set at the
 * depth of the reached leaf to mark the end of the path. Depths of 64 or more are not recorded.
 *
 * <p>Routing rule at each split:
 * <ul>
 *   <li>if the value is NaN, lies outside the range of the node's bitset (bitset splits only), or is an
 *       index at or beyond the column's domain size (when {@code domains} provides one), the row goes right
 *       unless the node's NA direction is leftward;</li>
 *   <li>otherwise an NA-vs-rest split sends it left;</li>
 *   <li>otherwise a numeric split goes right when {@code value >= splitVal}, and a bitset split goes right
 *       when the bitset contains {@code (int) value}.</li>
 * </ul>
 *
 * <p>Note: this function is also used from the {@code hex.tree.CompressedTree} class in the {@code h2o-algos}
 * project.
 *
 * @param tree                  compressed tree bytes
 * @param row                   input row, indexed by column id
 * @param nclasses              number of classes; a "small leaf" left child is skipped as 1 byte when below
 *                              256, else 2
 * @param computeLeafAssignment when true, return the encoded decision path instead of the prediction
 * @param domains               per-column categorical domains, or {@code null}; a {@code null} entry disables
 *                              the out-of-domain check for that column
 * @return the leaf prediction, or the encoded decision path when {@code computeLeafAssignment} is true
 */
// Suppresses IDE "ConstantConditions" warnings: `bs` is only dereferenced on bitset splits, after creation.
@SuppressWarnings("ConstantConditions")
public static double scoreTree(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment, String[][] domains) {
  ByteBufferWrapper ab = new ByteBufferWrapper(tree);
  // Created on the first bitset split and refilled for every later bitset split.
  GenmodelBitSet bs = null;
  long bitsRight = 0;
  int level = 0;
  while (true) {
    int nodeType = ab.get1U();
    int colId = ab.get2();
    if (colId == 65535) {
      // A column id of 65535 marks a leaf node header; only its prediction follows.
      if (computeLeafAssignment) {
        // The caller asked for the path, not the prediction: terminate the path at this depth.
        if (level < 64)
          bitsRight |= 1L << level;
        return Double.longBitsToDouble(bitsRight);
      }
      return ab.get4f();
    }
    int naSplitDir = ab.get1U();
    boolean naVsRest = naSplitDir == NsdNaVsRest;
    boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
    // Left-child descriptor (bits 0,1,4,5): 0-3 give the byte width of the left subtree's size field;
    // 16 and 48 mark a leaf left child.
    int lmask = (nodeType & 51);
    // Can be one of 0, 8, 12
    int equal = (nodeType & 12);
    // no longer supported
    assert equal != 4;
    float splitVal = -1;
    // NA-vs-rest splits carry no split value or bitset.
    if (!naVsRest) {
      if (equal == 0) {
        // Standard float-compare test: read the float to compare against.
        splitVal = ab.get4f();
      } else {
        // Bitset test
        if (bs == null)
          bs = new GenmodelBitSet(0);
        if (equal == 8)
          bs.fill2(tree, ab);
        else
          bs.fill3(tree, ab);
      }
    }

    double d = row[colId];
    boolean goRight;
    if (Double.isNaN(d)
        || (equal != 0 && bs != null && !bs.isInRange((int) d))
        || (domains != null && domains[colId] != null && domains[colId].length <= (int) d)) {
      // Missing value, out-of-bitset-range value, or unseen categorical level: follow the NA direction.
      goRight = !leftward;
    } else if (naVsRest) {
      goRight = false;
    } else if (equal == 0) {
      goRight = d >= splitVal;
    } else {
      goRight = bs.contains((int) d);
    }

    if (goRight) {
      // Skip over the encoded left subtree (or left leaf) to reach the right child.
      switch (lmask) {
        case 0:
          ab.skip(ab.get1U());
          break;
        case 1:
          ab.skip(ab.get2());
          break;
        case 2:
          ab.skip(ab.get3());
          break;
        case 3:
          ab.skip(ab.get4());
          break;
        case 16: // Small leaf
          ab.skip(nclasses < 256 ? 1 : 2);
          break;
        case 48: // skip the prediction
          ab.skip(4);
          break;
        default:
          assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
      }
      // Long shift: an int shift would corrupt depths >= 31 (sign extension and mod-32 wrap).
      if (computeLeafAssignment && level < 64)
        bitsRight |= 1L << level;
      // Replace leftmask with the rightmask
      lmask = (nodeType & 0xC0) >> 2;
    } else {
      // The left child immediately follows the (lmask + 1)-byte left-subtree size field.
      if (lmask <= 3)
        ab.skip(lmask + 1);
    }

    level++;
    // Bit 16 of the (left or right) mask means the chosen child is a leaf.
    if ((lmask & 16) != 0) {
      if (computeLeafAssignment) {
        // Mark the end of the path; beyond 64 levels there is no bit left to use.
        if (level < 64)
          bitsRight |= 1L << level;
        return Double.longBitsToDouble(bitsRight);
      }
      return ab.get4f();
    }
  }
}
Use of hex.genmodel.utils.GenmodelBitSet in project h2o-3 by h2oai.
Class SharedTreeNode, method calculateChildInclusiveLevels.
/**
 * Calculates the set of categorical levels (indices into {@code domainValues}) that flow through to a child.
 *
 * <p>No level is a candidate when {@code discardAllLevels} is set. Otherwise a level is a candidate when
 * {@code includeAllLevels} is set, or when it lies within the range of the node's bitset and its membership in
 * that bitset equals {@code nodeBitsetDoesContain}. Each candidate is then passed through
 * {@code calculateIncludeThisLevel} together with the inherited inclusive levels returned by
 * {@code findInclusiveLevels(colId)}.
 *
 * @param includeAllLevels      naVsRest dictates that all (inherited) levels are candidates
 * @param discardAllLevels      naVsRest dictates that no level flows into the child; takes precedence
 * @param nodeBitsetDoesContain the bitset membership value (contains / does not contain) that routes a
 *                              level into this child
 * @return the calculated set of level indices (possibly empty)
 */
private BitSet calculateChildInclusiveLevels(boolean includeAllLevels, boolean discardAllLevels, boolean nodeBitsetDoesContain) {
  final BitSet inheritedLevels = findInclusiveLevels(colId);
  final BitSet childLevels = new BitSet();
  for (int idx = 0; idx < domainValues.length; idx++) {
    if (discardAllLevels) {
      continue;
    }
    final boolean candidate = includeAllLevels
        || (bs.isInRange(idx) && bs.contains(idx) == nodeBitsetDoesContain);
    if (candidate && calculateIncludeThisLevel(inheritedLevels, idx)) {
      childLevels.set(idx);
    }
  }
  return childLevels;
}
Use of hex.genmodel.utils.GenmodelBitSet in project h2o-3 by h2oai.
Class SharedTreeMojoModel, method scoreTree0.
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
// DO NOT CHANGE THE CODE BELOW THIS LINE
/////////////////////////////////////////////////////
/**
* SET IN STONE FOR MOJO VERSION "1.00" - DO NOT CHANGE
*
* Walks one compressed tree from the root and returns either the reached leaf's prediction or, when
* computeLeafAssignment is true, the left/right decision path packed into the bits of a long and
* returned through Double.longBitsToDouble.
*
* Routing at each split: a NaN value follows the node's NA direction (right unless leftward). Otherwise
* an NA-vs-rest split goes left, a numeric split goes right when value >= splitVal, and a bitset split
* goes right when bs.contains0((int) value). Unlike later versions there is no bitset range check.
*
* @param tree compressed tree bytes
* @param row input row, indexed by column id
* @param nclasses number of classes; a "small leaf" left child is skipped as 1 byte when below 256, else 2
* @param computeLeafAssignment when true, return the encoded decision path instead of the prediction
* @return the leaf prediction, or the encoded decision path when computeLeafAssignment is true
*/
// Suppresses IDE "ConstantConditions" warnings (presumably for the possibly-null `bs`, which is only
// dereferenced on bitset splits after it has been created).
@SuppressWarnings("ConstantConditions")
public static double scoreTree0(byte[] tree, double[] row, int nclasses, boolean computeLeafAssignment) {
ByteBufferWrapper ab = new ByteBufferWrapper(tree);
// Lazily set on hitting first group test
GenmodelBitSet bs = null;
// Bit i is set when the split at depth i sent the row right (only depths < 64 are recorded).
long bitsRight = 0;
int level = 0;
while (true) {
int nodeType = ab.get1U();
int colId = ab.get2();
// A column id of 65535 marks a leaf node header; only its prediction follows.
// NOTE(review): this returns the prediction even when computeLeafAssignment is true.
if (colId == 65535)
return ab.get4f();
int naSplitDir = ab.get1U();
boolean naVsRest = naSplitDir == NsdNaVsRest;
boolean leftward = naSplitDir == NsdNaLeft || naSplitDir == NsdLeft;
// Left-child descriptor (bits 0,1,4,5 of nodeType): 0-3 give the byte width of the left subtree's
// size field; 16 and 48 mark a leaf left child (see the switch below).
int lmask = (nodeType & 51);
// Can be one of 0, 8, 12
int equal = (nodeType & 12);
// no longer supported
assert equal != 4;
float splitVal = -1;
// NA-vs-rest splits carry no split value or bitset.
if (!naVsRest) {
// Extract value or group to split on
if (equal == 0) {
// Standard float-compare test (either < or ==)
// Get the float to compare
splitVal = ab.get4f();
} else {
// Bitset test
if (bs == null)
bs = new GenmodelBitSet(0);
if (equal == 8)
bs.fill2(tree, ab);
else
bs.fill3_1(tree, ab);
}
}
double d = row[colId];
// NaN ? follow NA direction : (!naVsRest && (numeric ? d >= splitVal : contains0))
if (Double.isNaN(d) ? !leftward : !naVsRest && (equal == 0 ? d >= splitVal : bs.contains0((int) d))) {
// go RIGHT
// Skip over the encoded left subtree (or left leaf) to reach the right child.
switch(lmask) {
case 0:
ab.skip(ab.get1U());
break;
case 1:
ab.skip(ab.get2());
break;
case 2:
ab.skip(ab.get3());
break;
case 3:
ab.skip(ab.get4());
break;
// Small leaf
case 16:
ab.skip(nclasses < 256 ? 1 : 2);
break;
// skip the prediction
case 48:
ab.skip(4);
break;
default:
assert false : "illegal lmask value " + lmask + " in tree " + Arrays.toString(tree);
}
// NOTE(review): `1 << level` is an int shift: at level 31 it sign-extends into bits 31..63 of the
// long, and from level 32 the shift distance wraps mod 32. A long shift (1L << level) would be needed.
// Left unchanged because this scoring version is frozen.
if (computeLeafAssignment && level < 64)
bitsRight |= 1 << level;
// Replace leftmask with the rightmask
lmask = (nodeType & 0xC0) >> 2;
} else {
// go LEFT
// The left child immediately follows the (lmask + 1)-byte left-subtree size field.
if (lmask <= 3)
ab.skip(lmask + 1);
}
level++;
// Bit 16 of the (left or right) mask means the chosen child is a leaf.
if ((lmask & 16) != 0) {
if (computeLeafAssignment) {
// mark the end of the tree
// NOTE(review): same int-shift issue as above, and this marker is not guarded by level < 64.
bitsRight |= 1 << level;
return Double.longBitsToDouble(bitsRight);
} else {
return ab.get4f();
}
}
}
}
Use of hex.genmodel.utils.GenmodelBitSet in project h2o-3 by h2oai.
Class SharedTreeMojoModel, method computeTreeGraph.
//------------------------------------------------------------------------------------------------------------------
// Computing a Tree Graph
//------------------------------------------------------------------------------------------------------------------

/**
 * Recursively decodes one node of a compressed tree into {@code node} and attaches its children to {@code sg}.
 *
 * <p>Reads the node header (node type, column id, NA-split direction) and the split descriptor (split value or
 * bitset) from {@code ab}. Then, for the right child and then the left child, it opens a fresh cursor over
 * {@code tree} positioned at that child's encoding. It either records the leaf prediction or recurses.
 * Each child's weight, node number, prediction and squared error come from the {@link AuxInfo} stored in
 * {@code auxMap} under this node's number. For node number 0, the node's own prediction, squared error and
 * weight are derived from its two children.
 *
 * @param sg       subgraph in which child nodes are created
 * @param node     node to populate
 * @param tree     compressed tree bytes
 * @param ab       cursor positioned at the start of {@code node}'s encoding; advanced past its header and split
 * @param auxMap   auxiliary per-node information keyed by node number
 * @param nclasses number of classes; a "small leaf" left child is skipped as 1 byte when below 256, else 2
 */
private void computeTreeGraph(SharedTreeSubgraph sg, SharedTreeNode node, byte[] tree, ByteBufferWrapper ab, HashMap<Integer, AuxInfo> auxMap, int nclasses) {
  final int nodeType = ab.get1U();
  final int colId = ab.get2();
  if (colId == 65535) {
    // A column id of 65535 marks a leaf node header; only its prediction follows.
    node.setPredValue(ab.get4f());
    return;
  }

  final String colName = getNames()[colId];
  node.setCol(colId, colName);
  final int naSplitDir = ab.get1U();
  final boolean naVsRest = naSplitDir == NsdNaVsRest;
  node.setLeftward(naSplitDir == NsdNaLeft || naSplitDir == NsdLeft);
  node.setNaVsRest(naVsRest);

  // Left-child descriptor (bits 0,1,4,5): 0-3 give the byte width of the left subtree's size field;
  // 16 and 48 mark a leaf left child.
  final int leftMask = nodeType & 51;
  final int splitKind = nodeType & 12; // one of 0, 8, 12
  assert splitKind != 4; // no longer supported

  // NA-vs-rest splits carry no split value or bitset.
  if (!naVsRest) {
    if (splitKind == 0) {
      // Standard float-compare split.
      node.setSplitValue(ab.get4f());
    } else {
      final GenmodelBitSet splitSet = new GenmodelBitSet(0);
      if (splitKind == 8) {
        splitSet.fill2(tree, ab);
      } else {
        splitSet.fill3(tree, ab);
      }
      node.setBitset(getDomainValues(colId), splitSet);
    }
  }

  final AuxInfo auxInfo = auxMap.get(node.getNodeNumber());
  // Both children are located relative to the end of this node's header/split encoding.
  final int childrenStart = ab.position();

  // Right child: skip over the encoded left subtree (or left leaf) first.
  final ByteBufferWrapper rightCursor = new ByteBufferWrapper(tree);
  rightCursor.skip(childrenStart);
  switch (leftMask) {
    case 0:
      rightCursor.skip(rightCursor.get1U());
      break;
    case 1:
      rightCursor.skip(rightCursor.get2());
      break;
    case 2:
      rightCursor.skip(rightCursor.get3());
      break;
    case 3:
      rightCursor.skip(rightCursor.get4());
      break;
    case 16: // small leaf
      rightCursor.skip(nclasses < 256 ? 1 : 2);
      break;
    case 48: // skip the prediction
      rightCursor.skip(4);
      break;
    default:
      assert false : "illegal lmask value " + leftMask + " in tree " + Arrays.toString(tree);
  }
  // Right-child descriptor lives in bits 6-7; shift it into the same position as the leaf bit 16.
  final int rightMask = (nodeType & 0xC0) >> 2;
  final SharedTreeNode rightNode = sg.makeRightChildNode(node);
  rightNode.setWeight(auxInfo.weightR);
  rightNode.setNodeNumber(auxInfo.nidR);
  rightNode.setPredValue(auxInfo.predR);
  rightNode.setSquaredError(auxInfo.sqErrR);
  if ((rightMask & 16) == 0) {
    computeTreeGraph(sg, rightNode, tree, rightCursor, auxMap, nclasses);
  } else {
    final float rightLeaf = rightCursor.get4f();
    rightNode.setPredValue(rightLeaf);
    auxInfo.predR = rightLeaf;
  }

  // Left child: it follows the (leftMask + 1)-byte left-subtree size field, if there is one.
  final ByteBufferWrapper leftCursor = new ByteBufferWrapper(tree);
  leftCursor.skip(childrenStart);
  if (leftMask <= 3) {
    leftCursor.skip(leftMask + 1);
  }
  final SharedTreeNode leftNode = sg.makeLeftChildNode(node);
  leftNode.setWeight(auxInfo.weightL);
  leftNode.setNodeNumber(auxInfo.nidL);
  leftNode.setPredValue(auxInfo.predL);
  leftNode.setSquaredError(auxInfo.sqErrL);
  if ((leftMask & 16) == 0) {
    computeTreeGraph(sg, leftNode, tree, leftCursor, auxMap, nclasses);
  } else {
    final float leftLeaf = leftCursor.get4f();
    leftNode.setPredValue(leftLeaf);
    auxInfo.predL = leftLeaf;
  }

  if (node.getNodeNumber() == 0) {
    // Root: weighted mean of the children's predictions, with tiny magnitudes flushed to zero.
    final double weightedSum = (double) auxInfo.predL * (double) auxInfo.weightL
        + (double) auxInfo.predR * (double) auxInfo.weightR;
    final double totalWeight = (double) auxInfo.weightL + (double) auxInfo.weightR;
    float rootPred = (float) (weightedSum / totalWeight);
    if (Math.abs(rootPred) < 1e-7) {
      rootPred = 0;
    }
    node.setPredValue(rootPred);
    node.setSquaredError(auxInfo.sqErrR + auxInfo.sqErrL);
    node.setWeight(auxInfo.weightL + auxInfo.weightR);
  }
  checkConsistency(auxInfo, node);
}
Aggregations