Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.
The class TreeFeature, method split.
/**
 * Splits this tree feature into a left and a right child feature.
 * <p>
 * When the feature holds a single value, both children reuse that value with the supplied
 * left and right index sets respectively. Otherwise the left indices are loaded into a work
 * buffer and each inverted feature is split against the indices still pending; once no left
 * indices remain, every further inverted feature is sent to the right child unchanged.
 *
 * @param leftIndices The indices to go in the left branch.
 * @param rightIndices The indices to go in the right branch.
 * @param firstBuffer A buffer for temporary work.
 * @param secondBuffer A buffer for temporary work.
 * @return A pair of TreeFeatures, the first element is the left branch, the second the right.
 * @throws IllegalStateException If this feature has not been sorted.
 */
public Pair<TreeFeature, TreeFeature> split(int[] leftIndices, int[] rightIndices, IntArrayContainer firstBuffer, IntArrayContainer secondBuffer) {
    if (!sorted) {
        throw new IllegalStateException("TreeFeature must be sorted before split is called");
    }

    final List<InvertedFeature> leftBranch;
    final List<InvertedFeature> rightBranch;

    if (feature.size() == 1) {
        // Only one distinct value: both children share it, each with its own index set.
        final double onlyValue = feature.get(0).value;
        leftBranch = Collections.singletonList(new InvertedFeature(onlyValue, leftIndices));
        rightBranch = Collections.singletonList(new InvertedFeature(onlyValue, rightIndices));
    } else {
        leftBranch = new ArrayList<>();
        rightBranch = new ArrayList<>();

        // 'pending' holds the left indices not yet consumed; 'scratch' receives the output
        // of each split, after which the two buffers trade roles.
        IntArrayContainer pending = firstBuffer;
        IntArrayContainer scratch = secondBuffer;
        pending.fill(leftIndices);

        for (InvertedFeature invFeature : feature) {
            if (pending.size <= 0) {
                // All left indices have been placed, so the rest belong to the right branch.
                rightBranch.add(invFeature);
                continue;
            }

            Pair<InvertedFeature, InvertedFeature> halves = invFeature.split(pending, scratch);

            IntArrayContainer swap = scratch;
            scratch = pending;
            pending = swap;

            InvertedFeature leftHalf = halves.getA();
            InvertedFeature rightHalf = halves.getB();
            if (leftHalf != null) {
                leftBranch.add(leftHalf);
            }
            if (rightHalf != null) {
                rightBranch.add(rightHalf);
            }
        }
    }

    return new Pair<>(new TreeFeature(id, leftBranch), new TreeFeature(id, rightBranch));
}
Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.
The class BinaryCrossEntropy, method lossAndGradient.
/**
 * Computes the binary cross-entropy loss and its per label gradient.
 * <p>
 * Returns a {@link Pair} of {@link Double} and {@link SGDVector} representing the loss
 * and per label gradients respectively. For each label the loss uses the numerically stable
 * form {@code max(x, 0) - x * y + log1p(exp(-|x|))}, where {@code x} is the prediction (a
 * logit) and {@code y} the truth value. The gradient entry is {@code -(sigmoid(x) - y)},
 * i.e. the negated derivative of the loss with respect to {@code x}.
 * <p>
 * The prediction vector is transformed to produce the per label gradient and returned: if
 * {@code prediction} is dense it is overwritten in place; a sparse prediction is densified
 * first. Truth values are presumably in {@code [0, 1]}; this is not checked.
 *
 * @param truth The true label vector; must have the same dimension as {@code prediction}.
 * @param prediction The prediction for each label id.
 * @return A Pair of the summed loss across all labels and the per label gradient.
 * @throws IllegalArgumentException If {@code truth} and {@code prediction} have different dimensions.
 */
@Override
public Pair<Double, SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction) {
    // Without this check a longer truth vector would have its extra labels silently ignored.
    if (truth.size() != prediction.size()) {
        throw new IllegalArgumentException("truth and prediction must have the same dimension, found truth.size() = "
                + truth.size() + ", prediction.size() = " + prediction.size());
    }
    DenseVector labels;
    DenseVector densePred;
    if (truth instanceof SparseVector) {
        labels = ((SparseVector) truth).densify();
    } else {
        labels = (DenseVector) truth;
    }
    if (prediction instanceof SparseVector) {
        densePred = ((SparseVector) prediction).densify();
    } else {
        densePred = (DenseVector) prediction;
    }
    double loss = 0.0;
    for (int i = 0; i < densePred.size(); i++) {
        double label = labels.get(i);
        double pred = densePred.get(i);
        double yhat = SigmoidNormalizer.sigmoid(pred);
        // Numerically stable form of -[y log(sigmoid(x)) + (1 - y) log(1 - sigmoid(x))],
        // which avoids overflow in exp() for large |x|.
        loss += Math.max(pred, 0) - (pred * label) + Math.log1p(Math.exp(-Math.abs(pred)));
        // Overwrite the prediction slot with the negated gradient for this label.
        densePred.set(i, -(yhat - label));
    }
    return new Pair<>(loss, densePred);
}
Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.
The class Hinge, method lossAndGradient.
/**
 * Computes the hinge loss and its per label (sub)gradient.
 * <p>
 * Returns a {@link Pair} of {@link Double} and {@link SGDVector} representing the loss
 * and per label gradients respectively. Each truth entry equal to {@code 0.0} is treated as
 * the negative class ({@code -1}) and any other value as the positive class ({@code +1}).
 * For each label the loss is {@code max(0, margin - y * prediction)}. The gradient entry is
 * {@code y} when the margin is violated ({@code y * prediction < margin}) and {@code 0}
 * otherwise, i.e. the negated subgradient of the loss with respect to the prediction.
 * <p>
 * If {@code prediction} is dense it is overwritten in place with the gradient and returned;
 * a sparse prediction is densified first.
 *
 * @param truth The true label vector; must have the same dimension as {@code prediction}.
 * @param prediction The prediction for each label id.
 * @return The loss summed across all labels, and the per label gradient.
 * @throws IllegalArgumentException If {@code truth} and {@code prediction} have different dimensions.
 */
