Search in sources:

Example 1 with DummyWorkspace

Use of org.nd4j.linalg.memory.abstracts.DummyWorkspace in project nd4j by deeplearning4j.

In class BaseCudaDataBuffer, method read:

/**
 * Restores this buffer from a stream.
 *
 * <p>Expected stream layout, in order: a UTF string (read and discarded; the allocation
 * mode is always reset to {@code JAVACPP}), an {@code int} element count, a UTF string
 * holding the name of the source {@code Type}, and then that many elements encoded for the
 * source type: {@code readInt} for INT, {@code readDouble} for DOUBLE, {@code readFloat}
 * for FLOAT, and one {@code short} (decoded with {@code toFloat}) for HALF.
 *
 * <p>Storage selection: COMPRESSED only records the type and returns without reading any
 * elements or synchronizing. If either the source type or the global type is INT, the
 * data is stored as ints and the buffer is tagged INT. Otherwise elements are converted to
 * the global type (DOUBLE, FLOAT or HALF). The non-INT paths reuse existing host memory
 * when the length is unchanged and an indexer exists; the INT path always allocates.
 *
 * @param s stream positioned at the start of a serialized buffer
 * @throws RuntimeException wrapping any failure while reading or allocating, including an
 *     {@link IllegalStateException} for an unsupported global or source type
 */
@Override
public void read(DataInputStream s) {
    try {
        // The serialized allocation mode is skipped: restored buffers always use JAVACPP.
        s.readUTF();
        allocationMode = AllocationMode.JAVACPP;
        int locLength = s.readInt();
        // NOTE(review): when reallocate is false, the existing indexer is cast to the branch's
        // indexer type below; this assumes the previous contents used the same storage type.
        boolean reallocate = locLength != length || indexer == null;
        length = locLength;
        Type t = Type.valueOf(s.readUTF());
        if (globalType == null && Nd4j.dataType() != null) {
            globalType = Nd4j.dataType();
        }
        if (t != globalType && t != Type.INT && Nd4j.sizeOfDataType(globalType) < Nd4j.sizeOfDataType(t)) {
            log.warn("Loading a data stream with opType different from what is set globally. Expect precision loss");
            if (globalType == Type.INT)
                log.warn("Int to float/double widening UNSUPPORTED!!!");
        }
        if (t == Type.COMPRESSED) {
            // Compressed payloads are not decoded here; only the type is recorded.
            type = t;
            return;
        } else if (t == Type.INT || globalType == Type.INT) {
            // Integer storage (4-byte elements), always freshly allocated. The memory holds
            // ints even when only globalType is INT, so both the allocation shape and the
            // buffer dtype are tagged INT to stay consistent with the IntPointer below.
            this.elementSize = 4;
            this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, Type.INT), false);
            this.trackingPoint = allocationPoint.getObjectId();
            // INT buffers keep their dtype across serialization round-trips.
            this.type = Type.INT;
            this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asIntPointer();
            IntIndexer intIndexer = IntIndexer.create((IntPointer) pointer);
            indexer = intIndexer;
            for (int i = 0; i < length; i++) {
                intIndexer.put(i, (int) readElementAsDouble(s, t));
            }
            allocationPoint.tickHostWrite();
        } else if (globalType == Type.DOUBLE) {
            this.elementSize = 8;
            if (reallocate) {
                MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
                // NOTE(review): this attaches the buffer only when the current workspace IS a
                // DummyWorkspace, while the memory itself still comes from AtomicAllocator, and
                // the FLOAT/HALF branches do no workspace attachment at all - confirm intent.
                if (workspace instanceof DummyWorkspace) {
                    this.attached = true;
                    this.parentWorkspace = workspace;
                    workspaceGenerationId = workspace.getGenerationId();
                }
                this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, globalType), false);
                this.trackingPoint = allocationPoint.getObjectId();
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asDoublePointer();
                indexer = DoubleIndexer.create((DoublePointer) pointer);
            }
            DoubleIndexer doubleIndexer = (DoubleIndexer) indexer;
            for (int i = 0; i < length; i++) {
                doubleIndexer.put(i, readElementAsDouble(s, t));
            }
            allocationPoint.tickHostWrite();
        } else if (globalType == Type.FLOAT) {
            this.elementSize = 4;
            if (reallocate) {
                this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false);
                this.trackingPoint = allocationPoint.getObjectId();
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asFloatPointer();
                indexer = FloatIndexer.create((FloatPointer) pointer);
            }
            FloatIndexer floatIndexer = (FloatIndexer) indexer;
            for (int i = 0; i < length; i++) {
                floatIndexer.put(i, (float) readElementAsDouble(s, t));
            }
            allocationPoint.tickHostWrite();
        } else if (globalType == Type.HALF) {
            this.elementSize = 2;
            if (reallocate) {
                this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false);
                this.trackingPoint = allocationPoint.getObjectId();
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length).asShortPointer();
                indexer = HalfIndexer.create((ShortPointer) this.pointer);
            }
            HalfIndexer halfIndexer = (HalfIndexer) indexer;
            for (int i = 0; i < length; i++) {
                halfIndexer.put(i, (float) readElementAsDouble(s, t));
            }
            // For the HALF datatype the data is just tagged as fresh on the host.
            allocationPoint.tickHostWrite();
        } else {
            // The branch is selected by globalType, so report it alongside the stream's type.
            throw new IllegalStateException("Unknown dataType: [" + globalType + "], stream dataType: [" + t + "]");
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    // Data was written on the host side (tickHostWrite above); synchronizeToDevice is called so
    // the device copy is brought up to date - presumably a host-to-device copy, confirm against
    // the FlowController implementation.
    AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(allocationPoint);
}

/**
 * Reads one element serialized as source type {@code t} and returns it widened to double.
 *
 * <p>Every supported encoding (int, float, double, and a short decoded via {@code toFloat})
 * is exactly representable as a double. A caller that narrows the result to int or float
 * therefore gets the same value as a direct cast of the raw element.
 *
 * @param s stream positioned at the next element
 * @param t source element type recorded in the stream header
 * @return the element value as a double
 * @throws IllegalStateException if {@code t} has no element encoding handled here; failing
 *     fast avoids silently leaving the stream positioned in the middle of the buffer
 * @throws java.io.IOException if reading from the stream fails
 */
private double readElementAsDouble(DataInputStream s, Type t) throws java.io.IOException {
    switch (t) {
        case DOUBLE:
            return s.readDouble();
        case FLOAT:
            return s.readFloat();
        case INT:
            return s.readInt();
        case HALF:
            return toFloat((int) s.readShort());
        default:
            throw new IllegalStateException("Unsupported stream dataType: [" + t + "]");
    }
}
Also used: DummyWorkspace (org.nd4j.linalg.memory.abstracts.DummyWorkspace), AllocationShape (org.nd4j.jita.allocator.impl.AllocationShape), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint), IOException (java.io.IOException), MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace), CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)

Aggregations

IOException (java.io.IOException): 1; AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 1; AllocationShape (org.nd4j.jita.allocator.impl.AllocationShape): 1; CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 1; MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace): 1; DummyWorkspace (org.nd4j.linalg.memory.abstracts.DummyWorkspace): 1