Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project deeplearning4j by deeplearning4j.
The class CountFunction, method call().
@Override
public Pair<Sequence<T>, Long> call(Sequence<T> sequence) throws Exception {
// since we can't be 100% sure that the sequence size itself is correct, or that it doesn't overflow int limits, we recalculate it
// we have to loop through the sequence for element frequencies anyway
Counter<Long> localCounter = new Counter<>();
long seqLen = 0;
if (ela == null) {
try {
ela = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
driver = ela.getTrainingDriver();
//System.out.println("Initializing VoidParameterServer in CountFunction");
VoidParameterServer.getInstance().init(voidConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
for (T element : sequence.getElements()) {
if (element == null)
continue;
// FIXME: hashcode is bad idea here. we need Long id
localCounter.incrementCount(element.getStorageId(), 1.0);
seqLen++;
}
// FIXME: we're missing label information here due to shallow vocab mechanics
if (sequence.getSequenceLabels() != null)
for (T label : sequence.getSequenceLabels()) {
localCounter.incrementCount(label.getStorageId(), 1.0);
}
accumulator.add(localCounter);
return Pair.makePair(sequence, seqLen);
}
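A driver-side sketch of how this function is typically driven, mirroring the else-branch of fitSequences further down. The broadcasts, the corpus RDD and the use of VocabWord as element type are assumptions here, not part of the snippet above:

// sketch only: configurationBroadcast, paramServerConfigurationBroadcast and corpus are assumed to exist on the driver
Accumulator<Counter<Long>> freqAccum =
        corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
CountFunction<VocabWord> counter =
        new CountFunction<>(configurationBroadcast, paramServerConfigurationBroadcast, freqAccum, false);
// count() triggers the map function; per-element frequencies are gathered through the accumulator
long numberOfSequences = corpus.map(counter).count();
Counter<Long> frequencies = freqAccum.value();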
Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project deeplearning4j by deeplearning4j.
The class TrainingFunction, method call().
@Override
@SuppressWarnings("unchecked")
public void call(Sequence<T> sequence) throws Exception {
/**
* Depending on actual training mode, we'll either go for SkipGram/CBOW/PV-DM/PV-DBOW or whatever
*/
if (vectorsConfiguration == null)
vectorsConfiguration = configurationBroadcast.getValue();
if (paramServer == null) {
paramServer = VoidParameterServer.getInstance();
if (elementsLearningAlgorithm == null) {
try {
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
driver = elementsLearningAlgorithm.getTrainingDriver();
// FIXME: init line should probably be removed, basically init happens in VocabRddFunction
paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
}
if (vectorsConfiguration == null)
vectorsConfiguration = configurationBroadcast.getValue();
if (shallowVocabCache == null)
shallowVocabCache = vocabCacheBroadcast.getValue();
if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
// TODO: do ELA initialization
try {
elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
// TODO: do SLA initialization
try {
sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
}
/*
at this moment we should have everything ready for actual initialization
the only limitation we have is that our sequence is detached from the actual vocabulary, so we need to virtually merge it back
*/
Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
for (T element : sequence.getElements()) {
// it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
if (reduced != null)
mergedSequence.addElement(reduced);
}
// do the same with labels, transfer them, if any
if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
for (T label : sequence.getSequenceLabels()) {
ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
if (reduced != null)
mergedSequence.addSequenceLabel(reduced);
}
}
// FIXME: temporary hook
if (sequence.size() > 0)
paramServer.execDistributed(elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
else
log.warn("Skipping empty sequence...");
}
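As a usage note, this function is applied once per sequence; the commented-out corpus.foreach(trainer) line in fitSequences below shows this per-sequence form, while the shipped code uses foreachPartition with PartitionTrainingFunction. A short sketch of the per-sequence variant, assuming VocabWord elements and broadcasts that already exist on the driver:

// sketch only: the broadcasts are created on the driver, as in fitSequences() below
TrainingFunction<VocabWord> trainer =
        new TrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
// one call per Sequence<VocabWord>; each worker lazily initializes VoidParameterServer with a RoutedTransport
corpus.foreach(trainer);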
Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project deeplearning4j by deeplearning4j.
The class SparkSequenceVectors, method fitSequences().
/**
* Base training entry point
*
* @param corpus
*/
public void fitSequences(JavaRDD<Sequence<T>> corpus) {
/**
* Basically all we want for base implementation here is 3 things:
* a) build vocabulary
* b) build huffman tree
* c) do training
*
* in this case, all classes extending SeqVec (like DeepWalk or Word2Vec) just build their RDD<Sequence<T>>
* and call this method for training, instead of implementing their own routines
*/
validateConfiguration();
if (ela == null) {
try {
ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (workers > 1) {
log.info("Repartitioning corpus to {} parts...", workers);
corpus = corpus.repartition(workers);
}
if (storageLevel != null)
corpus.persist(storageLevel);
final JavaSparkContext sc = new JavaSparkContext(corpus.context());
// this will only have an effect if it wasn't called before, i.e. in extension classes
broadcastEnvironment(sc);
Counter<Long> finalCounter;
long numberOfSequences = 0;
/**
* Here we s
*/
if (paramServerConfiguration == null)
paramServerConfiguration = VoidConfiguration.builder().faultToleranceStrategy(FaultToleranceStrategy.NONE).numberOfShards(2).unicastPort(40123).multicastPort(40124).build();
isAutoDiscoveryMode = paramServerConfiguration.getShardAddresses() == null || paramServerConfiguration.getShardAddresses().isEmpty();
Broadcast<VoidConfiguration> paramServerConfigurationBroadcast = null;
if (isAutoDiscoveryMode) {
log.info("Trying auto discovery mode...");
elementsFreqAccumExtra = corpus.context().accumulator(new ExtraCounter<Long>(), new ExtraElementsFrequenciesAccumulator());
ExtraCountFunction<T> elementsCounter = new ExtraCountFunction<>(elementsFreqAccumExtra, configuration.isTrainSequenceVectors());
JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
// just to trigger map function, since we need huffman tree before proceeding
numberOfSequences = countedCorpus.count();
finalCounter = elementsFreqAccumExtra.value();
ExtraCounter<Long> spareReference = (ExtraCounter<Long>) finalCounter;
// getting list of available hosts
Set<NetworkInformation> availableHosts = spareReference.getNetworkInformation();
log.info("availableHosts: {}", availableHosts);
if (availableHosts.size() > 1) {
// now we have to pick N shards and optionally N backup nodes, and pass them within configuration bean
NetworkOrganizer organizer = new NetworkOrganizer(availableHosts, paramServerConfiguration.getNetworkMask());
paramServerConfiguration.setShardAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards()));
// backup shards are optional
if (paramServerConfiguration.getFaultToleranceStrategy() != FaultToleranceStrategy.NONE) {
paramServerConfiguration.setBackupAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.getShardAddresses()));
}
} else {
// for single host (aka driver-only, aka spark-local) just run on loopback interface
paramServerConfiguration.setShardAddresses(Arrays.asList("127.0.0.1:" + paramServerConfiguration.getUnicastPort()));
paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
}
log.info("Got Shards so far: {}", paramServerConfiguration.getShardAddresses());
// update ps configuration with real values where required
paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
} else {
// update ps configuration with real values where required
paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
// set up freqs accumulator
elementsFreqAccum = corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
CountFunction<T> elementsCounter = new CountFunction<>(configurationBroadcast, paramServerConfigurationBroadcast, elementsFreqAccum, configuration.isTrainSequenceVectors());
// count all sequence elements and their sum
JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
// just to trigger map function, since we need huffman tree before proceeding
numberOfSequences = countedCorpus.count();
// now we grab counter, which contains frequencies for all SequenceElements in corpus
finalCounter = elementsFreqAccum.value();
}
long numberOfElements = (long) finalCounter.totalCount();
long numberOfUniqueElements = finalCounter.size();
log.info("Total number of sequences: {}; Total number of elements entries: {}; Total number of unique elements: {}", numberOfSequences, numberOfElements, numberOfUniqueElements);
/*
build an RDD of reduced SequenceElements: temporarily drop the labels and stick to numerical values,
like index or hashcode, so we can reduce the driver memory footprint
*/
// build huffman tree, and update original RDD with huffman encoding info
shallowVocabCache = buildShallowVocabCache(finalCounter);
shallowVocabCacheBroadcast = sc.broadcast(shallowVocabCache);
// FIXME: probably we need to reconsider this approach
JavaRDD<T> vocabRDD = corpus.flatMap(new VocabRddFunctionFlat<T>(configurationBroadcast, paramServerConfigurationBroadcast)).distinct();
vocabRDD.count();
/**
* now we initialize Shards with values. This call should be issued from the driver, which is either a Client or, in standalone mode, a Shard.
*/
VoidParameterServer.getInstance().init(paramServerConfiguration, new RoutedTransport(), ela.getTrainingDriver());
VoidParameterServer.getInstance().initializeSeqVec(configuration.getLayersSize(), (int) numberOfUniqueElements, 119, configuration.getLayersSize() / paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.isUseHS(), paramServerConfiguration.isUseNS());
// proceed to training
// also, training function is the place where we invoke ParameterServer
TrainingFunction<T> trainer = new TrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
PartitionTrainingFunction<T> partitionTrainer = new PartitionTrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
if (configuration != null)
for (int e = 0; e < configuration.getEpochs(); e++) corpus.foreachPartition(partitionTrainer);
//corpus.foreach(trainer);
// we're transferring vectors to ExportContainer
JavaRDD<ExportContainer<T>> exportRdd = vocabRDD.map(new DistributedFunction<T>(paramServerConfigurationBroadcast, configurationBroadcast, shallowVocabCacheBroadcast));
// at this particular moment training should be pretty much done, and we're good to go for export
if (exporter != null)
exporter.export(exportRdd);
// unpersist, if we've persisted the corpus after all
if (storageLevel != null)
corpus.unpersist();
log.info("Training finish, starting cleanup...");
VoidParameterServer.getInstance().shutdown();
}
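fitSequences only attempts auto-discovery when paramServerConfiguration carries no shard addresses; pre-seeding explicit addresses makes it take the simpler branch. A hedged sketch of such a configuration, using only builder methods that appear in the snippets on this page (hosts and ports are illustrative):

// illustrative addresses and ports; with shardAddresses set, fitSequences skips auto discovery
VoidConfiguration paramServerConfiguration = VoidConfiguration.builder()
        .numberOfShards(2)
        .unicastPort(40123)
        .multicastPort(40124)
        .faultToleranceStrategy(FaultToleranceStrategy.NONE)
        .shardAddresses(Arrays.asList("192.168.0.10:40123", "192.168.0.11:40123"))
        .build();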
Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.
The class VoidParameterServerStressTest, method testPerformanceUnicast4().
/**
* This test checks multiple Clients hammering single Shard
*
* @throws Exception
*/
@Test
public void testPerformanceUnicast4() throws Exception {
VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(1).shardAddresses(Arrays.asList("127.0.0.1:49823")).build();
Transport transport = new RoutedTransport();
transport.setIpAndPort("127.0.0.1", 49823);
VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
parameterServer.setShardIndex((short) 0);
parameterServer.init(voidConfiguration, transport, new SkipGramTrainer());
parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);
VoidParameterServer[] clients = new VoidParameterServer[1];
for (int c = 0; c < clients.length; c++) {
clients[c] = new VoidParameterServer(NodeRole.CLIENT);
Transport clientTransport = new RoutedTransport();
clientTransport.setIpAndPort("127.0.0.1", Integer.valueOf("4872" + c));
clients[c].init(voidConfiguration, clientTransport, new SkipGramTrainer());
assertEquals(NodeRole.CLIENT, clients[c].getNodeRole());
}
final List<Long> times = new CopyOnWriteArrayList<>();
log.info("Starting loop...");
Thread[] threads = new Thread[clients.length];
for (int t = 0; t < threads.length; t++) {
final int c = t;
threads[t] = new Thread(() -> {
List<Long> results = new ArrayList<>();
AtomicLong sequence = new AtomicLong(0);
for (int i = 0; i < 500; i++) {
Frame<SkipGramRequestMessage> frame = new Frame<>(sequence.incrementAndGet());
for (int f = 0; f < 128; f++) {
frame.stackMessage(getSGRM());
}
long time1 = System.nanoTime();
clients[c].execDistributed(frame);
long time2 = System.nanoTime();
results.add(time2 - time1);
if ((i + 1) % 50 == 0)
log.info("Thread_{} finished {} frames...", c, i);
}
times.addAll(results);
});
threads[t].setDaemon(true);
threads[t].start();
}
for (Thread thread : threads) thread.join();
List<Long> newTimes = new ArrayList<>(times);
Collections.sort(newTimes);
log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
for (VoidParameterServer client : clients) {
client.shutdown();
}
parameterServer.shutdown();
}
Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.
The class VoidParameterServerStressTest, method testPerformanceUnicast3().
/**
* This test checks for single Shard scenario, when Shard is also a Client
*
* @throws Exception
*/
@Test
public void testPerformanceUnicast3() throws Exception {
VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(1).shardAddresses(Arrays.asList("127.0.0.1:49823")).build();
Transport transport = new RoutedTransport();
transport.setIpAndPort("127.0.0.1", 49823);
VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
parameterServer.setShardIndex((short) 0);
parameterServer.init(voidConfiguration, transport, new CbowTrainer());
parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);
final List<Long> times = new ArrayList<>();
log.info("Starting loop...");
for (int i = 0; i < 200; i++) {
Frame<CbowRequestMessage> frame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
for (int f = 0; f < 128; f++) {
frame.stackMessage(getCRM());
}
long time1 = System.nanoTime();
parameterServer.execDistributed(frame);
long time2 = System.nanoTime();
times.add(time2 - time1);
if (i % 50 == 0)
log.info("{} frames passed...", i);
}
Collections.sort(times);
log.info("p50: {} us", times.get(times.size() / 2) / 1000);
parameterServer.shutdown();
}