Usage of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.
Class ClipboardTest, method testPin1.
@Test
public void testPin1() throws Exception {
    // Pins 100 aggregations that each get a fresh random first constructor argument (presumably the
    // task id -- confirm against VectorAggregation) and a distinct shard index. The test expects every
    // one of them to end up in its own, incomplete stack.
    final Clipboard board = new Clipboard();
    final Random random = new Random(12345L);

    int pinned = 0;
    while (pinned < 100) {
        board.pin(new VectorAggregation(random.nextLong(), (short) 100, (short) pinned, Nd4j.create(5)));
        pinned++;
    }

    // nothing is complete, so there must be no candidates and one pinned stack per aggregation
    assertEquals(false, board.hasCandidates());
    assertEquals(0, board.getNumberOfCompleteStacks());
    assertEquals(100, board.getNumberOfPinnedStacks());
}
Usage of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.
Class ClipboardTest, method testPin2.
@Test
public void testPin2() throws Exception {
    // Pins 300 aggregations. The first 100 even-numbered ones are rewritten to share task id 123 with
    // consecutive shard indices 0..99, which is expected to form exactly one complete stack. All
    // others keep their random ids.
    final Clipboard board = new Clipboard();
    final Random random = new Random(12345L);
    final Long validId = 123L;
    short nextShard = 0;

    for (int step = 0; step < 300; step++) {
        VectorAggregation chunk = new VectorAggregation(random.nextLong(), (short) 100, (short) 1, Nd4j.create(5));

        // imitating valid: every even step contributes one chunk of the tracked stack, until 100 are pinned
        boolean imitateValid = (step & 1) == 0 && nextShard < 100;
        if (imitateValid) {
            chunk.setTaskId(validId);
            chunk.setShardIndex(nextShard);
            nextShard++;
        }

        board.pin(chunk);
    }

    VoidAggregation stack = board.getStackFromClipboard(0L, validId);
    assertNotEquals(null, stack);
    assertEquals(0, stack.getMissingChunks());
    assertEquals(true, board.hasCandidates());
    assertEquals(1, board.getNumberOfCompleteStacks());
}
Usage of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.
Class ClipboardTest, method testPin3.
/**
 * Checks how the clipboard handles a singular aggregation.
 * <p>
 * A single {@code InitializationAggregation(1, 0)} is pinned. The clipboard is then expected to
 * track it under originator {@code 0L} and to report it as ready immediately. Presumably the first
 * constructor argument means "one chunk in total", so the stack completes at once (assumption;
 * confirm against InitializationAggregation).
 *
 * @throws Exception propagated from clipboard operations; none is expected
 */
@Test
public void testPin3() throws Exception {
    Clipboard clipboard = new Clipboard();

    InitializationAggregation aggregation = new InitializationAggregation(1, 0);
    clipboard.pin(aggregation);

    assertTrue(clipboard.isTracking(0L, aggregation.getTaskId()));
    assertTrue(clipboard.isReady(0L, aggregation.getTaskId()));
}
Usage of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.
Class FrameTest, method testFrame1.
/**
 * Simple test for Frame functionality.
 * <p>
 * Ten anonymous {@code TrainingMessage}s are stacked into a frame. Each one reports
 * {@code getCounter() == 2} and bumps a shared counter from {@code processMessage()}.
 * The frame must hold 10 messages, and after {@code processMessage()} on the frame the counter
 * must be 20. Presumably each message is processed {@code getCounter()} times (assumption; confirm
 * against Frame).
 */
