Usage of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j (by deeplearning4j).
Class Word2VecTest, method testConcepts.
@Test
public void testConcepts() throws Exception {
    // Local Spark master with 8 worker threads.
    SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    Word2Vec word2Vec;
    try {
        // Training corpus shipped on the test classpath.
        String dataPath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
        JavaRDD<String> corpus = sc.textFile(dataPath);

        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        // "three" is registered as a stop word; the assertions on "two" below rely on that.
        word2Vec = new Word2Vec.Builder()
                .setNGrams(1)
                .tokenizerFactory(t)
                .seed(42L)
                .negative(10)
                .useAdaGrad(false)
                .layerSize(150)
                .windowSize(5)
                .learningRate(0.025)
                .minLearningRate(0.0001)
                .iterations(1)
                .batchSize(100)
                .minWordFrequency(5)
                .stopWords(Arrays.asList("three"))
                .useUnknown(true)
                .build();
        word2Vec.train(corpus);

        System.out.println("UNK: " + word2Vec.getWordVectorMatrix("UNK"));

        double sim = word2Vec.similarity("day", "night");
        System.out.println("day/night similarity: " + sim);

        Collection<String> words = word2Vec.wordsNearest("day", 10);
        printWords("day", words, word2Vec);
        assertTrue(words.contains("night"));
        assertTrue(words.contains("week"));
        assertTrue(words.contains("year"));

        sim = word2Vec.similarity("two", "four");
        System.out.println("two/four similarity: " + sim);
        words = word2Vec.wordsNearest("two", 10);
        printWords("two", words, word2Vec);
        // three should be absent due to stopWords
        assertFalse(words.contains("three"));
        assertTrue(words.contains("five"));
        assertTrue(words.contains("four"));
    } finally {
        // Release the Spark context even when an assertion above fails, so a failing
        // run does not leave a live context behind for later tests in the same JVM.
        sc.stop();
    }

    // Serialization round-trip: write the lookup table as text, reload it, and check
    // that the word index, the first vocab element and the "day" vector all survive.
    File tempFile = File.createTempFile("temp", "tmp");
    tempFile.deleteOnExit();
    int idx1 = word2Vec.vocab().wordFor("day").getIndex();
    INDArray array1 = word2Vec.getWordVectorMatrix("day").dup();
    VocabWord word1 = word2Vec.vocab().elementAtIndex(0);
    WordVectorSerializer.writeWordVectors(word2Vec.getLookupTable(), tempFile);

    WordVectors vectors = WordVectorSerializer.loadTxtVectors(tempFile);
    VocabWord word2 = ((VocabCache<VocabWord>) vectors.vocab()).elementAtIndex(0);
    VocabWord wordIT = ((VocabCache<VocabWord>) vectors.vocab()).wordFor("it");
    int idx2 = vectors.vocab().wordFor("day").getIndex();
    INDArray array2 = vectors.getWordVectorMatrix("day").dup();
    System.out.println("word 'i': " + word2);
    System.out.println("word 'it': " + wordIT);
    assertEquals(idx1, idx2);
    assertEquals(word1, word2);
    assertEquals(array1, array2);
}
Usage of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j (by deeplearning4j).
Class WordVectorSerializerTest, method testUnifiedLoaderBinary.
/**
 * Checks that the unified loader reads a binary word2vec file the same way as the
 * Google-model loader.
 *
 * <p>{@code binaryFile} is loaded via {@code loadGoogleModel(binaryFile, true)} and via
 * {@code readWord2VecModel(binaryFile, false)}. The vector for "Morgan_Freeman" must be
 * present in the first model and identical in both.
 *
 * @throws Exception if either load fails
 */
