Search in sources :

Example 1 with Pair

use of ai.djl.util.Pair in project djl by deepjavalibrary.

Class RandomAccessDataset, method toArray.

/**
 * Returns the dataset contents as a Java array.
 *
 * <p>Each Number[] is a flattened dataset record and the Number[][] is the array of all
 * records. Records are read one at a time (batch size 1) in sequential order, and each batch
 * is closed after it has been flattened, even if flattening fails.
 *
 * @return a pair of (data, labels), where entry {@code i} of each array is the flattened
 *     data/label of record {@code i}
 * @throws IOException for various exceptions depending on the dataset
 * @throws TranslateException if there is an error while processing input
 * @throws ArithmeticException if the dataset size does not fit in an {@code int}
 */
public Pair<Number[][], Number[][]> toArray() throws IOException, TranslateException {
    try (NDManager manager = NDManager.newBaseManager()) {
        // One record per batch, visited in sequence, so batch i corresponds to record i.
        Sampler sampler = new BatchSampler(new SequenceSampler(), 1, false);
        int size = Math.toIntExact(size());
        Number[][] data = new Number[size][];
        Number[][] labels = new Number[size][];
        int index = 0;
        for (Batch batch : this.getData(manager, sampler)) {
            try {
                data[index] = flattenRecord(batch.getData());
                labels[index] = flattenRecord(batch.getLabels());
            } finally {
                // Release the batch even when flattening throws; otherwise it would leak.
                batch.close();
            }
            index++;
        }
        return new Pair<>(data, labels);
    }
}
Also used : NDManager(ai.djl.ndarray.NDManager) Pair(ai.djl.util.Pair)

Example 2 with Pair

use of ai.djl.util.Pair in project djl by deepjavalibrary.

Class SingleShotDetectionLoss, method inputForComponent.

/**
 * Builds the (labels, predictions) inputs for one component of the single-shot-detection loss.
 *
 * <p>The anchors, the label head and the class predictions are first passed to
 * {@code multiBoxTarget} to compute training targets. The targets are then read as
 * (offset labels, masks, class labels) at indices 0, 1 and 2 and paired with the matching
 * prediction.
 *
 * @param componentIndex which loss component to build inputs for: {@code 0} for the class
 *     loss, {@code 1} for the bounding-box loss
 * @param labels target labels; only {@code labels.head()} is used, as input to
 *     {@code multiBoxTarget}
 * @param predictions model outputs, read as (anchors, class predictions, bounding-box
 *     predictions) at indices 0, 1 and 2
 * @return a pair of (labels, predictions) to feed to the selected loss component
 * @throws IllegalArgumentException if {@code componentIndex} is neither 0 nor 1
 */
@Override
protected Pair<NDList, NDList> inputForComponent(
        int componentIndex, NDList labels, NDList predictions) {
    NDArray anchors = predictions.get(0);
    NDArray classPredictions = predictions.get(1);
    // multiBoxTarget receives the class predictions with their last two axes swapped.
    NDList targets =
            multiBoxTarget.target(
                    new NDList(anchors, labels.head(), classPredictions.transpose(0, 2, 1)));
    switch (componentIndex) {
        case 0: // class loss
            {
                NDArray classLabels = targets.get(2);
                return new Pair<>(new NDList(classLabels), new NDList(classPredictions));
            }
        case 1: // bounding-box loss
            {
                NDArray boundingBoxPredictions = predictions.get(2);
                NDArray boundingBoxLabels = targets.get(0);
                NDArray boundingBoxMasks = targets.get(1);
                // Apply the masks to both sides so masked-out entries compare as zero.
                return new Pair<>(
                        new NDList(boundingBoxLabels.mul(boundingBoxMasks)),
                        new NDList(boundingBoxPredictions.mul(boundingBoxMasks)));
            }
        default:
            throw new IllegalArgumentException("Invalid component index: " + componentIndex);
    }
}
Also used : NDList(ai.djl.ndarray.NDList) NDArray(ai.djl.ndarray.NDArray) Pair(ai.djl.util.Pair)

Example 3 with Pair

use of ai.djl.util.Pair in project djl by deepjavalibrary.

Class Accuracy, method accuracyHelper.

/**
 * {@inheritDoc}
 *
 * <p>When the prediction shape differs from the label shape, the predictions are reduced to
 * class indices with {@code argMax(axis)} and reshaped to the label's shape. Otherwise they are
 * compared as-is. The result pairs the number of label elements with the count of positions
 * where prediction and label agree.
 */
@Override
protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
    NDArray truth = labels.head();
    NDArray predicted = predictions.head();
    checkLabelShapes(truth, predicted);

    NDArray classIndices = predicted;
    boolean needsReduction = !truth.getShape().equals(predicted.getShape());
    if (needsReduction) {
        classIndices = predicted.argMax(axis).reshape(truth.getShape());
    }

    long total = truth.size();
    // Compare in INT64; the temporary copy of the label is closed when the block exits.
    try (NDArray truthAsLong = truth.toType(DataType.INT64, true)) {
        NDArray matches = classIndices.toType(DataType.INT64, false).eq(truthAsLong);
        return new Pair<>(total, matches.countNonzero());
    }
}
Also used : NDArray(ai.djl.ndarray.NDArray) Pair(ai.djl.util.Pair)

