Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
From the class DataSetUtil, method merge2d.
/**
 * Merge the specified 2d arrays and masks. The arrays are concatenated along dimension 0.
 * See {@link #mergeFeatures(INDArray[], INDArray[])} and {@link #mergeLabels(INDArray[], INDArray[])}
 *
 * @param arrays Arrays to merge; must be non-empty and all must have the same number of columns
 * @param masks Mask arrays to merge. May be null, and individual entries may be null
 * @return Merged arrays and mask; the mask is null when no non-null mask array was supplied
 * @throws IllegalArgumentException if {@code arrays} is null or empty
 * @throws IllegalStateException if the arrays do not all have the same number of columns
 */
public static Pair<INDArray, INDArray> merge2d(INDArray[] arrays, INDArray[] masks) {
    if (arrays == null || arrays.length == 0) {
        // Fail with a descriptive message instead of an NPE / index error on arrays[0]
        throw new IllegalArgumentException("Cannot merge 2d arrays: no arrays provided");
    }
    int cols = arrays[0].columns();
    boolean hasMasks = false;
    for (int i = 0; i < arrays.length; i++) {
        if (arrays[i].columns() != cols) {
            throw new IllegalStateException("Cannot merge 2d arrays with different numbers of columns (firstNCols=" + cols + ", ithNCols=" + arrays[i].columns() + ")");
        }
        // A single non-null mask entry is enough to require a merged mask
        if (masks != null && masks[i] != null) {
            hasMasks = true;
        }
    }
    // Hand a shallow copy to the concat call, as the original did, so the caller's array
    // object itself is never passed on
    INDArray[] temp = arrays.clone();
    INDArray out = Nd4j.specialConcat(0, temp);
    INDArray outMask = null;
    if (hasMasks) {
        outMask = DataSetUtil.mergePerOutputMasks2d(out.shape(), arrays, masks);
    }
    return new Pair<>(out, outMask);
}
Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
From the class NDArrayCreationUtil, method get6dSubArraysWithShape.
/**
 * Builds seven sub-arrays, each extracted with {@code get(...)} from a larger random array.
 * <p>
 * For the first six entries, exactly one dimension of {@code shape} is enlarged by 5 and the
 * sub-array is taken with an interval on that dimension (all other dimensions are selected in
 * full). For the last entry every one of the six dimensions is enlarged by 5 and an interval
 * is applied on each dimension.
 * <p>
 * The body indexes {@code shape[0]} through {@code shape[5]}, so {@code shape} needs at least
 * six entries.
 *
 * @param seed  Seed set on the Nd4j random number generator before any array is created
 * @param shape Base shape that the enlarged arrays are derived from
 * @return List of (sub-array, description) pairs, described as ".get(0)" to ".get(6)"
 */
