Usage of edu.iu.dsc.tws.api.comms.packing.DataBuffer in project twister2 by DSC-SPIDAL.
Class TWSMPIChannel, method postReceive.
/**
 * Posts an MPI receive for every buffer currently waiting in the request's free-buffer queue.
 * Each posted receive is tracked in {@code requests.pendingRequests} together with its buffer.
 *
 * @param requests receive bookkeeping for one (rank, edge) pair
 */
private void postReceive(MPIReceiveRequests requests) {
  for (DataBuffer freeBuffer = requests.availableBuffers.poll();
       freeBuffer != null;
       freeBuffer = requests.availableBuffers.poll()) {
    // count the receive before handing the buffer to MPI
    pendingReceiveCount++;
    Request mpiRequest = postReceive(requests.rank, requests.edge, freeBuffer);
    requests.pendingRequests.add(new MPIRequest(mpiRequest, freeBuffer));
  }
}
Usage of edu.iu.dsc.tws.api.comms.packing.DataBuffer in project twister2 by DSC-SPIDAL.
Class TWSTCPChannel, method postReceive.
/**
 * Posts a single TCP receive using one buffer taken from the request's free-buffer queue.
 * Does nothing when no free buffer is available.
 *
 * @param requests receive bookkeeping for one (rank, edge) pair
 */
private void postReceive(TCPReceiveRequests requests) {
  DataBuffer freeBuffer = requests.availableBuffers.poll();
  if (freeBuffer == null) {
    // no free buffer to receive into at the moment
    return;
  }
  TCPMessage tcpRequest = postReceive(requests.rank, requests.edge, freeBuffer);
  requests.pendingRequests.add(new Request(tcpRequest, freeBuffer));
}
Usage of edu.iu.dsc.tws.api.comms.packing.DataBuffer in project twister2 by DSC-SPIDAL.
Class TWSTCPChannel, method postMessage.
/**
 * Sends a message to the destination rank held in {@code requests}.
 *
 * <p>Every normal buffer of the message is handed to {@code channel.iSend}. The returned send
 * handle is queued on {@code requests.pendingSends} so the progress loop can complete it later.
 *
 * @param requests the pending send, carrying the message and its destination rank
 */
private void postMessage(TCPSendRequests requests) {
  ChannelMessage message = requests.message;
  for (DataBuffer outgoing : message.getNormalBuffers()) {
    sendCount++;
    TCPMessage sendHandle = channel.iSend(
        outgoing.getByteBuffer(), outgoing.getSize(), requests.rank,
        message.getHeader().getEdge());
    // track the in-flight send so communication progress can finish it
    requests.pendingSends.add(new Request(sendHandle, outgoing));
  }
}
Usage of edu.iu.dsc.tws.api.comms.packing.DataBuffer in project twister2 by DSC-SPIDAL.
Class ControlledChannelOperation, method setupReceiveGroups.
/**
 * Start receiving from the next set of ids.
 *
 * <p>Stores {@code receivingIds} as the current receive groups. It then allocates a fresh pool
 * of free receive buffers: one buffer of {@code CommunicationContext.bufferSize(config)} bytes
 * for each id in the largest group.
 *
 * @param receivingIds groups of ids to receive from; an empty list (or only empty groups)
 *     yields an empty buffer pool instead of an exception
 */
public void setupReceiveGroups(List<IntArrayList> receivingIds) {
  this.receiveIdGroups = receivingIds;
  // size of the largest group; stays 0 when there are no (non-empty) groups
  int max = 0;
  for (IntArrayList group : receivingIds) {
    max = Math.max(max, group.size());
  }
  // we put max group size equal buffers
  int receiveBufferSize = CommunicationContext.bufferSize(config);
  // ArrayBlockingQueue requires capacity >= 1. The old Integer.MIN_VALUE seed made an
  // empty receivingIds list throw IllegalArgumentException here.
  this.freeReceiveBuffers = new ArrayBlockingQueue<>(Math.max(1, max));
  for (int i = 0; i < max; i++) {
    ByteBuffer byteBuffer = channel.createBuffer(receiveBufferSize);
    this.freeReceiveBuffers.offer(new DataBuffer(byteBuffer));
  }
}
Usage of edu.iu.dsc.tws.api.comms.packing.DataBuffer in project twister2 by DSC-SPIDAL.
Class ControlledChannelOperation, method receiveProgress.
/**
 * Progress the receive.
 *
 * <p>Walks the queue {@code pendingReceiveMessagesPerSource.get(receiveId)} in order, holding
 * {@code lock} while each head message is processed:
 * <ul>
 *   <li>BUILDING / BUILT: each built {@link ChannelMessage} is offered to
 *       {@code receiver.handleReceivedChannelMessage}. If accepted, it is released and a free
 *       buffer from {@code freeReceiveBuffers} is pushed onto the worker's
 *       {@code receiveBuffers} queue. For BUILDING this always happens; for BUILT only while
 *       {@code expected > count}. A BUILT message whose built messages are all consumed is
 *       moved to RECEIVE.</li>
 *   <li>RECEIVE: the deserialized data is offered to {@code receiver.receiveMessage}. If
 *       accepted, the message is marked DONE and removed from the queue.</li>
 * </ul>
 * Processing stops when the receiver rejects a message or the head message is not in RECEIVE.
 *
 * @param receiveId key into {@code pendingReceiveMessagesPerSource}; the code assumes an entry
 *     exists (a missing entry would cause a NullPointerException at {@code size()})
 * @throws RuntimeException if a replacement buffer is needed but {@code freeReceiveBuffers}
 *     is empty
 */
