Use of org.nd4j.parameterserver.distributed.messages.MeaningfulMessage in project nd4j by deeplearning4j.
Class BaseTransport, method clientMessageHandler.
/**
 * Handles a message received on the client side.
 *
 * <p>Only unicast "plain" messages are expected here, and each of them is expected to
 * implement {@link MeaningfulMessage}. The decoded message is stored in {@code completed}
 * under its task id.
 *
 * @param buffer buffer that contains the serialized message
 * @param offset index within {@code buffer} at which the message bytes start
 * @param length number of bytes that make up the message
 * @param header header supplied with the message; not used by this handler
 */
protected void clientMessageHandler(DirectBuffer buffer, int offset, int length, Header header) {
    // Copy the serialized message out of the buffer before decoding it.
    final byte[] payload = new byte[length];
    buffer.getBytes(offset, payload);

    // Incoming client-side traffic is expected to be MeaningfulMessage only, hence the direct cast.
    final MeaningfulMessage response = (MeaningfulMessage) VoidMessage.fromBytes(payload);

    // Publish the response under its task id.
    completed.put(response.getTaskId(), response);
}
Use of org.nd4j.parameterserver.distributed.messages.MeaningfulMessage in project nd4j by deeplearning4j.
Class BaseTransport, method sendMessageAndGetResponse.
/**
 * Sends the given message to the shard and waits until a response with the same
 * task id shows up in {@code completed}.
 *
 * <p>While waiting, the method idles via {@code feedbackIdler}. Whenever more than
 * {@code voidConfiguration.getResponseTimeout()} milliseconds pass since the last send
 * without a response, the message's retransmit counter is incremented and the message
 * is sent again. Once the counter exceeds the retransmit limit the method gives up.
 *
 * <p>Retransmission is done in a loop rather than by recursion, so the "giving up"
 * exception reaches the caller directly instead of being re-wrapped once per retry,
 * and the logged duration covers the whole exchange including retries.
 *
 * @param message message to send; must not be null
 * @return the response stored in {@code completed} for the message's task id
 * @throws RuntimeException if the retransmit limit is exceeded, or if idling fails
 */
@Override
public MeaningfulMessage sendMessageAndGetResponse(@NonNull VoidMessage message) {
    // TODO: make retransmit threshold configurable
    final int maxRetransmits = 20;

    long startTime = System.currentTimeMillis();
    long taskId = message.getTaskId();

    sendCommandToShard(message);
    long lastSendTime = System.currentTimeMillis();

    MeaningfulMessage msg;
    while ((msg = completed.get(taskId)) == null) {
        try {
            feedbackIdler.idle();
        } catch (Exception e) {
            // Only failures of the idle call itself are wrapped here.
            throw new RuntimeException(e);
        }

        if (System.currentTimeMillis() - lastSendTime > voidConfiguration.getResponseTimeout()) {
            log.info("Resending request for taskId [{}]", taskId);
            message.incrementRetransmitCount();

            if (message.getRetransmitCount() > maxRetransmits) {
                throw new RuntimeException("Giving up on message delivery...");
            }

            // Re-read the task id before every resend, exactly as the first send does.
            taskId = message.getTaskId();
            sendCommandToShard(message);
            lastSendTime = System.currentTimeMillis();
        }
    }

    // The response has been picked up; drop it from the map of completed tasks.
    completed.remove(taskId);

    long timeSpent = System.currentTimeMillis() - startTime;

    // Report timing for every 1000th Frame only.
    if (message instanceof Frame && frameCount.incrementAndGet() % 1000 == 0)
        log.info("Frame of {} messages [{}] processed in {} ms", ((Frame) message).size(), message.getTaskId(), timeSpent);

    return msg;
}
Aggregations