Use of org.nd4j.nativeblas.LongPointerWrapper in project nd4j by deeplearning4j.
The class CpuTADManager, method getTADOnlyShapeInfo.
/**
 * Returns the TAD (tensor along dimension) shape-info buffer and the TAD offsets buffer
 * for {@code array} along {@code dimension}.
 *
 * <p>On a cache miss the buffers are filled by {@code nativeOps.tadOnlyShapeInfo}. The result is
 * stored in the cache only while fewer than {@code MAX_ENTRIES} entries have been added.
 *
 * <p>Side effect: when {@code dimension} has more than one element it is sorted in place,
 * so the caller's array is modified.
 *
 * @param array     the array whose TADs are described
 * @param dimension the dimensions the TADs run along. {@code null}, or an array whose first
 *                  element is {@link Integer#MAX_VALUE}, selects the whole-array case.
 * @return a pair of (TAD shape-info buffer, TAD offsets buffer). In the whole-array case the
 *         first element is the array's own shape-info buffer and the second is {@code null}.
 */
@Override
public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
    if (dimension != null && dimension.length > 1)
        Arrays.sort(dimension);

    // Whole-array request: no TAD decomposition, no offsets.
    if (dimension == null || dimension.length >= 1 && dimension[0] == Integer.MAX_VALUE) {
        return new Pair<>(array.shapeInfoDataBuffer(), null);
    }

    TadDescriptor descriptor = new TadDescriptor(array, dimension);

    // A single get() replaces a containsKey()/get() pair. This does one hash lookup instead of
    // two, and leaves no window in which a concurrent removal makes the second call return null.
    Pair<DataBuffer, DataBuffer> cached = cache.get(descriptor);
    if (cached != null) {
        return cached;
    }

    // FIXME (from original): fast triage, remove later.
    // A previous variant sized this as `dimensionLength <= 1 ? 2 : dimensionLength`.
    int targetRank = array.rank();

    // shape() is loop-invariant; read it once instead of once per dimension.
    int[] shape = array.shape();
    long tadLength = 1;
    for (int d : dimension) {
        tadLength *= shape[d];
    }
    // Number of TADs = total element count / elements per TAD; one offset is stored per TAD.
    long offsetLength = array.lengthLong() / tadLength;

    // Output shape-info buffer holds 2 * rank + 4 ints; offsets are longs.
    DataBuffer outputBuffer = new IntBuffer(targetRank * 2 + 4);
    DataBuffer offsetsBuffer = new LongBuffer(offsetLength);
    DataBuffer dimensionBuffer = constantHandler.getConstantBuffer(dimension);

    Pointer dimensionPointer = dimensionBuffer.addressPointer();
    Pointer xShapeInfo = array.shapeInfoDataBuffer().addressPointer();
    Pointer targetPointer = outputBuffer.addressPointer();
    Pointer offsetsPointer = offsetsBuffer.addressPointer();

    nativeOps.tadOnlyShapeInfo((IntPointer) xShapeInfo, (IntPointer) dimensionPointer, dimension.length,
            (IntPointer) targetPointer, new LongPointerWrapper(offsetsPointer));
    // Alternative (not used): taking outputBuffer from
    // array.tensorAlongDimension(0, dimension).shapeInfoDataBuffer() would make native code use
    // JVM-computed shapes instead.

    Pair<DataBuffer, DataBuffer> pair = new Pair<>(outputBuffer, offsetsBuffer);
    if (counter.get() < MAX_ENTRIES) {
        counter.incrementAndGet();
        cache.put(descriptor, pair);
        // Memory accounting: 4 bytes per int shape-info element, 8 bytes per long offset.
        bytes.addAndGet((outputBuffer.length() * 4) + (offsetsBuffer.length() * 8));
    }
    return pair;
}
Use of org.nd4j.nativeblas.LongPointerWrapper in project nd4j by deeplearning4j.
The class CpuSparseNDArrayFactory, method sort.
/**
 * Sorts {@code x} along {@code dimension} by passing its data buffer to the native per-TAD sort.
 * Returns {@code x} itself.
 *
 * <p>Scalars are returned immediately. {@code dimension} is sorted in place before use.
 * Only FLOAT and DOUBLE data are handled.
 *
 * @param x          the array to sort; its own data buffer is handed to the native routine
 * @param descending passed through to the native sort as its descending flag
 * @param dimension  the dimensions to sort along (sorted in place)
 * @return {@code x}
 * @throws UnsupportedOperationException if the data type is neither FLOAT nor DOUBLE
 */
@Override
public INDArray sort(INDArray x, boolean descending, int... dimension) {
    if (x.isScalar()) {
        return x;
    }

    Arrays.sort(dimension);
    Pair<DataBuffer, DataBuffer> tad = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(x, dimension);

    // Reject unsupported types before any native arguments are built.
    DataBuffer.Type dataType = x.data().dataType();
    boolean isFloat = dataType == DataBuffer.Type.FLOAT;
    if (!isFloat && dataType != DataBuffer.Type.DOUBLE) {
        throw new UnsupportedOperationException("Unknown datatype " + dataType);
    }

    // Arguments shared by both the float and double native calls.
    IntPointer shapeInfo = (IntPointer) x.shapeInfoDataBuffer().addressPointer();
    IntPointer dims = new IntPointer(dimension);
    IntPointer tadShapeInfo = (IntPointer) tad.getFirst().addressPointer();
    // NOTE(review): some TAD managers (e.g. CpuTADManager) return a null offsets buffer for the
    // whole-array case (Integer.MAX_VALUE). In that case this line throws NullPointerException.
    // Confirm whether that input should be routed to a whole-array sort instead.
    LongPointerWrapper tadOffsets = new LongPointerWrapper(tad.getSecond().addressPointer());

    if (isFloat) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortTadFloat(null,
                (FloatPointer) x.data().addressPointer(), shapeInfo, dims, dimension.length,
                tadShapeInfo, tadOffsets, descending);
    } else {
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortTadDouble(null,
                (DoublePointer) x.data().addressPointer(), shapeInfo, dims, dimension.length,
                tadShapeInfo, tadOffsets, descending);
    }
    return x;
}
Aggregations