
Example 1 with CompareAndSet

Use of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in project nd4j by deeplearning4j.

In class BooleanIndexing, method applyWhere.

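/**
 * Sets the provided number on every element of the array that matches the specified condition.
 * The array is modified in place.
 *
 * @param to        the array to modify in place
 * @param condition the condition each element is tested against
 * @param number    the value written to every matching element
 */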
public static void applyWhere(final INDArray to, final Condition condition, final Number number) {
    if (condition instanceof BaseCondition) {
        // for all static conditions we go native
        Nd4j.getExecutioner().exec(new CompareAndSet(to, number.doubleValue(), condition));
    } else {
        final double value = number.doubleValue();
        final Function<Number, Number> dynamic = new Function<Number, Number>() {

            @Override
            public Number apply(Number number) {
                return value;
            }
        };
        Shape.iterate(to, new CoordinateFunction() {

            @Override
            public void process(int[]... coord) {
                if (condition.apply(to.getDouble(coord[0])))
                    to.putScalar(coord[0], dynamic.apply(to.getDouble(coord[0])).doubleValue());
            }
        });
    }
}
Also used: Function(com.google.common.base.Function), CoordinateFunction(org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction), BaseCondition(org.nd4j.linalg.indexing.conditions.BaseCondition), CompareAndSet(org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet), IComplexNumber(org.nd4j.linalg.api.complex.IComplexNumber)
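
The contract in the Javadoc (overwrite every element that matches a condition with one fixed value) can be modelled without nd4j. The sketch below is plain Java, not the nd4j API: a flat `double[]` stands in for the `INDArray`, a `java.util.function.DoublePredicate` stands in for `Condition`, and the class and method names are hypothetical. It mirrors the element-by-element fallback branch; the native `CompareAndSet` path is assumed to produce the same result.

```java
import java.util.Arrays;
import java.util.function.DoublePredicate;

/** Hypothetical plain-Java model of BooleanIndexing.applyWhere; not the nd4j API. */
public class ApplyWhereSketch {

    // Overwrites, in place, every element that satisfies the condition with the given value.
    static void applyWhere(double[] to, DoublePredicate condition, double value) {
        for (int i = 0; i < to.length; i++) {
            // The condition is evaluated against the element's current value.
            if (condition.test(to[i])) {
                to[i] = value;
            }
        }
    }

    public static void main(String[] args) {
        double[] data = {-2, -1, 0, 1, 2};
        // Clamp negative elements to zero.
        applyWhere(data, v -> v < 0, 0.0);
        System.out.println(Arrays.toString(data)); // [0.0, 0.0, 0.0, 1.0, 2.0]
    }
}
```

Because the condition only ever sees the element's own value, the loop has no cross-element dependencies, which is what makes a single native pass over the buffer a reasonable fast path.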

Example 2 with CompareAndSet

Use of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in project nd4j by deeplearning4j.

In class BooleanIndexingTest, method testCaSPairwiseTransform2.

@Test
public void testCaSPairwiseTransform2() throws Exception {
    INDArray x = Nd4j.create(new double[] { 1, 2, 0, 4, 5 });
    INDArray y = Nd4j.create(new double[] { 2, 4, 3, 0, 5 });
    INDArray comp = Nd4j.create(new double[] { 2, 4, 3, 4, 5 });
    Nd4j.getExecutioner().exec(new CompareAndSet(x, y, Conditions.epsNotEquals(0.0)));
    assertEquals(comp, x);
}
Also used: INDArray(org.nd4j.linalg.api.ndarray.INDArray), CompareAndSet(org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet), Test(org.junit.Test), BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)
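
The expected array in testCaSPairwiseTransform2 pins down the pairwise rule: wherever the element of `y` satisfies the condition, the element of `x` is replaced by it. The 0 at index 2 of `x` is overwritten with 3, while the 0 at index 3 of `y` fails `epsNotEquals(0.0)`, so `x` keeps its 4 there; the condition is therefore tested on `y`, not `x`. The plain-Java sketch below reproduces that inferred rule under stated assumptions: arrays are flat `double[]`, the helper names are hypothetical, and the tolerance `1e-5` passed to the `epsNotEquals` stand-in is an assumed value, not nd4j's default.

```java
import java.util.Arrays;
import java.util.function.DoublePredicate;

/** Hypothetical plain-Java model of pairwise CompareAndSet(x, y, condition); not the nd4j API. */
public class PairwiseCasSketch {

    // Replaces x[i] with y[i] wherever y[i] satisfies the condition; x is modified in place.
    static void compareAndSet(double[] x, double[] y, DoublePredicate condition) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have the same length");
        }
        for (int i = 0; i < x.length; i++) {
            if (condition.test(y[i])) {
                x[i] = y[i];
            }
        }
    }

    // Stand-in for Conditions.epsNotEquals(value); the eps argument is an assumption.
    static DoublePredicate epsNotEquals(double value, double eps) {
        return v -> Math.abs(v - value) > eps;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 0, 4, 5};
        double[] y = {2, 4, 3, 0, 5};
        compareAndSet(x, y, epsNotEquals(0.0, 1e-5));
        System.out.println(Arrays.toString(x)); // [2.0, 4.0, 3.0, 4.0, 5.0]
    }
}
```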

Example 3 with CompareAndSet

Use of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in project nd4j by deeplearning4j.

In class BooleanIndexingTest, method testConditionalUpdate.

@Test
public void testConditionalUpdate() {
    INDArray arr = Nd4j.linspace(-2, 2, 5);
    INDArray ones = Nd4j.ones(5);
    INDArray exp = Nd4j.create(new double[] { 1, 1, 0, 1, 1 });
    Nd4j.getExecutioner().exec(new CompareAndSet(ones, arr, ones, Conditions.equals(0.0)));
    assertEquals(exp, ones);
}
Also used: INDArray(org.nd4j.linalg.api.ndarray.INDArray), CompareAndSet(org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet), Test(org.junit.Test), BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)
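
Here the four-argument form receives `ones` as both the source `x` and the result `z`, with `arr` as `y`. The expected values are consistent with `z[i] = condition(y[i]) ? y[i] : x[i]`: only index 2, where `arr` holds 0, takes `arr`'s value. A plain-Java sketch of that inferred rule follows; `linspace` and `compareAndSet` are hypothetical helpers, and exact `== 0.0` stands in for `Conditions.equals(0.0)`, which may compare with a tolerance.

```java
import java.util.Arrays;
import java.util.function.DoublePredicate;

/** Hypothetical model of CompareAndSet(x, y, z, condition); not the nd4j API. */
public class ConditionalUpdateSketch {

    // Writes z[i] = y[i] where y[i] satisfies the condition, else z[i] = x[i].
    // z may be the same array as x: each index is read before it is written,
    // and no index depends on another index's result.
    static void compareAndSet(double[] x, double[] y, double[] z, DoublePredicate condition) {
        for (int i = 0; i < z.length; i++) {
            z[i] = condition.test(y[i]) ? y[i] : x[i];
        }
    }

    // Evenly spaced values from lower to upper inclusive (requires num >= 2).
    static double[] linspace(double lower, double upper, int num) {
        double[] out = new double[num];
        double step = (upper - lower) / (num - 1);
        for (int i = 0; i < num; i++) {
            out[i] = lower + i * step;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] arr = linspace(-2, 2, 5); // [-2, -1, 0, 1, 2]
        double[] ones = new double[5];
        Arrays.fill(ones, 1.0);
        // Same aliasing as the test: ones is both x and z.
        compareAndSet(ones, arr, ones, v -> v == 0.0);
        System.out.println(Arrays.toString(ones)); // [1.0, 1.0, 0.0, 1.0, 1.0]
    }
}
```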

Example 4 with CompareAndSet

Use of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in project nd4j by deeplearning4j.

In class BooleanIndexingTest, method testCaSPairwiseTransform1.

@Test
public void testCaSPairwiseTransform1() throws Exception {
    INDArray array = Nd4j.create(new double[] { 1, 2, 0, 4, 5 });
    INDArray comp = Nd4j.create(new double[] { 1, 2, 3, 4, 5 });
    Nd4j.getExecutioner().exec(new CompareAndSet(array, comp, Conditions.lessThan(5)));
    assertEquals(comp, array);
}
Also used: INDArray(org.nd4j.linalg.api.ndarray.INDArray), CompareAndSet(org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet), Test(org.junit.Test), BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)
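
In testCaSPairwiseTransform1 the result is the same whether `lessThan(5)` is evaluated on `array` or on `comp`: indices 0 to 3 pass either way, and at index 4 both arrays hold 5, so that element is left alone under either reading. The data of testCaSPairwiseTransform2 is what separates the two readings. The plain-Java sketch below runs both hypothetical rules (`conditionOnY` and `conditionOnX`, names made up for illustration) on both tests' data; a simple `!= 0.0` stands in for `epsNotEquals(0.0)`.

```java
import java.util.Arrays;
import java.util.function.DoublePredicate;

/** Compares two hypothetical readings of pairwise CompareAndSet; not the nd4j API. */
public class WhichOperandSketch {

    // Reading A: test y[i]; take y[i] when it matches. Returns a new array.
    static double[] conditionOnY(double[] x, double[] y, DoublePredicate c) {
        double[] out = x.clone();
        for (int i = 0; i < out.length; i++) {
            if (c.test(y[i])) out[i] = y[i];
        }
        return out;
    }

    // Reading B: test x[i]; take y[i] when it matches. Returns a new array.
    static double[] conditionOnX(double[] x, double[] y, DoublePredicate c) {
        double[] out = x.clone();
        for (int i = 0; i < out.length; i++) {
            if (c.test(x[i])) out[i] = y[i];
        }
        return out;
    }

    public static void main(String[] args) {
        // Data from testCaSPairwiseTransform1 with lessThan(5): both readings agree.
        double[] x1 = {1, 2, 0, 4, 5};
        double[] y1 = {1, 2, 3, 4, 5};
        DoublePredicate lessThan5 = v -> v < 5;
        System.out.println(Arrays.toString(conditionOnY(x1, y1, lessThan5))); // [1.0, 2.0, 3.0, 4.0, 5.0]
        System.out.println(Arrays.toString(conditionOnX(x1, y1, lessThan5))); // [1.0, 2.0, 3.0, 4.0, 5.0]

        // Data from testCaSPairwiseTransform2 with "not equal to 0": only reading A matches the test.
        double[] x2 = {1, 2, 0, 4, 5};
        double[] y2 = {2, 4, 3, 0, 5};
        DoublePredicate notZero = v -> v != 0.0;
        System.out.println(Arrays.toString(conditionOnY(x2, y2, notZero))); // [2.0, 4.0, 3.0, 4.0, 5.0]
        System.out.println(Arrays.toString(conditionOnX(x2, y2, notZero))); // [2.0, 4.0, 0.0, 0.0, 5.0]
    }
}
```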

Example 5 with CompareAndSet

Use of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in project nd4j by deeplearning4j.

In class NativeOpExecutionerTest, method testConditionalUpdate.

@Test
public void testConditionalUpdate() {
    INDArray arr = Nd4j.linspace(-2, 2, 5);
    INDArray ones = Nd4j.ones(5);
    System.out.println("arr: " + arr);
    System.out.println("ones: " + ones);
    Nd4j.getExecutioner().exec(new CompareAndSet(ones, arr, ones, Conditions.equals(0.0)));
    System.out.println("After:");
    System.out.println("arr: " + arr);
    System.out.println("ones: " + ones);
}
Also used: INDArray(org.nd4j.linalg.api.ndarray.INDArray), CompareAndSet(org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet), Test(org.junit.Test)
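
This variant only prints `arr` and `ones` before and after the op and asserts nothing. Under the rule suggested by the expected values in BooleanIndexingTest.testConditionalUpdate, `arr` (the `y` operand) is only read, so both printouts of `arr` should match while `ones` changes at index 2. The plain-Java sketch below keeps the same printouts and turns that expectation into a check; it models the op with a hypothetical `compareAndSet` helper and exact `== 0.0` equality, not the nd4j executioner.

```java
import java.util.Arrays;
import java.util.function.DoublePredicate;

/** Hypothetical check of which operands the modelled op writes to; not the nd4j API. */
public class OperandRolesSketch {

    // z[i] = condition(y[i]) ? y[i] : x[i]; only z is written, x and y are only read.
    static void compareAndSet(double[] x, double[] y, double[] z, DoublePredicate condition) {
        for (int i = 0; i < z.length; i++) {
            z[i] = condition.test(y[i]) ? y[i] : x[i];
        }
    }

    public static void main(String[] args) {
        double[] arr = {-2, -1, 0, 1, 2};
        double[] ones = {1, 1, 1, 1, 1};
        double[] arrBefore = arr.clone();

        System.out.println("arr: " + Arrays.toString(arr));
        System.out.println("ones: " + Arrays.toString(ones));
        compareAndSet(ones, arr, ones, v -> v == 0.0);
        System.out.println("After:");
        System.out.println("arr: " + Arrays.toString(arr));
        System.out.println("ones: " + Arrays.toString(ones));

        // The y operand is only read, so arr must be unchanged.
        if (!Arrays.equals(arr, arrBefore)) {
            throw new AssertionError("arr was modified");
        }
    }
}
```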

Aggregations

CompareAndSet (org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet): 10
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 9
Test (org.junit.Test): 7
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 5
Op (org.nd4j.linalg.api.ops.Op): 2
Condition (org.nd4j.linalg.indexing.conditions.Condition): 2
Function (com.google.common.base.Function): 1
IComplexNumber (org.nd4j.linalg.api.complex.IComplexNumber): 1
CoordinateFunction (org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction): 1
BaseCondition (org.nd4j.linalg.indexing.conditions.BaseCondition): 1