
Example 11 with CustomOp

Use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.

From class CustomOpsTests, method testNonInplaceOp1.

@Test
public void testNonInplaceOp1() throws Exception {
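    // Three separate 10x10 arrays: two inputs and a distinct output
    val arrayX = Nd4j.create(10, 10);
    val arrayY = Nd4j.create(10, 10);
    val arrayZ = Nd4j.create(10, 10);
    arrayX.assign(3.0);
    arrayY.assign(1.0);
    // Expected element-wise sum: 3.0 + 1.0 = 4.0
    val exp = Nd4j.create(10, 10).assign(4.0);
    // "add" writes its result into arrayZ rather than overwriting either input
    CustomOp op = DynamicCustomOp.builder("add").addInputs(arrayX, arrayY).addOutputs(arrayZ).build();
    Nd4j.getExecutioner().exec(op);
    assertEquals(exp, arrayZ);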
}
Also used: lombok.val, CustomOp (org.nd4j.linalg.api.ops.CustomOp), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), Test (org.junit.Test)
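The "add" op reads its two inputs and writes the element-wise sum into the separately supplied output array, so arrayX and arrayY should still hold 3.0 and 1.0 after the call. A minimal plain-Java sketch of that non-in-place contract follows, using flat double[] buffers as a stand-in for a 10x10 INDArray; the class and method names are hypothetical and are not ND4J API.

```java
import java.util.Arrays;

/** Hypothetical stand-in for a non-in-place element-wise "add" op (not ND4J code). */
final class NonInplaceAdd {

    /** Writes x[i] + y[i] into z[i]; x and y are only read, never modified. */
    static void add(double[] x, double[] y, double[] z) {
        if (x.length != y.length || x.length != z.length) {
            throw new IllegalArgumentException("Shape mismatch");
        }
        for (int i = 0; i < x.length; i++) {
            z[i] = x[i] + y[i];
        }
    }

    /** Mirrors testNonInplaceOp1 with 100 elements (10x10): 3.0 + 1.0 = 4.0. */
    static double[][] demo() {
        double[] x = new double[100];
        double[] y = new double[100];
        double[] z = new double[100];
        Arrays.fill(x, 3.0);
        Arrays.fill(y, 1.0);
        add(x, y, z);
        return new double[][] {x, y, z};
    }
}
```

Keeping the output separate from the inputs is what distinguishes this from an in-place variant, where the result would overwrite one of the operands.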

Example 12 with CustomOp

Use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.

From class CustomOpsTests, method testMergeMaxF.

@Test
public void testMergeMaxF() throws Exception {
    // Random array with positive values: uniform [0, 1) shifted to [1, 2)
    val array0 = Nd4j.rand('f', 5, 2).add(1);
    val array1 = array0.dup('f').add(5);
    // After this, array1 is larger than array0 everywhere except at (0, 0)
    array1.put(0, 0, 0);
    // Expected result of mergemax: array1, except array0's value at (0, 0)
    val exp = array1.dup('f');
    exp.putScalar(0, 0, array0.getDouble(0, 0));
    val zF = Nd4j.zeros(array0.shape(), 'f');
    CustomOp op = DynamicCustomOp.builder("mergemax").addInputs(array0, array1).addOutputs(zF).build();
    Nd4j.getExecutioner().exec(op);
    assertEquals(exp, zF);
}
Also used: lombok.val, CustomOp (org.nd4j.linalg.api.ops.CustomOp), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), Test (org.junit.Test)
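The test expects "mergemax" to take, at every position, the largest value among its inputs. Because the maximum is taken position by position, a flat double[] is enough to illustrate it. The sketch below is a hypothetical plain-Java illustration of those semantics, not ND4J's implementation.

```java
/** Hypothetical plain-Java sketch of an element-wise "mergemax" over N equally sized arrays. */
final class MergeMax {

    static double[] mergeMax(double[]... inputs) {
        if (inputs.length == 0) {
            throw new IllegalArgumentException("Need at least one input");
        }
        int n = inputs[0].length;
        double[] out = inputs[0].clone(); // start from a copy of the first input
        for (int k = 1; k < inputs.length; k++) {
            if (inputs[k].length != n) {
                throw new IllegalArgumentException("Length mismatch");
            }
            for (int i = 0; i < n; i++) {
                out[i] = Math.max(out[i], inputs[k][i]); // keep the larger value at each position
            }
        }
        return out;
    }
}
```

With a0 = {1.2, 1.5, 1.9} and a1 = a0 + 5 except a1[0] = 0, mergeMax(a0, a1) takes a1's value everywhere except the first position, which comes from a0, matching how the test builds exp.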

Example 13 with CustomOp

Use of org.nd4j.linalg.api.ops.CustomOp in project nd4j by deeplearning4j.

From class LossBinaryXENT, method scoreArray.

// clipEps and weights are fields of LossBinaryXENT (not shown in this excerpt)
private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
    if (labels.size(1) != preOutput.size(1)) {
        throw new IllegalArgumentException("Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer" + " number of outputs (nOut = " + preOutput.size(1) + ") ");
    }
    INDArray scoreArr;
    if (activationFn instanceof ActivationSoftmax) {
        // Use LogSoftMax op to avoid numerical issues when calculating score
        INDArray logsoftmax = Nd4j.getExecutioner().execAndReturn(new LogSoftMax(preOutput.dup()));
        scoreArr = logsoftmax.muli(labels);
    } else {
        // INDArray output = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(activationFn, preOutput.dup()));
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
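        if (clipEps > 0.0) {
            // Clip activations in place to [clipEps, 1 - clipEps] so neither log term below is -Infinity
            CustomOp op = DynamicCustomOp.builder("clipbyvalue").addInputs(output).callInplace(true).addFloatingPointArguments(clipEps, 1.0 - clipEps).build();
            Nd4j.getExecutioner().exec(op);
        }
        // First term: labels * log(output); log(output, true) works on a copy, so output is unchanged
        scoreArr = Transforms.log(output, true).muli(labels);
        // Second term: (1 - labels) * log(1 - output), computed in place on output
        INDArray secondTerm = output.rsubi(1);
        Transforms.log(secondTerm, false);
        secondTerm.muli(labels.rsub(1));
        // scoreArr = y*log(p) + (1-y)*log(1-p): the log-likelihood, not yet negated
        scoreArr.addi(secondTerm);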
    }
    // Weighted loss function
    if (weights != null) {
        if (weights.length() != preOutput.size(1)) {
            throw new IllegalStateException("Weights vector (length " + weights.length() + ") does not match output.size(1)=" + preOutput.size(1));
        }
        scoreArr.muliRowVector(weights);
    }
    if (mask != null) {
        LossUtil.applyMask(scoreArr, mask);
    }
    return scoreArr;
}
Also used: LogSoftMax (org.nd4j.linalg.api.ops.impl.transforms.LogSoftMax), INDArray (org.nd4j.linalg.api.ndarray.INDArray), ActivationSoftmax (org.nd4j.linalg.activations.impl.ActivationSoftmax), CustomOp (org.nd4j.linalg.api.ops.CustomOp), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp)
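In the non-softmax branch, scoreArray clips the activation p to [clipEps, 1 - clipEps] and then computes, per element, y·log(p) + (1 - y)·log(1 - p); in the softmax branch it computes y·logsoftmax(z). Both results are then scaled column-wise by the optional weights vector and by the mask. The plain-Java sketch below mirrors those steps on double[][] arrays (rows are examples, columns are outputs). Assumptions, all hypothetical: sigmoid stands in for the generic IActivation; log-softmax uses the common max-shift trick, which is not necessarily how ND4J's LogSoftMax is implemented; and the mask is simplified to one value per example (row), whereas LossUtil.applyMask may support other mask shapes.

```java
/**
 * Hypothetical plain-Java sketch of the per-element score computed by
 * LossBinaryXENT.scoreArray (not DL4J code). Rows are examples, columns are outputs.
 * Values are log-likelihoods (higher is better); negate them to get a loss.
 */
final class BinaryXentSketch {

    /** Numerically stable log-softmax of one row: z_j - max - log(sum_k exp(z_k - max)). */
    static double[] logSoftmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double sum = 0.0;
        for (double v : z) sum += Math.exp(v - max); // every exponent is <= 0, so no overflow
        double logSum = Math.log(sum);
        double[] out = new double[z.length];
        for (int j = 0; j < z.length; j++) out[j] = z[j] - max - logSum;
        return out;
    }

    /**
     * @param labels    targets in [0, 1], shape [n][m]
     * @param preOutput pre-activation values, shape [n][m]
     * @param softmax   true: softmax branch; false: sigmoid branch (assumed activation)
     * @param clipEps   clipping epsilon; ignored when <= 0
     * @param weights   optional per-output weights of length m, or null
     * @param mask      optional per-example mask of length n, or null (simplification)
     */
    static double[][] scoreArray(double[][] labels, double[][] preOutput, boolean softmax,
                                 double clipEps, double[] weights, double[] mask) {
        int n = preOutput.length;
        int m = preOutput[0].length;
        if (labels[0].length != m) {
            throw new IllegalArgumentException("Labels size(1) does not match nOut");
        }
        if (weights != null && weights.length != m) {
            throw new IllegalStateException("Weights length does not match nOut");
        }
        double[][] score = new double[n][m];
        for (int i = 0; i < n; i++) {
            double[] logSm = softmax ? logSoftmax(preOutput[i]) : null;
            for (int j = 0; j < m; j++) {
                double y = labels[i][j];
                double s;
                if (softmax) {
                    s = y * logSm[j];
                } else {
                    double p = 1.0 / (1.0 + Math.exp(-preOutput[i][j])); // sigmoid
                    if (clipEps > 0.0) {
                        // Keep p away from 0 and 1 so neither log term is -Infinity
                        p = Math.min(Math.max(p, clipEps), 1.0 - clipEps);
                    }
                    s = y * Math.log(p) + (1.0 - y) * Math.log(1.0 - p);
                }
                if (weights != null) s *= weights[j]; // column-wise weighting
                if (mask != null) s *= mask[i];       // per-example mask (assumed shape)
                score[i][j] = s;
            }
        }
        return score;
    }
}
```

Without clipping, a saturated activation (p = 1.0 in double precision) with label 0 would give log(0) = -Infinity; with clipEps > 0 the score stays finite. A caller would negate these values (and typically sum or average them) to obtain a loss.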

Aggregations (usage counts across the examples)

CustomOp (org.nd4j.linalg.api.ops.CustomOp): 13
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 13
lombok.val: 11
Test (org.junit.Test): 11
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2
Ignore (org.junit.Ignore): 1
ActivationSoftmax (org.nd4j.linalg.activations.impl.ActivationSoftmax): 1
LogSoftMax (org.nd4j.linalg.api.ops.impl.transforms.LogSoftMax): 1
TimesOneMinus (org.nd4j.linalg.api.ops.impl.transforms.TimesOneMinus): 1