Usage of edu.iu.dsc.tws.comms.mpi.MPIBuffer in project twister2 by DSC-SPIDAL.
Class BaseLoadBalanceCommunication, method init.
/**
 * Initializes this worker and runs the load-balance communication example.
 *
 * <p>Task ids {@code [0, NO_OF_TASKS / 2)} are used as sources and the remaining task ids as
 * destinations. Containers with id 0 and 1 first call {@code mapFunction} 50000 times with a
 * 1024-byte {@code MPIBuffer} whose size is set to 24, progressing the communication after each
 * call. Every container then progresses the channel and the load-balance operation forever.
 * The method therefore only returns if an error is thrown while communicating; that error is
 * logged and not propagated.
 *
 * @param cfg the job configuration
 * @param containerId the id of this container; stored as the worker id
 * @param plan the resource plan of the job
 */
@Override
public void init(Config cfg, int containerId, ResourcePlan plan) {
  LOG.log(Level.INFO, "Starting the example with container id: " + plan.getThisId());
  this.config = cfg;
  this.resourcePlan = plan;
  this.id = containerId;
  this.status = Status.INIT;
  this.noOfTasksPerExecutor = NO_OF_TASKS / plan.noOfContainers();

  // lets create the task plan
  TaskPlan taskPlan = Utils.createReduceTaskPlan(cfg, plan, NO_OF_TASKS);
  // create the network and obtain the data-flow communication channel from it
  TWSNetwork network = new TWSNetwork(cfg, taskPlan);
  channel = network.getDataFlowTWSCommunication();

  // the lower half of the task ids send, the upper half receive
  Set<Integer> sources = new HashSet<>();
  Set<Integer> dests = new HashSet<>();
  for (int i = 0; i < NO_OF_TASKS; i++) {
    if (i < NO_OF_TASKS / 2) {
      sources.add(i);
    } else {
      dests.add(i);
    }
  }
  LOG.info(String.format("Loadbalance: sources %s destinations: %s", sources, dests));

  Map<String, Object> newCfg = new HashMap<>();
  LOG.info("Setting up load balance dataflow operation");
  // NOTE(review): the original author flagged this call as possibly wrong ("I think this is
  // wrong"); the reason was not recorded — confirm the arguments with the operation's owner.
  loadBalance = channel.loadBalance(newCfg, MessageType.BUFFER, 0, sources, dests,
      new LoadBalanceReceiver());

  LOG.info("Starting worker: " + id);
  try {
    // only containers 0 and 1 produce data
    if (id == 0 || id == 1) {
      MPIBuffer data = new MPIBuffer(1024);
      data.setSize(24);
      for (int i = 0; i < 50000; i++) {
        mapFunction(data);
        channel.progress();
        // we should progress the communication directive
        loadBalance.progress();
      }
    }
    // every container keeps progressing the communication indefinitely
    while (true) {
      channel.progress();
      // we should progress the communication directive
      loadBalance.progress();
    }
  } catch (Throwable t) {
    // Deliberately not rethrown (as before), but reported through the logger with its cause
    // instead of being dumped to stderr.
    LOG.log(Level.SEVERE, "Worker " + id + " failed while progressing communication", t);
  }
}
Usage of edu.iu.dsc.tws.comms.mpi.MPIBuffer in project twister2 by DSC-SPIDAL.
Class MPIMessageSerializer, method build.
/**
 * Serializes {@code message} into free send buffers, continuing from any earlier partial work.
 *
 * <p>If the underlying MPI message is already complete, the send state is simply marked as
 * serialized. Otherwise buffers are taken from {@code sendBuffers} one at a time. Each buffer
 * first gets a header (when the state is INIT or SENT_INTERNALLY), then body bytes (when the
 * state is HEADER_BUILT or BODY_BUILT), and is then attached to the MPI message. Once the
 * message is fully serialized, the total byte count is written at offset 12 of the first
 * buffer, a {@code MessageHeader} is attached, the byte counter is reset and the MPI message
 * is marked complete.
 *
 * @param message the payload being serialized
 * @param partialBuildObject the {@code MPISendMessage} holding the serialization progress
 * @return the same {@code MPISendMessage}, possibly still only partially serialized
 */
