Use of org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative in project nd4j by deeplearning4j.
The class ActivationELU, method backprop.
/*
    f'(x) = alpha * exp(x) ; x < 0
          = 1              ; x >= 0
*/
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
    // The native ELU derivative op has no way to override alpha, so handle alpha != 1 manually
    if (alpha != 1.00) {
        // Compute the standard (alpha = 1) ELU derivative on a copy of the input
        INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new ELUDerivative(in.dup()));
        // Scale by alpha; entries for x >= 0 (originally 1) become alpha, so reset them to 1
        dLdz.muli(alpha);
        BooleanIndexing.replaceWhere(dLdz, 1, Conditions.equals(alpha));
        // Chain rule: multiply elementwise by the incoming gradient
        dLdz.muli(epsilon);
        return new Pair<>(dLdz, null);
    } else {
        // alpha == 1: the native op gives the derivative directly; apply the chain rule
        INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new ELUDerivative(in));
        dLdz.muli(epsilon);
        return new Pair<>(dLdz, null);
    }
}
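
For context, here is a minimal sketch of how this backprop could be called directly on sample data. The example values, the class name EluBackpropExample, and the import path for Pair are assumptions (the Pair package differs between nd4j versions); the ActivationELU constructor and backprop signature follow the snippet above.

import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; // may be org.nd4j.common.primitives.Pair in newer versions

public class EluBackpropExample {
    public static void main(String[] args) {
        // Pre-activations: a mix of negative and non-negative values
        INDArray in = Nd4j.create(new double[] {-2.0, -0.5, 0.0, 1.5});
        // Incoming gradient dL/df, all ones here for illustration
        INDArray epsilon = Nd4j.onesLike(in);

        // ELU with a non-default alpha to exercise the alpha != 1 branch
        ActivationELU elu = new ActivationELU(2.0);

        // Pass a copy of 'in', since backprop may modify its input in place
        Pair<INDArray, INDArray> grads = elu.backprop(in.dup(), epsilon);

        // Expect alpha * exp(x) for x < 0 and 1 for x >= 0, times epsilon
        System.out.println(grads.getFirst());
    }
}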