Search in sources:

Example 6 with CompareAndSet

Use of `org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet` in project nd4j by deeplearning4j.

Found in class `Nd4jTestsC`, method `testCompareAndSet1`.

@Test
public void testCompareAndSet1() {
    // Build a 25-element zero vector with 0.1 planted at positions 0, 10 and 20.
    INDArray actual = Nd4j.zeros(25);
    for (int idx : new int[] { 0, 10, 20 }) {
        actual.putScalar(idx, 0.1f);
    }
    // Compare against 0.1 and set matches to 0.0; the last argument (0.01) is presumably the
    // comparison tolerance, which absorbs the float -> double widening of 0.1f.
    Nd4j.getExecutioner().exec(new CompareAndSet(actual, 0.1, 0.0, 0.01));
    // Every planted value should have been reset, leaving an all-zero vector.
    INDArray expected = Nd4j.zeros(25);
    assertEquals(expected, actual);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), CompareAndSet (org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet), Test (org.junit.Test)

Example 7 with CompareAndSet

Use of `org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet` in project deeplearning4j by deeplearning4j.

Found in class `ROC`, method `eval`.

/**
 * Evaluate (collect statistics for) the given minibatch of data.
 * For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or
 * {@link #evalTimeSeries(INDArray, INDArray, INDArray)}. If both arrays passed here are rank 3,
 * they are handed to {@link #evalTimeSeries(INDArray, INDArray)} and this method returns.
 *
 * <p>For each of the {@code thresholdSteps + 1} evenly spaced thresholds {@code i / thresholdSteps}
 * in [0, 1], the positive-class predictions are binarized (values {@code >= threshold} become 1,
 * then values {@code <= threshold} become 0). The resulting true/false positive counts are added
 * to that threshold's running totals.
 *
 * @param labels      Labels / true outcomes: rank 2 with either 1 column (the positive-class label)
 *                    or 2 columns (column 0 = negative class, column 1 = positive class)
 * @param predictions Predictions: rank 2, same number of columns as {@code labels}
 * @throws IllegalArgumentException if the arrays are not rank 2 (after the rank-3 delegation),
 *                                  their column counts differ, or they have more than 2 columns
 */
public void eval(INDArray labels, INDArray predictions) {
    if (labels.rank() == 3 && predictions.rank() == 3) {
        //Assume time series input -> reshape to 2d (presumably done by evalTimeSeries).
        //Return immediately: falling through would hit the rank > 2 check below and throw,
        //even though the time series had already been handed off for evaluation.
        evalTimeSeries(labels, predictions);
        return;
    }
    if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1) || labels.size(1) > 2) {
        throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 array with size(1) == 1 or 2");
    }
    double step = 1.0 / thresholdSteps;
    boolean singleOutput = labels.size(1) == 1;
    INDArray positivePredictedClassColumn;
    INDArray positiveActualClassColumn;
    INDArray negativeActualClassColumn;
    if (singleOutput) {
        //Single binary variable case
        positiveActualClassColumn = labels;
        //1.0 - label
        negativeActualClassColumn = labels.rsub(1.0);
        positivePredictedClassColumn = predictions;
    } else {
        //Standard case - 2 output variables (probability distribution)
        positiveActualClassColumn = labels.getColumn(1);
        negativeActualClassColumn = labels.getColumn(0);
        positivePredictedClassColumn = predictions.getColumn(1);
    }
    //Increment global counts - actual positive/negative observed
    countActualPositive += positiveActualClassColumn.sumNumber().intValue();
    countActualNegative += negativeActualClassColumn.sumNumber().intValue();
    for (int i = 0; i <= thresholdSteps; i++) {
        //Computed the same way (i * step) as the keys of 'counts' are expected to be - TODO confirm
        double currThreshold = i * step;
        //Work out true/false positives - do this by replacing probabilities (predictions) with 1 or 0 based on threshold
        Condition condGeq = Conditions.greaterThanOrEqual(currThreshold);
        Condition condLeq = Conditions.lessThanOrEqual(currThreshold);
        //dup(): CompareAndSet writes into its input, and the original predictions are reused
        //for every threshold
        Op op = new CompareAndSet(positivePredictedClassColumn.dup(), 1.0, condGeq);
        INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
        op = new CompareAndSet(predictedClass1, 0.0, condLeq);
        predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
        //True positives: occur when positive predicted class and actual positive actual class...
        //False positive occurs when positive predicted class, but negative actual class
        //If predicted == 1 and actual == 1 at this threshold: 1x1 = 1. 0 otherwise
        INDArray isTruePositive = predictedClass1.mul(positiveActualClassColumn);
        //If predicted == 1 and actual == 0 at this threshold: 1x1 = 1. 0 otherwise
        INDArray isFalsePositive = predictedClass1.mul(negativeActualClassColumn);
        //Counts for this batch:
        int truePositiveCount = isTruePositive.sumNumber().intValue();
        int falsePositiveCount = isFalsePositive.sumNumber().intValue();
        //Increment counts for this threshold
        CountsForThreshold thresholdCounts = counts.get(currThreshold);
        thresholdCounts.incrementTruePositive(truePositiveCount);
        thresholdCounts.incrementFalsePositive(falsePositiveCount);
    }
}
Also used: Condition (org.nd4j.linalg.indexing.conditions.Condition), Op (org.nd4j.linalg.api.ops.Op), INDArray (org.nd4j.linalg.api.ndarray.INDArray), CompareAndSet (org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet)

