Search in sources:

Example 1 with Condition

Use of org.nd4j.linalg.indexing.conditions.Condition in project deeplearning4j by deeplearning4j.

From class ROC, method eval.

/**
 * Evaluate (collect statistics for) the given minibatch of data.
 * For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or
 * {@link #evalTimeSeries(INDArray, INDArray, INDArray)}; rank 3 labels and predictions passed
 * to this method are forwarded to {@link #evalTimeSeries(INDArray, INDArray)}.
 *
 * @param labels      Labels / true outcomes. Rank 2 with size(1) == 1 (single binary output,
 *                    summed as the positive count) or size(1) == 2 (column 0 = negative class,
 *                    column 1 = positive class)
 * @param predictions Predictions with the same size(1) as {@code labels}; in the 2-column case
 *                    column 1 is used as the positive-class prediction
 * @throws IllegalArgumentException if the (non time series) inputs are not rank 2 with matching
 *                                  size(1) of 1 or 2
 */
public void eval(INDArray labels, INDArray predictions) {
    if (labels.rank() == 3 && predictions.rank() == 3) {
        //Assume time series input -> reshape to 2d
        //Return immediately: falling through would always hit the rank > 2 check below and throw.
        evalTimeSeries(labels, predictions);
        return;
    }
    if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1) || labels.size(1) > 2) {
        throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 array with size(1) == 1 or 2");
    }
    double step = 1.0 / thresholdSteps;
    boolean singleOutput = labels.size(1) == 1;
    INDArray positivePredictedClassColumn;
    INDArray positiveActualClassColumn;
    INDArray negativeActualClassColumn;
    if (singleOutput) {
        //Single binary variable case
        positiveActualClassColumn = labels;
        //1.0 - label
        negativeActualClassColumn = labels.rsub(1.0);
        positivePredictedClassColumn = predictions;
    } else {
        //Standard case - 2 output variables (probability distribution)
        positiveActualClassColumn = labels.getColumn(1);
        negativeActualClassColumn = labels.getColumn(0);
        positivePredictedClassColumn = predictions.getColumn(1);
    }
    //Increment global counts - actual positive/negative observed
    countActualPositive += positiveActualClassColumn.sumNumber().intValue();
    countActualNegative += negativeActualClassColumn.sumNumber().intValue();
    //Thresholds are i * step for i = 0..thresholdSteps; counts is looked up with exactly these keys
    for (int i = 0; i <= thresholdSteps; i++) {
        double currThreshold = i * step;
        //Work out true/false positives - do this by replacing probabilities (predictions) with 1 or 0 based on threshold.
        //Two passes on a copy (assuming CompareAndSet sets matching elements to the given value):
        //values >= threshold -> 1.0, then values <= threshold -> 0.0. For a threshold < 1 this gives
        //1 iff prediction >= threshold; for a threshold >= 1.0 the second pass zeroes every element.
        Condition condGeq = Conditions.greaterThanOrEqual(currThreshold);
        Condition condLeq = Conditions.lessThanOrEqual(currThreshold);
        Op op = new CompareAndSet(positivePredictedClassColumn.dup(), 1.0, condGeq);
        INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
        op = new CompareAndSet(predictedClass1, 0.0, condLeq);
        predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
        //True positives: occur when positive predicted class and actual positive actual class...
        //False positive occurs when positive predicted class, but negative actual class
        //If predicted == 1 and actual == 1 at this threshold: 1x1 = 1. 0 otherwise
        INDArray isTruePositive = predictedClass1.mul(positiveActualClassColumn);
        //If predicted == 1 and actual == 0 at this threshold: 1x1 = 1. 0 otherwise
        INDArray isFalsePositive = predictedClass1.mul(negativeActualClassColumn);
        //Counts for this batch:
        int truePositiveCount = isTruePositive.sumNumber().intValue();
        int falsePositiveCount = isFalsePositive.sumNumber().intValue();
        //Increment counts for this threshold
        CountsForThreshold thresholdCounts = counts.get(currThreshold);
        thresholdCounts.incrementTruePositive(truePositiveCount);
        thresholdCounts.incrementFalsePositive(falsePositiveCount);
    }
}
Also used : Condition(org.nd4j.linalg.indexing.conditions.Condition) Op(org.nd4j.linalg.api.ops.Op) INDArray(org.nd4j.linalg.api.ndarray.INDArray) CompareAndSet(org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet)

