Search in sources:

Example 1 with Assign

Use of org.nd4j.linalg.api.ops.impl.transforms.Assign in project nd4j by deeplearning4j.

From class BaseNDArray, method put.

/**
 * Writes {@code element} into the positions of this array selected by {@code indices}.
 *
 * <p>{@code indices} holds one list of integer positions per leading dimension. Each list is
 * wrapped in a {@link SpecifiedIndex}. The selected coordinates are then enumerated by
 * {@link SpecifiedIndex#iterate}, and each generated coordinate is flattened into a single
 * {@code int[]}.
 *
 * <ul>
 *   <li>If there is one list per dimension ({@code indices.size() == rank()}), each coordinate
 *       receives one scalar of {@code element}. Scalars are read in the order produced by an
 *       {@link NdIndexIterator} over {@code element.shape()}.
 *   <li>Otherwise, for each coordinate this array is sliced successively by every entry of the
 *       coordinate. {@code element} is then assigned into the resulting sub-array with the
 *       {@link Assign} op. A single index list is handled the same way (previously it was
 *       silently ignored).
 * </ul>
 *
 * <p>Only exhaustion of the coordinate generator ends the loop. Any exception raised while
 * reading {@code element} or writing this array now propagates to the caller instead of being
 * swallowed.
 *
 * @param indices per-dimension lists of positions; an empty outer list is a no-op unless it
 *     matches {@code rank()}
 * @param element the values (full-rank case) or the sub-array (slice case) to write
 * @return this array, modified in place
 */
@Override
public INDArray put(List<List<Integer>> indices, INDArray element) {
    INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
    for (int i = 0; i < indArrayIndices.length; i++) {
        indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
    }
    boolean fullCoordinates = indices.size() == rank();
    if (!fullCoordinates && indices.isEmpty()) {
        // No index addresses any slice, so there is nothing to assign.
        return this;
    }
    Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
    NdIndexIterator ndIndexIterator = fullCoordinates ? new NdIndexIterator(element.shape()) : null;
    while (true) {
        List<List<Long>> next;
        try {
            next = iterate.next();
        } catch (NoSuchElementException e) {
            // The generator signals exhaustion with this exception; keep the catch this narrow
            // so failures elsewhere in the loop are not mistaken for "done".
            break;
        }
        // Flatten the per-dimension pieces of this coordinate into one int[] index.
        int[][] nextArr = new int[next.size()][];
        for (int i = 0; i < next.size(); i++) {
            nextArr[i] = Ints.toArray(next.get(i));
        }
        int[] curr = Ints.concat(nextArr);
        if (fullCoordinates) {
            putScalar(curr, element.getDouble(ndIndexIterator.next()));
        } else {
            INDArray currSlice = this;
            for (int idx : curr) {
                currSlice = currSlice.slice(idx);
            }
            Nd4j.getExecutioner().exec(new Assign(new INDArray[] { currSlice, element }, new INDArray[] { currSlice }));
        }
    }
    return this;
}
Also used : NdIndexIterator(org.nd4j.linalg.api.iter.NdIndexIterator) Assign(org.nd4j.linalg.api.ops.impl.transforms.Assign)

Example 2 with Assign

Use of org.nd4j.linalg.api.ops.impl.transforms.Assign in project nd4j by deeplearning4j.

From class BaseSparseNDArray, method put.

/**
 * Writes {@code element} into the positions of this sparse array selected by {@code indices}.
 *
 * <ul>
 *   <li>If there is one index list per dimension ({@code indices.size() == rank()}), the
 *       coordinates produced by {@link SpecifiedIndex#iterate} are each flattened into a single
 *       {@code int[]}, matching the dense {@code BaseNDArray.put} implementation. Each coordinate
 *       receives one scalar of {@code element}, read in {@link NdIndexIterator} order over
 *       {@code element.shape()}. Previously each per-dimension piece was written as its own
 *       coordinate, which also advanced the element iterator once per piece.
 *   <li>Otherwise, every integer in every list of {@code indices} is treated as a slice number.
 *       {@code element} is assigned into {@code slice(n)} with the {@link Assign} op. This now
 *       also applies when {@code indices} holds a single list; before, that case only inspected
 *       the first entry and never assigned anything.
 * </ul>
 *
 * <p>Only exhaustion of the coordinate generator ends the full-rank loop. Other exceptions
 * propagate to the caller.
 *
 * @param indices per-dimension index lists (full-rank case) or lists of slice numbers
 * @param element the values or sub-array to write
 * @return this array, modified in place
 */
