Use of org.nd4j.linalg.compression.CompressedDataBuffer in the nd4j project by deeplearning4j.
Class CpuThreshold, method compress:
/**
 * Threshold-encodes {@code buffer} into a {@link CompressedDataBuffer}.
 *
 * <p>The encoded data is an int array made of a 4-int header followed by {@code cntAbs}
 * payload slots that are filled by the native {@code convertDataEx(..., THRESHOLD, ...)} call.
 * Header layout as written here:
 * <ul>
 *   <li>[0] number of elements whose absolute value is {@code >= threshold}</li>
 *   <li>[1] number of elements in the original buffer</li>
 *   <li>[2] the threshold, stored as raw float bits</li>
 *   <li>[3] written as 0</li>
 * </ul>
 *
 * @param buffer the buffer to encode; its element count must fit in an {@code int}
 * @return the compressed buffer, or {@code null} when fewer than 2 elements reach the threshold
 * @throws IllegalArgumentException if {@code buffer.length()} exceeds {@link Integer#MAX_VALUE}
 */
@Override
public DataBuffer compress(DataBuffer buffer) {
    final long length = buffer.length();
    // The header stores the element count as an int, and the shape below is int-based,
    // so longer buffers cannot be encoded without silent truncation.
    if (length > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("Buffer too large for threshold encoding: " + length + " elements");
    }
    final int intLength = (int) length;

    // View the buffer as a 1 x length row vector so the MatchCondition op can count matches.
    INDArray temp = Nd4j.createArrayFromShapeBuffer(buffer,
            Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, intLength}).getFirst());
    MatchCondition condition = new MatchCondition(temp, Conditions.absGreaterThanOrEqual(threshold));
    int cntAbs = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
    if (cntAbs < 2)
        return null;

    final int elementSize = Nd4j.sizeOfDataType(buffer.dataType());
    // long arithmetic: length * elementSize may exceed the int range.
    final long originalLength = length * elementSize;

    // Four header ints followed by one payload slot per matching element.
    final int headerInts = 4;
    final int compressedLength = cntAbs + headerInts;

    IntPointer pointer = new IntPointer(compressedLength);
    pointer.put(0, cntAbs);
    pointer.put(1, intLength);
    pointer.put(2, Float.floatToIntBits(threshold));
    pointer.put(3, 0);

    CompressionDescriptor descriptor = new CompressionDescriptor();
    // Compressed size is in bytes; widen before multiplying to avoid int overflow.
    descriptor.setCompressedLength((long) compressedLength * Integer.BYTES);
    descriptor.setOriginalLength(originalLength);
    descriptor.setOriginalElementSize(elementSize);
    descriptor.setNumberOfElements(length);
    descriptor.setCompressionAlgorithm(getDescriptor());
    descriptor.setCompressionType(getCompressionType());

    CompressedDataBuffer cbuff = new CompressedDataBuffer(pointer, descriptor);
    // Native conversion fills the payload that follows the header.
    Nd4j.getNDArrayFactory().convertDataEx(getBufferTypeEx(buffer), buffer.addressPointer(),
            DataBuffer.TypeEx.THRESHOLD, pointer, length);
    // Marks the source buffer's location as HOST with the affinity manager.
    // NOTE(review): presumably because the native call touched the host copy — confirm.
    Nd4j.getAffinityManager().tagLocation(buffer, AffinityManager.Location.HOST);
    return cbuff;
}
Use of org.nd4j.linalg.compression.CompressedDataBuffer in the nd4j project by deeplearning4j.
Class NoOp, method decompress:
/**
 * "Decompresses" a NoOp-compressed buffer by copying its raw contents into a freshly
 * allocated buffer of the same length.
 *
 * @param buffer a {@link CompressedDataBuffer}; any other type triggers a {@link ClassCastException}
 * @return a new buffer holding a copy of the compressed buffer's contents
 */
@Override
public DataBuffer decompress(DataBuffer buffer) {
    final CompressedDataBuffer compressed = (CompressedDataBuffer) buffer;
    final DataBuffer decompressed = Nd4j.createBuffer(compressed.length(), false);
    Nd4j.getMemoryManager().memcpy(decompressed, compressed);
    return decompressed;
}
Use of org.nd4j.linalg.compression.CompressedDataBuffer in the nd4j project by deeplearning4j.
Class Float16, method compressPointer:
/**
 * Converts {@code length} source elements into FLOAT16 (2 bytes per element) and wraps the
 * result in a {@link CompressedDataBuffer}.
 *
 * @param srcType     encoding of the source data
 * @param srcPointer  pointer to the first source element
 * @param length      number of elements to convert
 * @param elementSize size in bytes of one source element
 * @return compressed buffer holding the FLOAT16 data and its descriptor
 */