@Override
public Pair<Double, SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction) {
    // Without this check a longer prediction vector would return its trailing raw
    // predictions as if they were gradients.
    if (truth.size() != prediction.size()) {
        throw new IllegalArgumentException("truth and prediction must have the same dimension, found truth.size() = "
                + truth.size() + ", prediction.size() = " + prediction.size());
    }
    DenseVector labels;
    DenseVector densePred;
    if (truth instanceof SparseVector) {
        labels = ((SparseVector) truth).densify();
    } else {
        labels = (DenseVector) truth;
    }
    if (prediction instanceof SparseVector) {
        densePred = ((SparseVector) prediction).densify();
    } else {
        densePred = (DenseVector) prediction;
    }
    double loss = 0.0;
    for (int i = 0; i < labels.size(); i++) {
        // Map {0, non-zero} labels onto {-1, +1}.
        double lbl = labels.get(i) == 0.0 ? -1.0 : 1.0;
        double pred = densePred.get(i);
        double score = lbl * pred;
        // Only margin violations contribute a non-zero gradient.
        if (score < margin) {
            densePred.set(i, lbl);
        } else {
            densePred.set(i, 0.0);
        }
        loss += Math.max(0.0, margin - score);
    }
    return new Pair<>(loss, densePred);
}
Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.
The class NeighbourQueryTestHelper, method neighboursQueryOneExclusive.
/**
 * Checks the {@code k = 3} nearest neighbours of a query point that is not part of the
 * (shuffled) data set being queried.
 *
 * @param nqf The factory used to build the {@link NeighboursQuery} under test.
 */
static void neighboursQueryOneExclusive(NeighboursQueryFactory nqf) {
    final int k = 3;
    SGDVector[] data = getShuffledTestDataVectorArray();
    // This point is excluded from the set of points being queried, so all k returned
    // neighbours are other data points, each at a non-zero distance.
    SGDVector vector = get2DPoint(5.22, 5.25);
    NeighboursQuery nq = nqf.createNeighboursQuery(data);
    List<Pair<Integer, Double>> indexDistancePairList = nq.query(vector, k);
    // The expected neighbouring points, listed in increasing order of distance.
    double[] firstExpectedPoint0 = { 5.21, 5.28 };
    double[] firstExpectedPoint1 = { 5.23, 5.02 };
    double[] firstExpectedPoint2 = { 4.95, 5.25 };
    assertNeighbourPoints(data, indexDistancePairList, firstExpectedPoint0, firstExpectedPoint1, firstExpectedPoint2);
    // The expected Euclidean distances from the query point to each of the points above.
    double expectedDistance0 = 0.031622776601683965;
    double expectedDistance1 = 0.23021728866442723;
    double expectedDistance2 = 0.2699999999999996;
    assertNeighbourDistances(indexDistancePairList, expectedDistance0, expectedDistance1, expectedDistance2);
}
Use of com.oracle.labs.mlrg.olcut.util.Pair in project Tribuo by Oracle.
The class NeighbourQueryTestHelper, method neighboursQueryOneInclusive.
/**
 * Checks the {@code k = 3} nearest neighbours of a query point that is itself part of the
 * data set being queried.
 *
 * @param nqf The factory used to build the {@link NeighboursQuery} under test.
 */
static void neighboursQueryOneInclusive(NeighboursQueryFactory nqf) {
    final int k = 3;
    SGDVector[] data = getTestDataVectorArray();
    // This point is included in the set of points being queried, so it is expected back
    // as its own nearest neighbour at distance 0.0. That match is asserted below.
    SGDVector vector = get2DPoint(5.21, 5.28);
    NeighboursQuery nq = nqf.createNeighboursQuery(data);
    List<Pair<Integer, Double>> indexDistancePairList = nq.query(vector, k);
    // The expected neighbouring points, listed in increasing order of distance; the first
    // is the query point itself.
    double[] firstExpectedPoint0 = { 5.21, 5.28 };
    double[] firstExpectedPoint1 = { 5.23, 5.02 };
    double[] firstExpectedPoint2 = { 4.95, 5.25 };
    assertNeighbourPoints(data, indexDistancePairList, firstExpectedPoint0, firstExpectedPoint1, firstExpectedPoint2);
    // The expected Euclidean distances from the query point to each of the points above.
    double expectedDistance0 = 0.0;
    double expectedDistance1 = 0.2607680962081067;
    double expectedDistance2 = 0.26172504656604784;
    assertNeighbourDistances(indexDistancePairList, expectedDistance0, expectedDistance1, expectedDistance2);
}
Aggregations