use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.
the class VoidParameterServer method getRole.
/**
 * Determines the role of this node by matching its local IP addresses against the addresses
 * listed in the configuration.
 *
 * <p>Shard addresses are checked first, then backup addresses (if any are configured). An
 * optional {@code :port} suffix of a configured address is ignored for the comparison. When no
 * configured address matches, the node is treated as a client and its address is taken from the
 * configured network mask or, as a last resort, from the {@code DL4J_VOID_IP} environment
 * variable.
 *
 * @param voidConfiguration configuration providing shard/backup addresses, network mask and unicast port
 * @param localIPs IP addresses that belong to this machine
 * @return pair of detected role and address: for SHARD/BACKUP the matching configured entry exactly
 *         as written in the configuration, for CLIENT the resolved IP followed by {@code :} and
 *         the configured unicast port
 * @throws ND4JIllegalStateException if the node is a client and no IP address could be resolved
 */
protected Pair<NodeRole, String> getRole(@NonNull VoidConfiguration voidConfiguration, @NonNull Collection<String> localIPs) {
    // shards take precedence over every other role
    for (String shardAddress : voidConfiguration.getShardAddresses()) {
        String host = shardAddress.replaceAll(":.*", "");
        if (localIPs.contains(host)) {
            return Pair.create(NodeRole.SHARD, shardAddress);
        }
    }

    // backup addresses are optional
    if (voidConfiguration.getBackupAddresses() != null) {
        for (String backupAddress : voidConfiguration.getBackupAddresses()) {
            String host = backupAddress.replaceAll(":.*", "");
            if (localIPs.contains(host)) {
                return Pair.create(NodeRole.BACKUP, backupAddress);
            }
        }
    }

    // neither shard nor backup: this node is a client, so resolve an address for it
    String clientIp = null;
    if (voidConfiguration.getNetworkMask() != null) {
        clientIp = new NetworkOrganizer(voidConfiguration.getNetworkMask()).getMatchingAddress();
    }

    // last resort here...
    if (clientIp == null) {
        clientIp = System.getenv("DL4J_VOID_IP");
    }

    log.info("Got [{}] as sparkIp", clientIp);

    if (clientIp == null) {
        throw new ND4JIllegalStateException("Can't get IP address for UDP communcation");
    }

    // local IP from pair is used for shard only, so we don't care
    return Pair.create(NodeRole.CLIENT, clientIp + ":" + voidConfiguration.getUnicastPort());
}
use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.
the class FrameTest method testFrame1.
/**
 * Simple test for Frame functionality: stacks ten stub messages into a frame, processes the
 * frame and checks how many times the stubs were processed.
 */
@Test
public void testFrame1() {
    // shared across all stub messages; bumped once per processMessage() call on a stub
    final AtomicInteger processed = new AtomicInteger(0);
    final Frame<TrainingMessage> frame = new Frame<>();

    for (int idx = 0; idx < 10; idx++) {
        // a fresh stub instance is stacked on every iteration
        frame.stackMessage(new TrainingMessage() {

            // --- the only method with an observable effect ---

            @Override
            public void processMessage() {
                processed.incrementAndGet();
            }

            // --- fixed-value getters ---

            @Override
            public byte getCounter() {
                return 2;
            }

            @Override
            public int getRetransmitCount() {
                return 0;
            }

            @Override
            public long getFrameId() {
                return 0;
            }

            @Override
            public long getOriginatorId() {
                return 0;
            }

            @Override
            public short getTargetId() {
                return 0;
            }

            @Override
            public long getTaskId() {
                return 0;
            }

            @Override
            public int getMessageType() {
                return 0;
            }

            @Override
            public byte[] asBytes() {
                return new byte[0];
            }

            @Override
            public UnsafeBuffer asUnsafeBuffer() {
                return null;
            }

            @Override
            public boolean isJoinSupported() {
                return false;
            }

            @Override
            public boolean isBlockingMessage() {
                return false;
            }

            // --- intentionally empty mutators ---

            @Override
            public void setTargetId(short id) {
            }

            @Override
            public void incrementRetransmitCount() {
            }

            @Override
            public void setFrameId(long frameId) {
            }

            @Override
            public void setOriginatorId(long id) {
            }

            @Override
            public void attachContext(VoidConfiguration voidConfiguration, TrainingDriver<? extends TrainingMessage> trainer, Clipboard clipboard, Transport transport, Storage storage, NodeRole role, short shardIndex) {
                // no-op intentionally
            }

            @Override
            public void extractContext(BaseVoidMessage message) {
                // no-op intentionally
            }

            @Override
            public void joinMessage(VoidMessage message) {
                // no-op
            }
        });
    }

    assertEquals(10, frame.size());

    frame.processMessage();

    // NOTE(review): the expected total depends on Frame.processMessage(), which is not visible here
    assertEquals(20, processed.get());
}
use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.
the class MulticastTransport method init.
/**
 * Initializes this multicast transport: validates the multicast settings, starts an embedded
 * Aeron media driver, connects an Aeron client to it and opens the publications/subscriptions
 * required for the given role.
 *
 * <p>Stream layout, relative to {@code voidConfiguration.getStreamId()}:
 * <ul>
 *   <li>{@code streamId} on the unicast channel - batches sent from clients to a shard</li>
 *   <li>{@code streamId + 1} on the multicast channel - completion reports from shards to clients</li>
 *   <li>{@code streamId + 2} on the multicast channel - communication between shards</li>
 * </ul>
 *
 * @param voidConfiguration configuration; must have TTL &gt;= 1 and a non-empty multicast network
 * @param clipboard clipboard instance stored for later use by this transport
 * @param role role of the current node
 * @param localIp IP address of the current node; used as the unicast endpoint for shard/backup
 *                nodes unless an IP was already assigned to this transport
 * @param localPort passed to the parent initializer only
 * @param shardIndex index of the current shard
 * @throws ND4JIllegalStateException if TTL or multicast network are not configured properly, or if
 *                                   a client is initialized without any shard address
 * @throws RuntimeException if the role is not one of SHARD, BACKUP or CLIENT
 */
