Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
The class Word2VecTest, method testConcepts.
@Test
public void testConcepts() throws Exception {
    // These are all default values for word2vec
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");

    // Set up the Spark context
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    // Path of data part-00000
    String dataPath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
    //dataPath = "/ext/Temp/part-00000";
    //String dataPath = new ClassPathResource("spark_word2vec_test.txt").getFile().getAbsolutePath();

    // Read in the data
    JavaRDD<String> corpus = sc.textFile(dataPath);

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec word2Vec = new Word2Vec.Builder()
                    .setNGrams(1)
                    .tokenizerFactory(t)
                    .seed(42L)
                    .negative(10)
                    .useAdaGrad(false)
                    .layerSize(150)
                    .windowSize(5)
                    .learningRate(0.025)
                    .minLearningRate(0.0001)
                    .iterations(1)
                    .batchSize(100)
                    .minWordFrequency(5)
                    .stopWords(Arrays.asList("three"))
                    .useUnknown(true)
                    .build();

    word2Vec.train(corpus);

    //word2Vec.setModelUtils(new FlatModelUtils());

    System.out.println("UNK: " + word2Vec.getWordVectorMatrix("UNK"));

    InMemoryLookupTable<VocabWord> table = (InMemoryLookupTable<VocabWord>) word2Vec.lookupTable();

    double sim = word2Vec.similarity("day", "night");
    System.out.println("day/night similarity: " + sim);

    /*
    System.out.println("Hornjo: " + word2Vec.getWordVectorMatrix("hornjoserbsce"));
    System.out.println("carro: " + word2Vec.getWordVectorMatrix("carro"));

    Collection<String> portu = word2Vec.wordsNearest("carro", 10);
    printWords("carro", portu, word2Vec);

    portu = word2Vec.wordsNearest("davi", 10);
    printWords("davi", portu, word2Vec);

    System.out.println("---------------------------------------");
    */

    Collection<String> words = word2Vec.wordsNearest("day", 10);
    printWords("day", words, word2Vec);

    assertTrue(words.contains("night"));
    assertTrue(words.contains("week"));
    assertTrue(words.contains("year"));

    sim = word2Vec.similarity("two", "four");
    System.out.println("two/four similarity: " + sim);

    words = word2Vec.wordsNearest("two", 10);
    printWords("two", words, word2Vec);

    // "three" should be absent, since it was passed in as a stop word
    assertFalse(words.contains("three"));
    assertTrue(words.contains("five"));
    assertTrue(words.contains("four"));

    sc.stop();

    // Test serialization: write the vectors out and read them back
    File tempFile = File.createTempFile("temp", "tmp");
    tempFile.deleteOnExit();

    int idx1 = word2Vec.vocab().wordFor("day").getIndex();
    INDArray array1 = word2Vec.getWordVectorMatrix("day").dup();
    VocabWord word1 = word2Vec.vocab().elementAtIndex(0);

    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), tempFile);
    WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);

    VocabWord word2 = ((VocabCache<VocabWord>) vectors.vocab()).elementAtIndex(0);
    VocabWord wordIT = ((VocabCache<VocabWord>) vectors.vocab()).wordFor("it");
    int idx2 = vectors.vocab().wordFor("day").getIndex();
    INDArray array2 = vectors.getWordVectorMatrix("day").dup();

    System.out.println("word 'i': " + word2);
    System.out.println("word 'it': " + wordIT);

    // The index, vocab element and vector for "day" must survive the round trip
    assertEquals(idx1, idx2);
    assertEquals(word1, word2);
    assertEquals(array1, array2);
}
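The printWords helper called above is not part of this excerpt. A minimal sketch of what such a helper might look like, with the signature inferred from the call sites (treat the body as an assumption, not the project's actual implementation):

// Hypothetical helper, inferred from the call sites above; the actual
// version in Word2VecTest may differ.
private static void printWords(String target, Collection<String> list, Word2Vec vec) {
    System.out.println("Words close to [" + target + "]:");
    for (String word : list) {
        // Print each neighbour together with its similarity to the target word
        System.out.print("'" + word + "': [" + vec.similarity(target, word) + "], ");
    }
    System.out.println();
}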
Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
The class SparkSequenceVectorsTest, method testFrequenciesCount.
@Test
public void testFrequenciesCount() throws Exception {
    JavaRDD<Sequence<VocabWord>> sequences = sc.parallelize(sequencesCyclic);

    SparkSequenceVectors<VocabWord> seqVec = new SparkSequenceVectors<>();
    seqVec.fitSequences(sequences);

    Counter<Long> counter = seqVec.getCounter();

    // Element "0" should have a frequency of 20
    assertEquals(20, counter.getCount(0L), 1e-5);

    // Elements 1..9 should each have a frequency of 10
    for (int e = 1; e < sequencesCyclic.get(0).getElements().size() - 1; e++) {
        assertEquals(10, counter.getCount(sequencesCyclic.get(0).getElementByIndex(e).getStorageId()), 1e-5);
    }

    VocabCache<ShallowSequenceElement> shallowVocab = seqVec.getShallowVocabCache();
    assertEquals(10, shallowVocab.numWords());

    ShallowSequenceElement zero = shallowVocab.tokenFor(0L);
    ShallowSequenceElement first = shallowVocab.tokenFor(1L);

    assertNotEquals(null, zero);
    assertEquals(20.0, zero.getElementFrequency(), 1e-5);
    assertEquals(0, zero.getIndex());
    assertEquals(10.0, first.getElementFrequency(), 1e-5);
}
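The expected counts follow from how setUp (next example) builds the data: each of the 10 sequences holds elements "0" through "9" once, plus one extra "0". A Spark-free sketch that reproduces the same bookkeeping, handy for sanity-checking the assertions (the class name is made up for illustration):

import java.util.HashMap;
import java.util.Map;

public class FrequencySanityCheck {
    public static void main(String[] args) {
        Map<Long, Integer> counts = new HashMap<>();
        for (int c = 0; c < 10; c++) {            // 10 sequences
            for (long e = 0; e < 10; e++) {       // elements 0..9, once each
                counts.merge(e, 1, Integer::sum);
            }
            counts.merge(0L, 1, Integer::sum);    // the extra "0" per sequence
        }
        System.out.println(counts.get(0L));       // 20, matching tokenFor(0L)
        System.out.println(counts.get(1L));       // 10, matching tokenFor(1L)
    }
}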
Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
The class SparkSequenceVectorsTest, method setUp.
@Before
public void setUp() throws Exception {
    if (sequencesCyclic == null) {
        sequencesCyclic = new ArrayList<>();

        // 10 sequences in total
        for (int c = 0; c < 10; c++) {
            Sequence<VocabWord> sequence = new Sequence<>();
            for (int e = 0; e < 10; e++) {
                // elements 1..9 appear once per sequence, so each ends up
                // with a total frequency of 10 across the 10 sequences
                sequence.addElement(new VocabWord(1.0, "" + e, (long) e));
            }
            // element "0" appears a second time in each sequence,
            // for a total frequency of 20
            sequence.addElement(new VocabWord(1.0, "0", 0L));
            sequencesCyclic.add(sequence);
        }
    }

    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
    sc = new JavaSparkContext(sparkConf);
}
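Both SparkSequenceVectorsTest snippets rely on class-level state that the excerpts leave out. A minimal sketch of the surrounding scaffolding, assuming a plain JUnit 4 test class (the field names match the snippets; the @After teardown is an assumption, added so the local context created in setUp is released between tests):

public class SparkSequenceVectorsTest {
    protected static List<Sequence<VocabWord>> sequencesCyclic;
    private JavaSparkContext sc;

    @After
    public void tearDown() throws Exception {
        // Stop the local Spark context so repeated runs don't collide
        if (sc != null) {
            sc.stop();
        }
    }

    // setUp() and the test methods shown above go here
}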
Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
The class TextPipelineTest, method testBuildVocabWordListRDD.
@Test
public void testBuildVocabWordListRDD() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    List<List<VocabWord>> vocabWordList = vocabWordListRDD.collect();
    List<VocabWord> firstSentenceVocabList = vocabWordList.get(0);
    List<VocabWord> secondSentenceVocabList = vocabWordList.get(1);

    System.out.println(Arrays.deepToString(firstSentenceVocabList.toArray()));

    List<String> firstSentenceTokenList = new ArrayList<>();
    List<String> secondSentenceTokenList = new ArrayList<>();
    for (VocabWord v : firstSentenceVocabList) {
        if (v != null) {
            firstSentenceTokenList.add(v.getWord());
        }
    }
    for (VocabWord v : secondSentenceVocabList) {
        if (v != null) {
            secondSentenceTokenList.add(v.getWord());
        }
    }

    assertEquals(pipeline.getTotalWordCount(), 9, 0);
    assertEquals(sentenceCountRDD.collect().get(0).get(), 6);
    assertEquals(sentenceCountRDD.collect().get(1).get(), 3);
    assertTrue(firstSentenceTokenList.containsAll(Arrays.asList("strange", "strange", "world")));
    assertTrue(secondSentenceTokenList.containsAll(Arrays.asList("flowers", "red")));

    sc.stop();
}
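The getCorpusRDD helper is not shown, but the assertions pin down its shape: two sentences of 6 and 3 raw tokens (9 in total), where the first contains "strange" twice plus "world" and the second contains "flowers" and "red". A hypothetical stand-in consistent with those numbers; the project's actual fixture may differ:

// Hypothetical stand-in for getCorpusRDD, consistent with the assertions
// above (6 + 3 = 9 total tokens). The real test fixture may differ.
private JavaRDD<String> getCorpusRDD(JavaSparkContext sc) {
    return sc.parallelize(Arrays.asList(
            "This is a strange strange world",  // 6 tokens
            "Flowers are red"));                // 3 tokens
}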
Use of org.deeplearning4j.models.word2vec.VocabWord in project deeplearning4j by deeplearning4j.
The class TextPipelineTest, method testZipFunction1.
/**
 * This test checks the per-sentence vocab word lists generated when
 * stop words are removed, zipped with the cumulative sentence word counts.
 *
 * @throws Exception
 */
@Test
public void testZipFunction1() throws Exception {
    JavaSparkContext sc = getContext();
    JavaRDD<String> corpusRDD = getCorpusRDD(sc);
    //  word2vec.setRemoveStop(false);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();

    JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
    JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();

    CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
    JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();

    JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCountCumSumRDD);
    List<Tuple2<List<VocabWord>, Long>> lst = vocabWordListSentenceCumSumRDD.collect();

    List<VocabWord> vocabWordsList1 = lst.get(0)._1();
    Long cumSumSize1 = lst.get(0)._2();
    assertEquals(3, vocabWordsList1.size());
    assertEquals(vocabWordsList1.get(0).getWord(), "strange");
    assertEquals(vocabWordsList1.get(1).getWord(), "strange");
    assertEquals(vocabWordsList1.get(2).getWord(), "world");
    assertEquals(cumSumSize1, 6L, 0);

    List<VocabWord> vocabWordsList2 = lst.get(1)._1();
    Long cumSumSize2 = lst.get(1)._2();
    assertEquals(2, vocabWordsList2.size());
    assertEquals(vocabWordsList2.get(0).getWord(), "flowers");
    assertEquals(vocabWordsList2.get(1).getWord(), "red");
    assertEquals(cumSumSize2, 9L, 0);

    sc.stop();
}
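CountCumSum turns the per-sentence counts into a running total, so the zip pairs each sentence's filtered word list with the cumulative raw-token offset at that sentence's end: 6 after the first sentence, 6 + 3 = 9 after the second. (JavaRDD.zip requires both RDDs to have the same partitioning and element counts, which holds here because both derive from the same corpus.) A plain-Java sketch of that pairing:

// Spark-free illustration of the cumulative-sum pairing checked above.
long[] rawTokensPerSentence = {6, 3};  // as implied by the test's assertions
long cumSum = 0;
for (int i = 0; i < rawTokensPerSentence.length; i++) {
    cumSum += rawTokensPerSentence[i];
    System.out.println("sentence " + i + " -> cumulative count " + cumSum);  // 6, then 9
}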