Search in sources :

Example 11 with SparkConf

use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.

In the class TestKryoWarning, the method testKryoMessageMLNIncorrectConfig.

/**
 * Runs the MLN check with Kryo configured as the Spark serializer.
 * The run is expected to print a warning message; nothing is asserted here.
 */
@Test
@Ignore
public void testKryoMessageMLNIncorrectConfig() {
    // Configure step by step instead of through one fluent chain.
    SparkConf kryoConf = new SparkConf();
    kryoConf.setMaster("local[*]");
    kryoConf.setAppName("sparktest");
    kryoConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
    doTestMLN(kryoConf);
}
Also used : SparkConf(org.apache.spark.SparkConf) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 12 with SparkConf

use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.

In the class Glove, the method train.

/**
 * Train on the corpus.
 *
 * <p>The method builds a vocabulary from {@code rdd} through a {@link TextPipeline},
 * computes co-occurrence counts (capped at the lookup table's max count), and then,
 * for the configured number of iterations, maps every word pair to a
 * {@code GloveChange}, collects the changes, applies them to the lookup table and
 * reshuffles the pairs for the next pass.
 *
 * @param rdd the rdd to train
 * @return the vocab and weights
 * @throws Exception if any step of the pipeline or a Spark job fails
 */
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
    // Each `train()` can use different parameters
    final JavaSparkContext sc = new JavaSparkContext(rdd.context());
    final SparkConf conf = sc.getConf();
    // NOTE(review): vectorLength, useAdaGrad, negative, window, alpha and minAlpha are read
    // but never used below (the lookup table reads its settings straight from conf).
    // The calls are kept because assignVar is defined elsewhere and may validate the
    // configuration - confirm before removing.
    final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
    final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
    final double negative = assignVar(NEGATIVE, conf, Double.class);
    final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
    final int window = assignVar(WINDOW, conf, Integer.class);
    final double alpha = assignVar(ALPHA, conf, Double.class);
    final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
    final int iterations = assignVar(ITERATIONS, conf, Integer.class);
    final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
    final String tokenizer = assignVar(TOKENIZER, conf, String.class);
    final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
    final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
    // A plain HashMap is used on purpose: "double brace" initialisation would create an
    // anonymous subclass that can carry a hidden reference to the enclosing instance,
    // and this map is handed to sc.broadcast() below.
    Map<String, Object> tokenizerVarMap = new HashMap<>();
    tokenizerVarMap.put("numWords", numWords);
    tokenizerVarMap.put("nGrams", nGrams);
    tokenizerVarMap.put("tokenizer", tokenizer);
    tokenizerVarMap.put("tokenPreprocessor", tokenPreprocessor);
    tokenizerVarMap.put("removeStop", removeStop);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count
    Long totalWordCount = pipeline.getTotalWordCount();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
    final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());
    final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder().cache(vocabAndNumWords.getFirst()).lr(conf.getDouble(GlovePerformer.ALPHA, 0.01)).maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100)).vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300)).xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75)).build();
    gloveWeightLookupTable.resetWeights();
    // Start both AdaGrad histories at one, sized after the weight matrix.
    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
    log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));
    CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD.map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize)).fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
    Iterator<Pair<String, String>> pair2 = coOccurrenceCounts.getPairIterator();
    List<Triple<String, String, Double>> counts = new ArrayList<>();
    while (pair2.hasNext()) {
        Pair<String, String> next = pair2.next();
        // Cap every co-occurrence count at the lookup table's max count.
        if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
            coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(), gloveWeightLookupTable.getMaxCount());
        }
        counts.add(new Triple<>(next.getFirst(), next.getSecond(), coOccurrenceCounts.getCount(next.getFirst(), next.getSecond())));
    }
    log.info("Calculated co occurrences");
    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    // (word1, word2, count) -> (word1, (word2, count))
    JavaPairRDD<String, Tuple2<String, Double>> pairs = parallel.mapToPair(new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {

        @Override
        public Tuple2<String, Tuple2<String, Double>> call(Triple<String, String, Double> stringStringDoubleTriple) throws Exception {
            return new Tuple2<>(stringStringDoubleTriple.getFirst(), new Tuple2<>(stringStringDoubleTriple.getSecond(), stringStringDoubleTriple.getThird()));
        }
    });
    // Resolve both words of every pair to their VocabWord through the broadcast vocab.
    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab = pairs.mapToPair(new PairFunction<Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {

        @Override
        public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2) throws Exception {
            VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
            VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
            return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
        }
    });
    for (int i = 0; i < iterations; i++) {
        JavaRDD<GloveChange> change = pairsVocab.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {

            @Override
            public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2) throws Exception {
                VocabWord w1 = vocabWordTuple2Tuple2._1();
                VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
                INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                INDArray bias = gloveWeightLookupTable.getBias();
                double score = vocabWordTuple2Tuple2._2()._2();
                double xMax = gloveWeightLookupTable.getxMax();
                double maxCount = gloveWeightLookupTable.getMaxCount();
                //w1 * w2 + bias
                double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
                double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
                // NOTE(review): when score > xMax the log(score) term is not subtracted from
                // the prediction - confirm this asymmetry is intended.
                double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
                if (Double.isNaN(fDiff))
                    fDiff = Nd4j.EPS_THRESHOLD;
                //amount of change
                double gradient = fDiff;
                Pair<INDArray, Double> w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w1, w1Vector, w2Vector, gradient);
                Pair<INDArray, Double> w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w2, w2Vector, w1Vector, gradient);
                // NOTE(review): the last two arguments pass the bias history for w2 before w1,
                // while every other argument pair is ordered w1, w2 - verify against the
                // GloveChange constructor.
                return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(), w1Update.getSecond(), w2Update.getSecond(), fDiff, gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()), gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
            }
        });
        // Changes are computed on the RDD but applied here, after collecting them.
        List<GloveChange> gloveChanges = change.collect();
        double error = 0.0;
        for (GloveChange change2 : gloveChanges) {
            change2.apply(gloveWeightLookupTable);
            error += change2.getError();
        }
        // Copy into an ArrayList before shuffling: collect() is not documented to return
        // a list that supports in-place modification.
        List<Tuple2<VocabWord, Tuple2<VocabWord, Double>>> shuffledPairs = new ArrayList<>(pairsVocab.collect());
        Collections.shuffle(shuffledPairs);
        pairsVocab = sc.parallelizePairs(shuffledPairs);
        log.info("Error at iteration " + i + " was " + error);
    }
    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
}
Also used : CoOccurrenceCounts(org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) Triple(org.deeplearning4j.berkeley.Triple) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) AtomicLong(java.util.concurrent.atomic.AtomicLong) CounterMap(org.deeplearning4j.berkeley.CounterMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) CoOccurrenceCalculator(org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) GloveWeightLookupTable(org.deeplearning4j.models.glove.GloveWeightLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) SparkConf(org.apache.spark.SparkConf)

