
Example 1 with TanhDerivative

Use of org.nd4j.linalg.api.ops.impl.transforms.TanhDerivative in project nd4j by deeplearning4j.

From the class ActivationTanH, method backprop:

@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
    INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in));
    dLdz.muli(epsilon);
    return new Pair<>(dLdz, null);
}
Imports used:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.TanhDerivative;
import org.nd4j.linalg.primitives.Pair;
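The backprop method above applies the chain rule: the local gradient of tanh, 1 - tanh(x)^2, is multiplied elementwise by the incoming gradient epsilon. As a minimal sketch of the same computation without the ND4J dependency (class and method names here are hypothetical, not part of nd4j):

public class TanhBackpropSketch {
    // Elementwise chain rule: dL/dz = (1 - tanh(z)^2) * epsilon,
    // mirroring TanhDerivative followed by muli(epsilon).
    static double[] backprop(double[] in, double[] epsilon) {
        double[] dLdz = new double[in.length];
        for (int i = 0; i < in.length; i++) {
            double t = Math.tanh(in[i]);
            dLdz[i] = (1.0 - t * t) * epsilon[i];
        }
        return dLdz;
    }

    public static void main(String[] args) {
        // At x = 0, tanh'(0) = 1, so the gradient passes through unchanged.
        double[] out = backprop(new double[] {0.0, 1.0}, new double[] {1.0, 2.0});
        System.out.println(out[0] + " " + out[1]);
    }
}

Returning the gradient paired with null matches the Pair<INDArray, INDArray> contract above, where the second element is unused for parameter-free activations like tanh.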

Example 2 with TanhDerivative

Use of org.nd4j.linalg.api.ops.impl.transforms.TanhDerivative in project nd4j by deeplearning4j.

From the class DerivativeTests, method testTanhDerivative:

@Test
public void testTanhDerivative() {
    // Derivative of tanh: d tanh(x)/dx = 1 - tanh(x)^2
    INDArray z = Nd4j.zeros(100);
    double[] expOut = new double[100];
    for (int i = 0; i < 100; i++) {
        double x = 0.1 * (i - 50);
        z.putScalar(i, x);
        double tanh = FastMath.tanh(x);
        expOut[i] = 1.0 - tanh * tanh;
    }
    INDArray zPrime = Nd4j.getExecutioner().execAndReturn(new TanhDerivative(z));
    for (int i = 0; i < 100; i++) {
        double relError = Math.abs(expOut[i] - zPrime.getDouble(i)) / (Math.abs(expOut[i]) + Math.abs(zPrime.getDouble(i)));
        assertTrue(relError < REL_ERROR_TOLERANCE);
    }
}
Imports used:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.TanhDerivative;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
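The test above checks the op against the closed-form derivative 1 - tanh(x)^2 over the same grid of inputs, x = 0.1 * (i - 50). A quick way to convince yourself the closed form itself is right is to compare it against a central finite difference, again in plain Java with no ND4J dependency (class name is hypothetical):

public class TanhDerivativeCheck {
    public static void main(String[] args) {
        double h = 1e-6; // finite-difference step
        for (int i = 0; i < 100; i++) {
            double x = 0.1 * (i - 50); // same grid as the test above
            double analytic = 1.0 - Math.tanh(x) * Math.tanh(x);
            // Central difference: (f(x+h) - f(x-h)) / (2h), error O(h^2)
            double numeric = (Math.tanh(x + h) - Math.tanh(x - h)) / (2 * h);
            if (Math.abs(analytic - numeric) > 1e-8) {
                throw new AssertionError("mismatch at x=" + x);
            }
        }
        System.out.println("ok");
    }
}

The test's relative-error metric |a - b| / (|a| + |b|) is a symmetric variant that stays well-behaved when the expected value is near zero, which matters here since the derivative approaches 0 at the ends of the grid.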

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2 usages
TanhDerivative (org.nd4j.linalg.api.ops.impl.transforms.TanhDerivative): 2 usages
Test (org.junit.Test): 1 usage
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 1 usage
Pair (org.nd4j.linalg.primitives.Pair): 1 usage