Example 4 with Pair

use of ai.djl.util.Pair in project djl by deepjavalibrary.

Class BinaryAccuracy, method accuracyHelper.

/**
 * {@inheritDoc}
 *
 * <p>A prediction is treated as positive when it is greater than or equal to {@code threshold}.
 * Both the thresholded predictions and the labels are compared as INT64. The result pairs the
 * number of label elements with the count of matching positions.
 */
@Override
protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
    boolean sameLength = labels.size() == predictions.size();
    Preconditions.checkArgument(sameLength, "labels and prediction length does not match.");

    NDArray truth = labels.head();
    NDArray scores = predictions.head();
    checkLabelShapes(truth, scores, false);

    NDArray positive = scores.gte(threshold);
    long total = truth.size();
    NDArray truthAsLong = truth.toType(DataType.INT64, false);
    NDArray positiveAsLong = positive.toType(DataType.INT64, false);
    NDArray hits = truthAsLong.eq(positiveAsLong).countNonzero();
    return new Pair<>(total, hits);
}
Also used : NDArray(ai.djl.ndarray.NDArray) Pair(ai.djl.util.Pair)

Example 5 with Pair

use of ai.djl.util.Pair in project djl by deepjavalibrary.

Class TopKAccuracy, method accuracyHelper.

/**
 * {@inheritDoc}
 *
 * <p>Predictions are arg-sorted along {@code axis}. For 2-D predictions, a sample counts as
 * correct when its label is among the {@code min(topK, numClasses)} highest-scoring classes.
 * For 1-D predictions, the sorted indices are compared element-wise with the label. The
 * {@code topK} field is never modified.
 *
 * @throws IllegalArgumentException if the prediction has more than 2 dimensions
 */
@Override
protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
    NDArray label = labels.head();
    NDArray prediction = predictions.head();
    // number of labels and predictions should be the same
    checkLabelShapes(label, prediction);
    // ascending by default, so the best classes sit at the end of the sort axis
    NDArray topKPrediction = prediction.argSort(axis).toType(DataType.INT64, false);
    // Loop-invariant: compare against the label in the same dtype as the sorted indices.
    NDArray flatLabel = label.flatten().toType(DataType.INT64, false);
    int numDims = topKPrediction.getShape().dimension();
    NDArray numCorrect;
    if (numDims == 1) {
        numCorrect = topKPrediction.flatten().eq(flatLabel).countNonzero();
    } else if (numDims == 2) {
        int numClasses = (int) topKPrediction.getShape().get(1);
        // Clamp into a local: assigning the topK field here would permanently shrink k
        // for every later batch evaluated by this metric.
        int k = Math.min(topK, numClasses);
        NDArray[] hitsPerRank =
                IntStream.range(0, k)
                        .mapToObj(
                                j ->
                                        // read from the last index as argSort is ascending
                                        topKPrediction
                                                .get(":, {}", numClasses - j - 1)
                                                .flatten()
                                                .eq(flatLabel)
                                                .countNonzero())
                        .toArray(NDArray[]::new);
        // With a single rank there is nothing to sum; don't rely on the varargs helper
        // accepting a one-element input.
        numCorrect = hitsPerRank.length == 1 ? hitsPerRank[0] : NDArrays.add(hitsPerRank);
    } else {
        throw new IllegalArgumentException(
                "Prediction should have at most 2 dimensions, but got " + numDims);
    }
    long total = label.getShape().get(0);
    return new Pair<>(total, numCorrect);
}
Also used : NDArray(ai.djl.ndarray.NDArray) Pair(ai.djl.util.Pair)

Aggregations

Pair (ai.djl.util.Pair): 16; Shape (ai.djl.ndarray.types.Shape): 7; NDArray (ai.djl.ndarray.NDArray): 6; Assert (org.testng.Assert): 6; Test (org.testng.annotations.Test): 6; IntStream (java.util.stream.IntStream): 5; NDList (ai.djl.ndarray.NDList): 3; NDManager (ai.djl.ndarray.NDManager): 2; Trainer (ai.djl.training.Trainer): 2; TrainingConfig (ai.djl.training.TrainingConfig): 2; Model (ai.djl.Model): 1; ModelException (ai.djl.ModelException): 1; Engine (ai.djl.engine.Engine): 1; Predictor (ai.djl.inference.Predictor): 1; Metrics (ai.djl.metric.Metrics): 1; Classifications (ai.djl.modality.Classifications): 1; Image (ai.djl.modality.cv.Image): 1; ImageFactory (ai.djl.modality.cv.ImageFactory): 1; NDArrays (ai.djl.ndarray.NDArrays): 1; Block (ai.djl.nn.Block): 1