Use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project deeplearning4j by deeplearning4j.
The class Evaluation, method eval.
/**
 * Evaluate the network, with optional metadata.
 * <p>
 * Updates the confusion matrix, the per-class true/false positive/negative counters and, when
 * {@code topN > 1} and there is more than one label column, the top-N accuracy counters.
 * A single-column label array is treated as a binary problem: predictions are thresholded with
 * {@code gt(0.5)} and class 1 is the positive class. Otherwise the predicted and actual classes
 * are the argmax of each row.
 *
 * @param realOutcomes   Data labels (one row per example)
 * @param guesses        Network predictions; must have the same length as {@code realOutcomes}
 * @param recordMetaData Optional; may be null. If not null, should have size equal to the number
 *                       of outcomes/guesses. Rows beyond its size are evaluated but get no
 *                       metadata entry.
 * @throws IllegalArgumentException if {@code realOutcomes} and {@code guesses} differ in length
 */
@Override
public void eval(INDArray realOutcomes, INDArray guesses, List<? extends Serializable> recordMetaData) {
    // Add the number of rows to numRowCounter
    numRowCounter += realOutcomes.shape()[0];

    // If confusion is null, Evaluation was instantiated without providing the classes:
    // infer the number of classes from the number of label columns
    if (confusion == null) {
        int nClasses = realOutcomes.columns();
        if (nClasses == 1) {
            // Binary (single output variable) case
            nClasses = 2;
        }
        labelsList = new ArrayList<>(nClasses);
        for (int i = 0; i < nClasses; i++) {
            labelsList.add(String.valueOf(i));
        }
        createConfusion(nClasses);
    }

    // Length of real labels must be same as length of predicted labels
    if (realOutcomes.length() != guesses.length()) {
        throw new IllegalArgumentException("Unable to evaluate. Outcome matrices not same length");
    }

    int nCols = realOutcomes.columns();
    int nRows = realOutcomes.rows();

    if (nCols == 1) {
        INDArray binaryGuesses = guesses.gt(0.5);

        // Counts with class 1 as the positive class; these match the per-column definitions
        // used in the multi-class branch below
        int tp = binaryGuesses.mul(realOutcomes).sumNumber().intValue();                        // predicted 1, actual 1
        int fp = binaryGuesses.mul(realOutcomes.mul(-1.0).addi(1.0)).sumNumber().intValue();    // predicted 1, actual 0
        int fn = binaryGuesses.mul(-1.0).addi(1.0).muli(realOutcomes).sumNumber().intValue();   // predicted 0, actual 1
        int tn = nRows - tp - fp - fn;

        // Confusion matrix entries are indexed (actual, predicted)
        confusion.add(1, 1, tp);
        confusion.add(1, 0, fn);
        confusion.add(0, 1, fp);
        confusion.add(0, 0, tn);

        // Class 1 statistics
        truePositives.incrementCount(1, tp);
        falsePositives.incrementCount(1, fp);
        falseNegatives.incrementCount(1, fn);
        trueNegatives.incrementCount(1, tn);

        // Class 0 statistics: the roles of positive and negative are mirrored
        truePositives.incrementCount(0, tn);
        falsePositives.incrementCount(0, fn);
        falseNegatives.incrementCount(0, fp);
        trueNegatives.incrementCount(0, tp);

        if (recordMetaData != null) {
            for (int i = 0; i < binaryGuesses.size(0) && i < recordMetaData.size(); i++) {
                // Read row i (not row 0) so each record is paired with its own outcome
                int actual = realOutcomes.getDouble(i) == 0.0 ? 0 : 1;
                int predicted = binaryGuesses.getDouble(i) == 0.0 ? 0 : 1;
                addToMetaConfusionMatrix(actual, predicted, recordMetaData.get(i));
            }
        }
    } else {
        // For each row get the most probable label (column) from prediction and assign as guessIndex
        // For each row get the column of the true label and assign as realOutcomeIndex
        INDArray guessIndex = Nd4j.argMax(guesses, 1);
        INDArray realOutcomeIndex = Nd4j.argMax(realOutcomes, 1);
        int nExamples = guessIndex.length();
        for (int i = 0; i < nExamples; i++) {
            int actual = (int) realOutcomeIndex.getDouble(i);
            int predicted = (int) guessIndex.getDouble(i);
            confusion.add(actual, predicted);
            if (recordMetaData != null && recordMetaData.size() > i) {
                Object m = recordMetaData.get(i);
                addToMetaConfusionMatrix(actual, predicted, m);
            }
        }

        // One-vs-rest counts per column
        for (int col = 0; col < nCols; col++) {
            INDArray colBinaryGuesses = guessIndex.eps(col);
            INDArray colRealOutcomes = realOutcomes.getColumn(col);
            int colTp = colBinaryGuesses.mul(colRealOutcomes).sumNumber().intValue();
            int colFp = colBinaryGuesses.mul(colRealOutcomes.mul(-1.0).addi(1.0)).sumNumber().intValue();
            int colFn = colBinaryGuesses.mul(-1.0).addi(1.0).muli(colRealOutcomes).sumNumber().intValue();
            int colTn = nRows - colTp - colFp - colFn;
            truePositives.incrementCount(col, colTp);
            falsePositives.incrementCount(col, colFp);
            falseNegatives.incrementCount(col, colFn);
            trueNegatives.incrementCount(col, colTn);
        }

        if (topN > 1) {
            // Calculate top N accuracy, reusing the label argmax computed above
            int nLabelRows = realOutcomeIndex.length();
            for (int i = 0; i < nLabelRows; i++) {
                int labelIdx = (int) realOutcomeIndex.getDouble(i);
                double prob = guesses.getDouble(i, labelIdx);
                INDArray row = guesses.getRow(i);
                int countGreaterThan = (int) Nd4j.getExecutioner()
                        .exec(new MatchCondition(row, Conditions.greaterThan(prob)), Integer.MAX_VALUE)
                        .getDouble(0);
                if (countGreaterThan < topN) {
                    // For example, for top 3 accuracy: can have at most 2 other probabilities larger
                    topNCorrectCount++;
                }
                topNTotalCount++;
            }
        }
    }
}
Use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project nd4j by deeplearning4j.
The class BooleanIndexing, method and.
/**
 * Tests whether every element of an ndarray satisfies a condition.
 *
 * @param n    the ndarray to test
 * @param cond the condition each element is checked against
 * @return {@code true} if all elements of {@code n} meet {@code cond}, {@code false} otherwise
 */