@Test
public void testFrame1() {
    final AtomicInteger invocations = new AtomicInteger(0);
    Frame<TrainingMessage> batch = new Frame<>();

    for (int remaining = 10; remaining > 0; remaining--) {
        batch.stackMessage(new TrainingMessage() {
            // --- behavior exercised by the test ---
            @Override
            public void processMessage() {
                invocations.incrementAndGet();
            }

            @Override
            public byte getCounter() {
                return 2;
            }

            // --- inert stubs: identifiers and routing ---
            @Override
            public long getTaskId() {
                return 0;
            }

            @Override
            public long getFrameId() {
                return 0;
            }

            @Override
            public void setFrameId(long frameId) {
            }

            @Override
            public long getOriginatorId() {
                return 0;
            }

            @Override
            public void setOriginatorId(long id) {
            }

            @Override
            public short getTargetId() {
                return 0;
            }

            @Override
            public void setTargetId(short id) {
            }

            @Override
            public int getMessageType() {
                return 0;
            }

            // --- inert stubs: retransmission ---
            @Override
            public int getRetransmitCount() {
                return 0;
            }

            @Override
            public void incrementRetransmitCount() {
            }

            // --- inert stubs: serialization ---
            @Override
            public byte[] asBytes() {
                return new byte[0];
            }

            @Override
            public UnsafeBuffer asUnsafeBuffer() {
                return null;
            }

            // --- inert stubs: context handling ---
            @Override
            public void attachContext(VoidConfiguration voidConfiguration, TrainingDriver<? extends TrainingMessage> trainer, Clipboard clipboard, Transport transport, Storage storage, NodeRole role, short shardIndex) {
                // no-op intentionally
            }

            @Override
            public void extractContext(BaseVoidMessage message) {
                // no-op intentionally
            }

            // --- inert stubs: joining / blocking ---
            @Override
            public boolean isJoinSupported() {
                return false;
            }

            @Override
            public void joinMessage(VoidMessage message) {
                // no-op
            }

            @Override
            public boolean isBlockingMessage() {
                return false;
            }
        });
    }

    assertEquals(10, batch.size());
    batch.processMessage();
    assertEquals(20, invocations.get());
}
Usage of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.
Class MulticastTransport, method init.
/**
 * Initializes the multicast transport for the given node role.
 * <p>
 * The method performs these steps in order:
 * <ol>
 *   <li>Validates the multicast-specific settings.</li>
 *   <li>Delegates to {@code super.init(...)}.</li>
 *   <li>Launches an embedded Aeron media driver and connects an Aeron client to it.</li>
 *   <li>Builds the multicast channel URI (endpoint, optional interface, TTL).</li>
 *   <li>Opens the publications and subscriptions required by {@code role}.</li>
 * </ol>
 * All argument validation happens before any native resources are created, so a rejected
 * configuration never leaves an embedded media driver running.
 *
 * @param voidConfiguration transport configuration: TTL, multicast network/port/interface, stream id,
 *                          unicast port, number of shards and shard addresses
 * @param clipboard         clipboard forwarded to the base transport and stored on this instance
 * @param role              node role; only SHARD, BACKUP and CLIENT are supported here
 * @param localIp           for SHARD/BACKUP, used as the unicast endpoint when no ip was set before
 * @param localPort         forwarded to {@code super.init(...)}; not otherwise used in this method
 * @param shardIndex        index of this node, stored on the instance
 * @throws ND4JIllegalStateException if TTL is below 1, no multicast network is configured, the role is
 *                                   not supported, or a CLIENT has no shard addresses configured
 */
