Use of ai.djl.util.Pair in project djl by deepjavalibrary.
The class RandomAccessDataset, method toArray.
/**
 * Returns the dataset contents as a Java array.
 *
 * <p>The dataset is read through a {@code BatchSampler} of batch size 1 wrapping a
 * {@code SequenceSampler}, so every batch holds one record. Each record's data and labels are
 * flattened into a {@code Number[]}. The returned pair's key holds all flattened data records
 * and its value holds all flattened label records, both indexed by the order in which the
 * records were produced.
 *
 * @return a pair of (flattened data records, flattened label records)
 * @throws IOException for various exceptions depending on the dataset
 * @throws TranslateException if there is an error while processing input
 */
public Pair<Number[][], Number[][]> toArray() throws IOException, TranslateException {
    try (NDManager manager = NDManager.newBaseManager()) {
        Sampler singleRecordSampler = new BatchSampler(new SequenceSampler(), 1, false);
        int recordCount = Math.toIntExact(size());
        Number[][] flatData = new Number[recordCount][];
        Number[][] flatLabels = new Number[recordCount][];
        int row = 0;
        for (Batch single : this.getData(manager, singleRecordSampler)) {
            flatData[row] = flattenRecord(single.getData());
            flatLabels[row] = flattenRecord(single.getLabels());
            // Release this batch's resources before fetching the next one.
            single.close();
            row += 1;
        }
        return new Pair<>(flatData, flatLabels);
    }
}
Use of ai.djl.util.Pair in project djl by deepjavalibrary.
The class SingleShotDetectionLoss, method inputForComponent.
/**
 * Builds the (labels, predictions) input for one component loss of the SSD loss.
 *
 * <p>The ground-truth labels, the anchors and the class predictions are first passed to
 * {@code multiBoxTarget.target(...)}. The returned target list is read as: bounding box offset
 * labels at index 0, bounding box masks at index 1 and class labels at index 2.
 *
 * @param componentIndex which component loss to build input for: {@code 0} for the class loss,
 *     {@code 1} for the bounding box loss
 * @param labels ground-truth labels; only {@code labels.head()} is used, as input to {@code
 *     multiBoxTarget}
 * @param predictions model outputs: anchors at index 0, class predictions at index 1 and
 *     bounding box predictions at index 2
 * @return a pair of (labels, predictions) for the requested component; for the bounding box
 *     loss both sides are multiplied element-wise by the bounding box masks
 * @throws IllegalArgumentException if {@code componentIndex} is neither 0 nor 1
 */
@Override
protected Pair<NDList, NDList> inputForComponent(
        int componentIndex, NDList labels, NDList predictions) {
    NDArray anchors = predictions.get(0);
    NDArray classPredictions = predictions.get(1);
    // Class predictions are handed to the target function with axes 1 and 2 swapped;
    // transpose(0, 2, 1) implies classPredictions is rank 3.
    NDList targets =
            multiBoxTarget.target(
                    new NDList(anchors, labels.head(), classPredictions.transpose(0, 2, 1)));
    switch (componentIndex) {
        case 0: // ClassLoss
            {
                NDArray classLabels = targets.get(2);
                return new Pair<>(new NDList(classLabels), new NDList(classPredictions));
            }
        case 1: // BoundingBoxLoss
            {
                NDArray boundingBoxPredictions = predictions.get(2);
                NDArray boundingBoxLabels = targets.get(0);
                NDArray boundingBoxMasks = targets.get(1);
                // Masking both sides zeroes out entries the masks exclude.
                return new Pair<>(
                        new NDList(boundingBoxLabels.mul(boundingBoxMasks)),
                        new NDList(boundingBoxPredictions.mul(boundingBoxMasks)));
            }
        default:
            throw new IllegalArgumentException("Invalid component index");
    }
}
Use of ai.djl.util.Pair in project djl by deepjavalibrary.
The class Accuracy, method accuracyHelper.
/**
 * {@inheritDoc}
 *
 * <p>If the label and prediction shapes differ, the prediction is reduced with {@code argMax}
 * along {@code axis} and reshaped to the label shape; if they already match, the prediction is
 * compared as is. Both sides are cast to {@code INT64} and the matching elements are counted.
 * The returned pair holds the total number of label elements and the count of matches.
 */