Example 13 with SparkConf

use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.

In the class Word2VecTest, the method testSparkW2VonBiggerCorpus.

/**
 * Trains Spark Word2Vec on the {@code spark_word2vec_test.txt} classpath resource and
 * writes the resulting word vectors to a fixed path on disk. Ignored by default.
 *
 * @throws Exception if the resource cannot be resolved, training fails, or the
 *                   vectors cannot be written
 */
@Ignore
@Test
public void testSparkW2VonBiggerCorpus() throws Exception {
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest").set("spark.driver.maxResultSize", "4g").set("spark.driver.memory", "8g").set("spark.executor.memory", "8g");
    // Set SparkContext
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    Word2Vec word2Vec;
    try {
        // Path of data part-00000
        //String dataPath = new ClassPathResource("/big/raw_sentences.txt").getFile().getAbsolutePath();
        //        String dataPath = "/ext/Temp/SampleRussianCorpus.txt";
        String dataPath = new ClassPathResource("spark_word2vec_test.txt").getFile().getAbsolutePath();
        // Read in data
        JavaRDD<String> corpus = sc.textFile(dataPath);
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new LowCasePreProcessor());
        word2Vec = new Word2Vec.Builder().setNGrams(1).tokenizerFactory(t).seed(42L).negative(3).useAdaGrad(false).layerSize(100).windowSize(5).learningRate(0.025).minLearningRate(0.0001).iterations(1).batchSize(100).minWordFrequency(5).useUnknown(true).build();
        word2Vec.train(corpus);
    } finally {
        // Stop the context even when setup or training throws, so it is not leaked
        // into whatever runs next in the same JVM.
        sc.stop();
    }
    // NOTE(review): hard-coded, machine-specific output path.
    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), "/ext/Temp/sparkRuModel.txt");
}
Also used : DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) LowCasePreProcessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.LowCasePreProcessor) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SparkConf(org.apache.spark.SparkConf) ClassPathResource(org.datavec.api.util.ClassPathResource) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 14 with SparkConf

