Use of org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk in the nd4j project by deeplearning4j.
Class AeronNDArrayPublisher, method publish.
/**
 * Publish an ndarray to an aeron channel.
 * <p>
 * On first use this calls {@code init()}, connects to the media driver with
 * {@code Aeron.connect(ctx)} and creates the publication for {@code channel}/{@code streamId}.
 * Publication creation is retried up to {@code NUM_RETRIES} times on driver time-outs, pausing
 * 1s, 2s, ... between attempts. If compression is enabled the array is GZIP-compressed in place.
 * Messages whose serialized size is at least {@code publication.maxMessageLength()} are split
 * into {@link NDArrayMessageChunk}s that are sent one after another; smaller messages are sent
 * whole.
 *
 * @param message the message to publish
 * @throws IllegalStateException if the media driver cannot be connected to, or the publication
 *         cannot be created within {@code NUM_RETRIES} attempts
 * @throws Exception propagated from sending, or if the pause between retries is interrupted
 */
public void publish(NDArrayMessage message) throws Exception {
    if (!init)
        init();
    // A separate media driver process needs to be running before a client can connect to it.
    if (aeron == null) {
        try {
            aeron = Aeron.connect(ctx);
        } catch (Exception e) {
            // Fail with the real cause. Merely logging here would leave aeron null and the
            // addPublication call below would fail with a NullPointerException.
            throw new IllegalStateException("Publisher unable to connect to media driver for channel " + channel + " and stream " + streamId, e);
        }
    }
    int connectionTries = 0;
    DriverTimeoutException lastTimeout = null;
    while (publication == null && connectionTries < NUM_RETRIES) {
        try {
            publication = aeron.addPublication(channel, streamId);
            log.info("Created publication on channel " + channel + " and stream " + streamId);
        } catch (DriverTimeoutException e) {
            lastTimeout = e;
            connectionTries++;
            // Linear back-off; there is no point pausing after the final attempt.
            if (connectionTries < NUM_RETRIES) {
                log.warn("Failed to connect due to driver time out on channel " + channel + " and stream " + streamId + "...retrying in " + connectionTries + " seconds");
                Thread.sleep(1000L * connectionTries);
            }
        }
    }
    if (publication == null) {
        throw new IllegalStateException("Publisher unable to connect to channel " + channel + " and stream " + streamId, lastTimeout);
    }
    log.info("Publishing to " + channel + " on stream Id " + streamId);
    INDArray arr = message.getArr();
    // Compress once, in place. Looping on isCompressed() would spin forever if the
    // compressor ever left the array uncompressed.
    if (isCompress() && !arr.isCompressed())
        Nd4j.getCompressor().compressi(arr, "GZIP");
    int maxMessageLength = publication.maxMessageLength();
    if (NDArrayMessage.byteBufferSizeForMessage(message) >= maxMessageLength) {
        // Array is too large for a single message: segment it and send chunk by chunk.
        // NOTE(review): each chunk carries maxMessageLength / 128 bytes of payload; the reason
        // for 128 is not visible here (presumably headroom for chunk headers) - confirm.
        NDArrayMessageChunk[] chunks = NDArrayMessage.chunks(message, maxMessageLength / 128);
        for (NDArrayMessageChunk chunk : chunks) {
            ByteBuffer sendBuff = NDArrayMessageChunk.toBuffer(chunk);
            sendBuff.rewind();
            DirectBuffer buffer = new UnsafeBuffer(sendBuff);
            sendBuffer(buffer);
        }
    } else {
        // Small enough to send the whole array in one message.
        DirectBuffer buffer = NDArrayMessage.toBuffer(message);
        sendBuffer(buffer);
    }
}
Use of org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk in the nd4j project by deeplearning4j.
Class NDArrayMessage, method chunks.
/**
 * Splits a message into chunks that can be sent independently and reassembled by the receiver.
 * <p>
 * The message is serialized once with {@link NDArrayMessage#toBuffer(NDArrayMessage)}. Chunk
 * {@code i} receives a read-only, zero-copy view of serialized bytes
 * {@code [i * chunkSize, min((i + 1) * chunkSize, capacity))}, in native byte order. All chunks
 * share one randomly generated message id. Each is built with message type
 * {@code MessageType.CHUNKED}, the total number of chunks, the chunk size and its own index.
 * When serialized, each message chunk has the layout:
 * messageType
 * number of chunks
 * chunkSize
 * length of uuid
 * uuid
 * buffer index
 * actual raw data
 *
 * @param message the message to turn into chunks
 * @param chunkSize the maximum number of serialized message bytes carried by each chunk
 * @return the chunks, ordered by chunk index
 * @throws IllegalArgumentException if {@code chunkSize} is not positive
 */
