Search in sources :

Example 36 with Pair

use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

the class NDArrayCreationUtil method getAll3dTestArraysWithShape.

/**
 * Collects a set of 3d test arrays, each paired with a string describing how it was produced.
 *
 * The list starts with two arrays built from Nd4j.linspace(1, len, len) reshaped to {@code shape}
 * in 'c' and 'f' order, followed by the results of get3dSubArraysWithShape,
 * get3dTensorAlongDimensionWithShape, get3dPermutedWithShape and get3dReshapedWithShape.
 *
 * @param seed  seed applied to the Nd4j random generator and forwarded to the helper methods
 * @param shape the shape of the arrays; must have exactly 3 entries
 * @return list of (array, description) pairs
 * @throws IllegalArgumentException if {@code shape} does not have length 3
 */
public static List<Pair<INDArray, String>> getAll3dTestArraysWithShape(int seed, int... shape) {
    if (shape.length != 3)
        throw new IllegalArgumentException("Shape is not length 3");
    List<Pair<INDArray, String>> list = new ArrayList<>();
    // Common prefix of every description; each entry appends its own index and origin.
    String baseMsg = "getAll3dTestArraysWithShape(" + seed + "," + Arrays.toString(shape) + ").get(";
    int len = ArrayUtil.prod(shape);
    // Basic 3d in C and F orders:
    Nd4j.getRandom().setSeed(seed);
    INDArray stdC = Nd4j.linspace(1, len, len).reshape('c', shape);
    INDArray stdF = Nd4j.linspace(1, len, len).reshape('f', shape);
    list.add(new Pair<>(stdC, baseMsg + "0)/Nd4j.linspace(1,len,len)(" + Arrays.toString(shape) + ",'c')"));
    // Description now closes the linspace(...) parenthesis, matching the 'c' entry above.
    list.add(new Pair<>(stdF, baseMsg + "1)/Nd4j.linspace(1,len,len)(" + Arrays.toString(shape) + ",'f')"));
    // Various sub arrays:
    list.addAll(get3dSubArraysWithShape(seed, shape));
    // TAD
    list.addAll(get3dTensorAlongDimensionWithShape(seed, shape));
    // Permuted
    list.addAll(get3dPermutedWithShape(seed, shape));
    // Reshaped
    list.addAll(get3dReshapedWithShape(seed, shape));
    return list;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) Pair(org.nd4j.linalg.primitives.Pair)

Example 37 with Pair

use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

the class NDArrayCreationUtil method get3dSubArraysWithShape.

/**
 * Produces four arrays obtained by slicing larger linspace-filled arrays with NDArrayIndex intervals.
 *
 * For the first three entries the source array is 5 elements longer than {@code shape} along
 * exactly one dimension (0, 1 and 2 respectively) and the slice starts at offset 2, 3 and 4 on
 * that dimension. For the fourth entry all three dimensions are enlarged by 5 and the slice
 * starts at offsets 4, 3 and 2. Every interval has the form (offset, shape[d] + offset).
 *
 * @param seed  seed applied to the Nd4j random generator; also included in the descriptions
 * @param shape the requested shape; indices 0, 1 and 2 are read
 * @return list of (sub-array, description) pairs, in the order described above
 */