@Override
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Clipboard clipboard, @NonNull NodeRole role, @NonNull String localIp, int localPort, short shardIndex) {
    if (voidConfiguration.getTtl() < 1) {
        throw new ND4JIllegalStateException("For MulticastTransport you should have TTL >= 1, it won't work otherwise");
    }

    if (voidConfiguration.getMulticastNetwork() == null || voidConfiguration.getMulticastNetwork().isEmpty()) {
        throw new ND4JIllegalStateException("For MulticastTransport you should provide IP from multicast network available/allowed in your environment, i.e.: 224.0.1.1");
    }

    // shutdown hook
    super.init(voidConfiguration, clipboard, role, localIp, localPort, shardIndex);

    this.voidConfiguration = voidConfiguration;
    this.nodeRole = role;
    this.clipboard = clipboard;

    // embedded media driver + client connected through the driver's directory
    context = new Aeron.Context();
    driver = MediaDriver.launchEmbedded();
    context.aeronDirectoryName(driver.aeronDirectoryName());
    aeron = Aeron.connect(context);

    this.shardIndex = shardIndex;

    // multicast URI: endpoint, optional interface, then ttl
    multicastChannelUri = "aeron:udp?endpoint=" + voidConfiguration.getMulticastNetwork() + ":" + voidConfiguration.getMulticastPort();
    if (voidConfiguration.getMulticastInterface() != null && !voidConfiguration.getMulticastInterface().isEmpty()) {
        multicastChannelUri = multicastChannelUri + "|interface=" + voidConfiguration.getMulticastInterface();
    }
    multicastChannelUri = multicastChannelUri + "|ttl=" + voidConfiguration.getTtl();

    // a negative shard count means "not set": derive it from the list of shard addresses
    if (voidConfiguration.getNumberOfShards() < 0) {
        voidConfiguration.setNumberOfShards(voidConfiguration.getShardAddresses().size());
    }

    switch (nodeRole) {
        case BACKUP:
        case SHARD:
            /*
                In case of Shard, unicast address for communication is known in advance
             */
            if (ip == null) {
                ip = localIp;
                port = voidConfiguration.getUnicastPort();
            }

            unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
            log.info("Shard unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());

            // this channel will be used to receive batches from Clients
            subscriptionForShards = aeron.addSubscription(unicastChannelUri, voidConfiguration.getStreamId());

            // this channel will be used to send completion reports back to Clients
            publicationForClients = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 1);

            // this channel will be used for communication with other Shards
            publicationForShards = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 2);

            // this channel will be used to receive messages from other Shards
            subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 2);

            messageHandlerForShards = new FragmentAssembler((buffer, offset, length, header) -> shardMessageHandler(buffer, offset, length, header));
            messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> internalMessageHandler(buffer, offset, length, header));
            break;
        case CLIENT:
            /*
                In case of Client, unicast will be one of shards. Currently it is always the first
                configured shard.
             */
            // FIXME: we don't want that
            // fail with a readable message instead of IndexOutOfBoundsException on get(0)
            if (voidConfiguration.getShardAddresses().isEmpty()) {
                throw new ND4JIllegalStateException("At least one shard address is required to initialize MulticastTransport as CLIENT");
            }

            String firstShard = voidConfiguration.getShardAddresses().get(0);
            String[] hostAndPort = firstShard.split(":");
            if (hostAndPort.length == 1) {
                // no explicit port in the shard address: fall back to the configured unicast port
                ip = firstShard;
                port = voidConfiguration.getUnicastPort();
            } else {
                ip = hostAndPort[0];
                port = Integer.parseInt(hostAndPort[1]);
            }

            unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
            log.info("Client unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());

            /*
                this channel will be used to send batches to Shards, it's 1:1 channel to one of the Shards
             */
            publicationForShards = aeron.addPublication(unicastChannelUri, voidConfiguration.getStreamId());

            // this channel will be used to receive completion reports from Shards
            subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 1);

            messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> clientMessageHandler(buffer, offset, length, header));
            break;
        default:
            log.warn("Unknown role passed: {}", nodeRole);
            throw new RuntimeException("Unknown role passed: " + nodeRole);
    }

    // if that's local spark run - we don't need this
    if (voidConfiguration.getNumberOfShards() == 1 && nodeRole == NodeRole.SHARD) {
        shutdownSilent();
    }
}
use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.
the class RoutedTransport method init.
/**
 * Initializes this routed (unicast) transport: starts an embedded Aeron media driver, subscribes
 * to this node's own unicast endpoint, opens one publication per configured shard and finally
 * initializes the router and the originator id of this node.
 *
 * @param voidConfiguration configuration providing shard addresses, unicast port and stream id
 * @param clipboard clipboard instance stored for later use by this transport
 * @param role role of the current node
 * @param localIp IP address used for this node's endpoint, unless a port was already assigned
 * @param localPort port used for this node's endpoint, unless a port was already assigned
 * @param shardIndex index of the current shard
 * @throws ND4JIllegalStateException if the role is not one of MASTER, BACKUP, SHARD or CLIENT
 */
