Use of org.nd4j.linalg.api.ops.impl.transforms.Assign in the nd4j project by deeplearning4j.
Class BaseNDArray, method put:
/**
 * Writes {@code element} into the positions selected by {@code indices}, where
 * {@code indices.get(d)} lists the indices chosen along dimension {@code d}.
 *
 * <p>Every combination of the listed indices (as produced by
 * {@code SpecifiedIndex.iterate}) is visited:
 * <ul>
 *   <li>If one index list is given per dimension ({@code indices.size() == rank()}), each
 *       combination is a full coordinate and receives the next value of {@code element},
 *       read in {@link NdIndexIterator} order over {@code element.shape()}.</li>
 *   <li>Otherwise, if at least two index lists are given, each combination selects a
 *       sub-array through successive {@code slice} calls, and {@code element} is assigned
 *       into that sub-array.</li>
 *   <li>In any other case nothing is written.</li>
 * </ul>
 *
 * @param indices per-dimension lists of indices to write to
 * @param element the value(s) to write
 * @return this array
 */
@Override
public INDArray put(List<List<Integer>> indices, INDArray element) {
    INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
    for (int i = 0; i < indArrayIndices.length; i++) {
        indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
    }

    boolean fullRank = indices.size() == rank();
    if (!fullRank && indices.size() < 2) {
        // Neither a full coordinate per combination nor a multi-dimension slice selection.
        return this;
    }

    Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
    // Only needed (and only advanced) when each combination addresses a single scalar.
    NdIndexIterator ndIndexIterator = fullRank ? new NdIndexIterator(element.shape()) : null;

    while (true) {
        List<List<Long>> next;
        try {
            next = iterate.next();
        } catch (NoSuchElementException exhausted) {
            // The generator signals the end of the combinations by throwing. The try is kept
            // this narrow so that a NoSuchElementException raised by the writes below
            // propagates, instead of silently ending the loop after a partial update.
            break;
        }

        // A combination holds one part per listed dimension; join them into one index vector.
        int[][] nextArr = new int[next.size()][];
        for (int i = 0; i < next.size(); i++) {
            nextArr[i] = Ints.toArray(next.get(i));
        }
        int[] curr = Ints.concat(nextArr);

        if (fullRank) {
            putScalar(curr, element.getDouble(ndIndexIterator.next()));
        } else {
            INDArray currSlice = this;
            for (int index : curr) {
                currSlice = currSlice.slice(index);
            }
            Nd4j.getExecutioner().exec(new Assign(new INDArray[] {currSlice, element}, new INDArray[] {currSlice}));
        }
    }
    return this;
}
Use of org.nd4j.linalg.api.ops.impl.transforms.Assign in the nd4j project by deeplearning4j.
Class BaseSparseNDArray, method put:
/**
 * Writes {@code element} into the positions selected by {@code indices}, where
 * {@code indices.get(d)} lists the indices chosen along dimension {@code d}.
 *
 * <p>Every combination of the listed indices (as produced by
 * {@code SpecifiedIndex.iterate}) is visited:
 * <ul>
 *   <li>If one index list is given per dimension ({@code indices.size() == rank()}), each
 *       combination is joined into a single full coordinate, which receives the next value
 *       of {@code element}, read in {@link NdIndexIterator} order over
 *       {@code element.shape()}.</li>
 *   <li>Otherwise, if at least two index lists are given, each combination selects a
 *       sub-array through successive {@code slice} calls (one per listed dimension), and
 *       {@code element} is assigned into that sub-array.</li>
 *   <li>In any other case nothing is written.</li>
 * </ul>
 *
 * @param indices per-dimension lists of indices to write to
 * @param element the value(s) to write
 * @return this array
 */
@Override
public INDArray put(List<List<Integer>> indices, INDArray element) {
    boolean fullRank = indices.size() == rank();
    if (!fullRank && indices.size() < 2) {
        // Neither a full coordinate per combination nor a multi-dimension slice selection.
        return this;
    }

    INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
    for (int i = 0; i < indArrayIndices.length; i++) {
        indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
    }
    Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
    // Only needed (and only advanced) when each combination addresses a single scalar.
    NdIndexIterator ndIndexIterator = fullRank ? new NdIndexIterator(element.shape()) : null;

    while (true) {
        List<List<Long>> next;
        try {
            next = iterate.next();
        } catch (NoSuchElementException exhausted) {
            // The generator signals the end of the combinations by throwing. The try is kept
            // this narrow so that a NoSuchElementException raised by the writes below
            // propagates, instead of silently ending the loop after a partial update.
            break;
        }

        // A combination holds one part per listed dimension. The parts must be joined into one
        // index vector; handing each part to putScalar on its own would address the array
        // with a partial coordinate.
        int[][] parts = new int[next.size()][];
        for (int i = 0; i < next.size(); i++) {
            parts[i] = Ints.toArray(next.get(i));
        }
        int[] curr = Ints.concat(parts);

        if (fullRank) {
            putScalar(curr, element.getDouble(ndIndexIterator.next()));
        } else {
            INDArray currSlice = this;
            for (int index : curr) {
                currSlice = currSlice.slice(index);
            }
            Nd4j.getExecutioner().exec(new Assign(new INDArray[] {currSlice, element}, new INDArray[] {currSlice}));
        }
    }
    return this;
}
Use of org.nd4j.linalg.api.ops.impl.transforms.Assign in the nd4j project by deeplearning4j.
Class BaseNDArray, method put:
/**
 * Writes {@code element} into this array at the positions described by {@code indices}.
 *
 * <ul>
 *   <li>If {@code indices} has one row per dimension of this array, each column is read as a
 *       full coordinate and receives the next value of {@code element}, read in
 *       {@link NdIndexIterator} order over {@code element.shape()}.</li>
 *   <li>Otherwise, if {@code indices} is a matrix or a column vector, every entry is read as a
 *       slice index, and {@code element} is assigned into {@code slice(entry)}.</li>
 *   <li>Otherwise, if {@code indices} is a row vector, each entry is only looked up with
 *       {@code slice}; nothing is written (see the note in the body).</li>
 * </ul>
 *
 * @param indices a vector or matrix of indices
 * @param element the value(s) to write
 * @return this array
 * @throws ND4JIllegalArgumentException if {@code indices} has rank greater than 2
 */
@Override
public INDArray put(INDArray indices, INDArray element) {
    if (indices.rank() > 2) {
        throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
    }

    if (indices.rows() == rank()) {
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        for (int i = 0; i < indices.columns(); i++) {
            int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
            putScalar(specifiedIndex, element.getDouble(ndIndexIterator.next()));
        }
    } else if (indices.isMatrix() || indices.isColumnVector()) {
        for (int i = 0; i < indices.rows(); i++) {
            INDArray row = indices.getRow(i);
            for (int j = 0; j < row.length(); j++) {
                INDArray slice = slice(row.getInt(j));
                Nd4j.getExecutioner().exec(new Assign(new INDArray[] {slice, element}, new INDArray[] {slice}));
            }
        }
    } else if (indices.isRowVector()) {
        // NOTE(review): this branch writes nothing. The slice lookups only surface invalid
        // indices. Confirm whether element should be assigned here, as in the column-vector
        // case above.
        for (int i = 0; i < indices.length(); i++) {
            slice(indices.getInt(i));
        }
    }
    return this;
}
Aggregations