@Override
public Object build(Object message, Object partialBuildObject) {
  final MPISendMessage outgoing = (MPISendMessage) partialBuildObject;

  // Fast path: nothing left to write, only the send state has to reflect that.
  if (outgoing.getMPIMessage().isComplete()) {
    outgoing.setSendState(MPISendMessage.SendState.SERIALIZED);
    return outgoing;
  }

  // Lazily attach the progress tracker used while the body spans several buffers.
  if (outgoing.getSerializationState() == null) {
    outgoing.setSerializationState(new SerializeState());
  }

  while (sendBuffers.size() > 0
      && outgoing.serializedState() != MPISendMessage.SendState.SERIALIZED) {
    final MPIBuffer target = sendBuffers.poll();
    if (target == null) {
      break;
    }

    final MPISendMessage.SendState beforeHeader = outgoing.serializedState();
    final boolean headerPending = beforeHeader == MPISendMessage.SendState.INIT
        || beforeHeader == MPISendMessage.SendState.SENT_INTERNALLY;
    if (headerPending) {
      buildHeader(target, outgoing);
      outgoing.setSendState(MPISendMessage.SendState.HEADER_BUILT);
    }

    final MPISendMessage.SendState beforeBody = outgoing.serializedState();
    final boolean bodyPending = beforeBody == MPISendMessage.SendState.HEADER_BUILT
        || beforeBody == MPISendMessage.SendState.BODY_BUILT;
    // serializeBody is only invoked while the body is still pending
    if (bodyPending && serializeBody(message, outgoing, target)) {
      outgoing.setSendState(MPISendMessage.SendState.SERIALIZED);
    }

    // the buffer joins the MPI message whether or not the body is finished yet
    outgoing.getMPIMessage().addBuffer(target);

    if (outgoing.serializedState() != MPISendMessage.SendState.SERIALIZED) {
      LOG.fine("Message NOT FULLY serialized");
      continue;
    }

    // Finalize: record the total length in the first buffer and attach the message header.
    final MPIMessage finished = outgoing.getMPIMessage();
    final SerializeState progress = outgoing.getSerializationState();
    final int byteCount = progress.getTotalBytes();
    finished.getBuffers().get(0).getByteBuffer().putInt(12, byteCount);
    final MessageHeader.Builder headerBuilder =
        MessageHeader.newBuilder(outgoing.getSource(), outgoing.getEdge(), byteCount);
    headerBuilder.destination(outgoing.getDestintationIdentifier());
    outgoing.getMPIMessage().setHeader(headerBuilder.build());
    progress.setTotalBytes(0);
    // mark the original message as complete
    finished.setComplete(true);
  }
  return outgoing;
}
Usage of edu.iu.dsc.tws.comms.mpi.MPIBuffer in project twister2 by DSC-SPIDAL.
Class MPIMessageSerializer, method serializeBuffer.
/**
 * Serializes an {@code MPIBuffer} payload into a single send buffer.
 *
 * <p>When the header has just been built (state HEADER_BUILT), the payload length is written
 * at byte offset 12 of the send buffer and a {@code MessageHeader} is attached to the MPI
 * message. The payload bytes are then copied directly after the 16-byte header, and the send
 * buffer's size is set to header plus payload. The whole payload must fit in one send buffer.
 *
 * @param object the payload; must be an {@code MPIBuffer} (assumed to hold its data in bytes
 *     {@code [0, getSize())} — TODO confirm against producers)
 * @param sendMessage the message being serialized; its send state becomes SERIALIZED
 * @param buffer the send buffer to write into
 * @return always {@code true}, since the payload is written in one step
 * @throws IllegalArgumentException if the payload does not fit after the header
 */
