
Example 1 with ELUDerivative

Use of org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative in project nd4j by deeplearning4j.

The example below is from the backprop method of the class ActivationELU.

/*
 * ELU derivative:
 *   f'(x) = alpha * exp(x)  for x < 0
 *         = 1               for x >= 0
 */
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
    // The native ELU derivative op provides no way to override alpha,
    // so the alpha != 1 case is handled by rescaling its output.
    if (alpha != 1.00) {
        // dup() so the input activations are not overwritten in place
        INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new ELUDerivative(in.dup()));
        // Scale the x < 0 branch (exp(x)) by alpha; this also turns the
        // x >= 0 entries from 1 into alpha...
        dLdz.muli(alpha);
        // ...so reset entries equal to alpha back to 1 for the x >= 0 branch
        BooleanIndexing.replaceWhere(dLdz, 1, Conditions.equals(alpha));
        // Chain rule: multiply elementwise by the incoming gradient
        dLdz.muli(epsilon);
        return new Pair<>(dLdz, null);
    } else {
        INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new ELUDerivative(in));
        dLdz.muli(epsilon);
        return new Pair<>(dLdz, null);
    }
}
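The piecewise derivative the snippet computes can be illustrated without nd4j. The following is a minimal standalone sketch (the class and method names are hypothetical, not part of the nd4j API) that applies f'(x) = alpha * exp(x) for x < 0 and 1 for x >= 0 elementwise:

```java
// Hypothetical standalone sketch of the ELU gradient; not nd4j code.
public class EluGradientDemo {

    // Elementwise ELU derivative: alpha * exp(x) for x < 0, 1 otherwise.
    static double[] eluGradient(double[] x, double alpha) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (x[i] < 0) ? alpha * Math.exp(x[i]) : 1.0;
        }
        return out;
    }

    public static void main(String[] args) {
        // For x = -1 and alpha = 1 the gradient is exp(-1) ~ 0.3679;
        // for x >= 0 it is exactly 1.
        double[] grad = eluGradient(new double[]{-1.0, 0.0, 2.0}, 1.0);
        System.out.printf("%.4f %.4f %.4f%n", grad[0], grad[1], grad[2]);
    }
}
```

In the actual backprop method this elementwise gradient is then multiplied by the incoming epsilon array, which is the chain-rule step.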
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), ELUDerivative (org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative), Pair (org.nd4j.linalg.primitives.Pair)
