Search in sources :

Example 16 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class TreeFeature method split.

/**
 * Partitions this tree feature into a left and a right tree feature.
 * <p>
 * When this feature holds a single {@link InvertedFeature}, both branches receive
 * a new inverted feature with that same value, built directly from the supplied
 * index arrays. Otherwise each inverted feature is split in turn against the
 * left indices, and once the working buffer reports no remaining elements every
 * later inverted feature is placed on the right branch unchanged.
 * <p>
 * {@code firstBuffer} is overwritten via {@code fill(leftIndices)} and both
 * buffers are handed to {@code InvertedFeature.split}, so their prior contents
 * should not be relied upon after this call.
 *
 * @param leftIndices The indices to go in the left branch.
 * @param rightIndices The indices to go in the right branch (only read in the single-value case).
 * @param firstBuffer A buffer for temporary work.
 * @param secondBuffer A buffer for temporary work.
 * @return A pair of TreeFeatures, the first element is the left branch, the second the right.
 * @throws IllegalStateException if this feature has not been sorted.
 */
public Pair<TreeFeature, TreeFeature> split(int[] leftIndices, int[] rightIndices, IntArrayContainer firstBuffer, IntArrayContainer secondBuffer) {
    if (!sorted) {
        throw new IllegalStateException("TreeFeature must be sorted before split is called");
    }

    // Single-value case: no per-feature splitting is needed.
    if (feature.size() == 1) {
        double onlyValue = feature.get(0).value;
        List<InvertedFeature> leftOnly = Collections.singletonList(new InvertedFeature(onlyValue, leftIndices));
        List<InvertedFeature> rightOnly = Collections.singletonList(new InvertedFeature(onlyValue, rightIndices));
        return new Pair<>(new TreeFeature(id, leftOnly), new TreeFeature(id, rightOnly));
    }

    List<InvertedFeature> leftBranch = new ArrayList<>();
    List<InvertedFeature> rightBranch = new ArrayList<>();

    // 'pending' is passed as the first argument of each split, 'scratch' as the second;
    // the two roles are exchanged after every split.
    IntArrayContainer pending = firstBuffer;
    IntArrayContainer scratch = secondBuffer;
    pending.fill(leftIndices);

    for (InvertedFeature current : feature) {
        if (pending.size <= 0) {
            // All left side indices are exhausted, the rest belongs to the right branch.
            rightBranch.add(current);
            continue;
        }

        Pair<InvertedFeature, InvertedFeature> halves = current.split(pending, scratch);

        // Exchange the buffers so the container just used as the second argument
        // becomes the first argument for the next inverted feature.
        // NOTE(review): presumably split() leaves the still-unmatched left indices there - confirm.
        IntArrayContainer swap = scratch;
        scratch = pending;
        pending = swap;

        InvertedFeature leftHalf = halves.getA();
        if (leftHalf != null) {
            leftBranch.add(leftHalf);
        }
        InvertedFeature rightHalf = halves.getB();
        if (rightHalf != null) {
            rightBranch.add(rightHalf);
        }
    }

    return new Pair<>(new TreeFeature(id, leftBranch), new TreeFeature(id, rightBranch));
}
Also used : IntArrayContainer(org.tribuo.common.tree.impl.IntArrayContainer) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 17 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class BinaryCrossEntropy method lossAndGradient.

/**
 * Computes the binary cross entropy loss summed over all labels, together with
 * the per label gradient, and returns them as a {@link Pair} of {@link Double}
 * and {@link SGDVector}.
 * <p>
 * The gradient is written into the dense form of {@code prediction}: when
 * {@code prediction} is already a {@link DenseVector} it is overwritten in place
 * and returned, otherwise the result of {@code densify()} is used.
 * @param truth The true label vector (sparse or dense).
 * @param prediction The raw prediction for each label id.
 * @return A Pair of the summed loss and the per label gradient.
 */
@Override
public Pair<Double, SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction) {
    DenseVector target = (truth instanceof SparseVector)
            ? ((SparseVector) truth).densify()
            : (DenseVector) truth;
    DenseVector gradient = (prediction instanceof SparseVector)
            ? ((SparseVector) prediction).densify()
            : (DenseVector) prediction;

    double total = 0.0;
    final int length = prediction.size();
    int idx = 0;
    while (idx < length) {
        double y = target.get(idx);
        double z = gradient.get(idx);
        double probability = SigmoidNormalizer.sigmoid(z);

        // Numerically stable cross entropy on the raw score: max(z, 0) - z * y + log(1 + e^-|z|).
        double positivePart = Math.max(z, 0);
        double weighted = z * y;
        double softTerm = Math.log1p(Math.exp(-Math.abs(z)));
        total += positivePart - weighted + softTerm;

        // The raw score has been consumed, so its slot is reused for the gradient.
        gradient.set(idx, -(probability - y));
        idx++;
    }
    return new Pair<>(total, gradient);
}
Also used : SparseVector(org.tribuo.math.la.SparseVector) DenseVector(org.tribuo.math.la.DenseVector) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 18 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class Hinge method lossAndGradient.