@Override
protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
    NDArray label = labels.head();
    NDArray prediction = predictions.head();
    checkLabelShapes(label, prediction);
    boolean shapesMatch = label.getShape().equals(prediction.getShape());
    NDArray predictedValues =
            shapesMatch ? prediction : prediction.argMax(axis).reshape(label.getShape());
    long total = label.size();
    // The converted label is closed below. Presumably copy=true is requested so that closing it
    // can never release the caller's own label array.
    try (NDArray labelAsLong = label.toType(DataType.INT64, true)) {
        NDArray matches = predictedValues.toType(DataType.INT64, false).eq(labelAsLong);
        return new Pair<>(total, matches.countNonzero());
    }
}
Use of ai.djl.util.Pair in project djl by deepjavalibrary.
The class BinaryAccuracy, method accuracyHelper.
/**
 * {@inheritDoc}
 *
 * <p>A prediction counts as positive when it is greater than or equal to {@code threshold}.
 * The labels and the thresholded predictions are both cast to {@code INT64} and compared
 * element-wise. The returned pair holds the total number of label elements and the count of
 * matches.
 */
@Override
protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
    Preconditions.checkArgument(
            labels.size() == predictions.size(), "labels and prediction length does not match.");
    NDArray label = labels.head();
    NDArray prediction = predictions.head();
    checkLabelShapes(label, prediction, false);
    NDArray aboveThreshold = prediction.gte(threshold);
    long total = label.size();
    NDArray labelAsLong = label.toType(DataType.INT64, false);
    NDArray predictionAsLong = aboveThreshold.toType(DataType.INT64, false);
    NDArray correct = labelAsLong.eq(predictionAsLong).countNonzero();
    return new Pair<>(total, correct);
}
Use of ai.djl.util.Pair in project djl by deepjavalibrary.
The class TopKAccuracy, method accuracyHelper.
/**
 * {@inheritDoc}
 *
 * <p>The prediction is arg-sorted (ascending) along {@code axis} and cast to {@code INT64}.
 * For a 2-D result, the last {@code k} columns are compared against the flattened labels and
 * the matches are summed. Here {@code k} is {@code topK} capped at the number of columns. A
 * 1-D result is compared element-wise with the flattened labels. The returned pair holds the
 * size of the label's first dimension and the number of correct predictions.
 *
 * @throws IllegalArgumentException if the arg-sorted prediction does not have 1 or 2
 *     dimensions
 */
@Override
protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
    NDArray label = labels.head();
    NDArray prediction = predictions.head();
    // number of labels and predictions should be the same
    checkLabelShapes(label, prediction);
    // ascending by default, so the highest-ranked indices sit in the last columns
    NDArray topKPrediction = prediction.argSort(axis).toType(DataType.INT64, false);
    int numDims = topKPrediction.getShape().dimension();
    NDArray numCorrect;
    if (numDims == 1) {
        // Cast the labels like the 2-D branch does, so eq() sees matching INT64 operands.
        NDArray flatLabel = label.flatten().toType(DataType.INT64, false);
        numCorrect = topKPrediction.flatten().eq(flatLabel).countNonzero();
    } else if (numDims == 2) {
        int numClasses = (int) topKPrediction.getShape().get(1);
        // Cap locally. Assigning back to the topK field would permanently shrink the metric's
        // configured k after a single batch with fewer columns.
        int k = Math.min(topK, numClasses);
        // Loop-invariant: flatten and cast the labels once, not once per candidate column.
        NDArray flatLabel = label.flatten().toType(DataType.INT64, false);
        numCorrect =
                NDArrays.add(
                        IntStream.range(0, k)
                                .mapToObj(
                                        j -> {
                                            // get from last index as argSort is ascending
                                            NDArray jPrediction =
                                                    topKPrediction.get(
                                                            ":, {}", numClasses - j - 1);
                                            return jPrediction
                                                    .flatten()
                                                    .eq(flatLabel)
                                                    .countNonzero();
                                        })
                                .toArray(NDArray[]::new));
    } else {
        throw new IllegalArgumentException(
                "Prediction should have 1 or 2 dimensions, but got " + numDims);
    }
    long total = label.getShape().get(0);
    return new Pair<>(total, numCorrect);
}
Aggregations