Example 2 with Condition

Use of org.nd4j.linalg.indexing.conditions.Condition in project deeplearning4j by deeplearning4j.

From class ROCMultiClass, method eval.

/**
 * Evaluate (collect statistics for) the given minibatch of data. Each column is evaluated as an
 * independent binary problem: column i of {@code labels} is the positive indicator for class i and
 * (1 - column i) the negative indicator.
 * For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or
 * {@link #evalTimeSeries(INDArray, INDArray, INDArray)}; rank 3 labels and predictions passed
 * to this method are forwarded to {@link #evalTimeSeries(INDArray, INDArray)}.
 *
 * @param labels      Labels / true outcomes; rank 2, one column per class
 * @param predictions Predictions; rank 2 with the same size(1) as {@code labels}
 * @throws IllegalArgumentException if the (non time series) inputs are not rank 2 with equal
 *                                  size(1), or if the number of classes differs from the first call
 */
public void eval(INDArray labels, INDArray predictions) {
    if (labels.rank() == 3 && predictions.rank() == 3) {
        //Assume time series input -> reshape to 2d
        //Return immediately: falling through would always hit the rank > 2 check below and throw.
        evalTimeSeries(labels, predictions);
        return;
    }
    if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1)) {
        throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(labels.shape()) + ", predictions shape = " + Arrays.toString(predictions.shape()) + "; require rank 2 arrays with equal size(1)");
    }
    double step = 1.0 / thresholdSteps;
    if (countActualPositive == null) {
        //This must be the first time eval has been called: size the per-class counters and create
        //one CountsForThreshold per threshold, keyed by the same j * step values used below
        int size = labels.size(1);
        countActualPositive = new long[size];
        countActualNegative = new long[size];
        for (int i = 0; i < size; i++) {
            Map<Double, ROC.CountsForThreshold> map = new LinkedHashMap<Double, ROC.CountsForThreshold>();
            counts.put(i, map);
            for (int j = 0; j <= thresholdSteps; j++) {
                double currThreshold = j * step;
                map.put(currThreshold, new ROC.CountsForThreshold(currThreshold));
            }
        }
    }
    if (countActualPositive.length != labels.size(1)) {
        throw new IllegalArgumentException("Cannot evaluate data: number of label classes does not match previous call. " + "Got " + labels.size(1) + " labels (from array shape " + Arrays.toString(labels.shape()) + ")" + " vs. expected number of label classes = " + countActualPositive.length);
    }
    for (int i = 0; i < countActualPositive.length; i++) {
        //Iterate over each class
        INDArray positiveActualColumn = labels.getColumn(i);
        INDArray positivePredictedColumn = predictions.getColumn(i);
        //1.0 - label: does not depend on the threshold, so compute it once per class
        INDArray negativeActualColumn = positiveActualColumn.rsub(1.0);
        //Increment global counts - actual positive/negative observed
        long currBatchPositiveActualCount = positiveActualColumn.sumNumber().longValue();
        countActualPositive[i] += currBatchPositiveActualCount;
        countActualNegative[i] += positiveActualColumn.length() - currBatchPositiveActualCount;
        for (int j = 0; j <= thresholdSteps; j++) {
            double currThreshold = j * step;
            //Work out true/false positives - do this by replacing probabilities (predictions) with 1 or 0 based on threshold.
            //Two passes on a copy (assuming CompareAndSet sets matching elements to the given value):
            //values >= threshold -> 1.0, then values <= threshold -> 0.0. For a threshold < 1 this gives
            //1 iff prediction >= threshold; for a threshold >= 1.0 the second pass zeroes every element.
            Condition condGeq = Conditions.greaterThanOrEqual(currThreshold);
            Condition condLeq = Conditions.lessThanOrEqual(currThreshold);
            Op op = new CompareAndSet(positivePredictedColumn.dup(), 1.0, condGeq);
            INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
            op = new CompareAndSet(predictedClass1, 0.0, condLeq);
            predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
            //True positives: occur when positive predicted class and actual positive actual class...
            //False positive occurs when positive predicted class, but negative actual class
            //If predicted == 1 and actual == 1 at this threshold: 1x1 = 1. 0 otherwise
            INDArray isTruePositive = predictedClass1.mul(positiveActualColumn);
            //If predicted == 1 and actual == 0 at this threshold: 1x1 = 1. 0 otherwise
            INDArray isFalsePositive = predictedClass1.mul(negativeActualColumn);
            //Counts for this batch:
            int truePositiveCount = isTruePositive.sumNumber().intValue();
            int falsePositiveCount = isFalsePositive.sumNumber().intValue();
            //Increment counts for this threshold
            ROC.CountsForThreshold thresholdCounts = counts.get(i).get(currThreshold);
            thresholdCounts.incrementTruePositive(truePositiveCount);
            thresholdCounts.incrementFalsePositive(falsePositiveCount);
        }
    }
}
Also used : Condition(org.nd4j.linalg.indexing.conditions.Condition) Op(org.nd4j.linalg.api.ops.Op) INDArray(org.nd4j.linalg.api.ndarray.INDArray) CompareAndSet(org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet)

