Use of org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer in project nd4j by deeplearning4j.
From the class VoidParameterServerTest, method testNodeRole3:
@Test
public void testNodeRole3() throws Exception {
    final VoidConfiguration conf = VoidConfiguration.builder()
                    .unicastPort(34567).multicastPort(45678)
                    .numberOfShards(10).shardAddresses(badIPs).backupAddresses(badIPs)
                    .multicastNetwork("224.0.1.1").ttl(4).build();

    VoidParameterServer node = new VoidParameterServer();
    node.init(conf, transport, new SkipGramTrainer());

    assertEquals(NodeRole.CLIENT, node.getNodeRole());
    node.shutdown();
}
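The fields badIPs and transport (and the localIPs list used in the next test) belong to VoidParameterServerTest and are not shown in this excerpt. A minimal sketch of what such fixtures could look like follows; the names match the excerpt, but the concrete values and the transport type are assumptions, not the actual test code:

// Hypothetical fixture fields (assumed, not taken from the excerpt): addresses that will
// never resolve to this machine, loopback addresses that will, and a transport for init().
// The real VoidParameterServerTest defines these elsewhere in the class.
private static final List<String> badIPs = Arrays.asList("192.0.2.1");    // TEST-NET-1, never a local interface
private static final List<String> localIPs = Arrays.asList("127.0.0.1");
private final Transport transport = new MulticastTransport();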
Use of org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer in project nd4j by deeplearning4j.
From the class VoidParameterServerTest, method testNodeRole2:
@Test
public void testNodeRole2() throws Exception {
    final VoidConfiguration conf = VoidConfiguration.builder()
                    .unicastPort(34567).multicastPort(45678)
                    .numberOfShards(10).shardAddresses(badIPs).backupAddresses(localIPs)
                    .multicastNetwork("224.0.1.1").ttl(4).build();

    VoidParameterServer node = new VoidParameterServer();
    node.init(conf, transport, new SkipGramTrainer());

    assertEquals(NodeRole.BACKUP, node.getNodeRole());
    node.shutdown();
}
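The two assertions above follow from how the node resolves its role from the configured address lists: both lists filled with unreachable addresses leave the node a CLIENT (testNodeRole3), while a local address in backupAddresses makes it a BACKUP (testNodeRole2). A simplified sketch of that resolution logic, written only as an illustration and not the actual VoidParameterServer implementation:

// Simplified illustration of role resolution (assumed logic, not the library code):
// a node whose local address is listed as a shard becomes SHARD, one listed as a
// backup becomes BACKUP, and everything else falls back to CLIENT.
static NodeRole resolveRole(VoidConfiguration conf, Collection<String> localAddresses) {
    for (String address : conf.getShardAddresses())
        if (localAddresses.contains(address))
            return NodeRole.SHARD;
    for (String address : conf.getBackupAddresses())
        if (localAddresses.contains(address))
            return NodeRole.BACKUP;
    return NodeRole.CLIENT;
}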
Use of org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer in project nd4j by deeplearning4j.
From the class DistributedSgDotMessage, method processMessage:
/**
 * This method calculates dot products for the given rows
 */
@Override
public void processMessage() {
    // this only picks up a new training round
    //log.info("sI_{} Processing DistributedSgDotMessage taskId: {}", transport.getShardIndex(), getTaskId());
    SkipGramRequestMessage sgrm = new SkipGramRequestMessage(w1, w2, rowsB, codes, negSamples, alpha, 119);

    if (negSamples > 0) {
        // unfortunately we have to get a copy of the negative samples here
        int[] negatives = Arrays.copyOfRange(rowsB, codes.length, rowsB.length);
        sgrm.setNegatives(negatives);
    }
    sgrm.setTaskId(this.taskId);
    sgrm.setOriginatorId(this.getOriginatorId());

    // FIXME: get rid of THAT
    SkipGramTrainer sgt = (SkipGramTrainer) trainer;
    sgt.pickTraining(sgrm);

    // TODO: make this thing a single op, even specialOp is ok
    // we calculate dot for all involved rows
    int resultLength = codes.length + (negSamples > 0 ? (negSamples + 1) : 0);
    INDArray result = Nd4j.createUninitialized(resultLength, 1);

    int e = 0;
    for (; e < codes.length; e++) {
        double dot = Nd4j.getBlasWrapper().dot(storage.getArray(WordVectorStorage.SYN_0).getRow(w2),
                        storage.getArray(WordVectorStorage.SYN_1).getRow(rowsB[e]));
        result.putScalar(e, dot);
    }

    // negSampling round
    for (; e < resultLength; e++) {
        double dot = Nd4j.getBlasWrapper().dot(storage.getArray(WordVectorStorage.SYN_0).getRow(w2),
                        storage.getArray(WordVectorStorage.SYN_1_NEGATIVE).getRow(rowsB[e]));
        result.putScalar(e, dot);
    }

    if (voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
        // just a local bypass
        DotAggregation dot = new DotAggregation(taskId, (short) 1, shardIndex, result);
        dot.setTargetId((short) -1);
        dot.setOriginatorId(getOriginatorId());
        transport.putMessage(dot);
    } else if (voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
        // send this message to everyone
        DotAggregation dot = new DotAggregation(taskId, (short) voidConfiguration.getNumberOfShards(), shardIndex, result);
        dot.setTargetId((short) -1);
        dot.setOriginatorId(getOriginatorId());
        transport.sendMessage(dot);
    }
}
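The two loops above compute plain vector dot products between the syn0 row of w2 and each syn1 (or syn1Neg) row listed in rowsB. Stripped of the parameter-server plumbing, the per-row computation reduces to the following standalone sketch; the shapes and row indices here are made up for illustration and are not taken from the message class:

// Standalone sketch of the per-row dot computation (shapes and indices are invented).
INDArray syn0 = Nd4j.rand(10, 32);         // one row per word
INDArray syn1 = Nd4j.rand(10, 32);
int w2 = 3;                                // "input" word row
int[] rowsB = new int[] {1, 4, 7};         // hierarchical-softmax / negative rows
INDArray result = Nd4j.createUninitialized(rowsB.length, 1);
for (int e = 0; e < rowsB.length; e++) {
    double dot = Nd4j.getBlasWrapper().dot(syn0.getRow(w2), syn1.getRow(rowsB[e]));
    result.putScalar(e, dot);
}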
Use of org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer in project nd4j by deeplearning4j.
From the class VoidParameterServerStressTest, method testPerformanceUnicast4:
/**
 * This test checks multiple Clients hammering a single Shard
 *
 * @throws Exception
 */
@Test
public void testPerformanceUnicast4() throws Exception {
    VoidConfiguration voidConfiguration = VoidConfiguration.builder()
                    .unicastPort(49823).numberOfShards(1)
                    .shardAddresses(Arrays.asList("127.0.0.1:49823")).build();

    Transport transport = new RoutedTransport();
    transport.setIpAndPort("127.0.0.1", Integer.valueOf("49823"));

    VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
    parameterServer.setShardIndex((short) 0);
    parameterServer.init(voidConfiguration, transport, new SkipGramTrainer());
    parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);

    VoidParameterServer[] clients = new VoidParameterServer[1];
    for (int c = 0; c < clients.length; c++) {
        clients[c] = new VoidParameterServer(NodeRole.CLIENT);

        Transport clientTransport = new RoutedTransport();
        clientTransport.setIpAndPort("127.0.0.1", Integer.valueOf("4872" + c));

        clients[c].init(voidConfiguration, clientTransport, new SkipGramTrainer());
        assertEquals(NodeRole.CLIENT, clients[c].getNodeRole());
    }

    final List<Long> times = new CopyOnWriteArrayList<>();
    log.info("Starting loop...");

    Thread[] threads = new Thread[clients.length];
    for (int t = 0; t < threads.length; t++) {
        final int c = t;
        threads[t] = new Thread(() -> {
            List<Long> results = new ArrayList<>();
            AtomicLong sequence = new AtomicLong(0);
            for (int i = 0; i < 500; i++) {
                Frame<SkipGramRequestMessage> frame = new Frame<>(sequence.incrementAndGet());
                for (int f = 0; f < 128; f++) {
                    frame.stackMessage(getSGRM());
                }
                long time1 = System.nanoTime();
                clients[c].execDistributed(frame);
                long time2 = System.nanoTime();
                results.add(time2 - time1);

                if ((i + 1) % 50 == 0)
                    log.info("Thread_{} finished {} frames...", c, i);
            }
            times.addAll(results);
        });
        threads[t].setDaemon(true);
        threads[t].start();
    }

    for (Thread thread : threads)
        thread.join();

    List<Long> newTimes = new ArrayList<>(times);
    Collections.sort(newTimes);

    log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);

    for (VoidParameterServer client : clients) {
        client.shutdown();
    }
    parameterServer.shutdown();
}
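The test stacks 128 messages per Frame via a getSGRM() helper defined elsewhere in VoidParameterServerStressTest. A hypothetical version of such a helper is sketched below; it mirrors the constructor call seen in processMessage above, but the array sizes, code values, and hyperparameters are invented:

// Hypothetical helper producing a random SkipGramRequestMessage; the real test
// fixture may use different sizes, codes, and hyperparameters.
protected static SkipGramRequestMessage getSGRM() {
    int w1 = RandomUtils.nextInt(0, NUM_WORDS);
    int w2 = RandomUtils.nextInt(0, NUM_WORDS);

    byte[] codes = new byte[RandomUtils.nextInt(5, 50)];
    int[] points = new int[codes.length];
    for (int e = 0; e < codes.length; e++) {
        codes[e] = (byte) (e % 2);
        points[e] = RandomUtils.nextInt(0, NUM_WORDS);
    }

    return new SkipGramRequestMessage(w1, w2, points, codes, (short) 0, 0.025f, 119);
}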
Use of org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer in project nd4j by deeplearning4j.
From the class VoidParameterServerStressTest, method testPerformanceUnicast1:
/**
 * This is one of the MOST IMPORTANT tests
 */
@Test
public void testPerformanceUnicast1() {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 1; t++) {
        list.add("127.0.0.1:3838" + t);
    }

    VoidConfiguration voidConfiguration = VoidConfiguration.builder()
                    .unicastPort(49823).numberOfShards(list.size()).shardAddresses(list).build();

    VoidParameterServer[] shards = new VoidParameterServer[list.size()];
    for (int t = 0; t < shards.length; t++) {
        shards[t] = new VoidParameterServer(NodeRole.SHARD);

        Transport transport = new RoutedTransport();
        transport.setIpAndPort("127.0.0.1", Integer.valueOf("3838" + t));

        shards[t].setShardIndex((short) t);
        shards[t].init(voidConfiguration, transport, new SkipGramTrainer());

        assertEquals(NodeRole.SHARD, shards[t].getNodeRole());
    }

    VoidParameterServer clientNode = new VoidParameterServer(NodeRole.CLIENT);
    RoutedTransport transport = new RoutedTransport();
    ClientRouter router = new InterleavedRouter(0);

    transport.setRouter(router);
    transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());

    router.init(voidConfiguration, transport);

    clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
    assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());

    final List<Long> times = new CopyOnWriteArrayList<>();

    // at this point, everything should be started, time for tests
    clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);

    log.info("Initialization finished, going to tests...");

    Thread[] threads = new Thread[4];
    for (int t = 0; t < threads.length; t++) {
        final int e = t;
        threads[t] = new Thread(() -> {
            List<Long> results = new ArrayList<>();

            int chunk = NUM_WORDS / threads.length;
            int start = e * chunk;
            int end = (e + 1) * chunk;

            for (int i = 0; i < 200; i++) {
                long time1 = System.nanoTime();
                INDArray array = clientNode.getVector(RandomUtils.nextInt(start, end));
                long time2 = System.nanoTime();
                results.add(time2 - time1);

                if ((i + 1) % 100 == 0)
                    log.info("Thread {} cnt {}", e, i + 1);
            }
            times.addAll(results);
        });
        threads[t].setDaemon(true);
        threads[t].start();
    }

    for (int t = 0; t < threads.length; t++) {
        try {
            threads[t].join();
        } catch (Exception e) {
            // ignore interruption while waiting for worker threads
        }
    }

    List<Long> newTimes = new ArrayList<>(times);
    Collections.sort(newTimes);

    log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);

    // shutdown everything
    for (VoidParameterServer shard : shards) {
        shard.getTransport().shutdown();
    }
    clientNode.getTransport().shutdown();
}
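Both stress tests report only the median of the sorted nanosecond timings. A small helper along the following lines would generalize that reporting to other percentiles; it is an illustrative addition, not part of the test class:

// Illustrative helper: percentile lookup over the already-sorted nanosecond timings,
// reported in microseconds to match the log output above.
static long percentileUs(List<Long> sortedNanos, double percentile) {
    int index = (int) Math.round(percentile / 100.0 * (sortedNanos.size() - 1));
    return sortedNanos.get(index) / 1000;
}

// e.g. log.info("p50: {} us; p90: {} us", percentileUs(newTimes, 50.0), percentileUs(newTimes, 90.0));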