Search in sources :

Example 1 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class Glove method train.

/**
 * Trains GloVe on the given corpus.
 *
 * <p>Steps, as performed by this method:
 * <ol>
 *   <li>read training parameters from the Spark configuration of the RDD's context;</li>
 *   <li>run a {@link TextPipeline} over the corpus to build the vocab cache and the
 *       per-sentence word lists;</li>
 *   <li>build and reset a {@link GloveWeightLookupTable} for that vocabulary;</li>
 *   <li>compute word co-occurrence counts, clamp each count to the table's max count,
 *       and turn them into an RDD of (word1, (word2, count)) pairs;</li>
 *   <li>for each iteration: compute a {@code GloveChange} per pair, collect the changes,
 *       apply them to the lookup table, then reshuffle the pairs.</li>
 * </ol>
 *
 * <p>{@code symmetric}, {@code windowSize} and {@code vocabCacheBroadcast} are not declared in
 * this method; presumably they are fields of the enclosing class (the class body is not
 * visible here). {@code vocabCacheBroadcast} is assigned by this method.
 *
 * @param rdd the rdd to train
 * @return a pair of the vocab cache and the lookup table holding the trained weights
 * @throws Exception declared by the signature; the specific failure modes are not visible here
 */
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
    // Each `train()` can use different parameters
    final JavaSparkContext sc = new JavaSparkContext(rdd.context());
    final SparkConf conf = sc.getConf();
    // NOTE(review): vectorLength, useAdaGrad, negative, window, alpha, minAlpha are assigned here
    // but never read again in this method; the lookup table below reads its settings directly
    // from conf using the GlovePerformer.* keys. Confirm whether assignVar is relied on for
    // validation before removing any of them.
    final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
    final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
    final double negative = assignVar(NEGATIVE, conf, Double.class);
    final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
    final int window = assignVar(WINDOW, conf, Integer.class);
    final double alpha = assignVar(ALPHA, conf, Double.class);
    final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
    final int iterations = assignVar(ITERATIONS, conf, Integer.class);
    final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
    final String tokenizer = assignVar(TOKENIZER, conf, String.class);
    final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
    final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
    // Tokenizer settings handed to the TextPipeline through a broadcast variable.
    Map<String, Object> tokenizerVarMap = new HashMap<String, Object>() {

        {
            put("numWords", numWords);
            put("nGrams", nGrams);
            put("tokenizer", tokenizer);
            put("tokenPreprocessor", tokenPreprocessor);
            put("removeStop", removeStop);
        }
    };
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count
    Long totalWordCount = pipeline.getTotalWordCount();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
    // Only the first element (the vocab cache) of this pair is read below.
    final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());
    // Lookup table settings come straight from conf, with these literal fallbacks:
    // lr 0.01, maxCount 100, vectorLength 300, xMax 0.75.
    final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder().cache(vocabAndNumWords.getFirst()).lr(conf.getDouble(GlovePerformer.ALPHA, 0.01)).maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100)).vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300)).xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75)).build();
    gloveWeightLookupTable.resetWeights();
    // Start both AdaGrad histories at all-ones: one entry per syn0 row for the bias,
    // and the full syn0 shape for the weights.
    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
    log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
    // Per-sentence co-occurrence maps are folded into a single CounterMap held by this method.
    CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD.map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize)).fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
    Iterator<Pair<String, String>> pair2 = coOccurrenceCounts.getPairIterator();
    List<Triple<String, String, Double>> counts = new ArrayList<>();
    // Clamp every count to the table's max count, then flatten to (word1, word2, count) triples.
    while (pair2.hasNext()) {
        Pair<String, String> next = pair2.next();
        if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
            coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(), gloveWeightLookupTable.getMaxCount());
        }
        counts.add(new Triple<>(next.getFirst(), next.getSecond(), coOccurrenceCounts.getCount(next.getFirst(), next.getSecond())));
    }
    log.info("Calculated co occurrences");
    // Redistribute the triples and re-key them as word1 -> (word2, count).
    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    JavaPairRDD<String, Tuple2<String, Double>> pairs = parallel.mapToPair(new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {

        @Override
        public Tuple2<String, Tuple2<String, Double>> call(Triple<String, String, Double> stringStringDoubleTriple) throws Exception {
            return new Tuple2<>(stringStringDoubleTriple.getFirst(), new Tuple2<>(stringStringDoubleTriple.getSecond(), stringStringDoubleTriple.getThird()));
        }
    });
    // Replace both word strings with their VocabWord from the broadcast vocab cache.
    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab = pairs.mapToPair(new PairFunction<Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {

        @Override
        public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2) throws Exception {
            VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
            VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
            return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
        }
    });
    for (int i = 0; i < iterations; i++) {
        // The anonymous Function below references gloveWeightLookupTable directly rather than
        // through a broadcast; presumably the table is shipped with the closure on every
        // iteration - TODO confirm.
        JavaRDD<GloveChange> change = pairsVocab.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {

            @Override
            public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2) throws Exception {
                VocabWord w1 = vocabWordTuple2Tuple2._1();
                VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
                INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                INDArray bias = gloveWeightLookupTable.getBias();
                // score is the (clamped) co-occurrence count of the pair
                double score = vocabWordTuple2Tuple2._2()._2();
                double xMax = gloveWeightLookupTable.getxMax();
                double maxCount = gloveWeightLookupTable.getMaxCount();
                // prediction = w1 . w2 + bias[w1] + bias[w2]
                double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
                // weight = min(1, score / maxCount) ^ xMax
                double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
                // NOTE(review): when score > xMax the raw prediction is used and log(score) is
                // not subtracted; xMax is also the exponent in the weight above. Confirm this
                // is the intended formula.
                double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
                // Guard against NaN by substituting Nd4j.EPS_THRESHOLD.
                if (Double.isNaN(fDiff))
                    fDiff = Nd4j.EPS_THRESHOLD;
                // amount of change: the same value is passed as the gradient for both words
                double gradient = fDiff;
                Pair<INDArray, Double> w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w1, w1Vector, w2Vector, gradient);
                Pair<INDArray, Double> w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w2, w2Vector, w1Vector, gradient);
                // NOTE(review): the last two arguments pass the bias historical gradient for w2
                // first and w1 second, whereas every other argument pair is ordered w1, w2.
                // Verify against the GloveChange constructor.
                return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(), w1Update.getSecond(), w2Update.getSecond(), fDiff, gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()), gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
            }
        });
        // Collect every change and apply it to the lookup table held by this method,
        // summing the per-change error for the log line below.
        List<GloveChange> gloveChanges = change.collect();
        double error = 0.0;
        for (GloveChange change2 : gloveChanges) {
            change2.apply(gloveWeightLookupTable);
            error += change2.getError();
        }
        // Reshuffle the pairs for the next iteration by collecting, shuffling and
        // re-parallelizing them. NOTE(review): `l` is a raw List, so the
        // parallelizePairs call is unchecked.
        List l = pairsVocab.collect();
        Collections.shuffle(l);
        pairsVocab = sc.parallelizePairs(l);
        log.info("Error at iteration " + i + " was " + error);
    }
    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
}
Also used : CoOccurrenceCounts(org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) Triple(org.deeplearning4j.berkeley.Triple) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) AtomicLong(java.util.concurrent.atomic.AtomicLong) CounterMap(org.deeplearning4j.berkeley.CounterMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) CoOccurrenceCalculator(org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) GloveWeightLookupTable(org.deeplearning4j.models.glove.GloveWeightLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) SparkConf(org.apache.spark.SparkConf)

