* This method checks for designated role, according to local IP addresses and configuration passed into method
* @param voidConfiguration
* @param localIPs
* @return
protected Pair<NodeRole, String> getRole(@NonNull VoidConfiguration voidConfiguration, @NonNull Collection<String> localIPs) {
NodeRole result = NodeRole.CLIENT;
for (String ip : voidConfiguration.getShardAddresses()) {
String cleansed = ip.replaceAll(":.*", "");
if (localIPs.contains(cleansed))
return Pair.create(NodeRole.SHARD, ip);
if (voidConfiguration.getBackupAddresses() != null)
for (String ip : voidConfiguration.getBackupAddresses()) {
String cleansed = ip.replaceAll(":.*", "");
if (localIPs.contains(cleansed))
return Pair.create(NodeRole.BACKUP, ip);
String sparkIp = null;
if (sparkIp == null && voidConfiguration.getNetworkMask() != null) {
NetworkOrganizer organizer = new NetworkOrganizer(voidConfiguration.getNetworkMask());
sparkIp = organizer.getMatchingAddress();
// last resort here...
if (sparkIp == null)
sparkIp = System.getenv("DL4J_VOID_IP");"Got [{}] as sparkIp", sparkIp);
if (sparkIp == null)
throw new ND4JIllegalStateException("Can't get IP address for UDP communcation");
// local IP from pair is used for shard only, so we don't care
return Pair.create(result, sparkIp + ":" + voidConfiguration.getUnicastPort());
* Simple test for Frame functionality
// Stacks 10 stub TrainingMessage instances into a Frame, then checks the frame size and a shared counter.
public void testFrame1() {
// Shared counter captured by every anonymous message created below.
final AtomicInteger count = new AtomicInteger(0);
Frame<TrainingMessage> frame = new Frame<>();
// One stub message per iteration: 10 messages in total.
for (int i = 0; i < 10; i++) {
frame.stackMessage(new TrainingMessage() {
// Every stub reports a counter value of 2.
public byte getCounter() {
return 2;
public void setTargetId(short id) {
public int getRetransmitCount() {
return 0;
public void incrementRetransmitCount() {
public long getFrameId() {
return 0;
public void setFrameId(long frameId) {
public long getOriginatorId() {
return 0;
public void setOriginatorId(long id) {
public short getTargetId() {
return 0;
public long getTaskId() {
return 0;
public int getMessageType() {
return 0;
// Stub serialization: empty payload / no buffer.
public byte[] asBytes() {
return new byte[0];
public UnsafeBuffer asUnsafeBuffer() {
return null;
public void attachContext(VoidConfiguration voidConfiguration, TrainingDriver<? extends TrainingMessage> trainer, Clipboard clipboard, Transport transport, Storage storage, NodeRole role, short shardIndex) {
// no-op intentionally
public void extractContext(BaseVoidMessage message) {
// no-op intentionally
public void processMessage() {
// NOTE(review): no statement is visible in this body, yet the final assertion expects count == 20.
// Confirm that the code incrementing `count` (presumably here) was not lost.
public boolean isJoinSupported() {
return false;
public void joinMessage(VoidMessage message) {
// no-op
public boolean isBlockingMessage() {
return false;
// All 10 stacked messages are expected to be held by the frame.
assertEquals(10, frame.size());
// NOTE(review): this view shows no call that processes the frame and no other write to `count`
// between stacking and this assertion. Confirm how count reaches 20 (10 messages x counter 2?).
assertEquals(20, count.get());
// Initializes a multicast-based transport.
// Steps:
//  - validate TTL and the multicast network;
//  - start an embedded Aeron media driver and connect an Aeron client;
//  - build the multicast channel URI;
//  - create the publications/subscriptions that match this node's role.
// Stream-id layout used below:
//  - streamId:     unicast batches from Clients to a Shard
//  - streamId + 1: multicast completion reports from Shards to Clients
//  - streamId + 2: multicast Shard-to-Shard traffic
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Clipboard clipboard, @NonNull NodeRole role, @NonNull String localIp, int localPort, short shardIndex) {
// Multicast requires a TTL of at least 1.
if (voidConfiguration.getTtl() < 1)
throw new ND4JIllegalStateException("For MulticastTransport you should have TTL >= 1, it won't work otherwise");
if (voidConfiguration.getMulticastNetwork() == null || voidConfiguration.getMulticastNetwork().isEmpty())
// NOTE(review): this message ends with "i.e.:" and looks truncated; consider completing it with an example.
throw new ND4JIllegalStateException("For MulticastTransport you should provide IP from multicast network available/allowed in your environment, i.e.:");
// NOTE(review): the "shutdown hook" comment below is followed by super.init(...), not a hook registration.
// The hook may be registered elsewhere; confirm.
// shutdown hook
super.init(voidConfiguration, clipboard, role, localIp, localPort, shardIndex);
this.voidConfiguration = voidConfiguration;
this.nodeRole = role;
this.clipboard = clipboard;
context = new Aeron.Context();
driver = MediaDriver.launchEmbedded();
aeron = Aeron.connect(context);
this.shardIndex = shardIndex;
// Multicast URI: aeron:udp?endpoint=<network>:<port>[|interface=<iface>]|ttl=<ttl>
multicastChannelUri = "aeron:udp?endpoint=" + voidConfiguration.getMulticastNetwork() + ":" + voidConfiguration.getMulticastPort();
if (voidConfiguration.getMulticastInterface() != null && !voidConfiguration.getMulticastInterface().isEmpty())
multicastChannelUri = multicastChannelUri + "|interface=" + voidConfiguration.getMulticastInterface();
multicastChannelUri = multicastChannelUri + "|ttl=" + voidConfiguration.getTtl();
// NOTE(review): the statement guarded by this condition is not visible in this view.
if (voidConfiguration.getNumberOfShards() < 0)
// NOTE(review): no `break` statements are visible between the cases below; confirm that the cases do not fall through.
switch(nodeRole) {
case BACKUP:
case SHARD:
In case of Shard, unicast address for communication is known in advance
// Use localIp and the configured unicast port only when `ip` was not already set.
if (ip == null) {
ip = localIp;
port = voidConfiguration.getUnicastPort();
unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;"Shard unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());
// this channel will be used to receive batches from Clients
subscriptionForShards = aeron.addSubscription(unicastChannelUri, voidConfiguration.getStreamId());
// this channel will be used to send completion reports back to Clients
publicationForClients = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 1);
// this channel will be used for communication with other Shards
publicationForShards = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 2);
// this channel will be used to receive messages from other Shards
// NOTE(review): this field is named "ForClients" but subscribes to the Shard-to-Shard stream (streamId + 2).
subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 2);
messageHandlerForShards = new FragmentAssembler((buffer, offset, length, header) -> shardMessageHandler(buffer, offset, length, header));
messageHandlerForClients = new FragmentAssembler(((buffer, offset, length, header) -> internalMessageHandler(buffer, offset, length, header)));
case CLIENT:
ip = localIp;
In case of Client, unicast will be one of shards, picked up with random
// FIXME: we don't want that
// ArrayUtil.getRandomElement(configuration.getShardAddresses());
// Currently always the FIRST configured shard, not a random one.
String rts = voidConfiguration.getShardAddresses().get(0);
// Accepts "host" (configured unicast port) or "host:port".
// A non-numeric port makes Integer.valueOf throw NumberFormatException.
String[] split = rts.split(":");
if (split.length == 1) {
ip = rts;
port = voidConfiguration.getUnicastPort();
} else {
ip = split[0];
port = Integer.valueOf(split[1]);
unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
// unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + (configuration.getUnicastPort()) ;"Client unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());
this channel will be used to send batches to Shards, it's 1:1 channel to one of the Shards
publicationForShards = aeron.addPublication(unicastChannelUri, voidConfiguration.getStreamId());
// this channel will be used to receive completion reports from Shards
subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 1);
messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> clientMessageHandler(buffer, offset, length, header));
// Any other role is rejected. The thrown exception carries no message; the role is only logged.
log.warn("Unknown role passed: {}", nodeRole);
throw new RuntimeException();
// if that's local spark run - we don't need this
// NOTE(review): the statement guarded by this condition is not visible in this view.
if (voidConfiguration.getNumberOfShards() == 1 && nodeRole == NodeRole.SHARD)
// Initializes a unicast ("routed") transport.
// Steps:
//  - store the configuration;
//  - start an embedded Aeron media driver and connect an Aeron client;
//  - subscribe to this node's own unicast endpoint;
//  - create one publication per configured shard address;
//  - initialize the router;
//  - derive originatorId from this node's ip:port.
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Clipboard clipboard, @NonNull NodeRole role, @NonNull String localIp, int localPort, short shardIndex) {
this.nodeRole = role;
this.clipboard = clipboard;
this.voidConfiguration = voidConfiguration;
this.shardIndex = shardIndex;
this.messages = new LinkedBlockingQueue<>();
// shutdown hook
super.init(voidConfiguration, clipboard, role, localIp, localPort, shardIndex);
// Raises the Aeron client liveness timeout via a system property.
// NOTE(review): the value is presumably in nanoseconds (30 s); confirm against the Aeron version in use.
setProperty("aeron.client.liveness.timeout", "30000000000");
// NOTE(review): presumably publicationConnectionTimeout and keepAliveInterval are in ns (30 s / 100 ms)
// and driverTimeoutMs is in ms (30 s). Confirm the units against the Aeron Context API.
context = new Aeron.Context().publicationConnectionTimeout(30000000000L).driverTimeoutMs(30000).keepAliveInterval(100000000);
driver = MediaDriver.launchEmbedded();
aeron = Aeron.connect(context);
// Default router when none was set beforehand.
if (router == null)
router = new InterleavedRouter();
// we skip IPs assign process if they were defined externally
// Only when port == 0 are ip/port taken from localIp/localPort.
if (port == 0) {
ip = localIp;
port = localPort;
unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
// Own unicast endpoint: incoming messages on streamId.
subscriptionForClients = aeron.addSubscription(unicastChannelUri, voidConfiguration.getStreamId());
// clean shut down
// NOTE(review): the body of this shutdown-hook lambda is not visible in this view.
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> jointMessageHandler(buffer, offset, length, header));
Now, regardless of current role,
we set up publication channel to each shard
String shardChannelUri = null;
String remoteIp = null;
int remotePort = 0;
// NOTE(review): the loop variable `ip` shadows the `ip` field assigned above; inside the loop `ip` is the shard address.
for (String ip : voidConfiguration.getShardAddresses()) {
// Accepts "host:port" or "host" (the configured unicast port is used then).
if (ip.contains(":")) {
shardChannelUri = "aeron:udp?endpoint=" + ip;
String[] split = ip.split(":");
remoteIp = split[0];
remotePort = Integer.valueOf(split[1]);
} else {
shardChannelUri = "aeron:udp?endpoint=" + ip + ":" + voidConfiguration.getUnicastPort();
remoteIp = ip;
remotePort = voidConfiguration.getUnicastPort();
Publication publication = aeron.addPublication(shardChannelUri, voidConfiguration.getStreamId());
// NOTE(review): `connection` is not used in the visible lines. Confirm it is registered
// (e.g. stored in a shard collection) so the publication is not orphaned.
RemoteConnection connection = RemoteConnection.builder().ip(remoteIp).port(remotePort).publication(publication).locker(new Object()).build();
if (nodeRole == NodeRole.SHARD)"Initialized as [{}]; ShardIndex: [{}]; Own endpoint: [{}]", nodeRole, shardIndex, unicastChannelUri);
else"Initialized as [{}]; Own endpoint: [{}]", nodeRole, unicastChannelUri);
switch(nodeRole) {
case MASTER:
case BACKUP:
case SHARD:
For unicast transport we want to have interconnects between all shards first of all, because we know their IPs in advance.
But due to design requirements, clients have the same first step, so it's kinda shared for all states :)
Next step is connections setup for backup nodes.
TODO: to be implemented
// Registers this node's own ip/port via addClient.
addClient(ip, port);
case CLIENT:
For Clients on unicast transport, we either set up connection to single Shard, or to multiple shards
But since this code is shared - we don't do anything here
// Any role not listed above is rejected.
throw new ND4JIllegalStateException("Unknown NodeRole being passed: " + nodeRole);
router.init(voidConfiguration, this);
// Node identity: hash of "ip:port".
this.originatorId = HashUtil.getLongHash(this.getIp() + ":" + this.getPort());
// Broadcasts `message` to every known client connection except the following:
//  - this node itself (same originatorId);
//  - connections with hash 0;
//  - any hash listed in `exclusions`.
// Only a SHARD may call this. Sends run through a parallel stream, and each send retries
// until the publication accepts the buffer.
public void sendMessageToAllClients(VoidMessage message, Long... exclusions) {
if (nodeRole != NodeRole.SHARD)
throw new ND4JIllegalStateException("Only SHARD allowed to send messages to all Clients");
// Serialized once and reused for every client.
final DirectBuffer buffer = message.asUnsafeBuffer();
// no need to search for matches above number of then exclusions
// NOTE(review): `cnt` is never incremented in the visible code. The `cnt.get() < exclusions.length`
// short-circuit below is therefore always true for a non-empty `exclusions`, and the optimization
// this comment describes does not happen.
final AtomicInteger cnt = new AtomicInteger(0);
// final StringBuilder builder = new StringBuilder("Got message from: [").append(message.getOriginatorId()).append("]; Resend: {");
clients.values().parallelStream().filter(rc -> {
// do not send message back to yourself :)
if (rc.getLongHash() == this.originatorId || rc.getLongHash() == 0) {
// builder.append(", SKIP: ").append(rc.getLongHash());
return false;
// we skip exclusions here
// A null element in `exclusions` would throw NullPointerException on exclude.longValue().
if (exclusions != null && cnt.get() < exclusions.length) {
for (Long exclude : exclusions) if (exclude.longValue() == rc.getLongHash()) {
// builder.append(", SKIP: ").append(rc.getLongHash());
return false;
// builder.append(", PASS: ").append(rc.getLongHash());
return true;
}).forEach((rc) -> {
//"Sending message to {}", rc.getLongHash());
RetransmissionHandler.TransmissionStatus res;
// NOTE(review): `retr` is never incremented in the visible code, so the `retr > 20` limit can never
// trigger. An inactive, unconnected publication would then be retried forever; confirm.
long retr = 0;
boolean delivered = false;
// Retries until the publication accepts the buffer.
while (!delivered) {
// still stupid. maybe use real reentrant lock here?
synchronized ( {
res = RetransmissionHandler.getTransmissionStatus(rc.getPublication().offer(buffer));
// NOTE(review): the case labels of this switch are not visible in this view.
switch(res) {
if (!rc.getActivated().get()) {
// NOTE(review): this method sends to clients, but the message says "Shard".
if (retr > 20)
throw new ND4JIllegalStateException("Can't connect to Shard: [" + rc.getPublication().channel() + "]");
try {
// Thread.sleep(voidConfiguration.getRetransmitTimeout());
// Backoff: timeout is converted ms -> ns. NOTE(review): if getRetransmitTimeout() returns int,
// this product is int arithmetic and overflows above ~2147 ms; confirm the type.
// LockSupport.parkNanos throws no checked exception, so the catch below is effectively dead.
LockSupport.parkNanos(voidConfiguration.getRetransmitTimeout() * 1000000);
} catch (Exception e) {
throw new RuntimeException(e);
} else {
throw new ND4JIllegalStateException("Shards reassignment is to be implemented yet");
try {
// Thread.sleep(voidConfiguration.getRetransmitTimeout());
LockSupport.parkNanos(voidConfiguration.getRetransmitTimeout() * 1000000);
} catch (Exception e) {
throw new RuntimeException(e);
delivered = true;
// s"RESULT: {}", builder.toString());