Use of org.nd4j.parameterserver.distributed.messages.aggregations.VectorAggregation in project nd4j by deeplearning4j.
In class ClipboardTest, method testPin1:
/**
 * Pins 100 aggregations that each get a random value as their first constructor argument
 * (presumably the task id) and the loop counter as their shard index, then checks that the
 * clipboard holds 100 pinned stacks, none of them complete, and reports no candidates.
 */
@Test
public void testPin1() throws Exception {
    Clipboard board = new Clipboard();
    Random random = new Random(12345L);

    int idx = 0;
    while (idx < 100) {
        board.pin(new VectorAggregation(random.nextLong(), (short) 100, (short) idx, Nd4j.create(5)));
        idx++;
    }

    assertEquals(false, board.hasCandidates());
    assertEquals(0, board.getNumberOfCompleteStacks());
    assertEquals(100, board.getNumberOfPinnedStacks());
}
Use of org.nd4j.parameterserver.distributed.messages.aggregations.VectorAggregation in project nd4j by deeplearning4j.
In class ClipboardTest, method testPin2:
/**
 * Pins 300 aggregations. The first 100 even-indexed ones are rewritten to share task id 123
 * and carry shard indices 0..99. All others keep the random value passed as their first
 * constructor argument (presumably their task id).
 *
 * <p>Expects the shared-id stack to be retrievable with no missing chunks, and to be the
 * clipboard's single complete stack.
 */
@Test
public void testPin2() throws Exception {
    Clipboard clipboard = new Clipboard();
    Random random = new Random(12345L);
    Long taskId = 123L;
    short nextShard = 0;

    for (int step = 0; step < 300; step++) {
        VectorAggregation candidate = new VectorAggregation(random.nextLong(), (short) 100, (short) 1, Nd4j.create(5));
        // imitate a chunk belonging to the "valid" task until all 100 shards are covered
        boolean imitateValid = step % 2 == 0 && nextShard < 100;
        if (imitateValid) {
            candidate.setTaskId(taskId);
            candidate.setShardIndex(nextShard);
            nextShard++;
        }
        clipboard.pin(candidate);
    }

    VoidAggregation stack = clipboard.getStackFromClipboard(0L, taskId);
    assertNotEquals(null, stack);
    assertEquals(0, stack.getMissingChunks());
    assertEquals(true, clipboard.hasCandidates());
    assertEquals(1, clipboard.getNumberOfCompleteStacks());
}
Use of org.nd4j.parameterserver.distributed.messages.aggregations.VectorAggregation in project nd4j by deeplearning4j.
In class DistributedVectorMessage, method processMessage:
/**
 * Executed in the context of the receiving node's executor (Shard, Client or Backup).
 *
 * <p>Copies row {@code rowIndex} of the array stored under {@code key} into a new
 * {@link VectorAggregation}, which also carries the shard count and this node's shard index.
 * It stamps the aggregation with this message's originator id and hands it to the transport.
 */
@Override
public void processMessage() {
    short shardCount = (short) voidConfiguration.getNumberOfShards();
    VectorAggregation vectorAggregation = new VectorAggregation(rowIndex, shardCount, shardIndex, storage.getArray(key).getRow(rowIndex).dup());

    vectorAggregation.setOriginatorId(getOriginatorId());

    transport.sendMessage(vectorAggregation);
}
Use of org.nd4j.parameterserver.distributed.messages.aggregations.VectorAggregation in project nd4j by deeplearning4j.
In class VectorRequestMessage, method processMessage:
/**
 * Handles a vector request. This message is only expected to be received by a Shard node.
 *
 * <p>Pins this shard's copy of row {@code rowIndex} (stored under {@code key}) to the clipboard
 * as a {@link VectorAggregation}. What happens next depends on the shard count:
 * <ul>
 *   <li>More than one shard: a {@link DistributedVectorMessage} for the same key and row is
 *       broadcast to all shards.</li>
 *   <li>A single shard: the aggregation takes this message's context and is processed locally
 *       right away.</li>
 * </ul>
 */
@Override
public void processMessage() {
    VectorAggregation aggregation = new VectorAggregation(rowIndex, (short) voidConfiguration.getNumberOfShards(), getShardIndex(), storage.getArray(key).getRow(rowIndex).dup());
    aggregation.setOriginatorId(this.getOriginatorId());
    // Pin the local chunk first (presumably so it is already present when other shards' chunks arrive).
    clipboard.pin(aggregation);

    DistributedVectorMessage dvm = new DistributedVectorMessage(key, rowIndex);
    // Read the originator through the same accessor as above so both outgoing objects agree.
    dvm.setOriginatorId(this.getOriginatorId());

    if (voidConfiguration.getNumberOfShards() > 1) {
        transport.sendMessageToAllShards(dvm);
    } else {
        // Single shard: there is nobody to broadcast to, so process the aggregation here.
        aggregation.extractContext(this);
        aggregation.processMessage();
    }
}
Aggregations