Use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.
The class NetworkOrganizerTest, method testSelectionSingleBox1.
@Test
public void testSelectionSingleBox1() throws Exception {
List<NetworkInformation> collection = new ArrayList<>();
NetworkInformation information = new NetworkInformation();
information.addIpAddress("192.168.21.12");
information.addIpAddress("10.0.27.19");
collection.add(information);
NetworkOrganizer organizer = new NetworkOrganizer(collection, "192.168.0.0/16");
List<String> shards = organizer.getSubset(10);
assertEquals(1, shards.size());
}
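Because the organizer is constructed with the 192.168.0.0/16 mask, only the 192.168.21.12 address is eligible for selection, so the single returned shard is that address. A hypothetical follow-up assertion (not part of the original test) would make this explicit:
assertEquals("192.168.21.12", shards.get(0));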
Use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.
The class NetworkOrganizerTest, method testSelectionSingleBox2.
@Test
public void testSelectionSingleBox2() throws Exception {
List<NetworkInformation> collection = new ArrayList<>();
NetworkInformation information = new NetworkInformation();
information.addIpAddress("192.168.72.12");
information.addIpAddress("10.2.88.19");
collection.add(information);
NetworkOrganizer organizer = new NetworkOrganizer(collection);
List<String> shards = organizer.getSubset(10);
assertEquals(1, shards.size());
}
Use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.
The class SparkSequenceVectors, method fitSequences.
/**
* Base training entry point.
*
* @param corpus RDD of sequences to train on
*/
public void fitSequences(JavaRDD<Sequence<T>> corpus) {
/*
* All we need for the base implementation here are 3 things:
* a) build the vocabulary
* b) build the Huffman tree
* c) do the training
*
* All classes extending SeqVec, like DeepWalk or Word2Vec, will just build their own RDD<Sequence<T>>
* and call this method for training, instead of implementing their own routines (see the caller sketch after this method).
*/
validateConfiguration();
if (ela == null) {
try {
ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()).newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
if (workers > 1) {
log.info("Repartitioning corpus to {} parts...", workers);
corpus = corpus.repartition(workers);
}
if (storageLevel != null)
corpus.persist(storageLevel);
final JavaSparkContext sc = new JavaSparkContext(corpus.context());
// this has an effect only if it wasn't already called before, in extension classes
broadcastEnvironment(sc);
Counter<Long> finalCounter;
long numberOfSequences = 0;
/*
* Here we set up the parameter server configuration, falling back to defaults if none was provided
*/
if (paramServerConfiguration == null)
paramServerConfiguration = VoidConfiguration.builder().faultToleranceStrategy(FaultToleranceStrategy.NONE).numberOfShards(2).unicastPort(40123).multicastPort(40124).build();
isAutoDiscoveryMode = paramServerConfiguration.getShardAddresses() == null || paramServerConfiguration.getShardAddresses().isEmpty();
Broadcast<VoidConfiguration> paramServerConfigurationBroadcast = null;
if (isAutoDiscoveryMode) {
log.info("Trying auto discovery mode...");
elementsFreqAccumExtra = corpus.context().accumulator(new ExtraCounter<Long>(), new ExtraElementsFrequenciesAccumulator());
ExtraCountFunction<T> elementsCounter = new ExtraCountFunction<>(elementsFreqAccumExtra, configuration.isTrainSequenceVectors());
JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
// just to trigger map function, since we need huffman tree before proceeding
numberOfSequences = countedCorpus.count();
finalCounter = elementsFreqAccumExtra.value();
ExtraCounter<Long> spareReference = (ExtraCounter<Long>) finalCounter;
// getting list of available hosts
Set<NetworkInformation> availableHosts = spareReference.getNetworkInformation();
log.info("availableHosts: {}", availableHosts);
if (availableHosts.size() > 1) {
// now we have to pick N shards and optionally N backup nodes, and pass them within configuration bean
NetworkOrganizer organizer = new NetworkOrganizer(availableHosts, paramServerConfiguration.getNetworkMask());
paramServerConfiguration.setShardAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards()));
// backup shards are optional
if (paramServerConfiguration.getFaultToleranceStrategy() != FaultToleranceStrategy.NONE) {
paramServerConfiguration.setBackupAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.getShardAddresses()));
}
} else {
// for single host (aka driver-only, aka spark-local) just run on loopback interface
paramServerConfiguration.setShardAddresses(Arrays.asList("127.0.0.1:" + paramServerConfiguration.getUnicastPort()));
paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
}
log.info("Got Shards so far: {}", paramServerConfiguration.getShardAddresses());
// update ps configuration with real values where required
paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
} else {
// update ps configuration with real values where required
paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
// set up freqs accumulator
elementsFreqAccum = corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
CountFunction<T> elementsCounter = new CountFunction<>(configurationBroadcast, paramServerConfigurationBroadcast, elementsFreqAccum, configuration.isTrainSequenceVectors());
// count all sequence elements and their sum
JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
// just to trigger map function, since we need huffman tree before proceeding
numberOfSequences = countedCorpus.count();
// now we grab counter, which contains frequencies for all SequenceElements in corpus
finalCounter = elementsFreqAccum.value();
}
long numberOfElements = (long) finalCounter.totalCount();
long numberOfUniqueElements = finalCounter.size();
log.info("Total number of sequences: {}; Total number of elements entries: {}; Total number of unique elements: {}", numberOfSequences, numberOfElements, numberOfUniqueElements);
/*
Build an RDD of reduced SequenceElements: temporarily drop the labels and stick to numerical values,
like index or hashcode, so we can reduce the driver memory footprint.
*/
// build huffman tree, and update original RDD with huffman encoding info
shallowVocabCache = buildShallowVocabCache(finalCounter);
shallowVocabCacheBroadcast = sc.broadcast(shallowVocabCache);
// FIXME: probably we need to reconsider this approach
JavaRDD<T> vocabRDD = corpus.flatMap(new VocabRddFunctionFlat<T>(configurationBroadcast, paramServerConfigurationBroadcast)).distinct();
vocabRDD.count();
/**
* Now we initialize the Shards with values. That call should be started from the driver, which is either a Client or a Shard in standalone mode.
*/
VoidParameterServer.getInstance().init(paramServerConfiguration, new RoutedTransport(), ela.getTrainingDriver());
VoidParameterServer.getInstance().initializeSeqVec(configuration.getLayersSize(), (int) numberOfUniqueElements, 119, configuration.getLayersSize() / paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.isUseHS(), paramServerConfiguration.isUseNS());
// proceed to training
// also, training function is the place where we invoke ParameterServer
TrainingFunction<T> trainer = new TrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
PartitionTrainingFunction<T> partitionTrainer = new PartitionTrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
if (configuration != null)
for (int e = 0; e < configuration.getEpochs(); e++) corpus.foreachPartition(partitionTrainer);
//corpus.foreach(trainer);
// we're transferring vectors to ExportContainer
JavaRDD<ExportContainer<T>> exportRdd = vocabRDD.map(new DistributedFunction<T>(paramServerConfigurationBroadcast, configurationBroadcast, shallowVocabCacheBroadcast));
// at this particular moment training should be pretty much done, and we're good to go for export
if (exporter != null)
exporter.export(exportRdd);
// unpersist, if we persisted the corpus earlier
if (storageLevel != null)
corpus.unpersist();
log.info("Training finish, starting cleanup...");
VoidParameterServer.getInstance().shutdown();
}
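Since fitSequences(...) is the base training entry point, concrete implementations such as DeepWalk or Word2Vec are expected to build their own JavaRDD<Sequence<T>> and delegate to it. The following is only a minimal caller sketch; the buildCorpus(...) helper and the preconfigured vectors instance are hypothetical, not part of the original class:
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaRDD<String> rawData = sc.textFile("hdfs:///path/to/corpus"); // illustrative path
// a subclass would convert its raw data into sequences of SequenceElements...
JavaRDD<Sequence<VocabWord>> corpus = buildCorpus(rawData); // hypothetical helper
// ...and delegate vocabulary building, Huffman encoding and training to the base class:
vectors.fitSequences(corpus);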
Use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.
The class NetworkOrganizerTest, method testSelectionWithoutMaskB1.
/**
* In this test we check shard selection in a typical AWS setup.
* By default, an AWS box has only one IP from the 172.16.0.0/12 space, plus the local loopback IP, which isn't exposed.
*
* @throws Exception
*/
@Test
public void testSelectionWithoutMaskB1() throws Exception {
List<NetworkInformation> collection = new ArrayList<>();
// we imitate 512 cluster nodes here
for (int i = 0; i < 512; i++) {
NetworkInformation information = new NetworkInformation();
information.addIpAddress(getRandomAwsIp());
collection.add(information);
}
NetworkOrganizer organizer = new NetworkOrganizer(collection);
List<String> shards = organizer.getSubset(10);
assertEquals(10, shards.size());
List<String> backup = organizer.getSubset(10, shards);
assertEquals(10, backup.size());
for (String ip : backup) {
assertNotEquals(null, ip);
assertTrue(ip.startsWith("172."));
assertFalse(shards.contains(ip));
}
}
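The test relies on a getRandomAwsIp() helper that isn't shown in this snippet. A minimal sketch, assuming the helper only needs to produce addresses somewhere in the 172.16.0.0/12 range (using java.util.concurrent.ThreadLocalRandom), could look like this:
protected String getRandomAwsIp() {
    // 172.16.0.0/12 covers second octets 16..31
    ThreadLocalRandom rnd = ThreadLocalRandom.current();
    return "172." + rnd.nextInt(16, 32) + "." + rnd.nextInt(0, 256) + "." + rnd.nextInt(0, 256);
}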
Use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.
The class NetworkOrganizer, method getSubset.
/**
* This method returns the specified number of IP addresses from the original list of addresses that are NOT listed in the primary collection.
*
* @param numShards number of addresses to return
* @param primary collection of IP addresses that shouldn't appear in the result
* @return list of selected IP addresses
*/
public List<String> getSubset(int numShards, Collection<String> primary) {
/*
* If the netmask is unset, we fall back to manual intersection-based selection
*/
if (networkMask == null)
return getIntersections(numShards, primary);
List<String> addresses = new ArrayList<>();
SubnetUtils utils = new SubnetUtils(networkMask);
Collections.shuffle(informationCollection);
for (NetworkInformation information : informationCollection) {
for (String ip : information.getIpAddresses()) {
if (primary != null && primary.contains(ip))
continue;
if (utils.getInfo().isInRange(ip)) {
log.debug("Picked {} as {}", ip, primary == null ? "Shard" : "Backup");
addresses.add(ip);
}
if (addresses.size() >= numShards)
break;
}
if (addresses.size() >= numShards)
break;
}
return addresses;
}
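Taken together with the tests above, typical usage is to pick the shard addresses first and then pick backups that exclude them. A short sketch, assuming a pre-built collection of NetworkInformation entries named hosts and an illustrative 10.0.0.0/8 mask:
NetworkOrganizer organizer = new NetworkOrganizer(hosts, "10.0.0.0/8");
// primary shards: addresses that fall inside the 10.0.0.0/8 mask
List<String> shards = organizer.getSubset(3);
// backups: same mask, but addresses already picked as shards are skipped
List<String> backups = organizer.getSubset(3, shards);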