public static List<Pair<INDArray, String>> get6dSubArraysWithShape(int seed, int... shape) {
    final List<Pair<INDArray, String>> result = new ArrayList<>();
    final String tag = "get6dSubArraysWithShape(" + seed + "," + Arrays.toString(shape) + ")";

    Nd4j.getRandom().setSeed(seed);

    // Dimension 0 enlarged, interval starting at 2
    final int[] grown0 = shape.clone();
    grown0[0] += 5;
    final INDArray sub0 = Nd4j.rand(grown0).get(
            NDArrayIndex.interval(2, shape[0] + 2), NDArrayIndex.all(), NDArrayIndex.all(),
            NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
    result.add(new Pair<>(sub0, tag + ".get(0)"));

    // Dimension 1 enlarged, interval starting at 3
    final int[] grown1 = shape.clone();
    grown1[1] += 5;
    final INDArray sub1 = Nd4j.rand(grown1).get(
            NDArrayIndex.all(), NDArrayIndex.interval(3, shape[1] + 3), NDArrayIndex.all(),
            NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
    result.add(new Pair<>(sub1, tag + ".get(1)"));

    // Dimension 2 enlarged, interval starting at 4
    final int[] grown2 = shape.clone();
    grown2[2] += 5;
    final INDArray sub2 = Nd4j.rand(grown2).get(
            NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, shape[2] + 4),
            NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
    result.add(new Pair<>(sub2, tag + ".get(2)"));

    // Dimension 3 enlarged, interval starting at 3
    final int[] grown3 = shape.clone();
    grown3[3] += 5;
    final INDArray sub3 = Nd4j.rand(grown3).get(
            NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),
            NDArrayIndex.interval(3, shape[3] + 3), NDArrayIndex.all(), NDArrayIndex.all());
    result.add(new Pair<>(sub3, tag + ".get(3)"));

    // Dimension 4 enlarged, interval starting at 3
    final int[] grown4 = shape.clone();
    grown4[4] += 5;
    final INDArray sub4 = Nd4j.rand(grown4).get(
            NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),
            NDArrayIndex.all(), NDArrayIndex.interval(3, shape[4] + 3), NDArrayIndex.all());
    result.add(new Pair<>(sub4, tag + ".get(4)"));

    // Dimension 5 enlarged, interval starting at 1
    final int[] grown5 = shape.clone();
    grown5[5] += 5;
    final INDArray sub5 = Nd4j.rand(grown5).get(
            NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),
            NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(1, shape[5] + 1));
    result.add(new Pair<>(sub5, tag + ".get(5)"));

    // All six dimensions enlarged, with an interval on every dimension
    final int[] grownAll = shape.clone();
    for (int d = 0; d < 6; d++) {
        grownAll[d] += 5;
    }
    final INDArray subAll = Nd4j.rand(grownAll).get(
            NDArrayIndex.interval(4, shape[0] + 4), NDArrayIndex.interval(3, shape[1] + 3),
            NDArrayIndex.interval(2, shape[2] + 2), NDArrayIndex.interval(1, shape[3] + 1),
            NDArrayIndex.interval(2, shape[4] + 2), NDArrayIndex.interval(3, shape[5] + 3));
    result.add(new Pair<>(subAll, tag + ".get(6)"));

    return result;
}
Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
From the class NDArrayCreationUtil, method get4dSubArraysWithShape.
/**
 * Builds five sub-arrays, each extracted with {@code get(...)} from a larger array that is
 * filled via {@code Nd4j.linspace(1, n, n)} and reshaped.
 * <p>
 * For the first four entries, exactly one dimension of {@code shape} is enlarged by 5 and the
 * sub-array is taken with an interval on that dimension (all other dimensions are selected in
 * full). For the last entry all four dimensions are enlarged by 5 and an interval is applied
 * on each dimension.
 * <p>
 * The body indexes {@code shape[0]} through {@code shape[3]}, so {@code shape} needs at least
 * four entries.
 *
 * @param seed  Seed set on the Nd4j random number generator; it also appears in the descriptions
 * @param shape Base shape that the enlarged arrays are derived from
 * @return List of (sub-array, description) pairs, described as ".get(0)" to ".get(4)"
 */