Example 2 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class Word2Vec method train.

/**
 * Trains a word2vec model on the given text corpus.
 *
 * <p>The corpus is tokenized by a {@link TextPipeline}, which also builds the vocab cache.
 * Every sentence is then paired with the cumulative word count up to that sentence and run
 * through {@code FirstIterationFunction}. The resulting per-word vectors are collected,
 * summed per vocab index and averaged into {@code syn0}. Finally the vocab, the lookup table
 * and the model utils of this instance are (re)initialized from the result.
 *
 * <p>When {@code workers > 0} the corpus is repartitioned into that many partitions before
 * any processing happens.
 *
 * @param corpusRDD training corpus
 * @throws Exception declared by the signature; the specific failure modes are not visible here
 */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    // RDDs are immutable: repartition() returns a NEW RDD instead of modifying the receiver,
    // so the returned RDD must be captured and used for everything downstream.
    final JavaRDD<String> corpus = workers > 0 ? corpusRDD.repartition(workers) : corpusRDD;
    // SparkContext
    final JavaSparkContext sc = new JavaSparkContext(corpus.context());
    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();
    // Highest number of update vectors received for any single word
    int maxRep = 1;
    // Start Training //
    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    // Processing every sentence and make a VocabCache which gets fed into a LookupCache
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpus, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count and put into word2vec variable map
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    // 2 RDDs: (vocab words list) and (sentence Count).Already cached
    final JavaRDD<AtomicLong> sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    final JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    // Get vocabCache and broad-casted vocabCache
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    final VocabCache<VocabWord> vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());
    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // No tree is built here: per the original author's note it was already built during
    // the TextPipeline.buildVocabCache() call above.
    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    final JavaRDD<Long> sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();
    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");
    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);
    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
    // Get all the syn0 updates into a list in driver
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
    // Instantiate syn0
    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
    // Updating syn0 first pass: just add vectors obtained from different nodes
    log.info("Averaging results...");
    Map<VocabWord, AtomicInteger> updates = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        VocabWord word = syn0UpdateEntry.getFirst();
        syn0.getRow(word.getIndex()).addi(syn0UpdateEntry.getSecond());
        // for proper averaging we need to divide resulting sums later, by the number of additions
        if (updates.containsKey(word)) {
            updates.get(word).incrementAndGet();
        } else {
            updates.put(word, new AtomicInteger(1));
        }
    }
    // Updating syn0 second pass: average obtained vectors
    for (Map.Entry<VocabWord, AtomicInteger> entry : updates.entrySet()) {
        int additions = entry.getValue().get();
        if (additions > 1) {
            if (additions > maxRep) {
                maxRep = additions;
            }
            syn0.getRow(entry.getKey().getIndex()).divi(additions);
        }
    }
    // Reported as available memory below; it is never changed from 0 in this method.
    long totals = 0;
    log.info("Finished calculations...");
    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
    Environment env = EnvironmentUtils.buildEnvironment();
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
Also used : HashMap(java.util.HashMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ArrayList(java.util.ArrayList) List(java.util.List) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) AtomicLong(java.util.concurrent.atomic.AtomicLong) Environment(org.nd4j.linalg.heartbeat.reports.Environment) HashMap(java.util.HashMap) Map(java.util.Map)

