Example 1 with RoutedTransport

Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project deeplearning4j by deeplearning4j.

The call method of the CountFunction class:

@Override
public Pair<Sequence<T>, Long> call(Sequence<T> sequence) throws Exception {
    // since we can't be 100% sure the reported sequence size is correct (or that it doesn't
    // overflow int limits), we recalculate it here; we're looping over the elements for their
    // frequencies anyway
    Counter<Long> localCounter = new Counter<>();
    long seqLen = 0;
    if (ela == null) {
        try {
            ela = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    driver = ela.getTrainingDriver();
    //System.out.println("Initializing VoidParameterServer in CountFunction");
    VoidParameterServer.getInstance().init(voidConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    for (T element : sequence.getElements()) {
        if (element == null)
            continue;
        // FIXME: hashcode is a bad idea here; we need a Long id
        localCounter.incrementCount(element.getStorageId(), 1.0);
        seqLen++;
    }
    // FIXME: we're missing label information here due to shallow vocab mechanics
    if (sequence.getSequenceLabels() != null)
        for (T label : sequence.getSequenceLabels()) {
            localCounter.incrementCount(label.getStorageId(), 1.0);
        }
    accumulator.add(localCounter);
    return Pair.makePair(sequence, seqLen);
}
Also used : Counter(org.deeplearning4j.berkeley.Counter) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)
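
For context, the driver-side wiring for this counting pass appears in Example 3 below. A condensed sketch of it, assuming the two broadcasts and the VectorsConfiguration already exist (VocabWord stands in for T here):

// accumulator that merges per-partition frequency counters (see Example 3)
Accumulator<Counter<Long>> elementsFreqAccum =
        corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
CountFunction<VocabWord> elementsCounter = new CountFunction<>(configurationBroadcast,
        paramServerConfigurationBroadcast, elementsFreqAccum, configuration.isTrainSequenceVectors());
// map() is lazy; count() forces it, so the accumulator gets populated
JavaRDD<Pair<Sequence<VocabWord>, Long>> countedCorpus = corpus.map(elementsCounter);
long numberOfSequences = countedCorpus.count();
// frequencies keyed by storageId, ready for Huffman tree construction
Counter<Long> finalCounter = elementsFreqAccum.value();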

Example 2 with RoutedTransport

Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project deeplearning4j by deeplearning4j.

The call method of the TrainingFunction class:

@Override
@SuppressWarnings("unchecked")
public void call(Sequence<T> sequence) throws Exception {
    /*
         * Depending on the actual training mode, we'll go for SkipGram/CBOW/PV-DM/PV-DBOW,
         * or whatever algorithm is configured.
         */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();
    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }
    /*
         At this point we should have everything ready for the actual initialization.
         The only limitation: our sequence is detached from the actual vocabulary, so we need to merge it back virtually.
         */
    Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
    for (T element : sequence.getElements()) {
        // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
        ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
        if (reduced != null)
            mergedSequence.addElement(reduced);
    }
    // do the same with labels, transfer them, if any
    if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
        for (T label : sequence.getSequenceLabels()) {
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
            if (reduced != null)
                mergedSequence.addSequenceLabel(reduced);
        }
    }
    // FIXME: temporary hook
    if (sequence.size() > 0)
        paramServer.execDistributed(elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
    else
        log.warn("Skipping empty sequence...");
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)
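
Driver-side, this function is built from three broadcasts and applied once per sequence; Example 3 below shows the construction (training there actually goes through the per-partition variant, with corpus.foreach(trainer) left commented out). A minimal sketch of the per-sequence path, assuming the broadcasts already exist and T = VocabWord:

TrainingFunction<VocabWord> trainer = new TrainingFunction<>(
        shallowVocabCacheBroadcast,         // broadcast of the shallow vocab cache
        configurationBroadcast,             // broadcast of the VectorsConfiguration
        paramServerConfigurationBroadcast); // broadcast of the VoidConfiguration
// one call() per sequence; the parameter server is initialized lazily on the first call
corpus.foreach(trainer);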

Example 3 with RoutedTransport

Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project deeplearning4j by deeplearning4j.

The fitSequences method of the SparkSequenceVectors class:

/**
     * Base training entry point.
     *
     * @param corpus the RDD of sequences to train on
     */
public void fitSequences(JavaRDD<Sequence<T>> corpus) {
    /*
         * Basically all we want for the base implementation here is 3 things:
         * a) build the vocabulary
         * b) build the Huffman tree
         * c) do the training
         *
         * All classes extending SeqVec, like DeepWalk or Word2Vec, then just build their RDD<Sequence<T>>
         * and call this method for training, instead of implementing their own routines.
         */
    validateConfiguration();
    if (ela == null) {
        try {
            ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (workers > 1) {
        log.info("Repartitioning corpus to {} parts...", workers);
        // repartition() returns a new RDD; without the reassignment the call would have no effect
        corpus = corpus.repartition(workers);
    }
    if (storageLevel != null)
        corpus.persist(storageLevel);
    final JavaSparkContext sc = new JavaSparkContext(corpus.context());
    // this will only have an effect if it wasn't already called before, in extension classes
    broadcastEnvironment(sc);
    Counter<Long> finalCounter;
    long numberOfSequences = 0;
    /*
         * Here we set up a default parameter server configuration, if none was provided.
         */
    if (paramServerConfiguration == null)
        paramServerConfiguration = VoidConfiguration.builder().faultToleranceStrategy(FaultToleranceStrategy.NONE).numberOfShards(2).unicastPort(40123).multicastPort(40124).build();
    isAutoDiscoveryMode = paramServerConfiguration.getShardAddresses() == null || paramServerConfiguration.getShardAddresses().isEmpty();
    Broadcast<VoidConfiguration> paramServerConfigurationBroadcast = null;
    if (isAutoDiscoveryMode) {
        log.info("Trying auto discovery mode...");
        elementsFreqAccumExtra = corpus.context().accumulator(new ExtraCounter<Long>(), new ExtraElementsFrequenciesAccumulator());
        ExtraCountFunction<T> elementsCounter = new ExtraCountFunction<>(elementsFreqAccumExtra, configuration.isTrainSequenceVectors());
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger the map function, since we need the Huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        finalCounter = elementsFreqAccumExtra.value();
        ExtraCounter<Long> spareReference = (ExtraCounter<Long>) finalCounter;
        // getting list of available hosts
        Set<NetworkInformation> availableHosts = spareReference.getNetworkInformation();
        log.info("availableHosts: {}", availableHosts);
        if (availableHosts.size() > 1) {
            // now we have to pick N shards and optionally N backup nodes, and pass them within configuration bean
            NetworkOrganizer organizer = new NetworkOrganizer(availableHosts, paramServerConfiguration.getNetworkMask());
            paramServerConfiguration.setShardAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards()));
            // backup shards are optional
            if (paramServerConfiguration.getFaultToleranceStrategy() != FaultToleranceStrategy.NONE) {
                paramServerConfiguration.setBackupAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.getShardAddresses()));
            }
        } else {
            // for single host (aka driver-only, aka spark-local) just run on loopback interface
            paramServerConfiguration.setShardAddresses(Arrays.asList("127.0.0.1:" + paramServerConfiguration.getUnicastPort()));
            paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
        }
        log.info("Got Shards so far: {}", paramServerConfiguration.getShardAddresses());
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
    } else {
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
        // set up freqs accumulator
        elementsFreqAccum = corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
        CountFunction<T> elementsCounter = new CountFunction<>(configurationBroadcast, paramServerConfigurationBroadcast, elementsFreqAccum, configuration.isTrainSequenceVectors());
        // count all sequence elements and their sum
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger the map function, since we need the Huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        // now we grab counter, which contains frequencies for all SequenceElements in corpus
        finalCounter = elementsFreqAccum.value();
    }
    long numberOfElements = (long) finalCounter.totalCount();
    long numberOfUniqueElements = finalCounter.size();
    log.info("Total number of sequences: {}; Total number of elements entries: {}; Total number of unique elements: {}", numberOfSequences, numberOfElements, numberOfUniqueElements);
    /*
         build an RDD of reduced SequenceElements: get rid of the labels temporarily, and stick to some numerical
         value like an index or hashcode, so we can reduce the driver's memory footprint
         */
    // build the Huffman tree, and update the original RDD with Huffman encoding info
    shallowVocabCache = buildShallowVocabCache(finalCounter);
    shallowVocabCacheBroadcast = sc.broadcast(shallowVocabCache);
    // FIXME: probably we need to reconsider this approach
    JavaRDD<T> vocabRDD = corpus.flatMap(new VocabRddFunctionFlat<T>(configurationBroadcast, paramServerConfigurationBroadcast)).distinct();
    vocabRDD.count();
    /*
         * Now we initialize the Shards with values. That call should be started from the driver,
         * which is either a Client, or a Shard in standalone mode.
         */
    VoidParameterServer.getInstance().init(paramServerConfiguration, new RoutedTransport(), ela.getTrainingDriver());
    VoidParameterServer.getInstance().initializeSeqVec(configuration.getLayersSize(), (int) numberOfUniqueElements, 119, configuration.getLayersSize() / paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.isUseHS(), paramServerConfiguration.isUseNS());
    // proceed to training
    // also, training function is the place where we invoke ParameterServer
    TrainingFunction<T> trainer = new TrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    PartitionTrainingFunction<T> partitionTrainer = new PartitionTrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    if (configuration != null)
        for (int e = 0; e < configuration.getEpochs(); e++)
            corpus.foreachPartition(partitionTrainer);
    //corpus.foreach(trainer);
    // we're transferring vectors to ExportContainer
    JavaRDD<ExportContainer<T>> exportRdd = vocabRDD.map(new DistributedFunction<T>(paramServerConfigurationBroadcast, configurationBroadcast, shallowVocabCacheBroadcast));
    // at this particular moment training should be pretty much done, and we're good to go for export
    if (exporter != null)
        exporter.export(exportRdd);
    // unpersist, if we persisted the corpus after all
    if (storageLevel != null)
        corpus.unpersist();
    log.info("Training finish, starting cleanup...");
    VoidParameterServer.getInstance().shutdown();
}
Also used : NetworkOrganizer(org.deeplearning4j.spark.models.sequencevectors.utils.NetworkOrganizer) ExportContainer(org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer) ExtraCounter(org.deeplearning4j.spark.models.sequencevectors.primitives.ExtraCounter) Counter(org.deeplearning4j.berkeley.Counter) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) DL4JInvalidConfigException(org.deeplearning4j.exception.DL4JInvalidConfigException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) NetworkInformation(org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)
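
Stripped of the Spark plumbing, Examples 1 through 3 all reduce to the same RoutedTransport idiom: build a VoidConfiguration, then hand a fresh RoutedTransport plus a TrainingDriver to the VoidParameterServer singleton. A condensed sketch, using SkipGramTrainer as the driver (in the code above the driver comes from SparkElementsLearningAlgorithm.getTrainingDriver()):

VoidConfiguration conf = VoidConfiguration.builder()
        .faultToleranceStrategy(FaultToleranceStrategy.NONE)
        .numberOfShards(2)
        .unicastPort(40123)
        .multicastPort(40124)
        .build();
// initialize the singleton; Example 2's FIXME suggests repeated init calls are tolerated
VoidParameterServer.getInstance().init(conf, new RoutedTransport(), new SkipGramTrainer());
// ... training ...
VoidParameterServer.getInstance().shutdown();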

Example 4 with RoutedTransport

Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.

The testPerformanceUnicast4 method of the VoidParameterServerStressTest class:

/**
 * This test checks multiple Clients hammering a single Shard.
 *
 * @throws Exception
 */
@Test
public void testPerformanceUnicast4() throws Exception {
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(1).shardAddresses(Arrays.asList("127.0.0.1:49823")).build();
    Transport transport = new RoutedTransport();
    transport.setIpAndPort("127.0.0.1", 49823);
    VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
    parameterServer.setShardIndex((short) 0);
    parameterServer.init(voidConfiguration, transport, new SkipGramTrainer());
    parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);
    VoidParameterServer[] clients = new VoidParameterServer[1];
    for (int c = 0; c < clients.length; c++) {
        clients[c] = new VoidParameterServer(NodeRole.CLIENT);
        Transport clientTransport = new RoutedTransport();
        // client ports: 48720, 48721, ... (one per client)
        clientTransport.setIpAndPort("127.0.0.1", Integer.valueOf("4872" + c));
        clients[c].init(voidConfiguration, clientTransport, new SkipGramTrainer());
        assertEquals(NodeRole.CLIENT, clients[c].getNodeRole());
    }
    final List<Long> times = new CopyOnWriteArrayList<>();
    log.info("Starting loop...");
    Thread[] threads = new Thread[clients.length];
    for (int t = 0; t < threads.length; t++) {
        final int c = t;
        threads[t] = new Thread(() -> {
            List<Long> results = new ArrayList<>();
            AtomicLong sequence = new AtomicLong(0);
            for (int i = 0; i < 500; i++) {
                Frame<SkipGramRequestMessage> frame = new Frame<>(sequence.incrementAndGet());
                for (int f = 0; f < 128; f++) {
                    frame.stackMessage(getSGRM());
                }
                long time1 = System.nanoTime();
                clients[c].execDistributed(frame);
                long time2 = System.nanoTime();
                results.add(time2 - time1);
                if ((i + 1) % 50 == 0)
                    log.info("Thread_{} finished {} frames...", c, i + 1);
            }
            times.addAll(results);
        });
        threads[t].setDaemon(true);
        threads[t].start();
    }
    for (Thread thread : threads) thread.join();
    List<Long> newTimes = new ArrayList<>(times);
    Collections.sort(newTimes);
    log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
    for (VoidParameterServer client : clients) {
        client.shutdown();
    }
    parameterServer.shutdown();
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) Frame(org.nd4j.parameterserver.distributed.messages.Frame) SkipGramTrainer(org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) AtomicLong(java.util.concurrent.atomic.AtomicLong) List(java.util.List) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) Test(org.junit.Test)
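
The setup here generalizes to any explicit shard/client topology: each node builds its own RoutedTransport, binds it to a distinct ip:port, and passes it to init() along with the shared VoidConfiguration. A condensed sketch of that handshake, using only calls that appear in this test:

VoidConfiguration config = VoidConfiguration.builder()
        .unicastPort(49823)
        .numberOfShards(1)
        .shardAddresses(Arrays.asList("127.0.0.1:49823"))
        .build();
// Shard: explicit role, bound to the advertised shard address
VoidParameterServer shard = new VoidParameterServer(NodeRole.SHARD);
shard.setShardIndex((short) 0);
Transport shardTransport = new RoutedTransport();
shardTransport.setIpAndPort("127.0.0.1", 49823);
shard.init(config, shardTransport, new SkipGramTrainer());
// Client: same configuration, its own transport on a free local port
VoidParameterServer client = new VoidParameterServer(NodeRole.CLIENT);
Transport clientTransport = new RoutedTransport();
clientTransport.setIpAndPort("127.0.0.1", 48720);
client.init(config, clientTransport, new SkipGramTrainer());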

Example 5 with RoutedTransport

Use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.

The testPerformanceUnicast3 method of the VoidParameterServerStressTest class:

/**
 * This test checks the single-Shard scenario, where the Shard is also a Client.
 *
 * @throws Exception
 */
@Test
public void testPerformanceUnicast3() throws Exception {
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(1).shardAddresses(Arrays.asList("127.0.0.1:49823")).build();
    Transport transport = new RoutedTransport();
    transport.setIpAndPort("127.0.0.1", 49823);
    VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
    parameterServer.setShardIndex((short) 0);
    parameterServer.init(voidConfiguration, transport, new CbowTrainer());
    parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);
    final List<Long> times = new ArrayList<>();
    log.info("Starting loop...");
    for (int i = 0; i < 200; i++) {
        Frame<CbowRequestMessage> frame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
        for (int f = 0; f < 128; f++) {
            frame.stackMessage(getCRM());
        }
        long time1 = System.nanoTime();
        parameterServer.execDistributed(frame);
        long time2 = System.nanoTime();
        times.add(time2 - time1);
        if (i % 50 == 0)
            log.info("{} frames passed...", i);
    }
    Collections.sort(times);
    log.info("p50: {} us", times.get(times.size() / 2) / 1000);
    parameterServer.shutdown();
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) CbowTrainer(org.nd4j.parameterserver.distributed.training.impl.CbowTrainer) Frame(org.nd4j.parameterserver.distributed.messages.Frame) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) AtomicLong(java.util.concurrent.atomic.AtomicLong) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) CbowRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage) Test(org.junit.Test)
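
Both stress tests report latency the same way: sort the per-frame round-trip times and take the middle element. A small self-contained helper for arbitrary percentiles, matching that median calculation (illustrative only, not part of nd4j):

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public final class LatencyStats {
    private LatencyStats() {
    }

    /**
     * Returns the pct-th percentile of the given nanosecond timings, in microseconds.
     * With pct = 50 this matches the tests' times.get(times.size() / 2) / 1000.
     */
    public static long percentileMicros(List<Long> nanoTimes, double pct) {
        List<Long> sorted = new ArrayList<>(nanoTimes);
        Collections.sort(sorted);
        int idx = (int) Math.round(pct / 100.0 * (sorted.size() - 1));
        return sorted.get(idx) / 1000; // ns -> us
    }
}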

Aggregations

RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport): 9
ArrayList (java.util.ArrayList): 5
AtomicLong (java.util.concurrent.atomic.AtomicLong): 5
VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration): 5
CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList): 4
Test (org.junit.Test): 4
MulticastTransport (org.nd4j.parameterserver.distributed.transport.MulticastTransport): 4
Transport (org.nd4j.parameterserver.distributed.transport.Transport): 4
List (java.util.List): 3
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3
Frame (org.nd4j.parameterserver.distributed.messages.Frame): 3
SkipGramTrainer (org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer): 3
Counter (org.deeplearning4j.berkeley.Counter): 2
Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence): 2
ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement): 2
ClientRouter (org.nd4j.parameterserver.distributed.logic.ClientRouter): 2
InterleavedRouter (org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter): 2
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 1
Pair (org.deeplearning4j.berkeley.Pair): 1
DL4JInvalidConfigException (org.deeplearning4j.exception.DL4JInvalidConfigException): 1