use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.

In the class Word2VecTest, the method testConcepts.

/**
 * Trains Spark Word2Vec on {@code raw_sentences.txt}, checks nearest-word results for
 * "day" and "two" (with "three" configured as a stop word), then writes the vectors to
 * a temp file, reloads them and compares index, first vocab element and the "day"
 * vector against the in-memory model.
 *
 * @throws Exception if the resource cannot be resolved, training fails, or the
 *                   serialization round trip fails
 */
@Test
public void testConcepts() throws Exception {
    // These are all default values for word2vec
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
    // Set SparkContext
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    Word2Vec word2Vec;
    try {
        // Path of data part-00000
        String dataPath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
        //        dataPath = "/ext/Temp/part-00000";
        //        String dataPath = new ClassPathResource("spark_word2vec_test.txt").getFile().getAbsolutePath();
        // Read in data
        JavaRDD<String> corpus = sc.textFile(dataPath);
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());
        word2Vec = new Word2Vec.Builder().setNGrams(1).tokenizerFactory(t).seed(42L).negative(10).useAdaGrad(false).layerSize(150).windowSize(5).learningRate(0.025).minLearningRate(0.0001).iterations(1).batchSize(100).minWordFrequency(5).stopWords(Arrays.asList("three")).useUnknown(true).build();
        word2Vec.train(corpus);
        //word2Vec.setModelUtils(new FlatModelUtils());
        System.out.println("UNK: " + word2Vec.getWordVectorMatrix("UNK"));
        // NOTE(review): 'table' is never read; the cast is kept in case it is meant as a
        // type check on the lookup table - confirm before removing.
        InMemoryLookupTable<VocabWord> table = (InMemoryLookupTable<VocabWord>) word2Vec.lookupTable();
        double sim = word2Vec.similarity("day", "night");
        System.out.println("day/night similarity: " + sim);
        /*
        System.out.println("Hornjo: " + word2Vec.getWordVectorMatrix("hornjoserbsce"));
        System.out.println("carro: " + word2Vec.getWordVectorMatrix("carro"));
        
        Collection<String> portu = word2Vec.wordsNearest("carro", 10);
        printWords("carro", portu, word2Vec);
        
        portu = word2Vec.wordsNearest("davi", 10);
        printWords("davi", portu, word2Vec);
        
        System.out.println("---------------------------------------");
        */
        Collection<String> words = word2Vec.wordsNearest("day", 10);
        printWords("day", words, word2Vec);
        assertTrue(words.contains("night"));
        assertTrue(words.contains("week"));
        assertTrue(words.contains("year"));
        sim = word2Vec.similarity("two", "four");
        System.out.println("two/four similarity: " + sim);
        words = word2Vec.wordsNearest("two", 10);
        printWords("two", words, word2Vec);
        // three should be absent due to stopWords
        assertFalse(words.contains("three"));
        assertTrue(words.contains("five"));
        assertTrue(words.contains("four"));
    } finally {
        // Stop the context even when training or an assertion above fails, so a
        // failure here does not leave a running context behind for later tests.
        sc.stop();
    }
    // test serialization
    File tempFile = File.createTempFile("temp", "tmp");
    tempFile.deleteOnExit();
    int idx1 = word2Vec.vocab().wordFor("day").getIndex();
    INDArray array1 = word2Vec.getWordVectorMatrix("day").dup();
    VocabWord word1 = word2Vec.vocab().elementAtIndex(0);
    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), tempFile);
    WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
    VocabWord word2 = ((VocabCache<VocabWord>) vectors.vocab()).elementAtIndex(0);
    VocabWord wordIT = ((VocabCache<VocabWord>) vectors.vocab()).wordFor("it");
    int idx2 = vectors.vocab().wordFor("day").getIndex();
    INDArray array2 = vectors.getWordVectorMatrix("day").dup();
    System.out.println("word 'i': " + word2);
    System.out.println("word 'it': " + wordIT);
    // The reloaded model must agree with the in-memory one.
    assertEquals(idx1, idx2);
    assertEquals(word1, word2);
    assertEquals(array1, array2);
}
Also used : TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ClassPathResource(org.datavec.api.util.ClassPathResource) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) SparkConf(org.apache.spark.SparkConf) File(java.io.File) Test(org.junit.Test)