public static List<Pair<INDArray, String>> get3dSubArraysWithShape(int seed, int... shape) {
    final String prefix = "get3dSubArraysWithShape(" + seed + "," + Arrays.toString(shape) + ")";
    final List<Pair<INDArray, String>> result = new ArrayList<>();
    Nd4j.getRandom().setSeed(seed);

    // Entry 0: extra room along dimension 0, slice starts at 2.
    int[] grownDim0 = Arrays.copyOf(shape, shape.length);
    grownDim0[0] += 5;
    int count0 = ArrayUtil.prod(grownDim0);
    INDArray source0 = Nd4j.linspace(1, count0, count0).reshape(grownDim0);
    result.add(new Pair<>(
            source0.get(NDArrayIndex.interval(2, shape[0] + 2), NDArrayIndex.all(), NDArrayIndex.all()),
            prefix + ".get(0)"));

    // Entry 1: extra room along dimension 1, slice starts at 3.
    int[] grownDim1 = Arrays.copyOf(shape, shape.length);
    grownDim1[1] += 5;
    int count1 = ArrayUtil.prod(grownDim1);
    INDArray source1 = Nd4j.linspace(1, count1, count1).reshape(grownDim1);
    result.add(new Pair<>(
            source1.get(NDArrayIndex.all(), NDArrayIndex.interval(3, shape[1] + 3), NDArrayIndex.all()),
            prefix + ".get(1)"));

    // Entry 2: extra room along dimension 2, slice starts at 4.
    int[] grownDim2 = Arrays.copyOf(shape, shape.length);
    grownDim2[2] += 5;
    int count2 = ArrayUtil.prod(grownDim2);
    INDArray source2 = Nd4j.linspace(1, count2, count2).reshape(grownDim2);
    result.add(new Pair<>(
            source2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, shape[2] + 4)),
            prefix + ".get(2)"));

    // Entry 3: extra room along every dimension, slice starts at (4, 3, 2).
    int[] grownAll = Arrays.copyOf(shape, shape.length);
    for (int d = 0; d < 3; d++) {
        grownAll[d] += 5;
    }
    int countAll = ArrayUtil.prod(grownAll);
    INDArray sourceAll = Nd4j.linspace(1, countAll, countAll).reshape(grownAll);
    result.add(new Pair<>(
            sourceAll.get(NDArrayIndex.interval(4, shape[0] + 4), NDArrayIndex.interval(3, shape[1] + 3),
                    NDArrayIndex.interval(2, shape[2] + 2)),
            prefix + ".get(3)"));

    return result;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) Pair(org.nd4j.linalg.primitives.Pair)

Example 38 with Pair

use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

the class NDArrayCreationUtil method getAll6dTestArraysWithShape.

/**
 * Collects a set of 6d test arrays, each paired with a string describing how it was produced.
 *
 * The list starts with two Nd4j.rand arrays of the given shape in 'c' and 'f' order, followed by
 * the results of get6dSubArraysWithShape, get6dPermutedWithShape and get6dReshapedWithShape.
 *
 * @param seed  seed applied to the Nd4j random generator and forwarded to the helper methods
 * @param shape the shape of the arrays; must have exactly 6 entries
 * @return list of (array, description) pairs
 * @throws IllegalArgumentException if {@code shape} does not have length 6
 */
public static List<Pair<INDArray, String>> getAll6dTestArraysWithShape(int seed, int... shape) {
    if (shape.length != 6) {
        throw new IllegalArgumentException("Shape is not length 6");
    }
    final String shapeText = Arrays.toString(shape);
    final String prefix = "getAll6dTestArraysWithShape(" + seed + "," + shapeText + ").get(";
    final List<Pair<INDArray, String>> result = new ArrayList<>();

    // Basic 6d in C and F orders; the 'c' array is drawn first, then the 'f' array.
    Nd4j.getRandom().setSeed(seed);
    INDArray randomC = Nd4j.rand(shape, 'c');
    INDArray randomF = Nd4j.rand(shape, 'f');
    String describeC = prefix + "0)/Nd4j.rand(" + shapeText + ",'c')";
    String describeF = prefix + "1)/Nd4j.rand(" + shapeText + ",'f')";
    result.add(new Pair<>(randomC, describeC));
    result.add(new Pair<>(randomF, describeF));

    // Derived arrays: sub arrays, then permuted, then reshaped.
    result.addAll(get6dSubArraysWithShape(seed, shape));
    result.addAll(get6dPermutedWithShape(seed, shape));
    result.addAll(get6dReshapedWithShape(seed, shape));
    return result;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) Pair(org.nd4j.linalg.primitives.Pair)

Example 39 with Pair

use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

the class NDArrayCreationUtil method getSubMatricesWithShape.

/**
 * Produces three sub-matrices, each sliced from its own freshly created source matrix of shape
 * (2 * rows + 4, 2 * cols + 4) filled via Nd4j.linspace(1, len, len) and reshaped with the
 * requested ordering.
 *
 * The slices use the intervals (0, rows) x (0, cols), (3, rows + 3) x (3, cols + 3) and
 * (rows, 2 * rows) x (cols, 2 * cols).
 *
 * @param ordering ordering character passed to reshape for every source matrix
 * @param rows     row count used to size the source matrix and the slice intervals
 * @param cols     column count used to size the source matrix and the slice intervals
 * @param seed     seed applied to the Nd4j random generator; also included in the descriptions
 * @return list of three (sub-matrix, description) pairs
 */
