use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
the class NDArrayCreationUtil method getAll3dTestArraysWithShape.
/**
 * Returns the standard collection of 3d test arrays for the given shape.
 * <p>
 * The first two entries are linspace arrays (values 1..prod(shape)) reshaped in 'c' and 'f' order. These are
 * followed by the results of {@code get3dSubArraysWithShape}, {@code get3dTensorAlongDimensionWithShape},
 * {@code get3dPermutedWithShape} and {@code get3dReshapedWithShape}, in that order.
 *
 * @param seed  seed passed to the global ND4J RNG and to each helper; also embedded in the descriptions
 * @param shape the requested 3d shape; must have exactly 3 elements
 * @return a list of (array, description) pairs; each description identifies how its array was created
 * @throws IllegalArgumentException if {@code shape} does not have exactly 3 elements
 */
public static List<Pair<INDArray, String>> getAll3dTestArraysWithShape(int seed, int... shape) {
    if (shape.length != 3) {
        throw new IllegalArgumentException("Shape is not length 3");
    }
    List<Pair<INDArray, String>> list = new ArrayList<>();
    String baseMsg = "getAll3dTestArraysWithShape(" + seed + "," + Arrays.toString(shape) + ").get(";
    int len = ArrayUtil.prod(shape);
    // Basic 3d arrays in C and F order. linspace is deterministic, so the seed does not change these two arrays.
    Nd4j.getRandom().setSeed(seed);
    INDArray stdC = Nd4j.linspace(1, len, len).reshape('c', shape);
    INDArray stdF = Nd4j.linspace(1, len, len).reshape('f', shape);
    // Both descriptions use the same format and differ only in index and ordering.
    list.add(new Pair<>(stdC, baseMsg + "0)/Nd4j.linspace(1,len,len)(" + Arrays.toString(shape) + ",'c')"));
    list.add(new Pair<>(stdF, baseMsg + "1)/Nd4j.linspace(1,len,len)(" + Arrays.toString(shape) + ",'f')"));
    // Various sub arrays:
    list.addAll(get3dSubArraysWithShape(seed, shape));
    // TAD
    list.addAll(get3dTensorAlongDimensionWithShape(seed, shape));
    // Permuted
    list.addAll(get3dPermutedWithShape(seed, shape));
    // Reshaped
    list.addAll(get3dReshapedWithShape(seed, shape));
    return list;
}
use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
the class NDArrayCreationUtil method get3dSubArraysWithShape.
/**
 * Builds four 3d test arrays, each obtained by {@code get()} on a larger linspace array whose extra elements
 * surround the selected region.
 * <ul>
 *   <li>entry 0: dimension 0 enlarged by 5; the interval [2, shape[0] + 2) is selected along it</li>
 *   <li>entry 1: dimension 1 enlarged by 5; the interval [3, shape[1] + 3) is selected along it</li>
 *   <li>entry 2: dimension 2 enlarged by 5; the interval [4, shape[2] + 4) is selected along it</li>
 *   <li>entry 3: all three dimensions enlarged by 5; the offsets are 4, 3 and 2 respectively</li>
 * </ul>
 *
 * @param seed  seed applied to the global ND4J RNG and embedded in the descriptions
 * @param shape the shape of each returned sub-array (indices 0..2 are used)
 * @return a list of (sub-array, description) pairs
 */
