
Example 1 with AbsValueGreaterThan

Use of org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan in project nd4j by deeplearning4j.

From the class BooleanIndexingTest, method testAbsValueGreaterThan.

@Test
public void testAbsValueGreaterThan() {
    final double threshold = 2;
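    // Matches elements whose absolute value exceeds the threshold
    Condition absValueCondition = new AbsValueGreaterThan(threshold);
    // Applied only to matching elements: maps each one to +threshold or -threshold by sign
    Function<Number, Number> clipFn = new Function<Number, Number>() {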

        @Override
        public Number apply(Number number) {
            System.out.println("Number: " + number.doubleValue());
            return (number.doubleValue() > threshold ? threshold : -threshold);
        }
    };
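    // Fixed seed so the random input is reproducible
    Nd4j.getRandom().setSeed(12345);
    // Uniform random values in [0, 1), scaled and shifted to [-3, 3)
    INDArray orig = Nd4j.rand(1, 20).muli(6).subi(3);
    // exp is clipped by hand in the loop below; after is clipped via BooleanIndexing
    INDArray exp = orig.dup();
    INDArray after = orig.dup();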
    for (int i = 0; i < exp.length(); i++) {
        double d = exp.getDouble(i);
        if (d > threshold) {
            exp.putScalar(i, threshold);
        } else if (d < -threshold) {
            exp.putScalar(i, -threshold);
        }
    }
    BooleanIndexing.applyWhere(after, absValueCondition, clipFn);
    System.out.println(orig);
    System.out.println(exp);
    System.out.println(after);
    assertEquals(exp, after);
}
Also used: Condition (org.nd4j.linalg.indexing.conditions.Condition), MatchCondition (org.nd4j.linalg.api.ops.impl.accum.MatchCondition), Function (com.google.common.base.Function), INDArray (org.nd4j.linalg.api.ndarray.INDArray), AbsValueGreaterThan (org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
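
The behaviour this test checks (elements whose absolute value exceeds the threshold are rewritten in place, everything else is left alone) can be sketched in plain Java without any ND4J classes. This is only an illustration of the semantics, not ND4J's implementation: the class ClipSketch and its methods applyWhere and clip are hypothetical names, and a double[] stands in for INDArray.

```java
import java.util.Arrays;
import java.util.function.DoublePredicate;
import java.util.function.DoubleUnaryOperator;

// Hypothetical plain-Java model of the clipping exercised by the test above.
public class ClipSketch {

    // Stand-in for BooleanIndexing.applyWhere (assumption: in-place update,
    // as the test's use of the "after" array suggests). Rewrites every
    // element that satisfies the condition and leaves the rest untouched.
    static void applyWhere(double[] values, DoublePredicate condition, DoubleUnaryOperator fn) {
        for (int i = 0; i < values.length; i++) {
            if (condition.test(values[i])) {
                values[i] = fn.applyAsDouble(values[i]);
            }
        }
    }

    // Returns a copy of the input clipped to [-threshold, threshold].
    static double[] clip(double[] input, double threshold) {
        double[] out = input.clone();
        // Plays the role of AbsValueGreaterThan(threshold)
        DoublePredicate absGreaterThan = v -> Math.abs(v) > threshold;
        // Same sign test as clipFn in the original test
        DoubleUnaryOperator clipFn = v -> v > threshold ? threshold : -threshold;
        applyWhere(out, absGreaterThan, clipFn);
        return out;
    }

    public static void main(String[] args) {
        double[] orig = {-3.0, -2.0, 0.5, 2.0, 2.5};
        // prints [-2.0, -2.0, 0.5, 2.0, 2.0]
        System.out.println(Arrays.toString(clip(orig, 2.0)));
    }
}
```

Values exactly equal to the threshold (here -2.0 and 2.0) do not match the strict greater-than condition, so they pass through unchanged, which agrees with the hand-written loop in the test.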

Aggregations

Function (com.google.common.base.Function): 1
Test (org.junit.Test): 1
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1
MatchCondition (org.nd4j.linalg.api.ops.impl.accum.MatchCondition): 1
AbsValueGreaterThan (org.nd4j.linalg.indexing.conditions.AbsValueGreaterThan): 1
Condition (org.nd4j.linalg.indexing.conditions.Condition): 1