@Test
public void testUnifiedLoaderBinary() throws Exception {
    String executorName = Nd4j.getExecutioner().getClass().getSimpleName();
    logger.info("Executor name: {}", executorName);

    WordVectors googleLoaded = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    WordVectors unifiedLoaded = WordVectorSerializer.readWord2VecModel(binaryFile, false);

    String probeWord = "Morgan_Freeman";
    INDArray fromGoogle = googleLoaded.getWordVectorMatrix(probeWord);
    INDArray fromUnified = unifiedLoaded.getWordVectorMatrix(probeWord);

    assertNotEquals(null, fromGoogle);
    assertEquals(fromGoogle, fromUnified);
}
Usage of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j (by deeplearning4j).
Class WordVectorSerializerTest, method testStaticLoaderArchive.
/**
 * Checks that a ZIP-archived model loads identically as a regular model and as a static model.
 *
 * <p>The classpath resource {@code word2vec.dl4j/file.w2v} is read through
 * {@code readWord2Vec} and through {@code loadStaticModel}. The vector for "night" must be
 * present in the first model and identical in both.
 *
 * @throws Exception if the resource cannot be resolved or either load fails
 */
@Test
public void testStaticLoaderArchive() throws Exception {
    String executorName = Nd4j.getExecutioner().getClass().getSimpleName();
    logger.info("Executor name: {}", executorName);

    File archive = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();
    WordVectors regularModel = WordVectorSerializer.readWord2Vec(archive);
    WordVectors staticModel = WordVectorSerializer.loadStaticModel(archive);

    String probeWord = "night";
    INDArray fromRegular = regularModel.getWordVectorMatrix(probeWord);
    INDArray fromStatic = staticModel.getWordVectorMatrix(probeWord);

    assertNotEquals(null, fromRegular);
    assertEquals(fromRegular, fromStatic);
}
Usage of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j (by deeplearning4j).
Class WordVectorSerializerTest, method testWriteWordVectors.
/**
 * Round-trips the Google binary model through the text format.
 *
 * <p>Loads {@code binaryFile} and writes its lookup table and vocab cache to
 * {@code pathToWriteto} as text. It then reloads that file and spot-checks the length and
 * first component of two known word vectors.
 *
 * @throws IOException if reading or writing the model files fails
 */
@Test
@Ignore
public void testWriteWordVectors() throws IOException {
    WordVectors vec = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
    InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();
    WordVectorSerializer.writeWordVectors(lookupTable, lookupCache, pathToWriteto);

    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File(pathToWriteto));
    double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
    double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");

    // assertEquals reports the actual length on failure; assertTrue(a == b) would not.
    assertEquals(300, wordVector1.length);
    assertEquals(300, wordVector2.length);

    // JUnit's contract is assertEquals(expected, actual, delta). The original call had
    // expected and actual swapped, which would invert the failure message.
    assertEquals(0.044423, wordVector1[0], 1e-3);
    assertEquals(0.051964, wordVector2[0], 1e-3);
}
Usage of org.deeplearning4j.models.embeddings.wordvectors.WordVectors in project deeplearning4j (by deeplearning4j).
Class WordVectorSerializerTest, method testStaticLoaderBinary.
/**
 * Checks that a binary word2vec file loads identically through the Google-model loader
 * and as a static model.
 *
 * <p>{@code binaryFile} is read with {@code loadGoogleModel(binaryFile, true)} and with
 * {@code loadStaticModel(binaryFile)}. The vector for "Morgan_Freeman" must be present in
 * the first model and identical in both.
 *
 * @throws Exception if either load fails
 */
@Test
public void testStaticLoaderBinary() throws Exception {
    String executorName = Nd4j.getExecutioner().getClass().getSimpleName();
    logger.info("Executor name: {}", executorName);

    WordVectors googleModel = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    WordVectors staticModel = WordVectorSerializer.loadStaticModel(binaryFile);

    String probeWord = "Morgan_Freeman";
    INDArray fromGoogle = googleModel.getWordVectorMatrix(probeWord);
    INDArray fromStatic = staticModel.getWordVectorMatrix(probeWord);

    assertNotEquals(null, fromGoogle);
    assertEquals(fromGoogle, fromStatic);
}
Aggregations