public void receiveProgress(int receiveId) {
Queue<InMessage> pendingReceiveMessages = pendingReceiveMessagesPerSource.get(receiveId);
boolean canProgress = true;
while (pendingReceiveMessages.size() > 0 && canProgress) {
InMessage currentMessage = pendingReceiveMessages.peek();
lock.lock();
try {
// lets keep track that we have completed a receive from this executor
int workerId = currentMessage.getOriginatingId();
int count = currentReceives.get(workerId);
int expected = expectedReceivePerWorker.get(workerId);
// snapshot of the state; it is NOT re-read after setReceivedState(RECEIVE) below
InMessage.ReceivedState receivedState = currentMessage.getReceivedState();
if (receivedState == InMessage.ReceivedState.BUILT) {
// if this message is built, we need to check how many we are expecting
// NOTE(review): this increments on every call that sees the message as BUILT. If the
// receiver rejected a channel message on an earlier call, the same message is counted
// again. Confirm this is intended.
count++;
currentReceives.put(workerId, count);
}
if (receivedState == InMessage.ReceivedState.BUILDING || receivedState == InMessage.ReceivedState.BUILT) {
while (currentMessage.getBuiltMessages().size() > 0) {
// get the first channel message
ChannelMessage msg = currentMessage.getBuiltMessages().peek();
if (msg != null) {
// receiver refused the channel message: stop and retry on a later call
if (!receiver.handleReceivedChannelMessage(msg)) {
canProgress = false;
break;
}
ChannelMessage releaseMsg = currentMessage.getBuiltMessages().poll();
Objects.requireNonNull(releaseMsg).release();
// replenish the worker's receive buffers: always while still building
if (receivedState == InMessage.ReceivedState.BUILDING) {
DataBuffer buffer = freeReceiveBuffers.poll();
Queue<DataBuffer> list = receiveBuffers.get(workerId);
if (buffer == null) {
throw new RuntimeException("Free buffers doesn't have any buffer");
}
// we get a free buffer and offer
list.offer(buffer);
} else {
// BUILT: only replenish if more receives are expected from this worker
if (expected > count) {
DataBuffer buffer = freeReceiveBuffers.poll();
Queue<DataBuffer> list = receiveBuffers.get(workerId);
if (buffer == null) {
throw new RuntimeException("Free buffers doesn't have any buffer");
}
// we get a free buffer and offer
list.offer(buffer);
}
}
}
}
// all built pieces were accepted: the message is ready for final delivery
if (receivedState == InMessage.ReceivedState.BUILT && currentMessage.getBuiltMessages().size() == 0 && canProgress) {
currentMessage.setReceivedState(InMessage.ReceivedState.RECEIVE);
}
}
// NOTE(review): receivedState is the local snapshot, so a message moved to RECEIVE just
// above falls into the else branch (break). Delivery then happens on a later call.
// Confirm whether immediate delivery was intended.
if (receivedState == InMessage.ReceivedState.RECEIVE) {
Object object = currentMessage.getDeserializedData();
// receiver not ready for the assembled object: leave it at the head of the queue
if (!receiver.receiveMessage(currentMessage.getHeader(), object)) {
break;
}
currentMessage.setReceivedState(InMessage.ReceivedState.DONE);
pendingReceiveMessages.poll();
} else {
break;
}
} finally {
lock.unlock();
}
}
}
Aggregations