Use of org.nd4j.linalg.indexing.conditions.Condition in project deeplearning4j by deeplearning4j.
In the class ROC, the method eval:
/**
* Evaluate (collect statistics for) the given minibatch of data.
* For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or {@link #evalTimeSeries(INDArray, INDArray, INDArray)}
*
* @param labels Labels / true outcomes
* @param predictions Predictions
*/
public void eval(INDArray labels, INDArray predictions) {
    if (labels.rank() == 3 && predictions.rank() == 3) {
        //Assume time series input -> reshape to 2d, evaluate, and return
        evalTimeSeries(labels, predictions);
        return;
    }
    if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1) || labels.size(1) > 2) {
        throw new IllegalArgumentException("Invalid input data shape: labels shape = "
                + Arrays.toString(labels.shape()) + ", predictions shape = "
                + Arrays.toString(predictions.shape()) + "; require rank 2 array with size(1) == 1 or 2");
    }

    double step = 1.0 / thresholdSteps;
    boolean singleOutput = labels.size(1) == 1;

    INDArray positivePredictedClassColumn;
    INDArray positiveActualClassColumn;
    INDArray negativeActualClassColumn;
    if (singleOutput) {
        //Single binary variable case
        positiveActualClassColumn = labels;
        //1.0 - label
        negativeActualClassColumn = labels.rsub(1.0);
        positivePredictedClassColumn = predictions;
    } else {
        //Standard case - 2 output variables (probability distribution)
        positiveActualClassColumn = labels.getColumn(1);
        negativeActualClassColumn = labels.getColumn(0);
        positivePredictedClassColumn = predictions.getColumn(1);
    }

    //Increment global counts - actual positives/negatives observed
    countActualPositive += positiveActualClassColumn.sumNumber().intValue();
    countActualNegative += negativeActualClassColumn.sumNumber().intValue();

    for (int i = 0; i <= thresholdSteps; i++) {
        double currThreshold = i * step;

        //Work out true/false positives by replacing probabilities (predictions) with 1 or 0 based on the threshold
        Condition condGeq = Conditions.greaterThanOrEqual(currThreshold);
        Condition condLeq = Conditions.lessThanOrEqual(currThreshold);

        Op op = new CompareAndSet(positivePredictedClassColumn.dup(), 1.0, condGeq);
        INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
        op = new CompareAndSet(predictedClass1, 0.0, condLeq);
        predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);

        //A true positive occurs when the predicted class is positive and the actual class is positive;
        //a false positive occurs when the predicted class is positive but the actual class is negative
        //If predicted == 1 and actual == 1 at this threshold: 1x1 = 1. 0 otherwise
        INDArray isTruePositive = predictedClass1.mul(positiveActualClassColumn);
        //If predicted == 1 and actual == 0 at this threshold: 1x1 = 1. 0 otherwise
        INDArray isFalsePositive = predictedClass1.mul(negativeActualClassColumn);

        //Counts for this batch:
        int truePositiveCount = isTruePositive.sumNumber().intValue();
        int falsePositiveCount = isFalsePositive.sumNumber().intValue();

        //Increment counts for this threshold
        CountsForThreshold thresholdCounts = counts.get(currThreshold);
        thresholdCounts.incrementTruePositive(truePositiveCount);
        thresholdCounts.incrementFalsePositive(falsePositiveCount);
    }
}
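To see the thresholding trick from eval in isolation, here is a minimal, self-contained sketch that counts true and false positives for a single threshold on a toy batch. It is not part of the DL4J source: the class name ThresholdSketch, the toy values, and the 0.5 threshold are illustrative, and the import paths assume the same nd4j version used by the snippet above (they may differ in other releases).

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;

public class ThresholdSketch {
    public static void main(String[] args) {
        //Probabilities for the positive class, and the corresponding 0/1 actual labels
        INDArray probs = Nd4j.create(new double[] {0.1, 0.7, 0.6, 0.9});
        INDArray actualPositive = Nd4j.create(new double[] {0, 0, 1, 1});
        INDArray actualNegative = actualPositive.rsub(1.0);   //1.0 - label

        double threshold = 0.5;
        //Replace probabilities >= threshold with 1.0, then the remaining (<= threshold) values with 0.0
        Op op = new CompareAndSet(probs.dup(), 1.0, Conditions.greaterThanOrEqual(threshold));
        INDArray predictedPositive = Nd4j.getExecutioner().execAndReturn(op);
        op = new CompareAndSet(predictedPositive, 0.0, Conditions.lessThanOrEqual(threshold));
        predictedPositive = Nd4j.getExecutioner().execAndReturn(op);

        //Element-wise products give indicator vectors for TP and FP; sum to count them
        int tp = predictedPositive.mul(actualPositive).sumNumber().intValue();   //expect 2
        int fp = predictedPositive.mul(actualNegative).sumNumber().intValue();   //expect 1
        System.out.println("TP = " + tp + ", FP = " + fp);
    }
}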
Use of org.nd4j.linalg.indexing.conditions.Condition in project deeplearning4j by deeplearning4j.
In the class ROCMultiClass, the method eval:
/**
* Evaluate (collect statistics for) the given minibatch of data.
* For time series (3 dimensions) use {@link #evalTimeSeries(INDArray, INDArray)} or {@link #evalTimeSeries(INDArray, INDArray, INDArray)}
*
* @param labels Labels / true outcomes
* @param predictions Predictions
*/
public void eval(INDArray labels, INDArray predictions) {
    if (labels.rank() == 3 && predictions.rank() == 3) {
        //Assume time series input -> reshape to 2d, evaluate, and return
        evalTimeSeries(labels, predictions);
        return;
    }
    if (labels.rank() > 2 || predictions.rank() > 2 || labels.size(1) != predictions.size(1)) {
        throw new IllegalArgumentException("Invalid input data shape: labels shape = "
                + Arrays.toString(labels.shape()) + ", predictions shape = "
                + Arrays.toString(predictions.shape()) + "; require rank 2 arrays with equal size(1)");
    }

    double step = 1.0 / thresholdSteps;

    if (countActualPositive == null) {
        //This must be the first time eval has been called...
        int size = labels.size(1);
        countActualPositive = new long[size];
        countActualNegative = new long[size];
        for (int i = 0; i < size; i++) {
            Map<Double, ROC.CountsForThreshold> map = new LinkedHashMap<Double, ROC.CountsForThreshold>();
            counts.put(i, map);
            for (int j = 0; j <= thresholdSteps; j++) {
                double currThreshold = j * step;
                map.put(currThreshold, new ROC.CountsForThreshold(currThreshold));
            }
        }
    }

    if (countActualPositive.length != labels.size(1)) {
        throw new IllegalArgumentException("Cannot evaluate data: number of label classes does not match previous call. "
                + "Got " + labels.size(1) + " labels (from array shape " + Arrays.toString(labels.shape()) + ")"
                + " vs. expected number of label classes = " + countActualPositive.length);
    }

    for (int i = 0; i < countActualPositive.length; i++) {
        //Iterate over each class
        INDArray positiveActualColumn = labels.getColumn(i);
        INDArray positivePredictedColumn = predictions.getColumn(i);

        //Increment global counts - actual positives/negatives observed for this class
        long currBatchPositiveActualCount = positiveActualColumn.sumNumber().longValue();
        countActualPositive[i] += currBatchPositiveActualCount;
        countActualNegative[i] += positiveActualColumn.length() - currBatchPositiveActualCount;

        for (int j = 0; j <= thresholdSteps; j++) {
            double currThreshold = j * step;

            //Work out true/false positives by replacing probabilities (predictions) with 1 or 0 based on the threshold
            Condition condGeq = Conditions.greaterThanOrEqual(currThreshold);
            Condition condLeq = Conditions.lessThanOrEqual(currThreshold);

            Op op = new CompareAndSet(positivePredictedColumn.dup(), 1.0, condGeq);
            INDArray predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);
            op = new CompareAndSet(predictedClass1, 0.0, condLeq);
            predictedClass1 = Nd4j.getExecutioner().execAndReturn(op);

            //A true positive occurs when the predicted class is positive and the actual class is positive;
            //a false positive occurs when the predicted class is positive but the actual class is negative
            //If predicted == 1 and actual == 1 at this threshold: 1x1 = 1. 0 otherwise
            INDArray isTruePositive = predictedClass1.mul(positiveActualColumn);
            INDArray negativeActualColumn = positiveActualColumn.rsub(1.0);
            //If predicted == 1 and actual == 0 at this threshold: 1x1 = 1. 0 otherwise
            INDArray isFalsePositive = predictedClass1.mul(negativeActualColumn);

            //Counts for this batch:
            int truePositiveCount = isTruePositive.sumNumber().intValue();
            int falsePositiveCount = isFalsePositive.sumNumber().intValue();

            //Increment counts for this threshold
            ROC.CountsForThreshold thresholdCounts = counts.get(i).get(currThreshold);
            thresholdCounts.incrementTruePositive(truePositiveCount);
            thresholdCounts.incrementFalsePositive(falsePositiveCount);
        }
    }
}
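For context, the following sketch shows one way this method is typically driven: construct the evaluator, call eval with one-hot labels and per-class probabilities, then read per-class results. It is illustrative only, assuming the ROCMultiClass(int thresholdSteps) constructor and the calculateAUC(int) accessor available in the same DL4J version; the class name ROCMultiClassSketch and the toy data are made up.

import org.deeplearning4j.eval.ROCMultiClass;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ROCMultiClassSketch {
    public static void main(String[] args) {
        //One-hot labels for 3 classes, and row-normalised probabilities for the same rows
        INDArray labels = Nd4j.create(new double[][] {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {0, 1, 0}});
        INDArray predictions = Nd4j.create(new double[][] {
                {0.7, 0.2, 0.1}, {0.2, 0.6, 0.2}, {0.1, 0.3, 0.6}, {0.4, 0.4, 0.2}});

        ROCMultiClass roc = new ROCMultiClass(30);   //30 threshold steps
        roc.eval(labels, predictions);               //one-vs-all TP/FP counts per class, per threshold

        for (int i = 0; i < 3; i++) {
            System.out.println("AUC for class " + i + ": " + roc.calculateAUC(i));
        }
    }
}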
Use of org.nd4j.linalg.indexing.conditions.Condition in project nd4j by deeplearning4j.
In the class BooleanIndexingTest, the method testAbsValueGreaterThan:
@Test
public void testAbsValueGreaterThan() {
    final double threshold = 2;

    Condition absValueCondition = new AbsValueGreaterThan(threshold);
    Function<Number, Number> clipFn = new Function<Number, Number>() {
        @Override
        public Number apply(Number number) {
            System.out.println("Number: " + number.doubleValue());
            return (number.doubleValue() > threshold ? threshold : -threshold);
        }
    };

    Nd4j.getRandom().setSeed(12345);
    //Random numbers: -3 to 3
    INDArray orig = Nd4j.rand(1, 20).muli(6).subi(3);
    INDArray exp = orig.dup();
    INDArray after = orig.dup();

    for (int i = 0; i < exp.length(); i++) {
        double d = exp.getDouble(i);
        if (d > threshold) {
            exp.putScalar(i, threshold);
        } else if (d < -threshold) {
            exp.putScalar(i, -threshold);
        }
    }

    BooleanIndexing.applyWhere(after, absValueCondition, clipFn);

    System.out.println(orig);
    System.out.println(exp);
    System.out.println(after);

    assertEquals(exp, after);
}
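The same clipping can be expressed without a callback by applying two conditions directly. The sketch below uses BooleanIndexing.replaceWhere with greaterThan/lessThan conditions and should produce the same result the test builds element by element; the class name ClipSketch and the values are illustrative, not part of the nd4j test suite.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;

public class ClipSketch {
    public static void main(String[] args) {
        double threshold = 2.0;
        Nd4j.getRandom().setSeed(12345);
        INDArray arr = Nd4j.rand(1, 20).muli(6).subi(3);   //values in [-3, 3)

        //Clip in place: values above +threshold become +threshold, values below -threshold become -threshold
        BooleanIndexing.replaceWhere(arr, threshold, Conditions.greaterThan(threshold));
        BooleanIndexing.replaceWhere(arr, -threshold, Conditions.lessThan(-threshold));

        System.out.println(arr);
    }
}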