Example 1 with NetworkInformation

use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.

the class NetworkOrganizerTest method testSelectionSingleBox1.

@Test
public void testSelectionSingleBox1() throws Exception {
    List<NetworkInformation> collection = new ArrayList<>();
    NetworkInformation information = new NetworkInformation();
    information.addIpAddress("192.168.21.12");
    information.addIpAddress("10.0.27.19");
    collection.add(information);
    NetworkOrganizer organizer = new NetworkOrganizer(collection, "192.168.0.0/16");
    List<String> shards = organizer.getSubset(10);
    assertEquals(1, shards.size());
}
Also used : NetworkInformation(org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation) Test(org.junit.Test)
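
The CIDR mask "192.168.0.0/16" passed to the NetworkOrganizer constructor restricts shard candidates to that subnet, which is why only the 192.168.x.x address is eligible and 10.0.27.19 is ignored. As Example 5 below shows, the range check relies on Apache Commons Net's SubnetUtils; a minimal standalone sketch of that check (the class name MaskCheckSketch is illustrative only):

import org.apache.commons.net.util.SubnetUtils;

public class MaskCheckSketch {
    public static void main(String[] args) {
        // same CIDR mask as in the test above
        SubnetUtils utils = new SubnetUtils("192.168.0.0/16");
        // 192.168.21.12 falls inside the /16 range, 10.0.27.19 does not
        System.out.println(utils.getInfo().isInRange("192.168.21.12")); // true
        System.out.println(utils.getInfo().isInRange("10.0.27.19"));    // false
    }
}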

Example 2 with NetworkInformation

use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.

the class NetworkOrganizerTest method testSelectionSingleBox2.

@Test
public void testSelectionSingleBox2() throws Exception {
    List<NetworkInformation> collection = new ArrayList<>();
    NetworkInformation information = new NetworkInformation();
    information.addIpAddress("192.168.72.12");
    information.addIpAddress("10.2.88.19");
    collection.add(information);
    NetworkOrganizer organizer = new NetworkOrganizer(collection);
    List<String> shards = organizer.getSubset(10);
    assertEquals(1, shards.size());
}
Also used : NetworkInformation(org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation) Test(org.junit.Test)

Example 3 with NetworkInformation

use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.

the class SparkSequenceVectors method fitSequences.

/**
     * Base training entry point
     *
     * @param corpus training corpus, as an RDD of sequences
     */
public void fitSequences(JavaRDD<Sequence<T>> corpus) {
    /**
         * Basically, all we want for the base implementation here are 3 things:
         * a) build vocabulary
         * b) build huffman tree
         * c) do training
         *
         * in this case, all classes extending SeqVec (like DeepWalk or Word2Vec) will just build their RDD<Sequence<T>>
         * and call this method for training, instead of implementing their own routines
         */
    validateConfiguration();
    if (ela == null) {
        try {
            ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (workers > 1) {
        log.info("Repartitioning corpus to {} parts...", workers);
        // repartition() returns a new RDD, so reassign; otherwise the call has no effect
        corpus = corpus.repartition(workers);
    }
    if (storageLevel != null)
        corpus.persist(storageLevel);
    final JavaSparkContext sc = new JavaSparkContext(corpus.context());
    // this has an effect only if it wasn't already called before, in extension classes
    broadcastEnvironment(sc);
    Counter<Long> finalCounter;
    long numberOfSequences = 0;
    /**
         * Here we set up the parameter server configuration, falling back to defaults if none was provided
         */
    if (paramServerConfiguration == null)
        paramServerConfiguration = VoidConfiguration.builder().faultToleranceStrategy(FaultToleranceStrategy.NONE).numberOfShards(2).unicastPort(40123).multicastPort(40124).build();
    isAutoDiscoveryMode = paramServerConfiguration.getShardAddresses() == null || paramServerConfiguration.getShardAddresses().isEmpty();
    Broadcast<VoidConfiguration> paramServerConfigurationBroadcast = null;
    if (isAutoDiscoveryMode) {
        log.info("Trying auto discovery mode...");
        elementsFreqAccumExtra = corpus.context().accumulator(new ExtraCounter<Long>(), new ExtraElementsFrequenciesAccumulator());
        ExtraCountFunction<T> elementsCounter = new ExtraCountFunction<>(elementsFreqAccumExtra, configuration.isTrainSequenceVectors());
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger map function, since we need huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        finalCounter = elementsFreqAccumExtra.value();
        ExtraCounter<Long> spareReference = (ExtraCounter<Long>) finalCounter;
        // getting list of available hosts
        Set<NetworkInformation> availableHosts = spareReference.getNetworkInformation();
        log.info("availableHosts: {}", availableHosts);
        if (availableHosts.size() > 1) {
            // now we have to pick N shards and optionally N backup nodes, and pass them within configuration bean
            NetworkOrganizer organizer = new NetworkOrganizer(availableHosts, paramServerConfiguration.getNetworkMask());
            paramServerConfiguration.setShardAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards()));
            // backup shards are optional
            if (paramServerConfiguration.getFaultToleranceStrategy() != FaultToleranceStrategy.NONE) {
                paramServerConfiguration.setBackupAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.getShardAddresses()));
            }
        } else {
            // for single host (aka driver-only, aka spark-local) just run on loopback interface
            paramServerConfiguration.setShardAddresses(Arrays.asList("127.0.0.1:" + paramServerConfiguration.getUnicastPort()));
            paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
        }
        log.info("Got Shards so far: {}", paramServerConfiguration.getShardAddresses());
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
    } else {
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
        // set up freqs accumulator
        elementsFreqAccum = corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
        CountFunction<T> elementsCounter = new CountFunction<>(configurationBroadcast, paramServerConfigurationBroadcast, elementsFreqAccum, configuration.isTrainSequenceVectors());
        // count all sequence elements and their sum
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger map function, since we need huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        // now we grab counter, which contains frequencies for all SequenceElements in corpus
        finalCounter = elementsFreqAccum.value();
    }
    long numberOfElements = (long) finalCounter.totalCount();
    long numberOfUniqueElements = finalCounter.size();
    log.info("Total number of sequences: {}; Total number of elements entries: {}; Total number of unique elements: {}", numberOfSequences, numberOfElements, numberOfUniqueElements);
    /*
         build RDD of reduced SequenceElements: just get rid of labels temporarily and stick to some numerical values,
         like index or hashcode, so we can reduce the driver memory footprint
         */
    // build huffman tree, and update original RDD with huffman encoding info
    shallowVocabCache = buildShallowVocabCache(finalCounter);
    shallowVocabCacheBroadcast = sc.broadcast(shallowVocabCache);
    // FIXME: probably we need to reconsider this approach
    JavaRDD<T> vocabRDD = corpus.flatMap(new VocabRddFunctionFlat<T>(configurationBroadcast, paramServerConfigurationBroadcast)).distinct();
    vocabRDD.count();
    /**
         * Now we initialize Shards with values. This call should be issued from the driver, which is either a Client or a Shard in standalone mode.
         */
    VoidParameterServer.getInstance().init(paramServerConfiguration, new RoutedTransport(), ela.getTrainingDriver());
    VoidParameterServer.getInstance().initializeSeqVec(configuration.getLayersSize(), (int) numberOfUniqueElements, 119, configuration.getLayersSize() / paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.isUseHS(), paramServerConfiguration.isUseNS());
    // proceed to training
    // also, training function is the place where we invoke ParameterServer
    TrainingFunction<T> trainer = new TrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    PartitionTrainingFunction<T> partitionTrainer = new PartitionTrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    if (configuration != null)
        for (int e = 0; e < configuration.getEpochs(); e++)
            corpus.foreachPartition(partitionTrainer);
    //corpus.foreach(trainer);
    // we're transferring vectors to ExportContainer
    JavaRDD<ExportContainer<T>> exportRdd = vocabRDD.map(new DistributedFunction<T>(paramServerConfigurationBroadcast, configurationBroadcast, shallowVocabCacheBroadcast));
    // at this particular moment training should be pretty much done, and we're good to go for export
    if (exporter != null)
        exporter.export(exportRdd);
    // unpersist, if we persisted the corpus earlier
    if (storageLevel != null)
        corpus.unpersist();
    log.info("Training finish, starting cleanup...");
    VoidParameterServer.getInstance().shutdown();
}
Also used : NetworkOrganizer(org.deeplearning4j.spark.models.sequencevectors.utils.NetworkOrganizer) ExportContainer(org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer) ExtraCounter(org.deeplearning4j.spark.models.sequencevectors.primitives.ExtraCounter) Counter(org.deeplearning4j.berkeley.Counter) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) DL4JInvalidConfigException(org.deeplearning4j.exception.DL4JInvalidConfigException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) NetworkInformation(org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)
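
The auto-discovery branch above boils down to: gather a NetworkInformation entry for every worker through the accumulator, then let NetworkOrganizer pick shard (and, if fault tolerance is enabled, backup) addresses within the configured network mask. A simplified, standalone sketch of just that selection step, assuming the host set has already been collected (the addresses and the ShardSelectionSketch class are made up for illustration):

import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation;
import org.deeplearning4j.spark.models.sequencevectors.utils.NetworkOrganizer;

public class ShardSelectionSketch {
    public static void main(String[] args) {
        // stand-in for the hosts normally obtained via spareReference.getNetworkInformation()
        Set<NetworkInformation> availableHosts = new HashSet<>();
        for (int i = 1; i <= 4; i++) {
            NetworkInformation info = new NetworkInformation();
            info.addIpAddress("10.0.0." + i);
            availableHosts.add(info);
        }

        // same selection as in fitSequences(): restrict candidates to the configured mask and pick N shards
        NetworkOrganizer organizer = new NetworkOrganizer(availableHosts, "10.0.0.0/24");
        List<String> shards = organizer.getSubset(2);
        System.out.println("Selected shards: " + shards);
    }
}

If shard addresses are set explicitly on the VoidConfiguration (via setShardAddresses), fitSequences skips auto discovery entirely and takes the else branch instead.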

Example 4 with NetworkInformation

use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.

the class NetworkOrganizerTest method testSelectionWithoutMaskB1.

/**
     * In this test we check shard selection in a typical AWS setup.
     * By default, an AWS box has only one IP from the 172.16.0.0/12 space, plus the local loopback IP, which isn't exposed
     *
     * @throws Exception
     */
@Test
public void testSelectionWithoutMaskB1() throws Exception {
    List<NetworkInformation> collection = new ArrayList<>();
    // we imitate 512 cluster nodes here
    for (int i = 0; i < 512; i++) {
        NetworkInformation information = new NetworkInformation();
        information.addIpAddress(getRandomAwsIp());
        collection.add(information);
    }
    NetworkOrganizer organizer = new NetworkOrganizer(collection);
    List<String> shards = organizer.getSubset(10);
    assertEquals(10, shards.size());
    List<String> backup = organizer.getSubset(10, shards);
    assertEquals(10, backup.size());
    for (String ip : backup) {
        assertNotEquals(null, ip);
        assertTrue(ip.startsWith("172."));
        assertFalse(shards.contains(ip));
    }
}
Also used : NetworkInformation(org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation) Test(org.junit.Test)
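
The getRandomAwsIp() helper used above isn't shown in this excerpt. A plausible implementation generates addresses from the 172.16.0.0/12 private range that the assertions check against; this exact body is an assumption, not the project's code:

import java.util.Random;

public class AwsIpSketch {
    private static final Random rng = new Random();

    // hypothetical helper: a random address from 172.16.0.0/12 (second octet 16..31)
    public static String getRandomAwsIp() {
        return "172." + (16 + rng.nextInt(16)) + "." + rng.nextInt(256) + "." + rng.nextInt(256);
    }

    public static void main(String[] args) {
        System.out.println(getRandomAwsIp()); // e.g. 172.23.104.211
    }
}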

Example 5 with NetworkInformation

use of org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation in project deeplearning4j by deeplearning4j.

the class NetworkOrganizer method getSubset.

/**
     * This method returns the specified number of IP addresses from the original list of addresses that are NOT listed in the primary collection
     *
     * @param numShards number of addresses to return
     * @param primary Collection of IP addresses that shouldn't be in the result
     * @return list of selected IP addresses
     */
public List<String> getSubset(int numShards, Collection<String> primary) {
    /**
         * If the netmask is unset, we fall back to manual intersection-based selection
         */
    if (networkMask == null)
        return getIntersections(numShards, primary);
    List<String> addresses = new ArrayList<>();
    SubnetUtils utils = new SubnetUtils(networkMask);
    Collections.shuffle(informationCollection);
    for (NetworkInformation information : informationCollection) {
        for (String ip : information.getIpAddresses()) {
            if (primary != null && primary.contains(ip))
                continue;
            if (utils.getInfo().isInRange(ip)) {
                log.debug("Picked {} as {}", ip, primary == null ? "Shard" : "Backup");
                addresses.add(ip);
            }
            if (addresses.size() >= numShards)
                break;
        }
        if (addresses.size() >= numShards)
            break;
    }
    return addresses;
}
Also used : SubnetUtils(org.apache.commons.net.util.SubnetUtils) NetworkInformation(org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation)
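
A short usage sketch of this two-argument overload, mirroring how Example 3 picks backup addresses that must not overlap with the already selected shards (the addresses and the BackupSelectionSketch class are made up):

import java.util.ArrayList;
import java.util.List;

import org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation;
import org.deeplearning4j.spark.models.sequencevectors.utils.NetworkOrganizer;

public class BackupSelectionSketch {
    public static void main(String[] args) {
        List<NetworkInformation> collection = new ArrayList<>();
        for (int i = 1; i <= 6; i++) {
            NetworkInformation info = new NetworkInformation();
            info.addIpAddress("192.168.1." + i); // made-up cluster addresses
            collection.add(info);
        }

        NetworkOrganizer organizer = new NetworkOrganizer(collection, "192.168.0.0/16");
        // primary shards first, then backups drawn from the same collection but excluding the shard addresses
        List<String> shards = organizer.getSubset(3);
        List<String> backups = organizer.getSubset(3, shards);
        System.out.println("Shards: " + shards + ", backups: " + backups);
    }
}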

Aggregations

NetworkInformation (org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation): 9 usages
Test (org.junit.Test): 6 usages
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 2 usages
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1 usage
SubnetUtils (org.apache.commons.net.util.SubnetUtils): 1 usage
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 1 usage
Counter (org.deeplearning4j.berkeley.Counter): 1 usage
Pair (org.deeplearning4j.berkeley.Pair): 1 usage
DL4JInvalidConfigException (org.deeplearning4j.exception.DL4JInvalidConfigException): 1 usage
ExportContainer (org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer): 1 usage
ExtraCounter (org.deeplearning4j.spark.models.sequencevectors.primitives.ExtraCounter): 1 usage
NetworkOrganizer (org.deeplearning4j.spark.models.sequencevectors.utils.NetworkOrganizer): 1 usage
VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration): 1 usage
RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport): 1 usage