
Example 21 with JavaSparkContext

use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.

the class Glove method train.

/**
     * Train on the corpus
     * @param rdd the RDD of raw text to train on
     * @return the vocab cache and weight lookup table
     */
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
    // Each train() call may use different parameters, so re-read them from the SparkConf
    final JavaSparkContext sc = new JavaSparkContext(rdd.context());
    final SparkConf conf = sc.getConf();
    final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
    final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
    final double negative = assignVar(NEGATIVE, conf, Double.class);
    final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
    final int window = assignVar(WINDOW, conf, Integer.class);
    final double alpha = assignVar(ALPHA, conf, Double.class);
    final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
    final int iterations = assignVar(ITERATIONS, conf, Integer.class);
    final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
    final String tokenizer = assignVar(TOKENIZER, conf, String.class);
    final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
    final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
    Map<String, Object> tokenizerVarMap = new HashMap<String, Object>() {

        {
            put("numWords", numWords);
            put("nGrams", nGrams);
            put("tokenizer", tokenizer);
            put("tokenPreprocessor", tokenPreprocessor);
            put("removeStop", removeStop);
        }
    };
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count
    Long totalWordCount = pipeline.getTotalWordCount();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
    final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());
    final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder()
                    .cache(vocabAndNumWords.getFirst())
                    .lr(conf.getDouble(GlovePerformer.ALPHA, 0.01))
                    .maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100))
                    .vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
                    .xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75))
                    .build();
    gloveWeightLookupTable.resetWeights();
    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
    log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
    CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD.map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize)).fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
    Iterator<Pair<String, String>> pair2 = coOccurrenceCounts.getPairIterator();
    List<Triple<String, String, Double>> counts = new ArrayList<>();
    while (pair2.hasNext()) {
        Pair<String, String> next = pair2.next();
        if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
            coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(), gloveWeightLookupTable.getMaxCount());
        }
        counts.add(new Triple<>(next.getFirst(), next.getSecond(), coOccurrenceCounts.getCount(next.getFirst(), next.getSecond())));
    }
    log.info("Calculated co occurrences");
    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    JavaPairRDD<String, Tuple2<String, Double>> pairs = parallel.mapToPair(new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {

        @Override
        public Tuple2<String, Tuple2<String, Double>> call(Triple<String, String, Double> stringStringDoubleTriple) throws Exception {
            return new Tuple2<>(stringStringDoubleTriple.getFirst(), new Tuple2<>(stringStringDoubleTriple.getSecond(), stringStringDoubleTriple.getThird()));
        }
    });
    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab = pairs.mapToPair(new PairFunction<Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {

        @Override
        public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2) throws Exception {
            VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
            VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
            return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
        }
    });
    for (int i = 0; i < iterations; i++) {
        JavaRDD<GloveChange> change = pairsVocab.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {

            @Override
            public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2) throws Exception {
                VocabWord w1 = vocabWordTuple2Tuple2._1();
                VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
                INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                INDArray bias = gloveWeightLookupTable.getBias();
                double score = vocabWordTuple2Tuple2._2()._2();
                double xMax = gloveWeightLookupTable.getxMax();
                double maxCount = gloveWeightLookupTable.getMaxCount();
                //w1 * w2 + bias
                double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
                double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
                double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
                if (Double.isNaN(fDiff))
                    fDiff = Nd4j.EPS_THRESHOLD;
                //amount of change
                double gradient = fDiff;
                Pair<INDArray, Double> w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w1, w1Vector, w2Vector, gradient);
                Pair<INDArray, Double> w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w2, w2Vector, w1Vector, gradient);
                return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(), w1Update.getSecond(), w2Update.getSecond(), fDiff, gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()), gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
            }
        });
        List<GloveChange> gloveChanges = change.collect();
        double error = 0.0;
        for (GloveChange change2 : gloveChanges) {
            change2.apply(gloveWeightLookupTable);
            error += change2.getError();
        }
        // re-shuffle pair order between iterations; copy into a new list since collect() may return an immutable one
        List<Tuple2<VocabWord, Tuple2<VocabWord, Double>>> l = new ArrayList<>(pairsVocab.collect());
        Collections.shuffle(l);
        pairsVocab = sc.parallelizePairs(l);
        log.info("Error at iteration " + i + " was " + error);
    }
    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
}
Also used : CoOccurrenceCounts(org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) Triple(org.deeplearning4j.berkeley.Triple) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) AtomicLong(java.util.concurrent.atomic.AtomicLong) CounterMap(org.deeplearning4j.berkeley.CounterMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) CoOccurrenceCalculator(org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) GloveWeightLookupTable(org.deeplearning4j.models.glove.GloveWeightLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) SparkConf(org.apache.spark.SparkConf)
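
The heart of the inner map function above is the weighted least-squares error of GloVe. The following standalone sketch isolates that computation in plain Java (ND4J is dropped; the dot-product prediction is taken as a precomputed argument, and Nd4j.EPS_THRESHOLD is stood in for by a small literal). Note that the sketch mirrors the snippet as written, where maxCount acts as the count ceiling and xMax as the weighting exponent, a reversal of the usual GloVe naming.

import org.apache.commons.math3.util.FastMath;

public class GloveErrorSketch {

    /**
     * Weighted GloVe error term for one co-occurrence pair, mirroring the
     * fDiff computation in the map function above.
     *
     * @param prediction w1 . w2 + bias(w1) + bias(w2), assumed precomputed
     * @param score      (clipped) co-occurrence count for the pair
     * @param maxCount   ceiling applied to co-occurrence counts
     * @param xMax       exponent of the weighting function
     */
    static double fDiff(double prediction, double score, double maxCount, double xMax) {
        // f(X_ij): down-weights rare pairs, saturating at 1.0 for frequent ones
        double weight = FastMath.pow(Math.min(1.0, score / maxCount), xMax);
        // weighted difference between the prediction and the log co-occurrence count
        double diff = score > xMax ? prediction : weight * (prediction - Math.log(score));
        // guard against NaN (e.g. log of a zero count); 1e-12 stands in for Nd4j.EPS_THRESHOLD
        return Double.isNaN(diff) ? 1e-12 : diff;
    }

    public static void main(String[] args) {
        // hypothetical values: a pair co-occurring 25 times, raw prediction 2.1
        System.out.println(fDiff(2.1, 25.0, 100.0, 0.75));
    }
}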

Example 22 with JavaSparkContext

use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.

the class SparkParagraphVectors method fitLabelledDocuments.

/**
     * This method builds a ParagraphVectors model from a JavaRDD<LabelledDocument>.
     * The documents may be either tokenized or non-tokenized.
     *
     * @param documentsRdd RDD of labelled documents to fit the model on
     */
public void fitLabelledDocuments(JavaRDD<LabelledDocument> documentsRdd) {
    validateConfiguration();
    broadcastEnvironment(new JavaSparkContext(documentsRdd.context()));
    JavaRDD<Sequence<VocabWord>> sequenceRDD = documentsRdd.map(new DocumentSequenceConvertFunction(configurationBroadcast));
    super.fitSequences(sequenceRDD);
}
Also used : JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) DocumentSequenceConvertFunction(org.deeplearning4j.spark.models.paragraphvectors.functions.DocumentSequenceConvertFunction)
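
A minimal usage sketch for this entry point, assuming an already-configured SparkParagraphVectors instance spv and a live JavaSparkContext sc; LabelledDocument's setContent/setLabel accessors are assumptions based on dl4j's document-iterator API and are not shown in this example:

// Hypothetical setup: spv and sc are assumed to exist and be configured.
List<LabelledDocument> docs = new ArrayList<>();
LabelledDocument doc = new LabelledDocument();
doc.setContent("raw text of the first document");
doc.setLabel("DOC_0");
docs.add(doc);

JavaRDD<LabelledDocument> documentsRdd = sc.parallelize(docs);
spv.fitLabelledDocuments(documentsRdd);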

Example 23 with JavaSparkContext

use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.

the class SparkSequenceVectors method fitSequences.

/**
     * Base training entry point
     *
     * @param corpus RDD of sequences to train on
     */
public void fitSequences(JavaRDD<Sequence<T>> corpus) {
    /**
         * Basically all we want for the base implementation here are 3 things:
         * a) build the vocabulary
         * b) build the huffman tree
         * c) do the training
         *
         * In this case all classes extending SeqVec, like DeepWalk or Word2Vec, will just build their RDD<Sequence<T>>
         * and call this method for training, instead of implementing their own routines.
         */
    validateConfiguration();
    if (ela == null) {
        try {
            ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (workers > 1) {
        log.info("Repartitioning corpus to {} parts...", workers);
        // repartition() returns a new RDD; keep the result or the call has no effect
        corpus = corpus.repartition(workers);
    }
    if (storageLevel != null)
        corpus.persist(storageLevel);
    final JavaSparkContext sc = new JavaSparkContext(corpus.context());
    // this will have an effect only if it wasn't called before, in extension classes
    broadcastEnvironment(sc);
    Counter<Long> finalCounter;
    long numberOfSequences = 0;
    /**
         * Here we set up the parameter server configuration, falling back to defaults when none was provided
         */
    if (paramServerConfiguration == null)
        paramServerConfiguration = VoidConfiguration.builder().faultToleranceStrategy(FaultToleranceStrategy.NONE).numberOfShards(2).unicastPort(40123).multicastPort(40124).build();
    isAutoDiscoveryMode = paramServerConfiguration.getShardAddresses() == null || paramServerConfiguration.getShardAddresses().isEmpty();
    Broadcast<VoidConfiguration> paramServerConfigurationBroadcast = null;
    if (isAutoDiscoveryMode) {
        log.info("Trying auto discovery mode...");
        elementsFreqAccumExtra = corpus.context().accumulator(new ExtraCounter<Long>(), new ExtraElementsFrequenciesAccumulator());
        ExtraCountFunction<T> elementsCounter = new ExtraCountFunction<>(elementsFreqAccumExtra, configuration.isTrainSequenceVectors());
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger map function, since we need huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        finalCounter = elementsFreqAccumExtra.value();
        ExtraCounter<Long> spareReference = (ExtraCounter<Long>) finalCounter;
        // getting list of available hosts
        Set<NetworkInformation> availableHosts = spareReference.getNetworkInformation();
        log.info("availableHosts: {}", availableHosts);
        if (availableHosts.size() > 1) {
            // now we have to pick N shards and optionally N backup nodes, and pass them within configuration bean
            NetworkOrganizer organizer = new NetworkOrganizer(availableHosts, paramServerConfiguration.getNetworkMask());
            paramServerConfiguration.setShardAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards()));
            // backup shards are optional
            if (paramServerConfiguration.getFaultToleranceStrategy() != FaultToleranceStrategy.NONE) {
                paramServerConfiguration.setBackupAddresses(organizer.getSubset(paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.getShardAddresses()));
            }
        } else {
            // for single host (aka driver-only, aka spark-local) just run on loopback interface
            paramServerConfiguration.setShardAddresses(Arrays.asList("127.0.0.1:" + paramServerConfiguration.getUnicastPort()));
            paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
        }
        log.info("Got Shards so far: {}", paramServerConfiguration.getShardAddresses());
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
    } else {
        // update ps configuration with real values where required
        paramServerConfiguration.setNumberOfShards(paramServerConfiguration.getShardAddresses().size());
        paramServerConfiguration.setUseHS(configuration.isUseHierarchicSoftmax());
        paramServerConfiguration.setUseNS(configuration.getNegative() > 0);
        paramServerConfigurationBroadcast = sc.broadcast(paramServerConfiguration);
        // set up freqs accumulator
        elementsFreqAccum = corpus.context().accumulator(new Counter<Long>(), new ElementsFrequenciesAccumulator());
        CountFunction<T> elementsCounter = new CountFunction<>(configurationBroadcast, paramServerConfigurationBroadcast, elementsFreqAccum, configuration.isTrainSequenceVectors());
        // count all sequence elements and their sum
        JavaRDD<Pair<Sequence<T>, Long>> countedCorpus = corpus.map(elementsCounter);
        // just to trigger map function, since we need huffman tree before proceeding
        numberOfSequences = countedCorpus.count();
        // now we grab counter, which contains frequencies for all SequenceElements in corpus
        finalCounter = elementsFreqAccum.value();
    }
    long numberOfElements = (long) finalCounter.totalCount();
    long numberOfUniqueElements = finalCounter.size();
    log.info("Total number of sequences: {}; Total number of elements entries: {}; Total number of unique elements: {}", numberOfSequences, numberOfElements, numberOfUniqueElements);
    /*
         build RDD of reduced SequenceElements: temporarily get rid of labels and stick to some numerical values,
         like index or hashcode, so we can reduce the driver memory footprint
         */
    // build huffman tree, and update original RDD with huffman encoding info
    shallowVocabCache = buildShallowVocabCache(finalCounter);
    shallowVocabCacheBroadcast = sc.broadcast(shallowVocabCache);
    // FIXME: probably we need to reconsider this approach
    JavaRDD<T> vocabRDD = corpus.flatMap(new VocabRddFunctionFlat<T>(configurationBroadcast, paramServerConfigurationBroadcast)).distinct();
    vocabRDD.count();
    /**
         * now we initialize Shards with values. That call should be started from the driver, which is either a Client or a Shard in standalone mode.
         */
    VoidParameterServer.getInstance().init(paramServerConfiguration, new RoutedTransport(), ela.getTrainingDriver());
    VoidParameterServer.getInstance().initializeSeqVec(configuration.getLayersSize(), (int) numberOfUniqueElements, 119, configuration.getLayersSize() / paramServerConfiguration.getNumberOfShards(), paramServerConfiguration.isUseHS(), paramServerConfiguration.isUseNS());
    // proceed to training
    // also, training function is the place where we invoke ParameterServer
    TrainingFunction<T> trainer = new TrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    PartitionTrainingFunction<T> partitionTrainer = new PartitionTrainingFunction<>(shallowVocabCacheBroadcast, configurationBroadcast, paramServerConfigurationBroadcast);
    if (configuration != null) {
        for (int e = 0; e < configuration.getEpochs(); e++) {
            corpus.foreachPartition(partitionTrainer);
        }
    }
    //corpus.foreach(trainer);
    // we're transferring vectors to ExportContainer
    JavaRDD<ExportContainer<T>> exportRdd = vocabRDD.map(new DistributedFunction<T>(paramServerConfigurationBroadcast, configurationBroadcast, shallowVocabCacheBroadcast));
    // at this particular moment training should be pretty much done, and we're good to go for export
    if (exporter != null)
        exporter.export(exportRdd);
    // unpersist, if we've persisted the corpus after all
    if (storageLevel != null)
        corpus.unpersist();
    log.info("Training finish, starting cleanup...");
    VoidParameterServer.getInstance().shutdown();
}
Also used : NetworkOrganizer(org.deeplearning4j.spark.models.sequencevectors.utils.NetworkOrganizer) ExportContainer(org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer) ExtraCounter(org.deeplearning4j.spark.models.sequencevectors.primitives.ExtraCounter) Counter(org.deeplearning4j.berkeley.Counter) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ExtraCounter(org.deeplearning4j.spark.models.sequencevectors.primitives.ExtraCounter) Pair(org.deeplearning4j.berkeley.Pair) VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) DL4JInvalidConfigException(org.deeplearning4j.exception.DL4JInvalidConfigException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) NetworkInformation(org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)
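
When shard addresses are known up front, auto-discovery can be bypassed by filling them into the VoidConfiguration before calling fitSequences(). A sketch using only the builder methods and setters that appear in the method above (the addresses themselves are hypothetical):

// With shard addresses present, isAutoDiscoveryMode evaluates to false
// and the CountFunction branch is taken instead of auto discovery.
VoidConfiguration psConfig = VoidConfiguration.builder()
        .faultToleranceStrategy(FaultToleranceStrategy.NONE)
        .numberOfShards(2)
        .unicastPort(40123)
        .multicastPort(40124)
        .build();
psConfig.setShardAddresses(Arrays.asList("10.0.0.1:40123", "10.0.0.2:40123"));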

Example 24 with JavaSparkContext

use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.

the class Word2Vec method train.

/**
     * Trains a word2vec model on a given text corpus
     *
     * @param corpusRDD training corpus
     * @throws Exception
     */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    if (workers > 0) {
        // repartition() returns a new RDD; keep the result or the call has no effect
        corpusRDD = corpusRDD.repartition(workers);
    }
    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();
    // Variables to fill in train
    final JavaRDD<AtomicLong> sentenceWordsCountRDD;
    final JavaRDD<List<VocabWord>> vocabWordListRDD;
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD;
    final VocabCache<VocabWord> vocabCache;
    final JavaRDD<Long> sentenceCumSumCountRDD;
    int maxRep = 1;
    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Process every sentence and build a VocabCache, which gets fed into a LookupCache
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count and put into word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    // 2 RDDs: (vocab words list) and (sentence count). Already cached
    sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    vocabWordListRDD = pipeline.getVocabWordListRDD();
    // Get vocabCache and the broadcast vocabCache
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());
    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Building the Huffman tree updates the code and point of each VocabWord in the vocabCache
    /*
        We don't need to build tree here, since it was built earlier, at TextPipeline.buildVocabCache() call.
        
        Huffman huffman = new Huffman(vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(vocabCache);
        */
    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();
    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");
    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
    // Get all the syn0 updates into a list in driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
    // Instantiate syn0
    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
    // Updating syn0 first pass: just add vectors obtained from different nodes
    log.info("Averaging results...");
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    Map<Long, Long> updaters = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        syn0.getRow(syn0UpdateEntry.getFirst().getIndex()).addi(syn0UpdateEntry.getSecond());
        // for proper averaging we need to divide the resulting sums later by the number of additions
        if (updates.containsKey(syn0UpdateEntry.getFirst())) {
            updates.get(syn0UpdateEntry.getFirst()).incrementAndGet();
        } else
            updates.put(syn0UpdateEntry.getFirst(), new AtomicInteger(1));
        if (!updaters.containsKey(syn0UpdateEntry.getFirst().getVocabId())) {
            updaters.put(syn0UpdateEntry.getFirst().getVocabId(), syn0UpdateEntry.getFirst().getAffinityId());
        }
    }
    // Updating syn0 second pass: average obtained vectors
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        if (entry.getValue().get() > 1) {
            if (entry.getValue().get() > maxRep)
                maxRep = entry.getValue().get();
            syn0.getRow(entry.getKey().getIndex()).divi(entry.getValue().get());
        }
    }
    long totals = 0;
    log.info("Finished calculations...");
    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
Also used : HashMap(java.util.HashMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ArrayList(java.util.ArrayList) List(java.util.List) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) AtomicLong(java.util.concurrent.atomic.AtomicLong) Environment(org.nd4j.linalg.heartbeat.reports.Environment) HashMap(java.util.HashMap) Map(java.util.Map)
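
The driver-side merge above is a two-pass sum-then-divide average. A stripped-down sketch of the same pattern with plain arrays (names and types are illustrative, not dl4j API):

import java.util.HashMap;
import java.util.List;
import java.util.Map;

class Syn0AveragingSketch {

    /**
     * Averages partial word vectors keyed by word index, mirroring the
     * sum-then-divide merge in train() above.
     */
    static Map<Integer, double[]> average(List<Map.Entry<Integer, double[]>> updates, int layerSize) {
        Map<Integer, double[]> sums = new HashMap<>();
        Map<Integer, Integer> counts = new HashMap<>();
        // first pass: accumulate sums and contribution counts per word index
        for (Map.Entry<Integer, double[]> u : updates) {
            double[] row = sums.computeIfAbsent(u.getKey(), k -> new double[layerSize]);
            for (int i = 0; i < layerSize; i++) row[i] += u.getValue()[i];
            counts.merge(u.getKey(), 1, Integer::sum);
        }
        // second pass: divide each summed row by its contribution count
        for (Map.Entry<Integer, double[]> e : sums.entrySet()) {
            int n = counts.get(e.getKey());
            if (n > 1) for (int i = 0; i < layerSize; i++) e.getValue()[i] /= n;
        }
        return sums;
    }
}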

Example 25 with JavaSparkContext

use of org.apache.spark.api.java.JavaSparkContext in project deeplearning4j by deeplearning4j.

the class TextPipeline method setup.

private void setup() {
    // Set up accumulators and broadcast stopwords
    this.sc = new JavaSparkContext(corpusRDD.context());
    this.wordFreqAcc = sc.accumulator(new Counter<String>(), new WordFreqAccumulator());
    this.stopWordBroadCast = sc.broadcast(stopWords);
}
Also used : Counter(org.deeplearning4j.berkeley.Counter) WordFreqAccumulator(org.deeplearning4j.spark.text.accumulators.WordFreqAccumulator) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext)
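
The wordFreqAcc accumulator relies on Spark 1.x's AccumulatorParam contract (zero, addInPlace, addAccumulator). As a sketch of that contract, here is a map-based word-frequency accumulator standing in for dl4j's Counter-backed WordFreqAccumulator; the class below is illustrative, not the actual dl4j implementation:

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.AccumulatorParam;

// Sketch of a word-frequency accumulator in the Spark 1.x style used above.
// dl4j's actual WordFreqAccumulator merges berkeley Counter<String> instances;
// a plain Map is substituted here for illustration.
class MapWordFreqAccumulator implements AccumulatorParam<Map<String, Long>> {

    @Override
    public Map<String, Long> zero(Map<String, Long> initialValue) {
        return new HashMap<>();
    }

    @Override
    public Map<String, Long> addInPlace(Map<String, Long> m1, Map<String, Long> m2) {
        // merge frequencies from m2 into m1
        for (Map.Entry<String, Long> e : m2.entrySet())
            m1.merge(e.getKey(), e.getValue(), Long::sum);
        return m1;
    }

    @Override
    public Map<String, Long> addAccumulator(Map<String, Long> m1, Map<String, Long> m2) {
        return addInPlace(m1, m2);
    }
}

It would be registered analogously to the wordFreqAcc line above: sc.accumulator(new HashMap<String, Long>(), new MapWordFreqAccumulator()).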

Aggregations

JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 251
Test (org.testng.annotations.Test): 65
BaseTest (org.broadinstitute.hellbender.utils.test.BaseTest): 64
Tuple2 (scala.Tuple2): 48
SparkConf (org.apache.spark.SparkConf): 46
Test (org.junit.Test): 43
ArrayList (java.util.ArrayList): 41
GATKRead (org.broadinstitute.hellbender.utils.read.GATKRead): 32
List (java.util.List): 26
Configuration (org.apache.hadoop.conf.Configuration): 23
JavaRDD (org.apache.spark.api.java.JavaRDD): 23
File (java.io.File): 22
SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval): 20
Collectors (java.util.stream.Collectors): 16
TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline): 15
DataSet (org.nd4j.linalg.dataset.DataSet): 15
IOException (java.io.IOException): 13
SAMFileHeader (htsjdk.samtools.SAMFileHeader): 12
RealMatrix (org.apache.commons.math3.linear.RealMatrix): 12
SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary): 11