use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
the class CudaHalfDataBufferTest method testSerialization2.
/**
 * Round-trips a half-precision buffer through a temporary file and checks that the
 * restored data, converted back to floats, matches the original values within 0.01.
 *
 * @throws Exception if the temp file cannot be created, written or read
 */
@Test
public void testSerialization2() throws Exception {
    DataBuffer bufferOriginal = new CudaFloatDataBuffer(new float[] { 1f, 2f, 3f, 4f, 5f });
    DataBuffer bufferHalfs = Nd4j.getNDArrayFactory().convertDataEx(DataBuffer.TypeEx.FLOAT, bufferOriginal, DataBuffer.TypeEx.FLOAT16);
    File tempFile = File.createTempFile("alpha", "11");
    tempFile.deleteOnExit();
    DataBuffer bufferRestored;
    DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
    try {
        // serialize the halfs while HALF is the context data type
        try (DataOutputStream dos = new DataOutputStream(Files.newOutputStream(Paths.get(tempFile.getAbsolutePath())))) {
            bufferHalfs.write(dos);
        }
        // loading data back from file; the stream is closed even if read() throws
        try (DataInputStream dis = new DataInputStream(new FileInputStream(tempFile.getAbsoluteFile()))) {
            bufferRestored = Nd4j.createBuffer(bufferOriginal.length());
            bufferRestored.read(dis);
        }
        // expected value first, so a failure message reports the right sides
        assertEquals(DataBuffer.Type.HALF, bufferRestored.dataType());
    } finally {
        // always put the context data type back to FLOAT, so a failure above
        // does not leave HALF in place for whatever runs next
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
    }
    DataBuffer bufferConverted = Nd4j.getNDArrayFactory().convertDataEx(DataBuffer.TypeEx.FLOAT16, bufferRestored, DataBuffer.TypeEx.FLOAT);
    assertArrayEquals(bufferOriginal.asFloat(), bufferConverted.asFloat(), 0.01f);
}
use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
the class CudaHalfDataBufferTest method testConversion1.
/**
 * Converts a float buffer to FLOAT16 and back, logs both buffers, and checks
 * that the round-tripped values match the originals within 0.01.
 */
@Test
public void testConversion1() throws Exception {
    float[] values = new float[] { 1f, 2f, 3f, 4f, 5f };
    DataBuffer source = new CudaFloatDataBuffer(values);

    // FLOAT -> FLOAT16
    DataBuffer asHalfs = Nd4j.getNDArrayFactory()
            .convertDataEx(DataBuffer.TypeEx.FLOAT, source, DataBuffer.TypeEx.FLOAT16);
    // FLOAT16 -> FLOAT
    DataBuffer roundTripped = Nd4j.getNDArrayFactory()
            .convertDataEx(DataBuffer.TypeEx.FLOAT16, asHalfs, DataBuffer.TypeEx.FLOAT);

    String sourceDump = Arrays.toString(source.asFloat());
    logger.info("Buffer original: {}", sourceDump);
    String roundTrippedDump = Arrays.toString(roundTripped.asFloat());
    logger.info("Buffer restored: {}", roundTrippedDump);

    assertArrayEquals(source.asFloat(), roundTripped.asFloat(), 0.01f);
}
use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
the class DirectSparseInfoProvider method createSparseInformation.
/**
 * Returns the sparse information buffer for the given arguments, reusing a cached
 * buffer when one exists for an equal {@code SparseDescriptor}.
 *
 * <p>New buffers are added to the cache only while {@code counter} is below
 * {@code MAX_ENTRIES}; once that limit is reached, a fresh uncached buffer is
 * created on every miss.
 *
 * @param flags            passed through to the descriptor and to {@code Shape.createSparseInformation}
 * @param sparseOffsets    passed through to the descriptor and to {@code Shape.createSparseInformation}
 * @param hiddenDimensions passed through to the descriptor and to {@code Shape.createSparseInformation}
 * @param underlyingRank   passed through to the descriptor and to {@code Shape.createSparseInformation}
 * @return the cached buffer for this descriptor, or a newly created one
 */
@Override
public DataBuffer createSparseInformation(int[] flags, long[] sparseOffsets, int[] hiddenDimensions, int underlyingRank) {
    SparseDescriptor descriptor = new SparseDescriptor(flags, sparseOffsets, hiddenDimensions, underlyingRank);
    if (!sparseCache.containsKey(descriptor)) {
        if (counter.get() < MAX_ENTRIES) {
            // Re-check under the lock: without it the second containsKey() is no
            // protection at all, and two threads could both count and insert the
            // same descriptor.
            synchronized (this) {
                if (!sparseCache.containsKey(descriptor)) {
                    counter.incrementAndGet();
                    DataBuffer buffer = Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
                    sparseCache.put(descriptor, buffer);
                    return buffer;
                }
            }
            // another thread inserted the entry first; fall through to the cached value
        } else {
            // cache is full: hand out an uncached buffer
            return Shape.createSparseInformation(flags, sparseOffsets, hiddenDimensions, underlyingRank);
        }
    }
    return sparseCache.get(descriptor);
}
use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
the class DirectShapeInfoProvider method createShapeInformation.
@Override
public Pair<DataBuffer, int[]> createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
// We enforce offset to 0 in shapeBuffer, since we need it for cache efficiency + we don't actually use offset value @ native side
offset = 0;
ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order);
if (!shapeCache.containsKey(descriptor)) {
if (counter.get() < MAX_ENTRIES) {
synchronized (this) {
if (!shapeCache.containsKey(descriptor)) {
counter.incrementAndGet();
Pair<DataBuffer, int[]> buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
shapeCache.put(descriptor, buffer);
bytes.addAndGet(buffer.getFirst().length() * 4 * 2);
return buffer;
} else
return shapeCache.get(descriptor);
}
} else {
return super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
}
}
return shapeCache.get(descriptor);
}
use of org.nd4j.linalg.api.buffer.DataBuffer in project nd4j by deeplearning4j.
the class ConstantBuffersCache method getConstantBuffer.
/**
 * Returns a buffer holding the given values, reusing the one stored in
 * {@code buffersCache} when an equal {@code ArrayDescriptor} is already present.
 *
 * <p>On a miss a detached buffer is created; it is stored in the cache only while
 * {@code counter} is below {@code MAX_ENTRIES}, and is returned either way.
 *
 * @param array the values the buffer should hold
 * @return the cached buffer for these values, or a newly created one
 */
@Override
public DataBuffer getConstantBuffer(double[] array) {
    ArrayDescriptor key = new ArrayDescriptor(array);

    // Cache hit: hand back the stored buffer.
    if (buffersCache.containsKey(key)) {
        return buffersCache.get(key);
    }

    DataBuffer created = Nd4j.createBufferDetached(array);
    boolean hasRoom = counter.get() < MAX_ENTRIES;
    if (hasRoom) {
        counter.incrementAndGet();
        buffersCache.put(key, created);
        bytes.addAndGet(array.length * Nd4j.sizeOfDataType());
    }
    return created;
}
Aggregations