
Example 21 with Pair

Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

From the class LossMixtureDensity, method computeGradientAndScore.

@Override
public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
    double score = computeScore(labels, preOutput, activationFn, mask, average);
    INDArray gradient = computeGradient(labels, preOutput, activationFn, mask);
    Pair<Double, INDArray> returnCode = new Pair<>(score, gradient);
    return returnCode;
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), Pair (org.nd4j.linalg.primitives.Pair)
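Example 21 computes the score and the gradient in two separate calls and returns them together as a Pair<Double, INDArray>. The sketch below shows the same pattern on plain Java arrays, assuming a simple squared-error loss with an identity activation. The Pair class here is a hypothetical minimal stand-in so the code compiles without ND4J, and SquaredErrorLoss is not part of the library.

```java
import java.util.Arrays;

// Minimal stand-in for org.nd4j.linalg.primitives.Pair, defined here only so the
// sketch compiles without ND4J (hypothetical, not the library class).
final class Pair<A, B> {
    private final A first;
    private final B second;

    Pair(A first, B second) {
        this.first = first;
        this.second = second;
    }

    A getFirst() {
        return first;
    }

    B getSecond() {
        return second;
    }
}

// Hypothetical squared-error loss on plain arrays, following the same pattern as
// computeGradientAndScore: compute the score, compute the gradient, then return
// both together in one Pair.
final class SquaredErrorLoss {
    // Score: sum over outputs of (preOutput - label)^2 (identity activation assumed).
    static double computeScore(double[] labels, double[] preOutput) {
        double score = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = preOutput[i] - labels[i];
            score += diff * diff;
        }
        return score;
    }

    // Gradient of the score with respect to preOutput: 2 * (preOutput - label).
    static double[] computeGradient(double[] labels, double[] preOutput) {
        double[] gradient = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            gradient[i] = 2.0 * (preOutput[i] - labels[i]);
        }
        return gradient;
    }

    static Pair<Double, double[]> computeGradientAndScore(double[] labels, double[] preOutput) {
        double score = computeScore(labels, preOutput);
        double[] gradient = computeGradient(labels, preOutput);
        return new Pair<>(score, gradient);
    }

    public static void main(String[] args) {
        Pair<Double, double[]> result =
                computeGradientAndScore(new double[] {1.0, 0.0}, new double[] {0.5, 0.5});
        // prints "0.5 [-1.0, 1.0]"
        System.out.println(result.getFirst() + " " + Arrays.toString(result.getSecond()));
    }
}
```

Returning both values from one call lets a caller that needs the score and the gradient avoid a second pass, while computeScore and computeGradient stay usable on their own.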

Example 22 with Pair

Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

From the class LossMultiLabel, method computeGradientAndScore.

@Override
public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
    final INDArray scoreArr = Nd4j.create(labels.size(0), 1);
    final INDArray grad = Nd4j.ones(labels.shape());
    calculate(labels, preOutput, activationFn, mask, scoreArr, grad);
    double score = scoreArr.sumNumber().doubleValue();
    if (average)
        score /= scoreArr.size(0);
    return new Pair<>(score, grad);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), Pair (org.nd4j.linalg.primitives.Pair)
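Example 22 fills a column of per-example scores (scoreArr has labels.size(0) rows, one per example), sums it, and divides by the number of rows when average is true. A minimal sketch of just that reduction, using a hypothetical ScoreReduction helper on a plain double array rather than an INDArray:

```java
// Hypothetical helper (not part of ND4J) isolating the reduction at the end of
// LossMultiLabel.computeGradientAndScore: one score per example, summed, and
// divided by the number of examples when average is true.
final class ScoreReduction {
    static double reduce(double[] perExampleScores, boolean average) {
        double score = 0.0;
        for (double s : perExampleScores) {
            score += s; // plays the role of scoreArr.sumNumber()
        }
        if (average) {
            score /= perExampleScores.length; // plays the role of scoreArr.size(0)
        }
        return score;
    }

    public static void main(String[] args) {
        double[] scores = {1.0, 2.0, 6.0}; // one entry per example (row of labels)
        System.out.println(reduce(scores, false)); // prints 9.0
        System.out.println(reduce(scores, true));  // prints 3.0
    }
}
```

Because the divisor is the number of examples rather than the number of labels per example, the averaged score is a mean per example.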

Example 23 with Pair

Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

From the class LoneTest, method checkSliceofSlice.

@Test
public void checkSliceofSlice() {
    /*
            Issue 1: Slices of slices taken from C-order and F-order views are not equal.

            Issue 2: With the assert commented out, getDouble() on F-order slices throws an
            index-out-of-bounds exception for certain shapes (the problem appears to occur when
            the last dimension, shape[rank-1], is 1), e.g. 1,2,1 and 2,2,1.
         */
    int[] ranksToCheck = new int[] { 2, 3, 4, 5 };
    for (int rank = 0; rank < ranksToCheck.length; rank++) {
        log.info("\nRunning through rank " + ranksToCheck[rank]);
        List<Pair<INDArray, String>> allF = NDArrayCreationUtil.getTestMatricesWithVaryingShapes(ranksToCheck[rank], 'f');
        Iterator<Pair<INDArray, String>> iter = allF.iterator();
        while (iter.hasNext()) {
            Pair<INDArray, String> currentPair = iter.next();
            INDArray origArrayF = currentPair.getFirst();
            INDArray sameArrayC = origArrayF.dup('c');
            log.info("\nLooping through slices for shape " + currentPair.getSecond());
            log.info("\nOriginal array:\n" + origArrayF);
            INDArray viewF = origArrayF.slice(0);
            INDArray viewC = sameArrayC.slice(0);
            log.info("\nSlice 0, C order:\n" + viewC.toString());
            log.info("\nSlice 0, F order:\n" + viewF.toString());
            for (int i = 0; i < viewF.slices(); i++) {
                // assertEquals(viewF.slice(i),viewC.slice(i));
                for (int j = 0; j < viewF.slice(i).length(); j++) {
                    // if (j>0) break;
                    // C order works as expected
                    log.info("\nC order slice " + i + ", element " + j + ": " + viewC.slice(i).getDouble(j));
                    // F order throws an index-out-of-bounds exception for some shapes
                    log.info("\nF order slice " + i + ", element " + j + ": " + viewF.slice(i).getDouble(j));
                }
            }
        }
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), Pair (org.nd4j.linalg.primitives.Pair), Test (org.junit.Test)
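The test in Example 23 expects a slice of a slice to give the same values for C-order and F-order copies of the same array. One way to see why that should hold is a simplified model in which a view is a flat buffer plus shape, strides and offset, and slice(i) only drops the leading dimension and moves the offset. The StridedView class below is a hypothetical sketch of that model, not ND4J's implementation, and it builds rank-3 inputs only for brevity; main uses the shape 2,2,1 mentioned in the test's comment.

```java
import java.util.Arrays;

// Hypothetical strided view (a simplified model, not ND4J's implementation):
// a flat buffer plus shape, strides and offset. slice(i) drops the leading
// dimension, keeps the remaining strides and moves the offset.
final class StridedView {
    final double[] buffer;
    final int[] shape;
    final int[] strides;
    final int offset;

    StridedView(double[] buffer, int[] shape, int[] strides, int offset) {
        this.buffer = buffer;
        this.shape = shape;
        this.strides = strides;
        this.offset = offset;
    }

    // 'c' (row-major): last dimension has stride 1. 'f' (column-major): first dimension has stride 1.
    static int[] stridesFor(int[] shape, char order) {
        int[] strides = new int[shape.length];
        int step = 1;
        if (order == 'c') {
            for (int d = shape.length - 1; d >= 0; d--) {
                strides[d] = step;
                step *= shape[d];
            }
        } else {
            for (int d = 0; d < shape.length; d++) {
                strides[d] = step;
                step *= shape[d];
            }
        }
        return strides;
    }

    // Rank-3 array in the given order whose logical element (i, j, k) holds its
    // row-major position, so C and F arrays are logically equal but laid out differently.
    static StridedView rank3(int d0, int d1, int d2, char order) {
        int[] shape = {d0, d1, d2};
        int[] strides = stridesFor(shape, order);
        double[] buffer = new double[d0 * d1 * d2];
        for (int i = 0; i < d0; i++) {
            for (int j = 0; j < d1; j++) {
                for (int k = 0; k < d2; k++) {
                    buffer[i * strides[0] + j * strides[1] + k * strides[2]] = (i * d1 + j) * d2 + k;
                }
            }
        }
        return new StridedView(buffer, shape, strides, 0);
    }

    StridedView slice(int i) {
        return new StridedView(buffer,
                Arrays.copyOfRange(shape, 1, shape.length),
                Arrays.copyOfRange(strides, 1, strides.length),
                offset + i * strides[0]);
    }

    // Element lookup always goes through offset + sum(index[d] * strides[d]).
    double getDouble(int... index) {
        int pos = offset;
        for (int d = 0; d < index.length; d++) {
            pos += index[d] * strides[d];
        }
        return buffer[pos];
    }

    public static void main(String[] args) {
        StridedView f = rank3(2, 2, 1, 'f');
        StridedView c = rank3(2, 2, 1, 'c');
        for (int i = 0; i < 2; i++) {
            // slice of a slice: both orders read the same logical element
            System.out.println(c.slice(0).slice(i).getDouble(0) + " " + f.slice(0).slice(i).getDouble(0));
        }
    }
}
```

In this model both orders print the same values, because every lookup goes through the strides; the buffers differ only in where each logical element is stored.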

Example 24 with Pair

Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

From the class NDArrayTestsFortran, method testToOffsetZeroCopy.

@Test
public void testToOffsetZeroCopy() {
    List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123);
    for (Pair<INDArray, String> pair : testInputs) {
        String msg = pair.getSecond();
        INDArray in = pair.getFirst();
        INDArray dup = Shape.toOffsetZeroCopy(in);
        INDArray dupc = Shape.toOffsetZeroCopy(in, 'c');
        INDArray dupf = Shape.toOffsetZeroCopy(in, 'f');
        INDArray dupany = Shape.toOffsetZeroCopyAnyOrder(in);
        assertEquals(msg, in, dup);
        assertEquals(msg, in, dupc);
        assertEquals(msg, in, dupf);
        assertEquals(msg, 'c', dupc.ordering());
        assertEquals(msg, 'f', dupf.ordering());
        assertEquals(msg, in, dupany);
        // every copy must start at offset 0 and own a buffer of exactly its own length
        assertEquals(msg, 0, dup.offset());
        assertEquals(msg, 0, dupc.offset());
        assertEquals(msg, 0, dupf.offset());
        assertEquals(msg, 0, dupany.offset());
        assertEquals(msg, dup.length(), dup.data().length());
        assertEquals(msg, dupc.length(), dupc.data().length());
        assertEquals(msg, dupf.length(), dupf.data().length());
        assertEquals(msg, dupany.length(), dupany.data().length());
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), Pair (org.nd4j.linalg.primitives.Pair), Test (org.junit.Test)
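Example 24 checks that Shape.toOffsetZeroCopy returns an array equal to its input, with offset 0 and a data buffer whose length equals the array's length. A minimal sketch of what such a copy means for a one-dimensional strided view, assuming a hypothetical VectorView class (not the ND4J Shape API):

```java
// Hypothetical 1-D view over a shared buffer (not ND4J's Shape API), used to show what an
// "offset zero copy" produces: a compact buffer of exactly length elements, offset 0, stride 1.
final class VectorView {
    final double[] data;
    final int offset;
    final int stride;
    final int length;

    VectorView(double[] data, int offset, int stride, int length) {
        this.data = data;
        this.offset = offset;
        this.stride = stride;
        this.length = length;
    }

    double get(int i) {
        return data[offset + i * stride];
    }

    // Copy the logical elements into a fresh buffer sized to the view.
    VectorView toOffsetZeroCopy() {
        double[] compact = new double[length];
        for (int i = 0; i < length; i++) {
            compact[i] = get(i);
        }
        return new VectorView(compact, 0, 1, length);
    }

    public static void main(String[] args) {
        double[] shared = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        VectorView view = new VectorView(shared, 3, 2, 3); // elements 3, 5, 7
        VectorView copy = view.toOffsetZeroCopy();
        // same logical values, offset 0, and the buffer length equals the view length
        System.out.println(copy.offset + " " + copy.data.length + " " + copy.get(0) + " " + copy.get(2)); // prints 0 3 3.0 7.0
    }
}
```

The copy drops both the offset and the stride of the original view, so none of the surrounding elements of the shared buffer are retained.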

Example 25 with Pair

Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

From the class NDArrayTestsFortran, method testDupAndDupWithOrder.

@Test
public void testDupAndDupWithOrder() {
    List<Pair<INDArray, String>> testInputs = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, 123);
    int count = 0;
    for (Pair<INDArray, String> pair : testInputs) {
        String msg = pair.getSecond();
        INDArray in = pair.getFirst();
        System.out.println("Count " + count);
        INDArray dup = in.dup();
        INDArray dupc = in.dup('c');
        INDArray dupf = in.dup('f');
        assertEquals(msg, in, dup);
        assertEquals(msg, (char) Nd4j.order(), dup.ordering());
        assertEquals(msg, 'c', dupc.ordering());
        assertEquals(msg, 'f', dupf.ordering());
        assertEquals(msg, in, dupc);
        assertEquals(msg, in, dupf);
        count++;
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), Pair (org.nd4j.linalg.primitives.Pair), Test (org.junit.Test)
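Example 25 checks that dup('c') and dup('f') produce arrays with the requested ordering that still compare equal to the input. The sketch below, a hypothetical OrderedCopy helper on plain Java arrays rather than ND4J code, writes the same 2 x 3 matrix into a flat buffer in row-major ('c') and column-major ('f') order:

```java
import java.util.Arrays;

// Hypothetical sketch of dup('c') / dup('f') on a plain 2-D Java array (not ND4J code):
// the same logical matrix is written to a flat buffer in row-major or column-major order.
final class OrderedCopy {
    static double[] dup(double[][] m, char order) {
        int rows = m.length;
        int cols = m[0].length;
        double[] flat = new double[rows * cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                int pos = (order == 'c') ? i * cols + j : j * rows + i;
                flat[pos] = m[i][j];
            }
        }
        return flat;
    }

    // Read element (i, j) back from a flat buffer written in the given order.
    static double get(double[] flat, int rows, int cols, char order, int i, int j) {
        return (order == 'c') ? flat[i * cols + j] : flat[j * rows + i];
    }

    public static void main(String[] args) {
        double[][] m = {{1, 2, 3}, {4, 5, 6}};
        System.out.println(Arrays.toString(dup(m, 'c'))); // [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        System.out.println(Arrays.toString(dup(m, 'f'))); // [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
    }
}
```

The two buffers differ, yet get returns the same element for every (i, j). The test's assertEquals(msg, in, dupf) suggests that INDArray equality is likewise based on logical contents rather than on raw buffer layout.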

Aggregations

Pair (org.nd4j.linalg.primitives.Pair): 66
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 63
Test (org.junit.Test): 24
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 9
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 5
ArrayList (java.util.ArrayList): 4
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 4
Ignore (org.junit.Ignore): 3
List (java.util.List): 2
RealMatrix (org.apache.commons.math3.linear.RealMatrix): 2
IntPointer (org.bytedeco.javacpp.IntPointer): 2
Pointer (org.bytedeco.javacpp.Pointer): 2
OpExecutionerUtil (org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil): 2
LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper): 2
IntBuffer (java.nio.IntBuffer): 1
Random (java.util.Random): 1
AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean): 1
Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix): 1
BlockRealMatrix (org.apache.commons.math3.linear.BlockRealMatrix): 1
LUDecomposition (org.apache.commons.math3.linear.LUDecomposition): 1