Example 15 with SparkConf

use of org.apache.spark.SparkConf in project deeplearning4j by deeplearning4j.

In the class SparkSequenceVectorsTest, the method setUp.

/**
 * Prepares the shared sequence fixture (built only while {@code sequencesCyclic} is
 * still null) and creates a fresh local Spark context for the test.
 */
@Before
public void setUp() throws Exception {
    if (sequencesCyclic == null) {
        sequencesCyclic = new ArrayList<>();
        // Ten sequences are generated in total.
        int sequenceIdx = 0;
        while (sequenceIdx < 10) {
            Sequence<VocabWord> current = new Sequence<>();
            // One element per label "0".."9" ...
            int label = 0;
            while (label < 10) {
                current.addElement(new VocabWord(1.0, String.valueOf(label), (long) label));
                label++;
            }
            // ... plus a second "0" element, so that label occurs twice per sequence.
            current.addElement(new VocabWord(1.0, "0", 0L));
            sequencesCyclic.add(current);
            sequenceIdx++;
        }
    }
    SparkConf localConf = new SparkConf();
    localConf.setMaster("local[8]");
    localConf.setAppName("SeqVecTests");
    sc = new JavaSparkContext(localConf);
}
Also used : VocabWord(org.deeplearning4j.models.word2vec.VocabWord) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SparkConf(org.apache.spark.SparkConf) Before(org.junit.Before)

Aggregations

SparkConf (org.apache.spark.SparkConf)83 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)46 Test (org.junit.Test)21 ArrayList (java.util.ArrayList)20 Configuration (org.apache.hadoop.conf.Configuration)20 Tuple2 (scala.Tuple2)15 Graph (uk.gov.gchq.gaffer.graph.Graph)13 DataOutputStream (java.io.DataOutputStream)11 File (java.io.File)10 HashSet (java.util.HashSet)10 ByteArrayOutputStream (org.apache.commons.io.output.ByteArrayOutputStream)10 Edge (uk.gov.gchq.gaffer.data.element.Edge)10 Element (uk.gov.gchq.gaffer.data.element.Element)10 Entity (uk.gov.gchq.gaffer.data.element.Entity)10 User (uk.gov.gchq.gaffer.user.User)10 Ignore (org.junit.Ignore)6 HBaseConfiguration (org.apache.hadoop.hbase.HBaseConfiguration)5 JavaHBaseContext (org.apache.hadoop.hbase.spark.JavaHBaseContext)5 Test (org.testng.annotations.Test)5 AddElements (uk.gov.gchq.gaffer.operation.impl.add.AddElements)5