public static NDArrayMessageChunk[] chunks(NDArrayMessage message, int chunkSize) {
    if (chunkSize <= 0)
        throw new IllegalArgumentException("chunkSize must be positive but was " + chunkSize);
    int numChunks = numChunksForMessage(message, chunkSize);
    NDArrayMessageChunk[] ret = new NDArrayMessageChunk[numChunks];
    DirectBuffer wholeBuffer = NDArrayMessage.toBuffer(message);
    // One read-only view of the serialized message; every chunk takes its own duplicate of it.
    ByteBuffer source = wholeBuffer.byteBuffer().asReadOnlyBuffer();
    int capacity = wholeBuffer.capacity();
    String messageId = UUID.randomUUID().toString();
    for (int i = 0; i < numChunks; i++) {
        int start = i * chunkSize;
        // data: only grab this chunk's window of the serialized bytes
        ByteBuffer view = source.duplicate();
        view.position(start);
        view.limit(Math.min(start + chunkSize, capacity));
        // slice() always returns a BIG_ENDIAN buffer, so the byte order has to be set on
        // the slice itself; setting it before slicing has no effect.
        view = view.slice().order(ByteOrder.nativeOrder());
        ret[i] = NDArrayMessageChunk.builder().id(messageId).chunkSize(chunkSize).numChunks(numChunks).messageType(MessageType.CHUNKED).chunkIndex(i).data(view).build();
    }
    return ret;
}
Use of org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk in the nd4j project by deeplearning4j.
Class NDArrayFragmentHandler, method onFragment.
/**
 * Callback for handling fragments of data being read from a log.
 * <p>
 * The leading int of the fragment (native byte order) is an index into
 * {@code NDArrayMessage.MessageType.values()}. {@code CHUNKED} fragments are decoded into
 * {@link NDArrayMessageChunk}s and handed to the chunk accumulator; once all chunks for a
 * message id are present, the reassembled message is passed to the callback. Any other
 * message type is decoded directly as a whole {@link NDArrayMessage}.
 *
 * @param buffer containing the data.
 * @param offset at which the data begins.
 * @param length of the data in bytes.
 * @param header representing the meta data for the data.
 * @throws IllegalStateException if the message type index is out of range, or a chunk reports
 *         fewer than one chunk in total
 */
@Override
public void onFragment(DirectBuffer buffer, int offset, int length, Header header) {
    ByteBuffer byteBuffer = buffer.byteBuffer();
    boolean byteArrayInput = false;
    if (byteBuffer == null) {
        byteArrayInput = true;
        // Copy exactly this fragment, starting at buffer index `offset`. (ByteBuffer.get(dst,
        // offset, length) would treat `offset` as the destination index: it reads from the start
        // of the backing array and overflows dst whenever offset > 0.)
        byte[] destination = new byte[length];
        buffer.getBytes(offset, destination, 0, length);
        byteBuffer = ByteBuffer.wrap(destination).order(ByteOrder.nativeOrder());
    } else {
        // Only applicable for direct buffers where we don't copy the bytes. Work on a duplicate
        // so the shared buffer's position and byte order are left untouched.
        // NOTE(review): assumes index 0 of the DirectBuffer is position 0 of its ByteBuffer
        // (no wrap adjustment) - confirm for the buffers delivered here.
        byteBuffer = byteBuffer.duplicate();
        byteBuffer.position(offset);
        byteBuffer.order(ByteOrder.nativeOrder());
    }
    NDArrayMessage.MessageType[] messageTypes = NDArrayMessage.MessageType.values();
    int messageTypeIndex = byteBuffer.getInt();
    if (messageTypeIndex < 0 || messageTypeIndex >= messageTypes.length)
        throw new IllegalStateException("Illegal index " + messageTypeIndex + " on message opType. Likely corrupt message. Please check the serialization of the bytebuffer. Input was byte array: " + byteArrayInput);
    NDArrayMessage.MessageType messageType = messageTypes[messageTypeIndex];
    if (messageType == NDArrayMessage.MessageType.CHUNKED) {
        NDArrayMessageChunk chunk = NDArrayMessageChunk.fromBuffer(byteBuffer, messageType);
        if (chunk.getNumChunks() < 1)
            throw new IllegalStateException("Found invalid number of chunks " + chunk.getNumChunks() + " on chunk index " + chunk.getChunkIndex());
        chunkAccumulator.accumulateChunk(chunk);
        log.info("Received chunk " + chunk.getChunkIndex() + " of " + chunk.getNumChunks() + " chunks for id " + chunk.getId() + "; chunks received so far: " + chunkAccumulator.numChunksSoFar(chunk.getId()));
        if (chunkAccumulator.allPresent(chunk.getId())) {
            NDArrayMessage message = chunkAccumulator.reassemble(chunk.getId());
            ndArrayCallback.onNDArrayMessage(message);
        }
    } else {
        // Unchunked message: decode straight from the original buffer at the fragment offset.
        NDArrayMessage message = NDArrayMessage.fromBuffer(buffer, offset);
        ndArrayCallback.onNDArrayMessage(message);
    }
}
Aggregations