Use of org.nd4j.linalg.memory.abstracts.DummyWorkspace in project nd4j by deeplearning4j.
Class BaseCudaDataBuffer, method read.
/**
 * Restores this buffer from a serialized stream.
 *
 * <p>The stream is consumed in this order: a UTF string holding the serialized allocation
 * mode (read and discarded), an {@code int} element count, a UTF string naming the serialized
 * data type, and then one value per element encoded in that serialized type.
 *
 * <p>How the elements are stored depends on the serialized type {@code t} and on
 * {@code globalType}:
 * <ul>
 *   <li>{@code COMPRESSED}: only {@code type} is recorded; no element data is read and the
 *       method returns immediately, without the trailing device synchronization.</li>
 *   <li>either side is {@code INT}: a fresh 4-byte-per-element allocation is always made and
 *       every value is converted to {@code int}.</li>
 *   <li>{@code globalType} is {@code DOUBLE}, {@code FLOAT} or {@code HALF}: the existing
 *       storage is reused unless the length changed or no indexer exists yet; every value is
 *       converted to the global type.</li>
 * </ul>
 *
 * @param s stream positioned at the start of a serialized buffer
 * @throws RuntimeException wrapping any exception raised while reading, including the
 *         {@link IllegalStateException} thrown when no branch matches the data type
 */
@Override
public void read(DataInputStream s) {
    try {
        // The serialized allocation mode is read only to advance the stream;
        // the restored buffer always uses JAVACPP.
        s.readUTF();
        allocationMode = AllocationMode.JAVACPP;

        int locLength = s.readInt();
        // Must be computed before 'length' is overwritten with the incoming value.
        boolean reallocate = locLength != length || indexer == null;
        length = locLength;

        Type t = Type.valueOf(s.readUTF());

        if (globalType == null && Nd4j.dataType() != null) {
            globalType = Nd4j.dataType();
        }

        // Warn when the serialized element type is wider than the global one.
        if (t != globalType && t != Type.INT && Nd4j.sizeOfDataType(globalType) < Nd4j.sizeOfDataType(t)) {
            log.warn("Loading a data stream with opType different from what is set globally. Expect precision loss");
            if (globalType == Type.INT) {
                log.warn("Int to float/double widening UNSUPPORTED!!!");
            }
        }

        if (t == Type.COMPRESSED) {
            // Only the type tag is restored here; no element data is read in this method.
            type = t;
            return;
        } else if (t == Type.INT || globalType == Type.INT) {
            this.elementSize = 4;
            // This branch allocates unconditionally; the 'reallocate' flag is not consulted.
            this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this,
                    new AllocationShape(length, elementSize, t), false);
            this.trackingPoint = allocationPoint.getObjectId();

            // we keep int buffer's dtype after ser/de
            this.type = t;

            this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asIntPointer();
            indexer = IntIndexer.create((IntPointer) pointer);

            IntIndexer Iindexer = (IntIndexer) indexer;
            // Values are written straight through the indexer; no intermediate Java array is needed.
            for (int i = 0; i < length(); i++) {
                if (t == Type.INT) {
                    Iindexer.put(i, s.readInt());
                } else if (t == Type.DOUBLE) {
                    Iindexer.put(i, (int) s.readDouble());
                } else if (t == Type.FLOAT) {
                    Iindexer.put(i, (int) s.readFloat());
                } else if (t == Type.HALF) {
                    Iindexer.put(i, (int) toFloat((int) s.readShort()));
                }
            }

            allocationPoint.tickHostWrite();
        } else if (globalType == Type.DOUBLE) {
            this.elementSize = 8;
            if (reallocate) {
                MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
                // NOTE(review): this attaches the buffer only when the current workspace IS a
                // DummyWorkspace, which looks inverted — confirm whether a negation is intended.
                if (workspace != null && (workspace instanceof DummyWorkspace)) {
                    this.attached = true;
                    this.parentWorkspace = workspace;
                    workspaceGenerationId = workspace.getGenerationId();
                }

                this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this,
                        new AllocationShape(length, elementSize, globalType), false);
                this.trackingPoint = allocationPoint.getObjectId();

                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asDoublePointer();
                indexer = DoubleIndexer.create((DoublePointer) pointer);
            }

            DoubleIndexer Dindexer = (DoubleIndexer) indexer;
            for (int i = 0; i < length(); i++) {
                if (t == Type.DOUBLE) {
                    Dindexer.put(i, s.readDouble());
                } else if (t == Type.FLOAT) {
                    Dindexer.put(i, (double) s.readFloat());
                } else if (t == Type.HALF) {
                    Dindexer.put(i, (double) toFloat((int) s.readShort()));
                }
            }

            allocationPoint.tickHostWrite();
        } else if (globalType == Type.FLOAT) {
            this.elementSize = 4;
            if (reallocate) {
                this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this,
                        new AllocationShape(length, elementSize, dataType()), false);
                this.trackingPoint = allocationPoint.getObjectId();

                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asFloatPointer();
                indexer = FloatIndexer.create((FloatPointer) pointer);
            }

            FloatIndexer Findexer = (FloatIndexer) indexer;
            for (int i = 0; i < length; i++) {
                if (t == Type.DOUBLE) {
                    Findexer.put(i, (float) s.readDouble());
                } else if (t == Type.FLOAT) {
                    Findexer.put(i, s.readFloat());
                } else if (t == Type.HALF) {
                    Findexer.put(i, toFloat((int) s.readShort()));
                }
            }

            allocationPoint.tickHostWrite();
        } else if (globalType == Type.HALF) {
            this.elementSize = 2;
            if (reallocate) {
                this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this,
                        new AllocationShape(length, elementSize, dataType()), false);
                this.trackingPoint = allocationPoint.getObjectId();

                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asShortPointer();
                indexer = HalfIndexer.create((ShortPointer) this.pointer);
            }

            HalfIndexer Hindexer = (HalfIndexer) indexer;
            for (int i = 0; i < length; i++) {
                if (t == Type.DOUBLE) {
                    Hindexer.put(i, (float) s.readDouble());
                } else if (t == Type.FLOAT) {
                    Hindexer.put(i, s.readFloat());
                } else if (t == Type.HALF) {
                    Hindexer.put(i, toFloat((int) s.readShort()));
                }
            }

            // for HALF & HALF2 datatype we just tag data as fresh on host
            allocationPoint.tickHostWrite();
        } else {
            // Caught by the handler below and rethrown wrapped in a RuntimeException.
            throw new IllegalStateException("Unknown dataType: [" + t.toString() + "]");
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // The data was written on the host side above; ask the flow controller to
    // synchronize this allocation point to the device.
    AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(allocationPoint);
}
Aggregations