Usage of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in the nd4j project (deeplearning4j).
Class BooleanIndexing, method applyWhere:
/**
 * Overwrites, in place, every element of {@code to} that satisfies {@code condition} with
 * {@code number}.
 *
 * <p>If {@code condition} is a {@link BaseCondition} (a "static" condition), the work is handed
 * to a native {@link CompareAndSet} op through the current executioner. Any other condition is
 * evaluated element by element on the Java side via {@link Shape#iterate}.
 *
 * @param to        the array to modify in place
 * @param condition predicate tested against each element's current value
 * @param number    replacement value, converted with {@link Number#doubleValue()}
 * @throws NullPointerException if {@code to}, {@code condition} or {@code number} is null
 */
public static void applyWhere(final INDArray to, final Condition condition, final Number number) {
    java.util.Objects.requireNonNull(to, "to");
    java.util.Objects.requireNonNull(condition, "condition");
    java.util.Objects.requireNonNull(number, "number");

    final double value = number.doubleValue();

    if (condition instanceof BaseCondition) {
        // Static conditions can be executed natively, so let the executioner do the whole pass.
        Nd4j.getExecutioner().exec(new CompareAndSet(to, value, condition));
        return;
    }

    // Arbitrary conditions can only be evaluated in Java: visit every coordinate and overwrite
    // the matching elements. The replacement is a constant, so each element is read once and no
    // per-element function object is needed.
    Shape.iterate(to, new CoordinateFunction() {
        @Override
        public void process(int[]... coord) {
            if (condition.apply(to.getDouble(coord[0]))) {
                to.putScalar(coord[0], value);
            }
        }
    });
}
Usage of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in the nd4j project (deeplearning4j).
Class BooleanIndexingTest, method testCaSPairwiseTransform2:
@Test
public void testCaSPairwiseTransform2() throws Exception {
    INDArray target = Nd4j.create(new double[] { 1, 2, 0, 4, 5 });
    INDArray source = Nd4j.create(new double[] { 2, 4, 3, 0, 5 });
    // Per the expected values: target takes source's value wherever source is non-zero,
    // and keeps its own value (4) at the one position where source is 0.
    INDArray expected = Nd4j.create(new double[] { 2, 4, 3, 4, 5 });

    CompareAndSet op = new CompareAndSet(target, source, Conditions.epsNotEquals(0.0));
    Nd4j.getExecutioner().exec(op);

    assertEquals(expected, target);
}
Usage of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in the nd4j project (deeplearning4j).
Class BooleanIndexingTest, method testConditionalUpdate:
@Test
public void testConditionalUpdate() {
    INDArray values = Nd4j.linspace(-2, 2, 5);
    INDArray target = Nd4j.ones(5);
    // Only the position whose entry in values equals 0 is expected to change (to 0); the rest stay 1.
    INDArray expected = Nd4j.create(new double[] { 1, 1, 0, 1, 1 });

    CompareAndSet op = new CompareAndSet(target, values, target, Conditions.equals(0.0));
    Nd4j.getExecutioner().exec(op);

    assertEquals(expected, target);
}
Usage of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in the nd4j project (deeplearning4j).
Class BooleanIndexingTest, method testCaSPairwiseTransform1:
@Test
public void testCaSPairwiseTransform1() throws Exception {
    INDArray array = Nd4j.create(new double[] { 1, 2, 0, 4, 5 });
    INDArray comp = Nd4j.create(new double[] { 1, 2, 3, 4, 5 });
    // The expected result is built independently of comp (the op's y operand). If comp itself
    // were used as the oracle, any unintended write into it by the op could mask a failure.
    INDArray exp = Nd4j.create(new double[] { 1, 2, 3, 4, 5 });
    Nd4j.getExecutioner().exec(new CompareAndSet(array, comp, Conditions.lessThan(5)));
    assertEquals(exp, array);
}
Usage of org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet in the nd4j project (deeplearning4j).
Class NativeOpExecutionerTest, method testConditionalUpdate:
@Test
public void testConditionalUpdate() {
    INDArray arr = Nd4j.linspace(-2, 2, 5);
    INDArray ones = Nd4j.ones(5);
    // Same scenario and expectation as BooleanIndexingTest#testConditionalUpdate: only the
    // position where arr == 0 is overwritten (with 0); every other element of ones stays 1.
    INDArray exp = Nd4j.create(new double[] { 1, 1, 0, 1, 1 });
    Nd4j.getExecutioner().exec(new CompareAndSet(ones, arr, ones, Conditions.equals(0.0)));
    // Previously this test only printed the arrays and could never fail on a wrong result.
    // A plain AssertionError keeps the check independent of this class's JUnit static imports.
    if (!exp.equals(ones)) {
        throw new AssertionError("expected: " + exp + " but was: " + ones);
    }
}
Aggregations