Use of org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative in project nd4j by deeplearning4j: class RationalTanhTest, method gradientCheck.
@Test
public void gradientCheck() {
    double eps = 1e-6;
    INDArray A = Nd4j.linspace(-3, 3, 10).reshape(2, 5);
    // Analytic derivative computed by the op under test.
    INDArray ADer = Nd4j.getExecutioner().execAndReturn(new RationalTanhDerivative(A.dup()));
    double[] a = A.data().asDouble();
    double[] aDer = ADer.data().asDouble();
    for (int i = 0; i < 10; i++) {
        // Central finite difference of the forward function f (the rational
        // tanh forward pass, defined elsewhere in the test class).
        double empirical = (f(a[i] + eps) - f(a[i] - eps)) / (2 * eps);
        double analytic = aDer[i];
        // Symmetric relative error between numerical and analytic derivatives.
        assertTrue(Math.abs(empirical - analytic) / (Math.abs(empirical) + Math.abs(analytic)) < 0.001);
    }
}
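The helper f(...) used by the test above is not shown in this snippet. A standalone sketch of the same central-difference gradient check follows, with a hand-written forward pass and derivative. The formulas for f and fPrime are assumptions: they use the commonly cited rational tanh approximation, f(x) = 1.7159 * tanhApprox(2x/3) with tanhApprox(y) = sign(y) * (1 - 1/(1 + |y| + y^2 + 1.41645*y^4)), which may differ from the exact formula implemented by RationalTanhDerivative.

```java
// Hedged sketch: finite-difference gradient check against a hand-written
// rational tanh. The formulas below are an assumed approximation, not the
// snippet's actual implementation.
public class RationalTanhCheck {

    // Assumed rational tanh forward pass (hypothetical, not from the snippet).
    static double f(double x) {
        double y = 2.0 * x / 3.0;
        double a = Math.abs(y);
        double denom = 1.0 + a + y * y + 1.41645 * Math.pow(y, 4);
        return 1.7159 * Math.signum(y) * (1.0 - 1.0 / denom);
    }

    // Analytic derivative, derived by hand from the formula above.
    static double fPrime(double x) {
        double y = 2.0 * x / 3.0;
        double a = Math.abs(y);
        double denom = 1.0 + a + y * y + 1.41645 * Math.pow(y, 4);
        double dInner = (1.0 + 2.0 * a + 4.0 * 1.41645 * a * a * a) / (denom * denom);
        return 1.7159 * (2.0 / 3.0) * dInner;
    }

    public static void main(String[] args) {
        double eps = 1e-6;
        for (double x = -3.0; x <= 3.0; x += 0.5) {
            // Central difference approximation of the derivative at x.
            double empirical = (f(x + eps) - f(x - eps)) / (2 * eps);
            double analytic = fPrime(x);
            double relError = Math.abs(empirical - analytic)
                    / (Math.abs(empirical) + Math.abs(analytic) + 1e-12);
            if (relError >= 1e-3) {
                throw new AssertionError("gradient check failed at x=" + x);
            }
        }
        System.out.println("gradient check passed");
    }
}
```

The central difference is preferred over a one-sided difference because its truncation error is O(eps^2) rather than O(eps), which is what lets the test demand a relative error below 1e-3.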
Use of org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative in project nd4j by deeplearning4j: class ActivationRationalTanh, method backprop.
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
    // Elementwise derivative of the activation at the pre-activations in,
    // multiplied in place by the incoming gradient epsilon (chain rule).
    INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new RationalTanhDerivative(in));
    dLdz.muli(epsilon);
    return new Pair<>(dLdz, null);
}
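The backprop above applies the chain rule elementwise: dL/dz[i] = f'(z[i]) * dL/da[i], where epsilon carries dL/da. A minimal plain-Java sketch of that same pattern, with a hypothetical fPrime standing in for RationalTanhDerivative (the formula is an assumed rational tanh approximation, not the snippet's code):

```java
// Hedged sketch of the elementwise chain rule that backprop() implements.
public class BackpropSketch {

    // Assumed derivative of the rational tanh (hypothetical formula, using
    // the common approximation f(x) = 1.7159 * tanhApprox(2x/3)).
    static double fPrime(double x) {
        double y = 2.0 * x / 3.0;
        double a = Math.abs(y);
        double denom = 1.0 + a + y * y + 1.41645 * Math.pow(y, 4);
        return 1.7159 * (2.0 / 3.0)
                * (1.0 + 2.0 * a + 4.0 * 1.41645 * a * a * a) / (denom * denom);
    }

    // Elementwise product of the local derivative with the incoming
    // gradient, mirroring dLdz.muli(epsilon) in the snippet.
    static double[] backprop(double[] in, double[] epsilon) {
        double[] dLdz = new double[in.length];
        for (int i = 0; i < in.length; i++) {
            dLdz[i] = fPrime(in[i]) * epsilon[i];
        }
        return dLdz;
    }

    public static void main(String[] args) {
        double[] z = {-1.0, 0.0, 1.0};
        double[] eps = {0.1, 0.2, 0.3};
        System.out.println(java.util.Arrays.toString(backprop(z, eps)));
    }
}
```

The second element of the returned Pair is null because this activation has no trainable parameters, so there is no parameter gradient to report.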