@Override
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Clipboard clipboard, @NonNull NodeRole role, @NonNull String localIp, int localPort, short shardIndex) {
    if (voidConfiguration.getTtl() < 1) {
        throw new ND4JIllegalStateException("For MulticastTransport you should have TTL >= 1, it won't work otherwise");
    }
    if (voidConfiguration.getMulticastNetwork() == null || voidConfiguration.getMulticastNetwork().isEmpty()) {
        throw new ND4JIllegalStateException("For MulticastTransport you should provide IP from multicast network available/allowed in your environment, i.e.: 224.0.1.1");
    }

    // Fail fast, before super.init() runs and before an embedded MediaDriver is launched, so that an
    // invalid role or client configuration cannot leave native resources behind.
    if (role != NodeRole.SHARD && role != NodeRole.BACKUP && role != NodeRole.CLIENT) {
        log.warn("Unknown role passed: {}", role);
        throw new ND4JIllegalStateException("Unknown role passed: " + role);
    }
    if (role == NodeRole.CLIENT && (voidConfiguration.getShardAddresses() == null || voidConfiguration.getShardAddresses().isEmpty())) {
        throw new ND4JIllegalStateException("For MulticastTransport in CLIENT role you should provide at least one shard address");
    }

    // NOTE(review): the original comment here said "shutdown hook"; presumably the base-class init
    // registers one -- confirm against BaseTransport.
    super.init(voidConfiguration, clipboard, role, localIp, localPort, shardIndex);
    this.voidConfiguration = voidConfiguration;
    this.nodeRole = role;
    this.clipboard = clipboard;

    // embedded media driver plus an Aeron client pointed at its directory
    context = new Aeron.Context();
    driver = MediaDriver.launchEmbedded();
    context.aeronDirectoryName(driver.aeronDirectoryName());
    aeron = Aeron.connect(context);
    this.shardIndex = shardIndex;

    multicastChannelUri = "aeron:udp?endpoint=" + voidConfiguration.getMulticastNetwork() + ":" + voidConfiguration.getMulticastPort();
    if (voidConfiguration.getMulticastInterface() != null && !voidConfiguration.getMulticastInterface().isEmpty()) {
        multicastChannelUri = multicastChannelUri + "|interface=" + voidConfiguration.getMulticastInterface();
    }
    multicastChannelUri = multicastChannelUri + "|ttl=" + voidConfiguration.getTtl();

    // a negative shard count is treated as "not set": derive it from the configured shard addresses
    if (voidConfiguration.getNumberOfShards() < 0) {
        voidConfiguration.setNumberOfShards(voidConfiguration.getShardAddresses().size());
    }

    switch (nodeRole) {
        case BACKUP:
        case SHARD:
            // In case of Shard, the unicast address for communication is known in advance
            if (ip == null) {
                ip = localIp;
                port = voidConfiguration.getUnicastPort();
            }
            unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
            log.info("Shard unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());
            // unicast, stream id + 0: receive batches from Clients
            subscriptionForShards = aeron.addSubscription(unicastChannelUri, voidConfiguration.getStreamId());
            // multicast, stream id + 1: send completion reports back to Clients
            publicationForClients = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 1);
            // multicast, stream id + 2: communication with other Shards
            publicationForShards = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 2);
            // multicast, stream id + 2: receive messages from other Shards. Despite its name,
            // subscriptionForClients carries shard-to-shard traffic for this role and is paired
            // with internalMessageHandler below.
            subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 2);
            messageHandlerForShards = new FragmentAssembler((buffer, offset, length, header) -> shardMessageHandler(buffer, offset, length, header));
            messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> internalMessageHandler(buffer, offset, length, header));
            break;
        case CLIENT: {
            // The client talks to the first configured shard address. Entries are "host" (the
            // configured unicast port is used) or "host:port".
            // FIXME: we don't want that -- a random pick was considered and left disabled:
            // ArrayUtil.getRandomElement(configuration.getShardAddresses());
            String rts = voidConfiguration.getShardAddresses().get(0);
            String[] split = rts.split(":");
            if (split.length == 1) {
                ip = rts;
                port = voidConfiguration.getUnicastPort();
            } else {
                ip = split[0];
                port = Integer.parseInt(split[1]);
            }
            unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
            log.info("Client unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());
            // unicast, stream id + 0: send batches to Shards; a 1:1 channel to one of the Shards
            publicationForShards = aeron.addPublication(unicastChannelUri, voidConfiguration.getStreamId());
            // multicast, stream id + 1: receive completion reports from Shards
            subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 1);
            messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> clientMessageHandler(buffer, offset, length, header));
            break;
        }
        default:
            // defensive only: unsupported roles are rejected above, before any resources are created
            log.warn("Unknown role passed: {}", nodeRole);
            throw new ND4JIllegalStateException("Unknown role passed: " + nodeRole);
    }

    // if that's local spark run - we don't need this
    if (voidConfiguration.getNumberOfShards() == 1 && nodeRole == NodeRole.SHARD) {
        shutdownSilent();
    }
}
Aggregations