public static List<Pair<INDArray, String>> get4dSubArraysWithShape(int seed, int... shape) {
    final List<Pair<INDArray, String>> result = new ArrayList<>();
    final String tag = "get4dSubArraysWithShape(" + seed + "," + Arrays.toString(shape) + ")";

    Nd4j.getRandom().setSeed(seed);

    // Dimension 0 enlarged, interval starting at 2
    final int[] grown0 = shape.clone();
    grown0[0] += 5;
    final int count0 = ArrayUtil.prod(grown0);
    final INDArray sub0 = Nd4j.linspace(1, count0, count0).reshape(grown0).get(
            NDArrayIndex.interval(2, shape[0] + 2), NDArrayIndex.all(),
            NDArrayIndex.all(), NDArrayIndex.all());
    result.add(new Pair<>(sub0, tag + ".get(0)"));

    // Dimension 1 enlarged, interval starting at 3
    final int[] grown1 = shape.clone();
    grown1[1] += 5;
    final int count1 = ArrayUtil.prod(grown1);
    final INDArray sub1 = Nd4j.linspace(1, count1, count1).reshape(grown1).get(
            NDArrayIndex.all(), NDArrayIndex.interval(3, shape[1] + 3),
            NDArrayIndex.all(), NDArrayIndex.all());
    result.add(new Pair<>(sub1, tag + ".get(1)"));

    // Dimension 2 enlarged, interval starting at 4
    final int[] grown2 = shape.clone();
    grown2[2] += 5;
    final int count2 = ArrayUtil.prod(grown2);
    final INDArray sub2 = Nd4j.linspace(1, count2, count2).reshape(grown2).get(
            NDArrayIndex.all(), NDArrayIndex.all(),
            NDArrayIndex.interval(4, shape[2] + 4), NDArrayIndex.all());
    result.add(new Pair<>(sub2, tag + ".get(2)"));

    // Dimension 3 enlarged, interval starting at 3
    final int[] grown3 = shape.clone();
    grown3[3] += 5;
    final int count3 = ArrayUtil.prod(grown3);
    final INDArray sub3 = Nd4j.linspace(1, count3, count3).reshape(grown3).get(
            NDArrayIndex.all(), NDArrayIndex.all(),
            NDArrayIndex.all(), NDArrayIndex.interval(3, shape[3] + 3));
    result.add(new Pair<>(sub3, tag + ".get(3)"));

    // All four dimensions enlarged, with an interval on every dimension
    final int[] grownAll = shape.clone();
    for (int d = 0; d < 4; d++) {
        grownAll[d] += 5;
    }
    final int countAll = ArrayUtil.prod(grownAll);
    final INDArray subAll = Nd4j.linspace(1, countAll, countAll).reshape(grownAll).get(
            NDArrayIndex.interval(4, shape[0] + 4), NDArrayIndex.interval(3, shape[1] + 3),
            NDArrayIndex.interval(2, shape[2] + 2), NDArrayIndex.interval(1, shape[3] + 1));
    result.add(new Pair<>(subAll, tag + ".get(4)"));

    return result;
}
Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
From the class NDArrayCreationUtil, method getTensorAlongDimensionMatricesWithShape.
/**
 * Builds twelve {@code rows x cols} matrices, each obtained by calling
 * {@code javaTensorAlongDimension} on a 3d linspace array and reshaping the result.
 * <p>
 * Six dimension pairs are used: [0,1], [0,2], [1,0], [1,2], [2,0], [2,1]. For each pair two
 * matrices are produced, one for tensor index 0 and one for tensor index 2, and each matrix
 * is taken from its own freshly created source array.
 * <p>
 * The {@code ordering} argument is not used by this method and does not appear in the
 * description strings.
 *
 * @param ordering Not used by this method
 * @param rows     Number of rows of each returned matrix
 * @param cols     Number of columns of each returned matrix
 * @param seed     Seed set on the Nd4j random number generator; it also appears in the descriptions
 * @return List of 12 (matrix, description) pairs, described as ".get(0)" to ".get(11)"
 */