Example 8 with CompareAndSet

Use of `org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet` in project deeplearning4j by deeplearning4j.

Found in class `ROCMultiClass`, method `eval`.

/**
 * Evaluate (collect statistics for) the given minibatch of data.
 * For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or
 * {@link #evalTimeSeries(INDArray, INDArray, INDArray)}. If both arrays passed here are rank 3,
 * they are handed to {@link #evalTimeSeries(INDArray, INDArray)} and this method returns.
 *
 * <p>Each column is evaluated as a one-vs-all binary problem. Column {@code i} of {@code labels}
 * holds the actual values for class {@code i}, and column {@code i} of {@code predictions} holds
 * the predicted values for that class. On the first call, per-class counters are allocated for the
 * {@code thresholdSteps + 1} evenly spaced thresholds in [0, 1]. Later calls must supply the same
 * number of columns.
 *
 * @param labels      Labels / true outcomes (rank 2)
 * @param predictions Predictions (rank 2, same number of columns as {@code labels})
 * @throws IllegalArgumentException if the arrays are not rank 2 (after the rank-3 delegation),
 *                                  their column counts differ, or the column count differs from
 *                                  the one seen on the first call
 */
public void eval(INDArray labels, INDArray predictions) {
    if (labels.rank() == 3 && predictions.rank() == 3) {
        //Assume time series input -> reshape to 2d (presumably done by evalTimeSeries).
        //Return immediately: falling through would hit the rank > 2 check below and throw,
        //even though the time series had already been handed off for evaluation.
        evalTimeSeries(labels, predictions);
        return;
    }
    if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1)) {
        //Any number of classes is accepted here, so the message only states what is checked
        throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 arrays with equal size(1)");
    }
    double step = 1.0 / thresholdSteps;
    if (countActualPositive == null) {
        //This must be the first time eval has been called: allocate per-class counters
        int size = labels.size(1);
        countActualPositive = new long[size];
        countActualNegative = new long[size];
        for (int i = 0; i < size; i++) {
            Map<Double, ROC.CountsForThreshold> map = new LinkedHashMap<Double, ROC.CountsForThreshold>();
            counts.put(i, map);
            for (int j = 0; j <= thresholdSteps; j++) {
                //Keys must be computed exactly as in the lookup loop below (j * step),
                //otherwise counts.get(i).get(currThreshold) would return null
                double currThreshold = j * step;
                map.put(currThreshold, new ROC.CountsForThreshold(currThreshold));
            }
        }
    }
    if (countActualPositive.length != labels.size(1)) {
        throw new IllegalArgumentException("Cannot evaluate data: number of label classes does not match previous call. " + "Got " + labels.size(1) + " labels (from array shape " + Arrays.toString(labels.shape()) + ")" + " vs. expected number of label classes = " + countActualPositive.length);
    }
    for (int i = 0; i < countActualPositive.length; i++) {
        //Iterate over each class
        INDArray positiveActualColumn = labels.getColumn(i);
        INDArray positivePredictedColumn = predictions.getColumn(i);
        //1.0 - label: does not depend on the threshold, so compute it once per class
        INDArray negativeActualColumn = positiveActualColumn.rsub(1.0);
        //Increment global counts - actual positive/negative observed
        long currBatchPositiveActualCount = positiveActualColumn.sumNumber().longValue();
        countActualPositive[i] += currBatchPositiveActualCount;
        countActualNegative[i] += positiveActualColumn.length() - currBatchPositiveActualCount;
        for (int j = 0; j <= thresholdSteps; j++) {
            double currThreshold = j * step;
            //Work out true/false positives - do this by replacing probabilities (predictions) with 1 or 0 based on threshold
            Condition condGeq = Conditions.greaterThanOrEqual(currThreshold);
            Condition condLeq = Conditions.lessThanOrEqual(currThreshold);
            //dup(): CompareAndSet writes into its input, and the original column is reused
            //for every threshold
            Op op = new CompareAndSet(positivePredictedColumn.dup(), 1.0, condGeq);
            INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
            op = new CompareAndSet(predictedClass1, 0.0, condLeq);
            predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
            //True positives: occur when positive predicted class and actual positive actual class...
            //False positive occurs when positive predicted class, but negative actual class
            //If predicted == 1 and actual == 1 at this threshold: 1x1 = 1. 0 otherwise
            INDArray isTruePositive = predictedClass1.mul(positiveActualColumn);
            //If predicted == 1 and actual == 0 at this threshold: 1x1 = 1. 0 otherwise
            INDArray isFalsePositive = predictedClass1.mul(negativeActualColumn);
            //Counts for this batch:
            int truePositiveCount = isTruePositive.sumNumber().intValue();
            int falsePositiveCount = isFalsePositive.sumNumber().intValue();
            //Increment counts for this threshold
            ROC.CountsForThreshold thresholdCounts = counts.get(i).get(currThreshold);
            thresholdCounts.incrementTruePositive(truePositiveCount);
            thresholdCounts.incrementFalsePositive(falsePositiveCount);
        }
    }
}
Also used: Condition (org.nd4j.linalg.indexing.conditions.Condition), Op (org.nd4j.linalg.api.ops.Op), INDArray (org.nd4j.linalg.api.ndarray.INDArray), CompareAndSet (org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet)

