
Example 6 with OldSoftMax

Use of org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax in project nd4j by deeplearning4j.

From the class CudaExecutionerTest, method testSoftmax1D_1:

@Test
public void testSoftmax1D_1() throws Exception {
    // Two identical 10-element inputs, each given its own OldSoftMax output buffer.
    INDArray input1T = Nd4j.create(new double[] { -0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04 });
    INDArray input1 = Nd4j.create(new double[] { -0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04 });
    // Independent copy of the input for the SoftMaxDerivative op.
    INDArray input2 = Nd4j.zerosLike(input1);
    Nd4j.copy(input1, input2);
    INDArray output1 = Nd4j.create(1, 10);
    INDArray output1T = Nd4j.create(1, 10);
    System.out.println("FA --------------------");
    // Apply softmax to both inputs, writing into the separate output arrays.
    Nd4j.getExecutioner().exec(new OldSoftMax(input1, output1));
    Nd4j.getExecutioner().exec(new OldSoftMax(input1T, output1T));
    System.out.println("FB --------------------");
    System.out.println("Softmax = " + output1);
    // The derivative output is only printed, not asserted.
    INDArray output2 = Nd4j.create(1, 10);
    Nd4j.getExecutioner().exec(new SoftMaxDerivative(input2, output2));
    System.out.println("Softmax Derivative = " + output2);
    // Expected softmax values, rounded to two decimal places.
    INDArray assertion1 = Nd4j.create(new double[] { 0.04, 0.16, 0.14, 0.26, 0.05, 0.11, 0.06, 0.06, 0.02, 0.09 });
    // Both softmax outputs must match the expected values within a tolerance of 0.01.
    assertArrayEquals(assertion1.data().asFloat(), output1.data().asFloat(), 0.01f);
    assertArrayEquals(assertion1.data().asFloat(), output1T.data().asFloat(), 0.01f);
}
Also used: OldSoftMax (org.nd4j.linalg.api.ops.impl.transforms.OldSoftMax), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SoftMaxDerivative (org.nd4j.linalg.api.ops.impl.transforms.SoftMaxDerivative), Test (org.junit.Test)
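To see where the expected values in assertion1 come from, here is a minimal plain-Java sketch of softmax applied to the same ten inputs, checked against the same 0.01 tolerance. It does not use ND4J; the class name SoftmaxSketch and its helper methods are hypothetical. The derivative helper assumes the elementwise form s_i * (1 - s_i) (the diagonal of the softmax Jacobian); treat that as an assumption rather than a description of what SoftMaxDerivative computes internally.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java reference for the values asserted in testSoftmax1D_1.
 * Not part of ND4J; it only illustrates the underlying math.
 */
public class SoftmaxSketch {

    // Numerically stable softmax: subtracting the max before exponentiating
    // leaves the result unchanged but avoids overflow for large inputs.
    static double[] softmax(double[] x) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : x) {
            max = Math.max(max, v);
        }
        double[] out = new double[x.length];
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.exp(x[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // Assumption: elementwise derivative s_i * (1 - s_i), i.e. the diagonal
    // of the softmax Jacobian (hypothetical helper, not an ND4J API).
    static double[] softmaxDiagDerivative(double[] x) {
        double[] s = softmax(x);
        double[] d = new double[s.length];
        for (int i = 0; i < s.length; i++) {
            d[i] = s[i] * (1.0 - s[i]);
        }
        return d;
    }

    public static void main(String[] args) {
        double[] input = { -0.75, 0.58, 0.42, 1.03, -0.61, 0.19, -0.37, -0.40, -1.42, -0.04 };
        // Same rounded expectations as assertion1 in the test.
        double[] expected = { 0.04, 0.16, 0.14, 0.26, 0.05, 0.11, 0.06, 0.06, 0.02, 0.09 };
        double[] s = softmax(input);
        for (int i = 0; i < s.length; i++) {
            if (Math.abs(s[i] - expected[i]) > 0.01) {
                throw new AssertionError("index " + i + ": " + s[i]);
            }
        }
        System.out.println("Softmax = " + Arrays.toString(s));
        System.out.println("Diagonal derivative = " + Arrays.toString(softmaxDiagDerivative(input)));
    }
}
```

The largest input (1.03) receives the largest share of the probability mass, roughly 0.257, which the test rounds to 0.26.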
