use of org.nd4j.linalg.indexing.conditions.BaseCondition in project nd4j by deeplearning4j.
the class BooleanIndexing method applyWhere.
/**
 * Sets the provided number on every element of the array that matches the given condition.
 * The array is modified in place.
 *
 * <p>When the condition is a {@code BaseCondition} the work is handed to the executioner as a
 * single {@code CompareAndSet} op. Any other condition is applied element by element while
 * iterating over the array's coordinates.
 *
 * @param to        the array whose matching elements are overwritten
 * @param condition the condition deciding which elements are overwritten
 * @param number    the replacement value; it is written as a {@code double}
 */
public static void applyWhere(final INDArray to, final Condition condition, final Number number) {
    final double value = number.doubleValue();
    if (condition instanceof BaseCondition) {
        // for all static conditions we go native
        Nd4j.getExecutioner().exec(new CompareAndSet(to, value, condition));
        return;
    }
    Shape.iterate(to, new CoordinateFunction() {
        @Override
        public void process(int[]... coord) {
            // The replacement is a constant, so it is written directly instead of being
            // routed through a function object that would ignore the current element.
            if (condition.apply(to.getDouble(coord[0]))) {
                to.putScalar(coord[0], value);
            }
        }
    });
}
use of org.nd4j.linalg.indexing.conditions.BaseCondition in project nd4j by deeplearning4j.
the class BooleanIndexing method and.
/**
 * Logical AND of a condition over the whole ndarray.
 *
 * <p>For a {@code BaseCondition} a {@code MatchCondition} op is executed over the full array and
 * the first value of its result (presumably the number of matching elements) is compared with
 * the array length. For any other condition each element is tested individually.
 *
 * @param n    the ndarray to test
 * @param cond the condition to test against
 * @return true if all of the elements meet the specified condition, false otherwise
 */
public static boolean and(final INDArray n, final Condition cond) {
    if (!(cond instanceof BaseCondition)) {
        final AtomicBoolean allMatch = new AtomicBoolean(true);
        Shape.iterate(n, new CoordinateFunction() {
            @Override
            public void process(int[]... coord) {
                // After the first mismatch the flag stays false and the condition
                // is no longer evaluated for the remaining coordinates.
                if (allMatch.get() && !cond.apply(n.getDouble(coord[0]))) {
                    allMatch.compareAndSet(true, false);
                }
            }
        });
        return allMatch.get();
    }
    long matched = (long) Nd4j.getExecutioner().exec(new MatchCondition(n, cond), Integer.MAX_VALUE).getDouble(0);
    return matched == n.lengthLong();
}
use of org.nd4j.linalg.indexing.conditions.BaseCondition in project nd4j by deeplearning4j.
the class BooleanIndexing method lastIndex.
/**
 * Returns the last index matching the given condition, wrapped in a scalar array.
 *
 * <p>PLEASE NOTE: This method will return -1 value if condition wasn't met
 *
 * @param array     the array to search
 * @param condition the condition to match; must be a {@code BaseCondition}
 * @return a scalar holding the final result of the {@code LastIndex} op as a double
 * @throws UnsupportedOperationException if the condition is not a {@code BaseCondition}
 */
public static INDArray lastIndex(INDArray array, Condition condition) {
    if (condition instanceof BaseCondition) {
        LastIndex op = new LastIndex(array, condition);
        Nd4j.getExecutioner().exec(op);
        return Nd4j.scalar((double) op.getFinalResult());
    }
    throw new UnsupportedOperationException("Only static Conditions are supported");
}
use of org.nd4j.linalg.indexing.conditions.BaseCondition in project nd4j by deeplearning4j.
the class BooleanIndexing method firstIndex.
/**
 * Returns the first index matching the given condition, wrapped in a scalar array.
 *
 * <p>PLEASE NOTE: This method will return -1 value if condition wasn't met
 *
 * @param array     the array to search
 * @param condition the condition to match; must be a {@code BaseCondition}
 * @return a scalar holding the final result of the {@code FirstIndex} op as a double
 * @throws UnsupportedOperationException if the condition is not a {@code BaseCondition}
 */
public static INDArray firstIndex(INDArray array, Condition condition) {
    if (condition instanceof BaseCondition) {
        FirstIndex op = new FirstIndex(array, condition);
        Nd4j.getExecutioner().exec(op);
        return Nd4j.scalar((double) op.getFinalResult());
    }
    throw new UnsupportedOperationException("Only static Conditions are supported");
}
use of org.nd4j.linalg.indexing.conditions.BaseCondition in project nd4j by deeplearning4j.
the class BooleanIndexing method and.
/**
 * Logical AND of a condition over the ndarray, evaluated along the given dimensions.
 *
 * <p>A {@code MatchCondition} op is executed along {@code dimension}; each value of its result
 * is compared with the TAD length obtained from {@code Shape.getTADLength} for the array shape
 * and those dimensions.
 *
 * @param n         the ndarray to test
 * @param condition the condition to test against; must be a {@code BaseCondition}
 * @param dimension the dimensions along which the condition is evaluated
 * @return one flag per entry of the op result, true where that entry equals the TAD length
 * @throws UnsupportedOperationException if the condition is not a {@code BaseCondition}
 */
public static boolean[] and(final INDArray n, final Condition condition, int... dimension) {
    if (!(condition instanceof BaseCondition)) {
        throw new UnsupportedOperationException("Only static Conditions are supported");
    }
    INDArray matches = Nd4j.getExecutioner().exec(new MatchCondition(n, condition), dimension);
    boolean[] result = new boolean[matches.length()];
    long tadLength = Shape.getTADLength(n.shape(), dimension);
    for (int i = 0; i < result.length; i++) {
        result[i] = matches.getDouble(i) == tadLength;
    }
    return result;
}
Aggregations