Example 9 with CompareAndSet

Use of `org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet` in project nd4j by deeplearning4j.

Found in class `BooleanIndexingTest`, method `testCaSTransform2`.

@Test
public void testCaSTransform2() throws Exception {
    // Expected result: the elements below 2 (the 1 and the 0) are replaced with 3.0;
    // the 2 itself is left untouched.
    INDArray expected = Nd4j.create(new double[] { 3, 2, 3, 4, 5 });
    INDArray actual = Nd4j.create(new double[] { 1, 2, 0, 4, 5 });
    CompareAndSet replaceBelowTwo = new CompareAndSet(actual, 3.0, Conditions.lessThan(2));
    Nd4j.getExecutioner().exec(replaceBelowTwo);
    assertEquals(expected, actual);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), CompareAndSet (org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)

Example 10 with CompareAndSet

Use of `org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet` in project nd4j by deeplearning4j.

Found in class `BooleanIndexingTest`, method `testCaSTransform1`.

@Test
public void testCaSTransform1() throws Exception {
    // Expected result: the single zero (index 2) becomes 3; every other element is unchanged.
    INDArray expected = Nd4j.create(new double[] { 1, 2, 3, 4, 5 });
    INDArray actual = Nd4j.create(new double[] { 1, 2, 0, 4, 5 });
    CompareAndSet replaceZeros = new CompareAndSet(actual, 3, Conditions.equals(0));
    Nd4j.getExecutioner().exec(replaceZeros);
    assertEquals(expected, actual);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), CompareAndSet (org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)

Aggregations

- CompareAndSet (org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet): 10
- INDArray (org.nd4j.linalg.api.ndarray.INDArray): 9
- Test (org.junit.Test): 7
- BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 5
- Op (org.nd4j.linalg.api.ops.Op): 2
- Condition (org.nd4j.linalg.indexing.conditions.Condition): 2
- Function (com.google.common.base.Function): 1
- IComplexNumber (org.nd4j.linalg.api.complex.IComplexNumber): 1
- CoordinateFunction (org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction): 1
- BaseCondition (org.nd4j.linalg.indexing.conditions.BaseCondition): 1