@Override
public INDArray put(List<List<Integer>> indices, INDArray element) {
    if (indices.size() == rank()) {
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
        for (int i = 0; i < indArrayIndices.length; i++) {
            indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
        }
        Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
        while (true) {
            List<List<Long>> next;
            try {
                next = iterate.next();
            } catch (NoSuchElementException e) {
                // The generator signals exhaustion with this exception.
                break;
            }
            // Concatenate the per-dimension pieces into one full coordinate.
            int[][] perDimension = new int[next.size()][];
            for (int i = 0; i < next.size(); i++) {
                perDimension[i] = Ints.toArray(next.get(i));
            }
            putScalar(Ints.concat(perDimension), element.getDouble(ndIndexIterator.next()));
        }
    } else {
        for (List<Integer> row : indices) {
            for (int sliceIdx : row) {
                INDArray slice = slice(sliceIdx);
                Nd4j.getExecutioner().exec(new Assign(new INDArray[] { slice, element }, new INDArray[] { slice }));
            }
        }
    }
    return this;
}
Also used: NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator), INDArrayIndex (org.nd4j.linalg.indexing.INDArrayIndex), ArrayList (java.util.ArrayList), SpecifiedIndex (org.nd4j.linalg.indexing.SpecifiedIndex), List (java.util.List), Assign (org.nd4j.linalg.api.ops.impl.transforms.Assign), NoSuchElementException (java.util.NoSuchElementException)

Example 3 with Assign

Use of org.nd4j.linalg.api.ops.impl.transforms.Assign in project nd4j by deeplearning4j.

From class BaseNDArray, method put.

/**
 * Writes {@code element} into the positions of this array selected by the integer array
 * {@code indices}.
 *
 * <ul>
 *   <li>If {@code indices.rows() == rank()}, each column of {@code indices} is one full
 *       coordinate. That coordinate receives one scalar of {@code element}, read in
 *       {@link NdIndexIterator} order over {@code element.shape()}.
 *   <li>Otherwise, every entry of {@code indices} is treated as a slice number, and
 *       {@code element} is assigned into {@code slice(n)} with the {@link Assign} op. Matrix and
 *       column-vector indices are visited row by row. Row-vector indices are visited in order.
 *       Previously the row-vector case collected slices but never assigned into them.
 * </ul>
 *
 * @param indices a vector or matrix of integer indices
 * @param element the values or sub-array to write
 * @return this array, modified in place
 * @throws ND4JIllegalArgumentException if {@code indices} has rank greater than 2
 */
@Override
public INDArray put(INDArray indices, INDArray element) {
    if (indices.rank() > 2) {
        throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
    }
    if (indices.rows() == rank()) {
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        for (int i = 0; i < indices.columns(); i++) {
            int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
            putScalar(specifiedIndex, element.getDouble(ndIndexIterator.next()));
        }
    } else if (indices.isMatrix() || indices.isColumnVector()) {
        for (int i = 0; i < indices.rows(); i++) {
            INDArray row = indices.getRow(i);
            for (int j = 0; j < row.length(); j++) {
                INDArray slice = slice(row.getInt(j));
                Nd4j.getExecutioner().exec(new Assign(new INDArray[] { slice, element }, new INDArray[] { slice }));
            }
        }
    } else if (indices.isRowVector()) {
        for (int i = 0; i < indices.length(); i++) {
            INDArray slice = slice(indices.getInt(i));
            Nd4j.getExecutioner().exec(new Assign(new INDArray[] { slice, element }, new INDArray[] { slice }));
        }
    }
    return this;
}
Also used : NdIndexIterator(org.nd4j.linalg.api.iter.NdIndexIterator) Assign(org.nd4j.linalg.api.ops.impl.transforms.Assign) ND4JIllegalArgumentException(org.nd4j.linalg.exception.ND4JIllegalArgumentException)

Aggregations

NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator): 3 · Assign (org.nd4j.linalg.api.ops.impl.transforms.Assign): 3 · ArrayList (java.util.ArrayList): 1 · List (java.util.List): 1 · NoSuchElementException (java.util.NoSuchElementException): 1 · ND4JIllegalArgumentException (org.nd4j.linalg.exception.ND4JIllegalArgumentException): 1 · INDArrayIndex (org.nd4j.linalg.indexing.INDArrayIndex): 1 · SpecifiedIndex (org.nd4j.linalg.indexing.SpecifiedIndex): 1