@Override
protected CompressedDataBuffer compressPointer(DataBuffer.TypeEx srcType, Pointer srcPointer, int length, int elementSize) {
    // FLOAT16 stores each element in 2 bytes.
    final int compressedBytes = length * 2;
    BytePointer ptr = new BytePointer(compressedBytes);

    CompressionDescriptor descriptor = new CompressionDescriptor();
    descriptor.setCompressedLength(compressedBytes);
    // Widen before multiplying: length * elementSize overflows int for sources larger than 2 GiB.
    descriptor.setOriginalLength((long) length * elementSize);
    descriptor.setOriginalElementSize(elementSize);
    descriptor.setNumberOfElements(length);
    descriptor.setCompressionAlgorithm(getDescriptor());
    descriptor.setCompressionType(getCompressionType());

    CompressedDataBuffer buffer = new CompressedDataBuffer(ptr, descriptor);
    Nd4j.getNDArrayFactory().convertDataEx(srcType, srcPointer, DataBuffer.TypeEx.FLOAT16, ptr, length);
    return buffer;
}
Use of org.nd4j.linalg.compression.CompressedDataBuffer in the nd4j project by deeplearning4j.
Class Gzip, method compress:
/**
 * GZIP-compresses the serialized form of {@code buffer} (as produced by
 * {@code DataBuffer.write}) and wraps the compressed bytes in a {@link CompressedDataBuffer}.
 *
 * @param buffer the buffer to compress
 * @return a compressed buffer whose descriptor records the compressed byte count
 * @throws RuntimeException wrapping the {@link java.io.IOException} if serialization or
 *         compression fails
 */
@Override
public DataBuffer compress(DataBuffer buffer) {
    ByteArrayOutputStream stream = new ByteArrayOutputStream();
    // try-with-resources releases the GZIP deflater even if serialization fails.
    // Closing the DataOutputStream flushes it and finishes the GZIP trailer, so the
    // byte array is complete once this block exits.
    try (DataOutputStream dos = new DataOutputStream(new GZIPOutputStream(stream))) {
        buffer.write(dos);
    } catch (java.io.IOException e) {
        throw new RuntimeException(e);
    }

    byte[] bytes = stream.toByteArray();
    BytePointer pointer = new BytePointer(bytes);
    CompressionDescriptor descriptor = new CompressionDescriptor(buffer, this);
    descriptor.setCompressedLength(bytes.length);
    return new CompressedDataBuffer(pointer, descriptor);
}
Use of org.nd4j.linalg.compression.CompressedDataBuffer in the nd4j project by deeplearning4j.
Class Float8, method compressPointer:
/**
 * Converts {@code length} source elements into FLOAT8 (1 byte per element) and wraps the
 * result in a {@link CompressedDataBuffer}.
 *
 * @param srcType     encoding of the source data
 * @param srcPointer  pointer to the first source element
 * @param length      number of elements to convert
 * @param elementSize size in bytes of one source element
 * @return compressed buffer holding the FLOAT8 data and its descriptor
 */
@Override
protected CompressedDataBuffer compressPointer(DataBuffer.TypeEx srcType, Pointer srcPointer, int length, int elementSize) {
    // FLOAT8 stores each element in a single byte, so the byte count equals the element count.
    final int compressedBytes = length;
    BytePointer ptr = new BytePointer(compressedBytes);

    CompressionDescriptor descriptor = new CompressionDescriptor();
    descriptor.setCompressedLength(compressedBytes);
    // Widen before multiplying: length * elementSize overflows int for sources larger than 2 GiB.
    descriptor.setOriginalLength((long) length * elementSize);
    descriptor.setOriginalElementSize(elementSize);
    descriptor.setNumberOfElements(length);
    descriptor.setCompressionAlgorithm(getDescriptor());
    descriptor.setCompressionType(getCompressionType());

    CompressedDataBuffer buffer = new CompressedDataBuffer(ptr, descriptor);
    Nd4j.getNDArrayFactory().convertDataEx(srcType, srcPointer, DataBuffer.TypeEx.FLOAT8, ptr, length);
    return buffer;
}
Aggregations