@Override
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Clipboard clipboard, @NonNull NodeRole role, @NonNull String localIp, int localPort, short shardIndex) {
    this.voidConfiguration = voidConfiguration;
    this.clipboard = clipboard;
    this.nodeRole = role;
    this.shardIndex = shardIndex;
    this.messages = new LinkedBlockingQueue<>();

    // shutdown hook
    super.init(voidConfiguration, clipboard, role, localIp, localPort, shardIndex);

    setProperty("aeron.client.liveness.timeout", "30000000000");

    context = new Aeron.Context()
                    .publicationConnectionTimeout(30000000000L)
                    .driverTimeoutMs(30000)
                    .keepAliveInterval(100000000);

    driver = MediaDriver.launchEmbedded();
    context.aeronDirectoryName(driver.aeronDirectoryName());
    aeron = Aeron.connect(context);

    if (router == null) {
        router = new InterleavedRouter();
    }

    // we skip IPs assign process if they were defined externally
    if (port == 0) {
        ip = localIp;
        port = localPort;
    }

    unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
    subscriptionForClients = aeron.addSubscription(unicastChannelUri, voidConfiguration.getStreamId());

    // clean shut down
    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
        CloseHelper.quietClose(aeron);
        CloseHelper.quietClose(driver);
        CloseHelper.quietClose(context);
        CloseHelper.quietClose(subscriptionForClients);
    }));

    messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> jointMessageHandler(buffer, offset, length, header));

    /*
        Now, regardless of current role,
        we set up publication channel to each shard
     */
    for (String shardAddress : voidConfiguration.getShardAddresses()) {
        final String shardChannelUri;
        final String remoteIp;
        final int remotePort;

        if (shardAddress.contains(":")) {
            // address already carries its own port
            String[] hostAndPort = shardAddress.split(":");
            shardChannelUri = "aeron:udp?endpoint=" + shardAddress;
            remoteIp = hostAndPort[0];
            remotePort = Integer.parseInt(hostAndPort[1]);
        } else {
            // no port given: use the configured unicast port
            remoteIp = shardAddress;
            remotePort = voidConfiguration.getUnicastPort();
            shardChannelUri = "aeron:udp?endpoint=" + shardAddress + ":" + remotePort;
        }

        Publication shardPublication = aeron.addPublication(shardChannelUri, voidConfiguration.getStreamId());

        RemoteConnection shardConnection = RemoteConnection.builder()
                        .ip(remoteIp)
                        .port(remotePort)
                        .publication(shardPublication)
                        .locker(new Object())
                        .build();

        shards.add(shardConnection);
    }

    if (nodeRole != NodeRole.SHARD) {
        log.info("Initialized as [{}]; Own endpoint: [{}]", nodeRole, unicastChannelUri);
    } else {
        log.info("Initialized as [{}]; ShardIndex: [{}]; Own endpoint: [{}]", nodeRole, shardIndex, unicastChannelUri);
    }

    switch (nodeRole) {
        case MASTER:
        case BACKUP:
        case SHARD:
            /*
                For unicast transport we want to have interconnects between all shards first of all, because we know their IPs in advance.
                But due to design requirements, clients have the same first step, so it's kinda shared for all states :)

                Next step is connections setup for backup nodes.
                TODO: to be implemented
             */
            addClient(ip, port);
            break;
        case CLIENT:
            /*
                For Clients on unicast transport, we either set up connection to single Shard, or to multiple shards
                But since this code is shared - we don't do anything here
             */
            break;
        default:
            throw new ND4JIllegalStateException("Unknown NodeRole being passed: " + nodeRole);
    }

    router.init(voidConfiguration, this);
    this.originatorId = HashUtil.getLongHash(this.getIp() + ":" + this.getPort());
}
use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.
the class RoutedTransport method sendMessageToAllClients.
/**
 * Sends the given message to every known client connection, except this node itself, connections
 * with a zero hash, and connections whose hash is listed in {@code exclusions}.
 *
 * <p>Delivery to each connection is retried until the message is accepted. While a connection
 * that has never been activated reports NOT_CONNECTED, up to 20 retries are made before giving
 * up. Between retries the calling thread is parked for the configured retransmit timeout
 * (taken as milliseconds).
 *
 * @param message message to send; it is serialized once and the same buffer is offered to every connection
 * @param exclusions hashes of connections that must not receive the message; may be {@code null},
 *                   {@code null} elements are ignored
 * @throws ND4JIllegalStateException if the current node is not a SHARD, if a never-activated
 *                                   connection stays unreachable after 20 retries, or if an
 *                                   already activated connection reports NOT_CONNECTED
 */