private boolean serializeBuffer(Object object, MPISendMessage sendMessage, MPIBuffer buffer) {
  // byte offset of the length field inside the header, and the header size; the payload
  // starts right after the header
  final int lengthOffset = 12;
  final int headerSize = 16;

  MPIBuffer dataBuffer = (MPIBuffer) object;
  int dataSize = dataBuffer.getSize();
  ByteBuffer byteBuffer = buffer.getByteBuffer();
  if (headerSize + dataSize > byteBuffer.capacity()) {
    throw new IllegalArgumentException(String.format(
        "Buffer payload of %d bytes does not fit after a %d byte header in a send buffer of "
            + "capacity %d", dataSize, headerSize, byteBuffer.capacity()));
  }

  if (sendMessage.serializedState() == MPISendMessage.SendState.HEADER_BUILT) {
    // at this point we know the length of the data
    byteBuffer.putInt(lengthOffset, dataSize);
    // now lets set the header
    MessageHeader.Builder builder = MessageHeader.newBuilder(sendMessage.getSource(),
        sendMessage.getEdge(), dataSize);
    builder.destination(sendMessage.getDestintationIdentifier());
    sendMessage.getMPIMessage().setHeader(builder.build());
  }

  // Copy the payload itself; previously only its length was recorded and the bytes were never
  // transferred. Duplicates are used so the positions/limits of both the caller's buffer and
  // the send buffer are left exactly as before.
  ByteBuffer payload = dataBuffer.getByteBuffer().duplicate();
  payload.clear();
  payload.limit(dataSize);
  ByteBuffer target = byteBuffer.duplicate();
  target.clear();
  target.position(headerSize);
  target.put(payload);

  buffer.setSize(headerSize + dataSize);
  // okay we are done with the message
  sendMessage.setSendState(MPISendMessage.SendState.SERIALIZED);
  return true;
}
Usage of edu.iu.dsc.tws.comms.mpi.MPIBuffer in project twister2 by DSC-SPIDAL.
Class MPIMultiMessageSerializer, method serializeBufferedMessage.
/**
 * Copies as much of an already-serialized message as fits into a target buffer.
 *
 * <p>Copying resumes at the source buffer index and byte offset stored in {@code state}. The
 * first {@code HEADER_SIZE} bytes of source buffer 0 (the message header) are skipped. Copying
 * stops when all source buffers are consumed, or when fewer than
 * {@code NORMAL_SUB_MESSAGE_HEADER_SIZE} bytes remain in the target.
 *
 * @param message the message whose buffers are copied; released once fully copied
 * @param state copy progress carried between calls (buffer index, byte offset, total bytes);
 *     updated in place
 * @param targetBuffer the buffer receiving the bytes; its size is set to its byte position
 * @return {@code true} if the whole message has been copied (progress is then reset),
 *     {@code false} if further calls with another target buffer are needed
 * @throws IllegalStateException if a source buffer's size is smaller than the offset to copy
 *     from
 */
