Search in sources :

Example 1 with NDArrayMessageChunk

use of org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk in project nd4j by deeplearning4j.

Class AeronNDArrayPublisher, method publish.

/**
 * Publish an ndarray message to an aeron channel.
 * <p>
 * On first use this lazily initializes the publisher, connects an aeron client (a media
 * driver must already be running) and creates the publication, retrying up to
 * {@code NUM_RETRIES} times on driver time-outs with a linearly increasing back-off.
 * If {@link #isCompress()} is set, the message's array is GZIP-compressed in place before
 * sending. A message whose serialized size reaches the publication's maximum message length
 * is split into {@link NDArrayMessageChunk}s; otherwise it is sent as a single buffer.
 *
 * @param message the message to publish
 * @throws IllegalStateException if the aeron client or the publication cannot be created
 * @throws Exception if the retry back-off is interrupted or sending fails
 */
public void publish(NDArrayMessage message) throws Exception {
    if (!init)
        init();
    // Connect the client once; without it no publication can be created, so fail fast
    // (with the cause attached) instead of hitting a NullPointerException below.
    if (aeron == null) {
        try {
            aeron = Aeron.connect(ctx);
        } catch (Exception e) {
            log.warn("Reconnecting on publisher...failed to connect", e);
            throw new IllegalStateException("Publisher unable to connect to channel " + channel + " and stream " + streamId, e);
        }
    }
    int connectionTries = 0;
    while (publication == null && connectionTries < NUM_RETRIES) {
        try {
            publication = aeron.addPublication(channel, streamId);
            log.info("Created publication on channel " + channel + " and stream " + streamId);
        } catch (DriverTimeoutException e) {
            // Linear back-off: wait 1s, 2s, 3s, ... between attempts.
            int backoffSeconds = connectionTries + 1;
            log.warn("Failed to connect due to driver time out on channel " + channel + " and stream " + streamId + "...retrying in " + backoffSeconds + " seconds");
            Thread.sleep(1000L * backoffSeconds);
            connectionTries++;
        }
    }
    // The only thing that matters is whether a publication exists after the retries.
    if (publication == null) {
        throw new IllegalStateException("Publisher unable to connect to channel " + channel + " and stream " + streamId);
    }
    log.info("Publishing to " + channel + " on stream Id " + streamId);
    INDArray arr = message.getArr();
    // Compress once; looping until isCompressed() could spin forever if compression is a no-op.
    if (isCompress() && !arr.isCompressed()) {
        Nd4j.getCompressor().compressi(arr, "GZIP");
        if (!arr.isCompressed())
            log.warn("GZIP compression did not mark the array as compressed; publishing it as is");
    }
    if (NDArrayMessage.byteBufferSizeForMessage(message) >= publication.maxMessageLength()) {
        // Array is too large for a single aeron message: segment it into chunks.
        NDArrayMessageChunk[] chunks = NDArrayMessage.chunks(message, publication.maxMessageLength() / 128);
        for (NDArrayMessageChunk chunk : chunks) {
            ByteBuffer sendBuff = NDArrayMessageChunk.toBuffer(chunk);
            sendBuff.rewind();
            DirectBuffer buffer = new UnsafeBuffer(sendBuff);
            sendBuffer(buffer);
        }
    } else {
        // Send the whole array as one message.
        DirectBuffer buffer = NDArrayMessage.toBuffer(message);
        sendBuffer(buffer);
    }
}
Also used : DirectBuffer(org.agrona.DirectBuffer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DriverTimeoutException(io.aeron.exceptions.DriverTimeoutException) NDArrayMessageChunk(org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk) UnsafeBuffer(org.agrona.concurrent.UnsafeBuffer) ByteBuffer(java.nio.ByteBuffer)

Example 2 with NDArrayMessageChunk

use of org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk in project nd4j by deeplearning4j.

Class NDArrayMessage, method chunks.

/**
 * Splits a message into chunks that can be sent independently and reassembled
 * by the receiver.
 * <p>
 * The message is serialized once. Each chunk carries a read-only slice of at most
 * {@code chunkSize} bytes of that serialization, together with a message id shared by
 * every chunk, the total number of chunks, the chunk size and the chunk's index.
 * The last chunk may be shorter than {@code chunkSize}.
 *
 * @param message the message to turn into chunks
 * @param chunkSize the maximum number of payload bytes per chunk; must be positive
 * @return the chunks, ordered by chunk index
 * @throws IllegalArgumentException if {@code chunkSize} is not positive
 */
public static NDArrayMessageChunk[] chunks(NDArrayMessage message, int chunkSize) {
    if (chunkSize <= 0)
        throw new IllegalArgumentException("Chunk size must be positive but was " + chunkSize);
    int numChunks = numChunksForMessage(message, chunkSize);
    NDArrayMessageChunk[] ret = new NDArrayMessageChunk[numChunks];
    DirectBuffer wholeBuffer = NDArrayMessage.toBuffer(message);
    // One id for all chunks so the receiver can group them for reassembly.
    String messageId = UUID.randomUUID().toString();
    for (int i = 0; i < ret.length; i++) {
        int start = i * chunkSize;
        int end = Math.min(start + chunkSize, wholeBuffer.capacity());
        // Read-only view so chunks cannot modify the serialized message.
        ByteBuffer view = (ByteBuffer) wholeBuffer.byteBuffer().asReadOnlyBuffer().position(start);
        view.limit(end);
        // slice() always yields BIG_ENDIAN, so the byte order must be set on the slice itself.
        view = view.slice();
        view.order(ByteOrder.nativeOrder());
        ret[i] = NDArrayMessageChunk.builder().id(messageId).chunkSize(chunkSize).numChunks(numChunks).messageType(MessageType.CHUNKED).chunkIndex(i).data(view).build();
    }
    return ret;
}
Also used : DirectBuffer(org.agrona.DirectBuffer) NDArrayMessageChunk(org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk) ByteBuffer(java.nio.ByteBuffer)

Example 3 with NDArrayMessageChunk

use of org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk in project nd4j by deeplearning4j.

Class NDArrayFragmentHandler, method onFragment.

/**
 * Callback for handling fragments of data being read from a log.
 * <p>
 * The first native-order int of the fragment selects the {@link NDArrayMessage.MessageType}.
 * Chunked messages are accumulated until all chunks for an id are present, then reassembled
 * and passed to the callback. Any other message type is decoded directly and passed on.
 *
 * @param buffer containing the data.
 * @param offset at which the data begins.
 * @param length of the data in bytes.
 * @param header representing the meta data for the data.
 * @throws IllegalStateException if the message type index is out of range or a chunk
 *         reports fewer than one chunk
 */
@Override
public void onFragment(DirectBuffer buffer, int offset, int length, Header header) {
    // Only used in diagnostics: true when the buffer is array-backed rather than ByteBuffer-backed.
    boolean byteArrayInput = buffer.byteBuffer() == null;
    // Copy exactly this fragment using the DirectBuffer's own indexing. The old code had two
    // problems: ByteBuffer.get(dst, offset, length) treats 'offset' as the destination offset,
    // and positioning the underlying ByteBuffer at 'offset' ignores any wrap adjustment. The
    // copy also keeps accumulated chunks from aliasing the transport buffer, in case fromBuffer
    // retains a view of it.
    byte[] fragment = new byte[length];
    buffer.getBytes(offset, fragment);
    ByteBuffer byteBuffer = ByteBuffer.wrap(fragment).order(ByteOrder.nativeOrder());
    int messageTypeIndex = byteBuffer.getInt();
    NDArrayMessage.MessageType[] messageTypes = NDArrayMessage.MessageType.values();
    if (messageTypeIndex < 0 || messageTypeIndex >= messageTypes.length)
        throw new IllegalStateException("Illegal index " + messageTypeIndex + " on message opType. Likely corrupt message. Please check the serialization of the bytebuffer. Input was byte array: " + byteArrayInput);
    NDArrayMessage.MessageType messageType = messageTypes[messageTypeIndex];
    if (messageType == NDArrayMessage.MessageType.CHUNKED) {
        NDArrayMessageChunk chunk = NDArrayMessageChunk.fromBuffer(byteBuffer, messageType);
        if (chunk.getNumChunks() < 1)
            throw new IllegalStateException("Found invalid number of chunks " + chunk.getNumChunks() + " on chunk index " + chunk.getChunkIndex());
        chunkAccumulator.accumulateChunk(chunk);
        log.info("Received chunk " + chunk.getChunkIndex() + " of " + chunk.getNumChunks() + " for id " + chunk.getId() + "; chunks received so far: " + chunkAccumulator.numChunksSoFar(chunk.getId()));
        if (chunkAccumulator.allPresent(chunk.getId())) {
            NDArrayMessage message = chunkAccumulator.reassemble(chunk.getId());
            ndArrayCallback.onNDArrayMessage(message);
        }
    } else {
        NDArrayMessage message = NDArrayMessage.fromBuffer(buffer, offset);
        ndArrayCallback.onNDArrayMessage(message);
    }
}
Also used : NDArrayMessageChunk(org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk) ByteBuffer(java.nio.ByteBuffer)

Aggregations

ByteBuffer (java.nio.ByteBuffer)3 NDArrayMessageChunk (org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk)3 DirectBuffer (org.agrona.DirectBuffer)2 DriverTimeoutException (io.aeron.exceptions.DriverTimeoutException)1 UnsafeBuffer (org.agrona.concurrent.UnsafeBuffer)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1