public static List<Pair<INDArray, String>> get3dSubArraysWithShape(int seed, int... shape) {
    final String prefix = "get3dSubArraysWithShape(" + seed + "," + Arrays.toString(shape) + ")";
    final List<Pair<INDArray, String>> out = new ArrayList<>();
    Nd4j.getRandom().setSeed(seed);

    // Entry 0: enlarge dimension 0 only
    final int[] bigDim0 = shape.clone();
    bigDim0[0] = bigDim0[0] + 5;
    final int n0 = ArrayUtil.prod(bigDim0);
    final INDArray src0 = Nd4j.linspace(1, n0, n0).reshape(bigDim0);
    final INDArray view0 = src0.get(NDArrayIndex.interval(2, 2 + shape[0]), NDArrayIndex.all(), NDArrayIndex.all());
    out.add(new Pair<>(view0, prefix + ".get(0)"));

    // Entry 1: enlarge dimension 1 only
    final int[] bigDim1 = shape.clone();
    bigDim1[1] = bigDim1[1] + 5;
    final int n1 = ArrayUtil.prod(bigDim1);
    final INDArray src1 = Nd4j.linspace(1, n1, n1).reshape(bigDim1);
    final INDArray view1 = src1.get(NDArrayIndex.all(), NDArrayIndex.interval(3, 3 + shape[1]), NDArrayIndex.all());
    out.add(new Pair<>(view1, prefix + ".get(1)"));

    // Entry 2: enlarge dimension 2 only
    final int[] bigDim2 = shape.clone();
    bigDim2[2] = bigDim2[2] + 5;
    final int n2 = ArrayUtil.prod(bigDim2);
    final INDArray src2 = Nd4j.linspace(1, n2, n2).reshape(bigDim2);
    final INDArray view2 = src2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, 4 + shape[2]));
    out.add(new Pair<>(view2, prefix + ".get(2)"));

    // Entry 3: enlarge all three dimensions
    final int[] bigAll = shape.clone();
    for (int d = 0; d < 3; d++) {
        bigAll[d] = bigAll[d] + 5;
    }
    final int n3 = ArrayUtil.prod(bigAll);
    final INDArray src3 = Nd4j.linspace(1, n3, n3).reshape(bigAll);
    final INDArray view3 = src3.get(NDArrayIndex.interval(4, 4 + shape[0]),
                    NDArrayIndex.interval(3, 3 + shape[1]),
                    NDArrayIndex.interval(2, 2 + shape[2]));
    out.add(new Pair<>(view3, prefix + ".get(3)"));

    return out;
}
use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
the class NDArrayCreationUtil method getAll6dTestArraysWithShape.
/**
 * Returns the standard collection of 6d test arrays for the given shape.
 * <p>
 * The first two entries come from {@code Nd4j.rand(shape, order)} for orders 'c' and 'f', generated in that order
 * right after seeding the global RNG. They are followed by the results of {@code get6dSubArraysWithShape},
 * {@code get6dPermutedWithShape} and {@code get6dReshapedWithShape}, in that order.
 *
 * @param seed  seed for the global ND4J RNG; also passed to each helper and embedded in the descriptions
 * @param shape the requested 6d shape; must have exactly 6 elements
 * @return a list of (array, description) pairs
 * @throws IllegalArgumentException if {@code shape} does not have exactly 6 elements
 */
public static List<Pair<INDArray, String>> getAll6dTestArraysWithShape(int seed, int... shape) {
    if (shape.length != 6) {
        throw new IllegalArgumentException("Shape is not length 6");
    }
    final String prefix = "getAll6dTestArraysWithShape(" + seed + "," + Arrays.toString(shape) + ").get(";
    final List<Pair<INDArray, String>> arrays = new ArrayList<>();

    // Basic 6d random arrays. 'c' is drawn before 'f', which determines which RNG values each one receives.
    Nd4j.getRandom().setSeed(seed);
    final char[] orders = {'c', 'f'};
    for (int i = 0; i < orders.length; i++) {
        final char order = orders[i];
        final INDArray random = Nd4j.rand(shape, order);
        arrays.add(new Pair<>(random, prefix + i + ")/Nd4j.rand(" + Arrays.toString(shape) + ",'" + order + "')"));
    }

    // Derived variants: sub-arrays, permutations, reshapes
    arrays.addAll(get6dSubArraysWithShape(seed, shape));
    arrays.addAll(get6dPermutedWithShape(seed, shape));
    arrays.addAll(get6dReshapedWithShape(seed, shape));
    return arrays;
}
use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
the class NDArrayCreationUtil method getSubMatricesWithShape.
/**
 * Returns three {@code rows x cols} sub-matrices, each obtained by {@code get()} on its own
 * {@code (2*rows+4) x (2*cols+4)} linspace source array reshaped in the requested {@code ordering}.
 * <ul>
 *   <li>entry 0: rows [0, rows), columns [0, cols)</li>
 *   <li>entry 1: rows [3, rows+3), columns [3, cols+3)</li>
 *   <li>entry 2: rows [rows, 2*rows), columns [cols, 2*cols)</li>
 * </ul>
 *
 * @param ordering ordering ('c' or 'f') used when reshaping every source array
 * @param rows     number of rows in each returned sub-matrix
 * @param cols     number of columns in each returned sub-matrix
 * @param seed     seed applied to the global ND4J RNG before creating each source; also embedded in descriptions
 * @return a list of three (sub-matrix, description) pairs
 */