Example 3 with Condition

Use of org.nd4j.linalg.indexing.conditions.Condition in project nd4j by deeplearning4j.

From class BooleanIndexingTest, method testAbsValueGreaterThan.

@Test
public void testAbsValueGreaterThan() {
    final double threshold = 2;
    Condition absValueCondition = new AbsValueGreaterThan(threshold);
    // Replacement used by applyWhere for elements matching the condition: anything above
    // +threshold becomes +threshold, any other matched element becomes -threshold.
    Function<Number, Number> clipFn = new Function<Number, Number>() {

        @Override
        public Number apply(Number number) {
            double value = number.doubleValue();
            System.out.println("Number: " + value);
            if (value > threshold) {
                return threshold;
            }
            return -threshold;
        }
    };
    Nd4j.getRandom().setSeed(12345);
    // Random values spread between -3 and 3
    INDArray orig = Nd4j.rand(1, 20).muli(6).subi(3);
    INDArray exp = orig.dup();
    INDArray after = orig.dup();
    // Build the expected array by hand: clip every out-of-range value to the nearest bound
    for (int idx = 0; idx < exp.length(); idx++) {
        double v = exp.getDouble(idx);
        boolean tooHigh = v > threshold;
        boolean tooLow = v < -threshold;
        if (tooHigh || tooLow) {
            exp.putScalar(idx, tooHigh ? threshold : -threshold);
        }
    }
    BooleanIndexing.applyWhere(after, absValueCondition, clipFn);
    for (INDArray arr : new INDArray[] {orig, exp, after}) {
        System.out.println(arr);
    }
    assertEquals(exp, after);
}
Also used : Condition(org.nd4j.linalg.indexing.conditions.Condition) MatchCondition(org.nd4j.linalg.api.ops.impl.accum.MatchCondition) Function(com.google.common.base.Function) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AbsValueGreaterThan(org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray)3 Condition (org.nd4j.linalg.indexing.conditions.Condition)3 Op (org.nd4j.linalg.api.ops.Op)2 CompareAndSet (org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet)2 Function (com.google.common.base.Function)1 Test (org.junit.Test)1 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)1 MatchCondition (org.nd4j.linalg.api.ops.impl.accum.MatchCondition)1 AbsValueGreaterThan (org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan)1