Search in sources:

Example 1 with MatchCondition

use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project deeplearning4j by deeplearning4j.

Method eval of class Evaluation.

/**
 * Evaluates one batch of network predictions against the true labels, with optional per-example metadata.
 * <p>
 * Updates the confusion matrix, the per-class true/false positive/negative counters and the row counter.
 * When {@code nCols > 1 && topN > 1}, it also updates the top-N accuracy counters. If this Evaluation was
 * created without a class count, the number of classes is inferred from the label columns on the first
 * call; a single label column is treated as the binary (2-class) case.
 *
 * @param realOutcomes   Data labels, one row per example. With a single column the value is the binary label
 *                       (NOTE(review): assumed to be 0/1 - not validated here); otherwise the argmax of each
 *                       row is taken as the true class.
 * @param guesses        Network predictions, same length as {@code realOutcomes}. With a single column a
 *                       prediction {@code > 0.5} counts as class 1; otherwise the argmax of each row is the
 *                       predicted class.
 * @param recordMetaData Optional; may be null. If not null, should have size equal to the number of
 *                       outcomes/guesses. Rows without a corresponding entry are not added to the metadata
 *                       confusion matrix.
 * @throws IllegalArgumentException if {@code realOutcomes} and {@code guesses} have different lengths
 */
@Override
public void eval(INDArray realOutcomes, INDArray guesses, List<? extends Serializable> recordMetaData) {
    // Validate before touching any state, so a rejected batch leaves the counters (and the
    // lazily created confusion matrix) unchanged.
    if (realOutcomes.length() != guesses.length())
        throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");
    // Add the number of rows to numRowCounter
    numRowCounter += realOutcomes.shape()[0];
    // If confusion is null, then Evaluation was instantiated without providing the classes:
    // infer the number of classes from the number of label columns
    if (confusion == null) {
        int nClasses = realOutcomes.columns();
        if (nClasses == 1)
            //Binary (single output variable) case
            nClasses = 2;
        labelsList = new ArrayList<>(nClasses);
        for (int i = 0; i < nClasses; i++) labelsList.add(String.valueOf(i));
        createConfusion(nClasses);
    }
    int nCols = realOutcomes.columns();
    int nRows = realOutcomes.rows();
    if (nCols == 1) {
        INDArray binaryGuesses = guesses.gt(0.5);
        // 1 - x on copies (mul returns a new array, addi then modifies that copy only)
        INDArray notGuesses = binaryGuesses.mul(-1.0).addi(1.0);
        INDArray notLabels = realOutcomes.mul(-1.0).addi(1.0);
        // tp: predicted 1, actual 1
        int tp = binaryGuesses.mul(realOutcomes).sumNumber().intValue();
        // fp: predicted 1, actual 0
        int fp = binaryGuesses.mul(notLabels).sumNumber().intValue();
        // fn: predicted 0, actual 1
        int fn = notGuesses.mul(realOutcomes).sumNumber().intValue();
        int tn = nRows - tp - fp - fn;
        // The confusion matrix is indexed (actual, predicted), matching the multi-class branch below
        confusion.add(1, 1, tp);
        confusion.add(1, 0, fn);
        confusion.add(0, 1, fp);
        confusion.add(0, 0, tn);
        // Class 1 treated as the positive class
        truePositives.incrementCount(1, tp);
        falsePositives.incrementCount(1, fp);
        falseNegatives.incrementCount(1, fn);
        trueNegatives.incrementCount(1, tn);
        // Class 0 treated as the positive class: the roles of the four counts are mirrored
        truePositives.incrementCount(0, tn);
        falsePositives.incrementCount(0, fn);
        falseNegatives.incrementCount(0, fp);
        trueNegatives.incrementCount(0, tp);
        if (recordMetaData != null) {
            for (int i = 0; i < binaryGuesses.size(0); i++) {
                if (i >= recordMetaData.size())
                    break;
                // Per-row values: each example is classified by its own label and prediction
                int actual = realOutcomes.getDouble(i) == 0.0 ? 0 : 1;
                int predicted = binaryGuesses.getDouble(i) == 0.0 ? 0 : 1;
                addToMetaConfusionMatrix(actual, predicted, recordMetaData.get(i));
            }
        }
    } else {
        // For each row get the most probable label (column) from the prediction and the column of the true label
        INDArray guessIndex = Nd4j.argMax(guesses, 1);
        INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
        int nExamples = guessIndex.length();
        for (int i = 0; i < nExamples; i++) {
            int actual = (int) realOutcomeIndex.getDouble(i);
            int predicted = (int) guessIndex.getDouble(i);
            confusion.add(actual, predicted);
            if (recordMetaData != null && recordMetaData.size() > i) {
                Object m = recordMetaData.get(i);
                addToMetaConfusionMatrix(actual, predicted, m);
            }
        }
        for (int col = 0; col < nCols; col++) {
            // Presumably 1 where the predicted class index equals col, else 0
            INDArray colBinaryGuesses = guessIndex.eps(col);
            INDArray colRealOutcomes = realOutcomes.getColumn(col);
            int colTp = colBinaryGuesses.mul(colRealOutcomes).sumNumber().intValue();
            int colFp = colBinaryGuesses.mul(colRealOutcomes.mul(-1.0).addi(1.0)).sumNumber().intValue();
            int colFn = colBinaryGuesses.mul(-1.0).addi(1.0).muli(colRealOutcomes).sumNumber().intValue();
            int colTn = nRows - colTp - colFp - colFn;
            truePositives.incrementCount(col, colTp);
            falsePositives.incrementCount(col, colFp);
            falseNegatives.incrementCount(col, colFn);
            trueNegatives.incrementCount(col, colTn);
        }
    }
    if (nCols > 1 && topN > 1) {
        //Calculate top N accuracy
        //TODO: this could be more efficient
        INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
        int nExamples = realOutcomeIndex.length();
        for (int i = 0; i < nExamples; i++) {
            int labelIdx = (int) realOutcomeIndex.getDouble(i);
            double prob = guesses.getDouble(i, labelIdx);
            INDArray row = guesses.getRow(i);
            // Number of classes whose predicted probability is strictly larger than the true class's
            int countGreaterThan = (int) Nd4j.getExecutioner().exec(new MatchCondition(row, Conditions.greaterThan(prob)), Integer.MAX_VALUE).getDouble(0);
            if (countGreaterThan < topN) {
                //For example, for top 3 accuracy: can have at most 2 other probabilities larger
                topNCorrectCount++;
            }
            topNTotalCount++;
        }
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MatchCondition(org.nd4j.linalg.api.ops.impl.accum.MatchCondition)

Example 2 with MatchCondition

use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project nd4j by deeplearning4j.

Method and of class BooleanIndexing.

/**
 * And over the whole ndarray given some condition
 *
 * @param n    the ndarray to test
 * @param cond the condition to test against
 * @return true if all of the elements meet the specified
 * condition false otherwise
 */
public static boolean and(final INDArray n, final Condition cond) {
    if (!(cond instanceof BaseCondition)) {
        // Not a BaseCondition: evaluate the condition element by element, stopping
        // further evaluation once any element has failed it.
        final AtomicBoolean allMatch = new AtomicBoolean(true);
        Shape.iterate(n, new CoordinateFunction() {

            @Override
            public void process(int[]... coord) {
                if (!allMatch.get()) {
                    return;
                }
                allMatch.compareAndSet(true, allMatch.get() && cond.apply(n.getDouble(coord[0])));
            }
        });
        return allMatch.get();
    }
    // BaseCondition: count matching elements with a single MatchCondition op over the whole array;
    // the condition holds everywhere exactly when every element matched.
    long matchCount = (long) Nd4j.getExecutioner().exec(new MatchCondition(n, cond), Integer.MAX_VALUE).getDouble(0);
    return matchCount == n.lengthLong();
}
Also used : AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) BaseCondition(org.nd4j.linalg.indexing.conditions.BaseCondition) CoordinateFunction(org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction) MatchCondition(org.nd4j.linalg.api.ops.impl.accum.MatchCondition)

Example 3 with MatchCondition

use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project nd4j by deeplearning4j.

Method testMatchConditionAllDimensions2 of class BooleanIndexingTest.

@Test
public void testMatchConditionAllDimensions2() throws Exception {
    // Ten values, exactly one of which is NaN
    double[] values = { 0, 1, 2, 3, Double.NaN, 5, 6, 7, 8, 9 };
    INDArray input = Nd4j.create(values);
    MatchCondition nanMatcher = new MatchCondition(input, Conditions.isNan());
    // Integer.MAX_VALUE is passed to request a single count over the whole array
    int nanCount = (int) Nd4j.getExecutioner().exec(nanMatcher, Integer.MAX_VALUE).getDouble(0);
    assertEquals(1, nanCount);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MatchCondition(org.nd4j.linalg.api.ops.impl.accum.MatchCondition) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 4 with MatchCondition

use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project nd4j by deeplearning4j.

Method testInf of class NativeOpExecutionerTest.

@Test
public void testInf() {
    // The data type and the executioner's profiling mode are global settings; remember the
    // current values so this test does not leak FLOAT / INF_PANIC into tests that run after it.
    DataBuffer.Type previousDataType = Nd4j.dataType();
    OpExecutioner.ProfilingMode previousProfilingMode = Nd4j.getExecutioner().getProfilingMode();
    Nd4j.setDataType(DataBuffer.Type.FLOAT);
    try {
        INDArray x = Nd4j.create(10, 10);
        x.minNumber();
        MatchCondition condition = new MatchCondition(x, Conditions.isInfinite());
        int match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
        log.info("Matches: {}", match);
        // A freshly created array is expected to contain no infinite values
        // (NOTE(review): relies on Nd4j.create zero-filling - confirm)
        if (match != 0) {
            throw new AssertionError("Expected no infinite values in a new array, but found " + match);
        }
        // With INF_PANIC enabled, a reduction over a finite array must complete without panicking
        Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC);
        x = Nd4j.create(10, 10);
        x.minNumber();
    } finally {
        Nd4j.getExecutioner().setProfilingMode(previousProfilingMode);
        Nd4j.setDataType(previousDataType);
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MatchCondition(org.nd4j.linalg.api.ops.impl.accum.MatchCondition) Test(org.junit.Test)

Example 5 with MatchCondition

use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project nd4j by deeplearning4j.

Method testBinomial of class ShufflesTest.

@Test
public void testBinomial() {
    // Binomial(3, p) where p is 0 everywhere except 1e-5 at index 1
    // (NOTE(review): relies on Nd4j.create(10) being zero-filled - confirm).
    // NOTE(review): 3 trials x 10000 iterations at p = 1e-5 gives roughly a 26% chance of at least
    // one non-zero sample, so this test is flaky unless the RNG sequence is deterministic - confirm.
    Distribution distribution = Nd4j.getDistributions().createBinomial(3, Nd4j.create(10).putScalar(1, 0.00001));
    for (int x = 0; x < 10000; x++) {
        INDArray z = distribution.sample(new int[] { 1, 10 });
        // Every sampled value is expected to be exactly zero
        MatchCondition condition = new MatchCondition(z, Conditions.equals(0.0));
        int match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
        assertEquals("Non-zero binomial sample at iteration " + x, z.length(), match);
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) Distribution(org.nd4j.linalg.api.rng.distribution.Distribution) MatchCondition(org.nd4j.linalg.api.ops.impl.accum.MatchCondition) Test(org.junit.Test)

Aggregations

MatchCondition (org.nd4j.linalg.api.ops.impl.accum.MatchCondition)17 INDArray (org.nd4j.linalg.api.ndarray.INDArray)12 Test (org.junit.Test)8 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)6 BaseCondition (org.nd4j.linalg.indexing.conditions.BaseCondition)4 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)2 CoordinateFunction (org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction)2 CompressionDescriptor (org.nd4j.linalg.compression.CompressionDescriptor)2 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)2 IntPointer (org.bytedeco.javacpp.IntPointer)1 IActivation (org.nd4j.linalg.activations.IActivation)1 ActivationSigmoid (org.nd4j.linalg.activations.impl.ActivationSigmoid)1 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)1 BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution)1 DefaultRandom (org.nd4j.linalg.api.rng.DefaultRandom)1 Random (org.nd4j.linalg.api.rng.Random)1 Distribution (org.nd4j.linalg.api.rng.distribution.Distribution)1 CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer)1 NativeRandom (org.nd4j.rng.NativeRandom)1