private boolean serializeBufferedMessage(MPIMessage message, SerializeState state, MPIBuffer targetBuffer) {
  ByteBuffer targetByteBuffer = targetBuffer.getByteBuffer();
  // the target remaining space left
  int targetRemainingSpace = targetByteBuffer.remaining();
  // the current source buffer number and the bytes already copied from it
  int currentSourceBuffer = state.getBufferNo();
  int bytesCopiedFromSource = state.getBytesCopied();
  List<MPIBuffer> buffers = message.getBuffers();
  MPIBuffer currentMPIBuffer = null;
  int totalBytes = state.getTotalBytes();
  while (targetRemainingSpace > 0 && currentSourceBuffer < buffers.size()) {
    currentMPIBuffer = buffers.get(currentSourceBuffer);
    ByteBuffer currentSourceByteBuffer = currentMPIBuffer.getByteBuffer();
    // the 0th buffer starts with the message header, which is not copied
    if (currentSourceBuffer == 0 && bytesCopiedFromSource == 0) {
      bytesCopiedFromSource += HEADER_SIZE;
    }
    int needsCopy = currentMPIBuffer.getSize() - bytesCopiedFromSource;
    if (needsCopy < 0) {
      throw new IllegalStateException(String.format(
          "Source buffer %d has size %d, smaller than the copy offset %d",
          currentSourceBuffer, currentMPIBuffer.getSize(), bytesCopiedFromSource));
    }
    int canCopy = Math.min(needsCopy, targetRemainingSpace);
    // Copy the byte range straight into the target through a bounded view, instead of staging
    // it in a temporary byte[] sized to the whole target buffer on every call.
    currentSourceByteBuffer.position(bytesCopiedFromSource);
    ByteBuffer window = currentSourceByteBuffer.slice();
    window.limit(canCopy);
    targetByteBuffer.put(window);
    // leave the source position where a relative read of canCopy bytes would have left it
    currentSourceByteBuffer.position(bytesCopiedFromSource + canCopy);
    totalBytes += canCopy;
    targetRemainingSpace -= canCopy;
    bytesCopiedFromSource += canCopy;
    // the target cannot take another NORMAL_SUB_MESSAGE_HEADER_SIZE bytes, so stop here
    if (targetRemainingSpace < NORMAL_SUB_MESSAGE_HEADER_SIZE) {
      // move on to the next source buffer only if this one was copied completely
      if (canCopy == needsCopy) {
        currentSourceBuffer++;
        bytesCopiedFromSource = 0;
      }
      break;
    }
    // the target still has room, so (assuming NORMAL_SUB_MESSAGE_HEADER_SIZE > 0) this source
    // buffer was copied completely; continue with the next one
    currentSourceBuffer++;
    bytesCopiedFromSource = 0;
  }
  // set the data size of the target buffer
  targetBuffer.setSize(targetByteBuffer.position());
  state.setTotalBytes(totalBytes);
  if (currentSourceBuffer == buffers.size() && currentMPIBuffer != null) {
    // everything copied: reset the progress and release the source message
    state.setBufferNo(0);
    state.setBytesCopied(0);
    message.release();
    return true;
  }
  // remember where to resume on the next call
  state.setBufferNo(currentSourceBuffer);
  state.setBytesCopied(bytesCopiedFromSource);
  return false;
}
Usage of edu.iu.dsc.tws.comms.mpi.MPIBuffer in project twister2 by DSC-SPIDAL.
Class MPIMultiMessageDeserializer, method getDataBuffers.
/**
 * Splits a received multi-message into its individual sub-messages.
 *
 * <p>Starting at buffer 0, each sub-message is read as a 4-byte length (a relative
 * {@code getInt()} on the current buffer) followed by its data, which may span several
 * buffers. Those buffers and the length are passed to {@code getSingleDataBuffers}, and the
 * results are collected in order. Reading continues until the consumed byte count reaches
 * {@code header.getLength()}. For keyed messages with a non-primitive key type, 4 extra bytes
 * per sub-message are counted for the key length.
 *
 * @param partialObject the received message; must be an {@code MPIMessage} with a header
 * @param edge the communication edge (not used here)
 * @return a {@code List<Object>} with one entry per sub-message
 * @throws RuntimeException if the message header has not been built
 */
@Override
public Object getDataBuffers(Object partialObject, int edge) {
  MPIMessage currentMessage = (MPIMessage) partialObject;
  int readLength = 0;
  int bufferIndex = 0;
  List<MPIBuffer> buffers = currentMessage.getBuffers();
  List<Object> returnList = new ArrayList<>();
  MessageHeader header = currentMessage.getHeader();
  if (header == null) {
    throw new RuntimeException("Header must be built before the message");
  }
  while (readLength < header.getLength()) {
    // now read the length of the next sub-message
    int length = buffers.get(bufferIndex).getByteBuffer().getInt();
    // collect the buffers that together hold at least `length` unread bytes
    List<MPIBuffer> messageBuffers = new ArrayList<>();
    int tempLength = 0;
    int tempBufferIndex = bufferIndex;
    while (tempLength < length) {
      MPIBuffer part = buffers.get(tempBufferIndex);
      messageBuffers.add(part);
      tempLength += part.getByteBuffer().remaining();
      tempBufferIndex++;
    }
    // presumably advances the positions of the collected buffers — confirm in the helper
    Object object = getSingleDataBuffers(currentMessage, messageBuffers, length);
    readLength += length + 4;
    if (keyed && !MessageTypeUtils.isPrimitiveType(currentMessage.getKeyType())) {
      // adding 4 to the length since the key length is also kept
      readLength += 4;
    }
    // Resume from the last buffer actually touched. For an empty payload no buffer is
    // collected, so that is the buffer the length was read from; without the max() the index
    // would step backwards (or stay on an exhausted buffer) in that case.
    int lastIndex = Math.max(bufferIndex, tempBufferIndex - 1);
    if (buffers.get(lastIndex).getByteBuffer().remaining() > 0) {
      bufferIndex = lastIndex;
    } else {
      bufferIndex = lastIndex + 1;
    }
    returnList.add(object);
  }
  return returnList;
}
Aggregations