Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.
The class VoidParameterServerStressTest, method testPerformanceUnicast1:
/**
 * This is one of the MOST IMPORTANT tests
 */
@Test
public void testPerformanceUnicast1() {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 1; t++) {
        list.add("127.0.0.1:3838" + t);
    }

    VoidConfiguration voidConfiguration = VoidConfiguration.builder()
                    .unicastPort(49823)
                    .numberOfShards(list.size())
                    .shardAddresses(list)
                    .build();

    VoidParameterServer[] shards = new VoidParameterServer[list.size()];
    for (int t = 0; t < shards.length; t++) {
        shards[t] = new VoidParameterServer(NodeRole.SHARD);

        Transport transport = new RoutedTransport();
        transport.setIpAndPort("127.0.0.1", Integer.valueOf("3838" + t));

        shards[t].setShardIndex((short) t);
        shards[t].init(voidConfiguration, transport, new SkipGramTrainer());

        assertEquals(NodeRole.SHARD, shards[t].getNodeRole());
    }

    VoidParameterServer clientNode = new VoidParameterServer(NodeRole.CLIENT);
    RoutedTransport transport = new RoutedTransport();
    ClientRouter router = new InterleavedRouter(0);

    transport.setRouter(router);
    transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
    router.init(voidConfiguration, transport);

    clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
    assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());

    final List<Long> times = new CopyOnWriteArrayList<>();

    // at this point, everything should be started, time for tests
    clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);

    log.info("Initialization finished, going to tests...");

    Thread[] threads = new Thread[4];
    for (int t = 0; t < threads.length; t++) {
        final int e = t;
        threads[t] = new Thread(() -> {
            List<Long> results = new ArrayList<>();
            int chunk = NUM_WORDS / threads.length;
            int start = e * chunk;
            int end = (e + 1) * chunk;

            for (int i = 0; i < 200; i++) {
                long time1 = System.nanoTime();
                INDArray array = clientNode.getVector(RandomUtils.nextInt(start, end));
                long time2 = System.nanoTime();
                results.add(time2 - time1);

                if ((i + 1) % 100 == 0)
                    log.info("Thread {} cnt {}", e, i + 1);
            }
            times.addAll(results);
        });

        threads[t].setDaemon(true);
        threads[t].start();
    }

    for (int t = 0; t < threads.length; t++) {
        try {
            threads[t].join();
        } catch (Exception e) {
            // ignored: best-effort join before collecting timings
        }
    }

    List<Long> newTimes = new ArrayList<>(times);
    Collections.sort(newTimes);

    log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);

    // shutdown everything
    for (VoidParameterServer shard : shards) {
        shard.getTransport().shutdown();
    }
    clientNode.getTransport().shutdown();
}
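The test above reports only the median, by indexing into the middle of the sorted timing list. As a minimal, self-contained sketch, the same data could be summarized at several percentiles; the LatencySummary class and its percentile helper below are hypothetical illustrations, not part of the nd4j test suite:

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class LatencySummary {

    // Nearest-rank percentile over a sorted list of nanosecond timings.
    static long percentile(List<Long> sortedNanos, double p) {
        int rank = (int) Math.ceil(p / 100.0 * sortedNanos.size()) - 1;
        return sortedNanos.get(Math.max(rank, 0));
    }

    public static void main(String[] args) {
        List<Long> times = new ArrayList<>();
        // stand-in for the timings collected by the worker threads above
        for (int i = 1; i <= 800; i++)
            times.add(i * 1000L);

        Collections.sort(times);
        for (double p : new double[] {50, 90, 99}) {
            System.out.printf("p%.0f: %d us%n", p, percentile(times, p) / 1000);
        }
    }
}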
Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project deeplearning4j by deeplearning4j.
The class PartitionTrainingFunction, method call:
@SuppressWarnings("unchecked")
@Override
public void call(Iterator<Sequence<T>> sequenceIterator) throws Exception {
    /**
     * first we initialize
     */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();

    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();

        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                                .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        driver = elementsLearningAlgorithm.getTrainingDriver();

        // FIXME: this init line should probably be removed, since init basically happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }

    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();

    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                            .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm != null)
        elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
                            .forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm != null)
        sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }

    List<Sequence<ShallowSequenceElement>> sequences = new ArrayList<>();

    // now we roll through the Sequences and prepare/convert/"learn" them
    while (sequenceIterator.hasNext()) {
        Sequence<T> sequence = sequenceIterator.next();
        Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();

        for (T element : sequence.getElements()) {
            // it's possible to get null here, i.e. if the frequency for this element is below the minWordFrequency threshold
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
            if (reduced != null)
                mergedSequence.addElement(reduced);
        }

        // do the same with labels, transferring them, if any
        if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
            for (T label : sequence.getSequenceLabels()) {
                ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
                if (reduced != null)
                    mergedSequence.addSequenceLabel(reduced);
            }
        }

        sequences.add(mergedSequence);
        if (sequences.size() >= 8) {
            trainAllAtOnce(sequences);
            sequences.clear();
        }
    }

    if (sequences.size() > 0) {
        // finish the training round, to make sure no trailing sequences are left behind
        trainAllAtOnce(sequences);
        sequences.clear();
    }
}
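The two reflective instantiation blocks above repeat the same pattern: resolve a class by name, call its no-arg constructor, and wrap any failure in a RuntimeException. As a hedged sketch, that pattern could be factored into one generic helper; the Instantiator class and its create method below are hypothetical, not part of deeplearning4j:

public class Instantiator {

    // Loads the named class, instantiates it via its public no-arg constructor,
    // and casts the result to the caller's expected type.
    @SuppressWarnings("unchecked")
    static <T> T create(String className) {
        try {
            return (T) Class.forName(className).newInstance();
        } catch (Exception e) {
            // mirrors the original code: any reflection failure becomes unchecked
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        // usage example with a JDK class, standing in for an ELA/SLA class name
        StringBuilder sb = Instantiator.create("java.lang.StringBuilder");
        System.out.println(sb.append("instantiated reflectively"));
    }
}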
Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.
The class InterleavedRouterTest, method setUp:
@Before
public void setUp() {
    configuration = VoidConfiguration.builder()
                    .shardAddresses(Arrays.asList("1.2.3.4", "2.3.4.5", "3.4.5.6", "4.5.6.7"))
                    .numberOfShards(4) // we set it manually here
                    .build();

    transport = new RoutedTransport();
    transport.setIpAndPort("8.9.10.11", 87312);

    originator = HashUtil.getLongHash(transport.getIp() + ":" + transport.getPort());
}
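The InterleavedRouter implementation itself is not shown in this listing. As a simplified illustration of the interleaving idea it is named for, a client can spread successive messages round-robin across the configured shard addresses; the RoundRobinRouter class below is a hypothetical sketch, not the actual nd4j router:

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

public class RoundRobinRouter {
    private final List<String> shards;
    private final AtomicLong counter;

    // 'initialShard' stands in for the constructor argument seen in InterleavedRouter(0)
    RoundRobinRouter(List<String> shards, int initialShard) {
        this.shards = shards;
        this.counter = new AtomicLong(initialShard);
    }

    // each call targets the next shard, wrapping around the address list
    String nextTarget() {
        return shards.get((int) (counter.getAndIncrement() % shards.size()));
    }

    public static void main(String[] args) {
        RoundRobinRouter router = new RoundRobinRouter(
                        Arrays.asList("1.2.3.4", "2.3.4.5", "3.4.5.6", "4.5.6.7"), 0);
        for (int i = 0; i < 6; i++)
            System.out.println("message " + i + " -> " + router.nextTarget());
    }
}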
Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.
The class VoidParameterServerStressTest, method testPerformanceUnicast2:
/**
 * This is the second super-important test for unicast transport.
 * Here we send non-blocking messages
 */
@Test
@Ignore
public void testPerformanceUnicast2() {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 5; t++) {
        list.add("127.0.0.1:3838" + t);
    }

    VoidConfiguration voidConfiguration = VoidConfiguration.builder()
                    .unicastPort(49823)
                    .numberOfShards(list.size())
                    .shardAddresses(list)
                    .build();

    VoidParameterServer[] shards = new VoidParameterServer[list.size()];
    for (int t = 0; t < shards.length; t++) {
        shards[t] = new VoidParameterServer(NodeRole.SHARD);

        Transport transport = new RoutedTransport();
        transport.setIpAndPort("127.0.0.1", Integer.valueOf("3838" + t));

        shards[t].setShardIndex((short) t);
        shards[t].init(voidConfiguration, transport, new SkipGramTrainer());

        assertEquals(NodeRole.SHARD, shards[t].getNodeRole());
    }

    VoidParameterServer clientNode = new VoidParameterServer();
    RoutedTransport transport = new RoutedTransport();
    ClientRouter router = new InterleavedRouter(0);

    transport.setRouter(router);
    transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
    router.init(voidConfiguration, transport);

    clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
    assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());

    final List<Long> times = new CopyOnWriteArrayList<>();

    // at this point, everything should be started, time for tests
    clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);

    log.info("Initialization finished, going to tests...");

    Thread[] threads = new Thread[4];
    for (int t = 0; t < threads.length; t++) {
        final int e = t;
        threads[t] = new Thread(() -> {
            List<Long> results = new ArrayList<>();
            int chunk = NUM_WORDS / threads.length;
            int start = e * chunk;
            int end = (e + 1) * chunk;

            for (int i = 0; i < 200; i++) {
                Frame<SkipGramRequestMessage> frame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
                for (int f = 0; f < 128; f++) {
                    frame.stackMessage(getSGRM());
                }

                long time1 = System.nanoTime();
                clientNode.execDistributed(frame);
                long time2 = System.nanoTime();
                results.add(time2 - time1);

                if ((i + 1) % 100 == 0)
                    log.info("Thread {} cnt {}", e, i + 1);
            }
            times.addAll(results);
        });

        threads[t].setDaemon(true);
        threads[t].start();
    }

    for (int t = 0; t < threads.length; t++) {
        try {
            threads[t].join();
        } catch (Exception e) {
            // ignored: best-effort join before collecting timings
        }
    }

    List<Long> newTimes = new ArrayList<>(times);
    Collections.sort(newTimes);

    log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);

    // shutdown everything
    for (VoidParameterServer shard : shards) {
        shard.getTransport().shutdown();
    }
    clientNode.getTransport().shutdown();
}
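Each iteration above stacks 128 SkipGramRequestMessage instances into a single Frame before one execDistributed call, so transport overhead is amortized over the whole batch. As a simplified, self-contained sketch of that stack-then-flush pattern, the MessageFrame class below is illustrative only, not the actual Frame implementation:

import java.util.ArrayList;
import java.util.List;

public class MessageFrame<T> {
    private final long frameId;
    private final List<T> stack = new ArrayList<>();

    MessageFrame(long frameId) {
        this.frameId = frameId;
    }

    // accumulate messages locally; nothing is sent yet
    void stackMessage(T message) {
        stack.add(message);
    }

    // one "network" call flushes the whole batch at once
    void execute() {
        System.out.println("frame " + frameId + ": sending " + stack.size() + " messages in one call");
        stack.clear();
    }

    public static void main(String[] args) {
        MessageFrame<String> frame = new MessageFrame<>(1L);
        for (int f = 0; f < 128; f++)
            frame.stackMessage("request-" + f);
        frame.execute(); // single round-trip for all 128 stacked messages
    }
}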