Example 3 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class TextPipelineTest method testBuildVocabWordListRDD.

/**
 * Builds the vocab cache and the vocab word list RDD for the test corpus and checks the
 * total word count, the word count of the first two sentences, and that the expected
 * tokens are present in the first two vocab word lists.
 *
 * @throws Exception if the pipeline or the Spark job fails
 */
@Test
public void testBuildVocabWordListRDD() throws Exception {
    JavaSparkContext sc = getContext();
    // Stop the context in finally so that a failing assertion does not leave it running.
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();
        JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
        JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
        List<List<VocabWord>> vocabWordList = vocabWordListRDD.collect();
        List<VocabWord> firstSentenceVocabList = vocabWordList.get(0);
        List<VocabWord> secondSentenceVocabList = vocabWordList.get(1);
        System.out.println(Arrays.deepToString(firstSentenceVocabList.toArray()));
        List<String> firstSentenceTokenList = new ArrayList<>();
        List<String> secondSentenceTokenList = new ArrayList<>();
        // The vocab lists may contain null entries; only real words are compared.
        for (VocabWord v : firstSentenceVocabList) {
            if (v != null) {
                firstSentenceTokenList.add(v.getWord());
            }
        }
        for (VocabWord v : secondSentenceVocabList) {
            if (v != null) {
                secondSentenceTokenList.add(v.getWord());
            }
        }
        // Collect once instead of running the Spark job for every assertion.
        List<AtomicLong> sentenceCounts = sentenceCountRDD.collect();
        // JUnit convention: expected value first, actual value second.
        assertEquals(9, pipeline.getTotalWordCount(), 0);
        assertEquals(6, sentenceCounts.get(0).get());
        assertEquals(3, sentenceCounts.get(1).get());
        assertTrue(firstSentenceTokenList.containsAll(Arrays.asList("strange", "strange", "world")));
        assertTrue(secondSentenceTokenList.containsAll(Arrays.asList("flowers", "red")));
    } finally {
        sc.stop();
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Test(org.junit.Test)

Example 4 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class TextPipelineTest method testZipFunction1.

/**
 * Zips the vocab word lists with the cumulative sentence word counts, using the tokenizer
 * settings of {@code word2vec}, and checks the first two entries.
 *
 * <p>With these settings the first two sentences yield 3 and 2 vocab words while the
 * cumulative counts are 6 and 9, i.e. the word lists are shorter than the counted
 * sentences (presumably because stop words are removed - compare testZipFunction2).
 *
 * @throws Exception if the pipeline or the Spark job fails
 */
@Test
public void testZipFunction1() throws Exception {
    JavaSparkContext sc = getContext();
    // Stop the context in finally so that a failing assertion does not leave it running.
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();
        JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
        JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
        CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
        JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
        JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
        List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();
        // JUnit convention throughout: expected value first, actual value second.
        List<VocabWord> vocabWordsList1 = lst.get(0)._1();
        Long cumSumSize1 = lst.get(0)._2();
        assertEquals(3, vocabWordsList1.size());
        assertEquals("strange", vocabWordsList1.get(0).getWord());
        assertEquals("strange", vocabWordsList1.get(1).getWord());
        assertEquals("world", vocabWordsList1.get(2).getWord());
        assertEquals(6L, cumSumSize1, 0);
        List<VocabWord> vocabWordsList2 = lst.get(1)._1();
        Long cumSumSize2 = lst.get(1)._2();
        assertEquals(2, vocabWordsList2.size());
        assertEquals("flowers", vocabWordsList2.get(0).getWord());
        assertEquals("red", vocabWordsList2.get(1).getWord());
        assertEquals(9L, cumSumSize2, 0);
    } finally {
        sc.stop();
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) Tuple2(scala.Tuple2) AtomicLong(java.util.concurrent.atomic.AtomicLong) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Test(org.junit.Test)

Example 5 with TextPipeline

use of org.deeplearning4j.spark.text.functions.TextPipeline in project deeplearning4j by deeplearning4j.

the class TextPipelineTest method testZipFunction2.

/**
 * Zips the vocab word lists with the cumulative sentence word counts, using the tokenizer
 * settings of {@code word2vecNoStop}, and checks the first two entries.
 *
 * <p>With these settings the first two sentences yield all 6 and 3 of their words, and the
 * cumulative counts are 6 and 9.
 *
 * @throws Exception if the pipeline or the Spark job fails
 */
@Test
public void testZipFunction2() throws Exception {
    JavaSparkContext sc = getContext();
    // Stop the context in finally so that a failing assertion does not leave it running.
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vecNoStop.getTokenizerVarMap());
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();
        JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
        JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
        CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
        JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
        JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
        List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();
        // JUnit convention throughout: expected value first, actual value second.
        List<VocabWord> vocabWordsList1 = lst.get(0)._1();
        Long cumSumSize1 = lst.get(0)._2();
        assertEquals(6, vocabWordsList1.size());
        assertEquals("this", vocabWordsList1.get(0).getWord());
        assertEquals("is", vocabWordsList1.get(1).getWord());
        assertEquals("a", vocabWordsList1.get(2).getWord());
        assertEquals("strange", vocabWordsList1.get(3).getWord());
        assertEquals("strange", vocabWordsList1.get(4).getWord());
        assertEquals("world", vocabWordsList1.get(5).getWord());
        assertEquals(6L, cumSumSize1, 0);
        List<VocabWord> vocabWordsList2 = lst.get(1)._1();
        Long cumSumSize2 = lst.get(1)._2();
        assertEquals(3, vocabWordsList2.size());
        assertEquals("flowers", vocabWordsList2.get(0).getWord());
        assertEquals("are", vocabWordsList2.get(1).getWord());
        assertEquals("red", vocabWordsList2.get(2).getWord());
        assertEquals(9L, cumSumSize2, 0);
    } finally {
        sc.stop();
    }
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) Tuple2(scala.Tuple2) AtomicLong(java.util.concurrent.atomic.AtomicLong) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Test(org.junit.Test)

Aggregations

JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)15 TextPipeline (org.deeplearning4j.spark.text.functions.TextPipeline)15 Test (org.junit.Test)13 VocabWord (org.deeplearning4j.models.word2vec.VocabWord)10 AtomicLong (java.util.concurrent.atomic.AtomicLong)8 CountCumSum (org.deeplearning4j.spark.text.functions.CountCumSum)6 Pair (org.deeplearning4j.berkeley.Pair)4 Tuple2 (scala.Tuple2)4 Huffman (org.deeplearning4j.models.word2vec.Huffman)2 VocabCache (org.deeplearning4j.models.word2vec.wordstore.VocabCache)2 INDArray (org.nd4j.linalg.api.ndarray.INDArray)2 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 List (java.util.List)1 Map (java.util.Map)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 SparkConf (org.apache.spark.SparkConf)1 FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction)1 CounterMap (org.deeplearning4j.berkeley.CounterMap)1 Triple (org.deeplearning4j.berkeley.Triple)1