public static boolean and(final INDArray n, final Condition cond) {
    if (!(cond instanceof BaseCondition)) {
        // Conditions that are not a BaseCondition are applied element by element via Shape.iterate
        final AtomicBoolean allMatch = new AtomicBoolean(true);
        Shape.iterate(n, new CoordinateFunction() {
            @Override
            public void process(int[]... coord) {
                // Once an element has failed, the condition is no longer evaluated
                if (allMatch.get() && !cond.apply(n.getDouble(coord[0]))) {
                    allMatch.set(false);
                }
            }
        });
        return allMatch.get();
    }

    // Count matching elements over the whole array; all match iff the count equals the length
    long matches = (long) Nd4j.getExecutioner().exec(new MatchCondition(n, cond), Integer.MAX_VALUE).getDouble(0);
    return matches == n.lengthLong();
}
Use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project nd4j by deeplearning4j.
The class BooleanIndexingTest, method testMatchConditionAllDimensions2.
/**
 * Counts NaN entries in a 10-element array that contains exactly one NaN.
 */
@Test
public void testMatchConditionAllDimensions2() throws Exception {
    // Ten values, exactly one of which is NaN
    double[] values = { 0, 1, 2, 3, Double.NaN, 5, 6, 7, 8, 9 };
    INDArray array = Nd4j.create(values);

    // Integer.MAX_VALUE is passed as the dimension, as used for the all-dimensions match
    MatchCondition nanMatch = new MatchCondition(array, Conditions.isNan());
    int val = (int) Nd4j.getExecutioner().exec(nanMatch, Integer.MAX_VALUE).getDouble(0);

    assertEquals(1, val);
}
Use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project nd4j by deeplearning4j.
The class NativeOpExecutionerTest, method testInf.
/**
 * Checks that a freshly created FLOAT array has no infinite entries.
 * Also checks that creating and reducing such an array completes with INF_PANIC profiling enabled.
 * The global data type and profiling mode are restored afterwards so later tests are unaffected.
 */
@Test
public void testInf() {
    DataBuffer.Type previousType = Nd4j.dataType();
    OpExecutioner.ProfilingMode previousMode = Nd4j.getExecutioner().getProfilingMode();
    try {
        Nd4j.setDataType(DataBuffer.Type.FLOAT);
        INDArray x = Nd4j.create(10, 10);
        x.minNumber();
        MatchCondition condition = new MatchCondition(x, Conditions.isInfinite());
        int match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
        log.info("Matches: {}", match);
        // Nd4j.create(rows, cols) is expected to be zero-filled, so nothing can be infinite
        org.junit.Assert.assertEquals(0, match);

        // With INF_PANIC enabled, operations on finite data must complete without error
        Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.INF_PANIC);
        x = Nd4j.create(10, 10);
        x.minNumber();
    } finally {
        // Undo global state changes so other tests do not run in INF_PANIC / FLOAT mode
        Nd4j.getExecutioner().setProfilingMode(previousMode);
        Nd4j.setDataType(previousType);
    }
}
Use of org.nd4j.linalg.api.ops.impl.accum.MatchCondition in project nd4j by deeplearning4j.
The class ShufflesTest, method testBinomial.
/**
 * Samples a binomial distribution built from a 10-element probability vector.
 * The vector is zero everywhere except index 1, which is 0.00001.
 * The test checks that every element of every sample equals zero.
 */
@Test
public void testBinomial() {
    INDArray probabilities = Nd4j.create(10).putScalar(1, 0.00001);
    Distribution distribution = Nd4j.getDistributions().createBinomial(3, probabilities);
    // NOTE(review): if 3 is the number of trials and the sampler is exact, index 1 is non-zero
    // with probability ~3e-5 per sample. That gives roughly a 26% chance of at least one failure
    // over 10000 samples, so this assertion may be flaky — confirm the intended tolerance.
    for (int x = 0; x < 10000; x++) {
        INDArray z = distribution.sample(new int[] { 1, 10 });
        MatchCondition condition = new MatchCondition(z, Conditions.equals(0.0));
        int match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
        assertEquals(z.length(), match);
    }
}
Aggregations