Usage of org.nd4j.linalg.api.ops.CustomOp in the nd4j project by deeplearning4j.
From class CustomOpsTests, method testNonInplaceOp1.
/**
 * Runs the "add" custom op with a dedicated output array and checks that the sum (3 + 1 = 4)
 * is written into that output, while both input arrays are left untouched, i.e. the op was
 * not executed in place.
 */
@Test
public void testNonInplaceOp1() throws Exception {
    val arrayX = Nd4j.create(10, 10);
    val arrayY = Nd4j.create(10, 10);
    val arrayZ = Nd4j.create(10, 10);
    arrayX.assign(3.0);
    arrayY.assign(1.0);
    val exp = Nd4j.create(10, 10).assign(4.0);
    // Snapshots of the inputs: a non-inplace op must not overwrite either of them.
    val expX = arrayX.dup();
    val expY = arrayY.dup();

    CustomOp op = DynamicCustomOp.builder("add").addInputs(arrayX, arrayY).addOutputs(arrayZ).build();
    Nd4j.getExecutioner().exec(op);

    assertEquals(exp, arrayZ);
    assertEquals(expX, arrayX);
    assertEquals(expY, arrayY);
}
Usage of org.nd4j.linalg.api.ops.CustomOp in the nd4j project by deeplearning4j.
From class CustomOpsTests, method testMergeMaxF.
/**
 * Checks the "mergemax" custom op on two 'f'-ordered 5x2 inputs.
 *
 * <p>The second input is the first plus 5 everywhere except at (0, 0), where it is forced to 0.
 * The element-wise maximum is therefore the second input with the first input's (0, 0) value
 * restored.
 */
@Test
public void testMergeMaxF() throws Exception {
    // random array shifted by +1 so its entries are positive
    val first = Nd4j.rand('f', 5, 2).add(1);
    val second = first.dup('f').add(5);
    // second now beats first everywhere except at (0, 0)
    second.put(0, 0, 0);

    // expected element-wise max
    val expected = second.dup('f');
    expected.putScalar(0, 0, first.getDouble(0, 0));

    val result = Nd4j.zeros(first.shape(), 'f');
    CustomOp mergeMax = DynamicCustomOp.builder("mergemax")
            .addInputs(first, second)
            .addOutputs(result)
            .build();
    Nd4j.getExecutioner().exec(mergeMax);

    assertEquals(expected, result);
}
Usage of org.nd4j.linalg.api.ops.CustomOp in the nd4j project by deeplearning4j.
From class LossBinaryXENT, method scoreArray.
/**
 * Builds the per-element binary cross-entropy log-likelihood terms for a minibatch.
 *
 * <p>For a softmax activation the result is {@code labels * logSoftmax(preOutput)}. LogSoftMax is
 * used directly to avoid numerical issues. For any other activation, with {@code p} the activated
 * output (optionally clipped into {@code [clipEps, 1 - clipEps]}), the result is
 * {@code labels * log(p) + (1 - labels) * log(1 - p)}. The result is then scaled column-wise by
 * {@code weights} when present and masked when {@code mask} is non-null. No negation or reduction
 * happens here; presumably the caller handles that (TODO confirm).
 *
 * @param labels       target values; must have the same size(1) as {@code preOutput}
 * @param preOutput    pre-activation output; not modified (only duplicates are operated on)
 * @param activationFn activation applied to {@code preOutput}
 * @param mask         optional mask, may be null
 * @return a newly allocated score array
 * @throws IllegalArgumentException if labels and preOutput differ in size(1)
 * @throws IllegalStateException    if {@code weights} length does not match preOutput size(1)
 */
private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
    long nLabelColumns = labels.size(1);
    long nOut = preOutput.size(1);
    if (nLabelColumns != nOut) {
        throw new IllegalArgumentException("Labels array numColumns (size(1) = " + nLabelColumns
                        + ") does not match output layer" + " number of outputs (nOut = " + nOut + ") ");
    }

    INDArray scoreArr;
    if (!(activationFn instanceof ActivationSoftmax)) {
        INDArray probs = activationFn.getActivation(preOutput.dup(), true);
        if (clipEps > 0.0) {
            // Clip in place so neither log(p) nor log(1 - p) below sees an exact 0.
            CustomOp clip = DynamicCustomOp.builder("clipbyvalue")
                            .addInputs(probs)
                            .callInplace(true)
                            .addFloatingPointArguments(clipEps, 1.0 - clipEps)
                            .build();
            Nd4j.getExecutioner().exec(clip);
        }
        // log(..., true) works on a copy, so probs still holds p after this line.
        scoreArr = Transforms.log(probs, true).muli(labels);
        // probs is now recycled in place: (1 - p) -> log(1 - p) -> log(1 - p) * (1 - labels).
        INDArray negativeTerm = probs.rsubi(1);
        Transforms.log(negativeTerm, false);
        negativeTerm.muli(labels.rsub(1));
        scoreArr.addi(negativeTerm);
    } else {
        // LogSoftMax runs on a duplicate, leaving preOutput untouched.
        INDArray logSoftmax = Nd4j.getExecutioner().execAndReturn(new LogSoftMax(preOutput.dup()));
        scoreArr = logSoftmax.muli(labels);
    }

    // Optional per-output weighting
    if (weights != null) {
        if (weights.length() != nOut) {
            throw new IllegalStateException("Weights vector (length " + weights.length() + ") does not match output.size(1)=" + nOut);
        }
        scoreArr.muliRowVector(weights);
    }

    if (mask != null) {
        LossUtil.applyMask(scoreArr, mask);
    }
    return scoreArr;
}
Aggregations