public static List<Pair<INDArray, String>> getSubMatricesWithShape(char ordering, int rows, int cols, int seed) {
    // Each sub-matrix is taken from a freshly created source array instead of calling get() three times on one
    // array, so that in-place modifications a test makes to one result cannot affect the others.
    int[] shape = new int[] {2 * rows + 4, 2 * cols + 4};
    int len = ArrayUtil.prod(shape);

    Nd4j.getRandom().setSeed(seed);
    INDArray orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray first = orig.get(NDArrayIndex.interval(0, rows), NDArrayIndex.interval(0, cols));

    Nd4j.getRandom().setSeed(seed);
    // Must honour 'ordering' like the other two sources; otherwise this sub-matrix always uses the default order.
    orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray second = orig.get(NDArrayIndex.interval(3, rows + 3), NDArrayIndex.interval(3, cols + 3));

    Nd4j.getRandom().setSeed(seed);
    orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray third = orig.get(NDArrayIndex.interval(rows, 2 * rows), NDArrayIndex.interval(cols, 2 * cols));

    String baseMsg = "getSubMatricesWithShape(" + rows + "," + cols + "," + seed + ")";
    List<Pair<INDArray, String>> list = new ArrayList<>(3);
    list.add(new Pair<>(first, baseMsg + ".get(0)"));
    list.add(new Pair<>(second, baseMsg + ".get(1)"));
    list.add(new Pair<>(third, baseMsg + ".get(2)"));
    return list;
}
use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
the class NDArrayCreationUtil method getReshapedWithShape.
/**
 * Returns a {@code rows x cols} matrix produced by reshaping a 3d linspace array with the same element count.
 * <p>
 * The 3d source shape is chosen as follows:
 * <ul>
 *   <li>{@code [rows/2, cols, 2]} when {@code rows} is even</li>
 *   <li>{@code [rows, cols/2, 2]} when {@code rows} is odd and {@code cols} is even</li>
 *   <li>{@code [1, rows, cols]} otherwise</li>
 * </ul>
 * Both the source and the final reshape use {@code ordering}.
 *
 * @param ordering ordering ('c' or 'f') for both reshapes
 * @param rows     number of rows of the returned matrix
 * @param cols     number of columns of the returned matrix
 * @param seed     seed applied to the global ND4J RNG; also embedded in the description
 * @return the reshaped matrix paired with its description
 */
public static Pair<INDArray, String> getReshapedWithShape(char ordering, int rows, int cols, int seed) {
    Nd4j.getRandom().setSeed(seed);
    final int[] sourceShape;
    if (rows % 2 == 0) {
        sourceShape = new int[] {rows / 2, cols, 2};
    } else if (cols % 2 == 0) {
        sourceShape = new int[] {rows, cols / 2, 2};
    } else {
        sourceShape = new int[] {1, rows, cols};
    }
    final int count = ArrayUtil.prod(sourceShape);
    final INDArray source = Nd4j.linspace(1, count, count).reshape(ordering, sourceShape);
    final INDArray reshaped = source.reshape(ordering, rows, cols);
    final String description = "getReshapedWithShape(" + rows + "," + cols + "," + seed + ")";
    return new Pair<>(reshaped, description);
}
Aggregations