Use of org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan in project nd4j by deeplearning4j.
Class BooleanIndexingTest, method testAbsValueGreaterThan.
/**
 * Checks that {@code BooleanIndexing.applyWhere} applies a clipping function to the elements
 * matched by an {@code AbsValueGreaterThan} condition and leaves every other element unchanged.
 *
 * <p>The expected result is built independently with a plain loop. Every value above
 * {@code threshold} becomes {@code threshold}, every value below {@code -threshold} becomes
 * {@code -threshold}, and all other values are kept. That result is then compared with the
 * output of {@code applyWhere} on a copy of the same data.
 */
@Test
public void testAbsValueGreaterThan() {
    final double threshold = 2;
    Condition absValueCondition = new AbsValueGreaterThan(threshold);
    // The expectation below assumes applyWhere only calls this for elements matching the
    // condition (|x| > threshold). The sign test then picks the correct clip bound.
    Function<Number, Number> clipFn = new Function<Number, Number>() {
        @Override
        public Number apply(Number number) {
            return (number.doubleValue() > threshold ? threshold : -threshold);
        }
    };

    Nd4j.getRandom().setSeed(12345);
    // Random numbers: roughly -3 to 3 (assumes Nd4j.rand draws from [0, 1))
    INDArray orig = Nd4j.rand(1, 20).muli(6).subi(3);
    INDArray exp = orig.dup();
    INDArray after = orig.dup();

    // Build the expected array element by element, independently of BooleanIndexing.
    int clipped = 0;
    for (int i = 0; i < exp.length(); i++) {
        double d = exp.getDouble(i);
        if (d > threshold) {
            exp.putScalar(i, threshold);
            clipped++;
        } else if (d < -threshold) {
            exp.putScalar(i, -threshold);
            clipped++;
        }
    }

    // Guard against a vacuous pass. If no element exceeds the threshold, exp equals orig and
    // the final assertion would succeed even if applyWhere did nothing at all.
    if (clipped == 0) {
        throw new AssertionError(
                "Test data has no element with |x| > " + threshold + "; test would be vacuous: " + orig);
    }

    BooleanIndexing.applyWhere(after, absValueCondition, clipFn);
    assertEquals(exp, after);
}
Aggregations