/**
 * Computes the hinge loss summed over all labels, together with the per label
 * gradient, and returns them as a {@link Pair} of {@link Double} and {@link SGDVector}.
 * <p>
 * A truth value of exactly {@code 0.0} is treated as the negative class (-1),
 * anything else as the positive class (+1). The gradient is written into the
 * dense form of {@code prediction}: a {@link DenseVector} argument is overwritten
 * in place, otherwise the result of {@code densify()} is used.
 * @param truth The true label vector (sparse or dense).
 * @param prediction The prediction for each label id.
 * @return The loss and per label gradient.
 */
@Override
public Pair<Double, SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction) {
    DenseVector target = (truth instanceof SparseVector)
            ? ((SparseVector) truth).densify()
            : (DenseVector) truth;
    DenseVector gradient = (prediction instanceof SparseVector)
            ? ((SparseVector) prediction).densify()
            : (DenseVector) prediction;

    double total = 0.0;
    int idx = 0;
    while (idx < target.size()) {
        double sign;
        if (target.get(idx) == 0.0) {
            sign = -1.0;
        } else {
            sign = 1.0;
        }
        double signedScore = sign * gradient.get(idx);

        // Inside the margin the gradient is the signed label, outside it is zero.
        gradient.set(idx, signedScore < margin ? sign : 0.0);
        total += Math.max(0.0, margin - signedScore);
        idx++;
    }
    return new Pair<>(total, gradient);
}
Also used : SparseVector(org.tribuo.math.la.SparseVector) DenseVector(org.tribuo.math.la.DenseVector) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 19 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class NeighbourQueryTestHelper method neighboursQueryOneExclusive.

/**
 * Runs a single k = 3 neighbour query for a point which is excluded from the
 * queried data, then checks the returned neighbour points and distances.
 * @param nqf The factory used to build the neighbours query under test.
 */
static void neighboursQueryOneExclusive(NeighboursQueryFactory nqf) {
    SGDVector[] points = getShuffledTestDataVectorArray();

    // This point is excluded from the set of points being queried
    SGDVector queryPoint = get2DPoint(5.22, 5.25);

    NeighboursQuery query = nqf.createNeighboursQuery(points);
    List<Pair<Integer, Double>> neighbours = query.query(queryPoint, 3);

    // Expected neighbouring points for k = 3; the first expected distance below is non-zero.
    assertNeighbourPoints(points, neighbours,
            new double[] { 5.21, 5.28 },
            new double[] { 5.23, 5.02 },
            new double[] { 4.95, 5.25 });

    // Expected neighbouring distances for k = 3, in the same order as the points above.
    assertNeighbourDistances(neighbours,
            0.031622776601683965,
            0.23021728866442723,
            0.2699999999999996);
}
Also used : SGDVector(org.tribuo.math.la.SGDVector) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Example 20 with Pair

use of com.oracle.labs.mlrg.olcut.util.Pair in project tribuo by oracle.

the class NeighbourQueryTestHelper method neighboursQueryOneInclusive.

/**
 * Runs a single k = 3 neighbour query for a point which is included in the
 * queried data, then checks the returned neighbour points and distances.
 * @param nqf The factory used to build the neighbours query under test.
 */
static void neighboursQueryOneInclusive(NeighboursQueryFactory nqf) {
    SGDVector[] points = getTestDataVectorArray();

    // This point is included in the set of points being queried
    SGDVector queryPoint = get2DPoint(5.21, 5.28);

    NeighboursQuery query = nqf.createNeighboursQuery(points);
    List<Pair<Integer, Double>> neighbours = query.query(queryPoint, 3);

    // Expected neighbouring points for k = 3; the first one has the query point's
    // own coordinates, matching the expected distance of 0.0 below.
    assertNeighbourPoints(points, neighbours,
            new double[] { 5.21, 5.28 },
            new double[] { 5.23, 5.02 },
            new double[] { 4.95, 5.25 });

    // Expected neighbouring distances for k = 3, in the same order as the points above.
    assertNeighbourDistances(neighbours,
            0.0,
            0.2607680962081067,
            0.26172504656604784);
}
Also used : SGDVector(org.tribuo.math.la.SGDVector) Pair(com.oracle.labs.mlrg.olcut.util.Pair)

Aggregations

Pair (com.oracle.labs.mlrg.olcut.util.Pair)59 ArrayList (java.util.ArrayList)27 List (java.util.List)21 HashMap (java.util.HashMap)18 MutableDataset (org.tribuo.MutableDataset)17 SimpleDataSourceProvenance (org.tribuo.provenance.SimpleDataSourceProvenance)16 Label (org.tribuo.classification.Label)14 Feature (org.tribuo.Feature)11 Regressor (org.tribuo.regression.Regressor)11 Prediction (org.tribuo.Prediction)10 DenseVector (org.tribuo.math.la.DenseVector)10 SparseVector (org.tribuo.math.la.SparseVector)10 SGDVector (org.tribuo.math.la.SGDVector)9 Map (java.util.Map)7 Example (org.tribuo.Example)7 ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap)7 PriorityQueue (java.util.PriorityQueue)6 Excuse (org.tribuo.Excuse)5 Model (org.tribuo.Model)5 LabelFactory (org.tribuo.classification.LabelFactory)5