public static List<Pair<INDArray, String>> getTensorAlongDimensionMatricesWithShape(char ordering, int rows, int cols, int seed) {
    Nd4j.getRandom().setSeed(seed);
    final INDArray[] matrices = new INDArray[12];

    // Dimensions [0,1]
    INDArray src01 = Nd4j.linspace(1, cols * rows * 4, cols * rows * 4).reshape(cols, rows, 4);
    matrices[0] = src01.javaTensorAlongDimension(0, 0, 1).reshape(rows, cols);
    final int[] shape01 = new int[] { cols, rows, 4 };
    final int total = ArrayUtil.prod(shape01);
    src01 = Nd4j.linspace(1, total, total).reshape(shape01);
    matrices[1] = src01.javaTensorAlongDimension(2, 0, 1).reshape(rows, cols);

    Nd4j.getRandom().setSeed(seed);

    // Dimensions [0,2]
    INDArray src02 = Nd4j.linspace(1, total, total).reshape(new int[] { cols, 4, rows });
    matrices[2] = src02.javaTensorAlongDimension(0, 0, 2).reshape(rows, cols);
    src02 = Nd4j.linspace(1, total, total).reshape(cols, 4, rows);
    matrices[3] = src02.javaTensorAlongDimension(2, 0, 2).reshape(rows, cols);

    // Dimensions [1,0]
    INDArray src10 = Nd4j.linspace(1, total, total).reshape(rows, cols, 4);
    matrices[4] = src10.javaTensorAlongDimension(0, 1, 0).reshape(rows, cols);
    src10 = Nd4j.linspace(1, total, total).reshape(rows, cols, 4);
    matrices[5] = src10.javaTensorAlongDimension(2, 1, 0).reshape(rows, cols);

    // Dimensions [1,2]
    INDArray src12 = Nd4j.linspace(1, total, total).reshape(4, cols, rows);
    matrices[6] = src12.javaTensorAlongDimension(0, 1, 2).reshape(rows, cols);
    src12 = Nd4j.linspace(1, total, total).reshape(4, cols, rows);
    matrices[7] = src12.javaTensorAlongDimension(2, 1, 2).reshape(rows, cols);

    // Dimensions [2,0]
    INDArray src20 = Nd4j.linspace(1, total, total).reshape(rows, 4, cols);
    matrices[8] = src20.javaTensorAlongDimension(0, 2, 0).reshape(rows, cols);
    src20 = Nd4j.linspace(1, total, total).reshape(rows, 4, cols);
    matrices[9] = src20.javaTensorAlongDimension(2, 2, 0).reshape(rows, cols);

    // Dimensions [2,1]
    INDArray src21 = Nd4j.linspace(1, total, total).reshape(4, rows, cols);
    matrices[10] = src21.javaTensorAlongDimension(0, 2, 1).reshape(rows, cols);
    src21 = Nd4j.linspace(1, total, total).reshape(4, rows, cols);
    matrices[11] = src21.javaTensorAlongDimension(2, 2, 1).reshape(rows, cols);

    // Pair every matrix with a description carrying its position in the list
    final String tag = "getTensorAlongDimensionMatricesWithShape(" + rows + "," + cols + "," + seed + ")";
    final List<Pair<INDArray, String>> result = new ArrayList<>(12);
    int position = 0;
    for (INDArray matrix : matrices) {
        result.add(new Pair<>(matrix, tag + ".get(" + position + ")"));
        position++;
    }
    return result;
}
Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
From the class NDArrayCreationUtil, method getPermutedWithShape.
/**
 * Creates a linspace array of {@code rows * cols} values reshaped to {@code cols x rows},
 * and returns it after {@code permute(1, 0)} together with a description string.
 * <p>
 * The {@code ordering} argument is not used by this method and does not appear in the
 * description string.
 *
 * @param ordering Not used by this method
 * @param rows     Second dimension of the array before the permute
 * @param cols     First dimension of the array before the permute
 * @param seed     Seed set on the Nd4j random number generator; it also appears in the description
 * @return Pair of the permuted array and its description
 */
public static Pair<INDArray, String> getPermutedWithShape(char ordering, int rows, int cols, int seed) {
    Nd4j.getRandom().setSeed(seed);
    final int elementCount = rows * cols;
    final INDArray source = Nd4j.linspace(1, elementCount, elementCount).reshape(cols, rows);
    final INDArray permuted = source.permute(1, 0);
    final String description = "getPermutedWithShape(" + rows + "," + cols + "," + seed + ")";
    return new Pair<>(permuted, description);
}
Aggregations