public static List<Pair<INDArray, String>> getSubMatricesWithShape(char ordering, int rows, int cols, int seed) {
    // Create 3 identical matrices. Could do get() on single original array, but in-place modifications on one
    // might mess up tests for another
    Nd4j.getRandom().setSeed(seed);
    int[] shape = new int[] { 2 * rows + 4, 2 * cols + 4 };
    int len = ArrayUtil.prod(shape);
    INDArray orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray first = orig.get(NDArrayIndex.interval(0, rows), NDArrayIndex.interval(0, cols));
    Nd4j.getRandom().setSeed(seed);
    // Use the requested ordering here as well, so the second source matrix matches the first and third.
    orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray second = orig.get(NDArrayIndex.interval(3, rows + 3), NDArrayIndex.interval(3, cols + 3));
    Nd4j.getRandom().setSeed(seed);
    orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray third = orig.get(NDArrayIndex.interval(rows, 2 * rows), NDArrayIndex.interval(cols, 2 * cols));
    String baseMsg = "getSubMatricesWithShape(" + rows + "," + cols + "," + seed + ")";
    List<Pair<INDArray, String>> list = new ArrayList<>(3);
    list.add(new Pair<>(first, baseMsg + ".get(0)"));
    list.add(new Pair<>(second, baseMsg + ".get(1)"));
    list.add(new Pair<>(third, baseMsg + ".get(2)"));
    return list;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) Pair(org.nd4j.linalg.primitives.Pair)

Example 40 with Pair

use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.

the class NDArrayCreationUtil method getReshapedWithShape.

/**
 * Creates a 3d linspace-filled array and reshapes it to (rows, cols), returning the result
 * together with a description string.
 *
 * The 3d source shape is chosen as follows:
 * (rows / 2, cols, 2) when rows is even; otherwise (rows, cols / 2, 2) when cols is even;
 * otherwise (1, rows, cols).
 *
 * @param ordering ordering character passed to both reshape calls
 * @param rows     number of rows of the returned array
 * @param cols     number of columns of the returned array
 * @param seed     seed applied to the Nd4j random generator; also included in the description
 * @return pair of the reshaped array and its description
 */
public static Pair<INDArray, String> getReshapedWithShape(char ordering, int rows, int cols, int seed) {
    Nd4j.getRandom().setSeed(seed);

    final int[] sourceShape;
    if (rows % 2 == 0) {
        sourceShape = new int[] { rows / 2, cols, 2 };
    } else if (cols % 2 == 0) {
        sourceShape = new int[] { rows, cols / 2, 2 };
    } else {
        // Both dimensions odd: prepend a dimension of size 1 instead of splitting one.
        sourceShape = new int[] { 1, rows, cols };
    }

    int total = ArrayUtil.prod(sourceShape);
    INDArray source = Nd4j.linspace(1, total, total).reshape(ordering, sourceShape);
    INDArray reshaped = source.reshape(ordering, rows, cols);
    String description = "getReshapedWithShape(" + rows + "," + cols + "," + seed + ")";
    return new Pair<>(reshaped, description);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) Pair(org.nd4j.linalg.primitives.Pair)

Aggregations

Pair (org.nd4j.linalg.primitives.Pair)66 INDArray (org.nd4j.linalg.api.ndarray.INDArray)63 Test (org.junit.Test)24 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)9 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)5 ArrayList (java.util.ArrayList)4 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)4 Ignore (org.junit.Ignore)3 List (java.util.List)2 RealMatrix (org.apache.commons.math3.linear.RealMatrix)2 IntPointer (org.bytedeco.javacpp.IntPointer)2 Pointer (org.bytedeco.javacpp.Pointer)2 OpExecutionerUtil (org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil)2 LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper)2 IntBuffer (java.nio.IntBuffer)1 Random (java.util.Random)1 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)1 Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix)1 BlockRealMatrix (org.apache.commons.math3.linear.BlockRealMatrix)1 LUDecomposition (org.apache.commons.math3.linear.LUDecomposition)1