@Override
public void sendMessageToAllClients(VoidMessage message, Long... exclusions) {
    if (nodeRole != NodeRole.SHARD) {
        throw new ND4JIllegalStateException("Only SHARD allowed to send messages to all Clients");
    }

    final DirectBuffer buffer = message.asUnsafeBuffer();

    // no need to search for matches above number of then exclusions
    final AtomicInteger matchedExclusions = new AtomicInteger(0);

    clients.values().parallelStream().filter(rc -> {
        // do not send message back to yourself :)
        if (rc.getLongHash() == this.originatorId || rc.getLongHash() == 0) {
            return false;
        }

        // we skip exclusions here
        if (exclusions != null && matchedExclusions.get() < exclusions.length) {
            for (Long exclude : exclusions) {
                // null entries can never match a connection, so they are skipped instead of causing an NPE
                if (exclude != null && exclude.longValue() == rc.getLongHash()) {
                    matchedExclusions.incrementAndGet();
                    return false;
                }
            }
        }

        return true;
    }).forEach(rc -> {
        long retries = 0;
        boolean delivered = false;

        while (!delivered) {
            final RetransmissionHandler.TransmissionStatus status;

            // still stupid. maybe use real reentrant lock here?
            synchronized (rc.locker) {
                status = RetransmissionHandler.getTransmissionStatus(rc.getPublication().offer(buffer));
            }

            switch (status) {
                case MESSAGE_SENT:
                    delivered = true;
                    rc.getActivated().set(true);
                    break;
                case NOT_CONNECTED:
                    if (rc.getActivated().get()) {
                        // the connection worked before and is gone now
                        throw new ND4JIllegalStateException("Shards reassignment is to be implemented yet");
                    }

                    retries++;
                    if (retries > 20) {
                        throw new ND4JIllegalStateException("Can't connect to Shard: [" + rc.getPublication().channel() + "]");
                    }

                    // long arithmetic: the multiplication must not overflow even if the timeout getter returns int
                    LockSupport.parkNanos(voidConfiguration.getRetransmitTimeout() * 1_000_000L);
                    break;
                case ADMIN_ACTION:
                case BACKPRESSURE:
                default:
                    // any other status is retried after the same pause rather than spinning without a delay
                    LockSupport.parkNanos(voidConfiguration.getRetransmitTimeout() * 1_000_000L);
                    break;
            }
        }
    });
}
Aggregations