Search in sources:

Example 1 with FlatModelUtils

use of org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils in project deeplearning4j by deeplearning4j.

The following example comes from the class SequenceVectorsTest, method testDeepWalk.

@Test
@Ignore
public void testDeepWalk() throws Exception {
    Heartbeat.getInstance().disableHeartbeat();

    // Vocabulary cache shared by the transformer and the lookup table.
    AbstractCache<Blogger> cache = new AbstractCache.Builder<Blogger>().build();
    Graph<Blogger, Double> bloggerGraph = buildGraph();

    // Popularity-biased walker: 40-step unique-forward walks with a small restart probability.
    GraphWalker<Blogger> popularityWalker = new PopularityWalker.Builder<>(bloggerGraph)
                    .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
                    .setWalkLength(40)
                    .setWalkDirection(WalkDirection.FORWARD_UNIQUE)
                    .setRestartProbability(0.05)
                    .setPopularitySpread(10)
                    .setPopularityMode(PopularityMode.MAXIMUM)
                    .setSpreadSpectrum(SpreadSpectrum.PROPORTIONAL)
                    .build();

    // Turns graph walks into training sequences, reshuffling on every reset.
    GraphTransformer<Blogger> transformer = new GraphTransformer.Builder<>(bloggerGraph)
                    .setGraphWalker(popularityWalker)
                    .shuffleOnReset(true)
                    .setVocabCache(cache)
                    .build();

    // Sanity check on the fixture: vertex 0 must carry the expected frequency.
    Blogger blogger = bloggerGraph.getVertex(0).getValue();
    assertEquals(119, blogger.getElementFrequency(), 0.001);
    logger.info("Blogger: " + blogger);

    AbstractSequenceIterator<Blogger> sequences =
                    new AbstractSequenceIterator.Builder<>(transformer).build();

    WeightLookupTable<Blogger> table = new InMemoryLookupTable.Builder<Blogger>()
                    .lr(0.025)
                    .vectorLength(150)
                    .useAdaGrad(false)
                    .cache(cache)
                    .seed(42)
                    .build();
    table.resetWeights(true);

    // SkipGram over the walk sequences; element vectors only, no sequence vectors.
    SequenceVectors<Blogger> model = new SequenceVectors.Builder<Blogger>(new VectorsConfiguration())
                    .lookupTable(table)
                    .iterate(sequences)
                    .vocabCache(cache)
                    .batchSize(1000)
                    .iterations(1)
                    .epochs(10)
                    .resetModel(false)
                    .trainElementsRepresentation(true)
                    .trainSequencesRepresentation(false)
                    .elementsLearningAlgorithm(new SkipGram<Blogger>())
                    .learningRate(0.025)
                    .layerSize(150)
                    .sampling(0)
                    .negativeSample(0)
                    .windowSize(4)
                    .workers(6)
                    .seed(42)
                    .build();
    model.fit();

    // Exhaustive (flat) similarity lookups for the assertions below.
    model.setModelUtils(new FlatModelUtils());

    double similarity = model.similarity("12", "72");
    Collection<String> nearest = model.wordsNearest("12", 20);
    logger.info("12->72: " + similarity);
    printWords("12", nearest, model);

    assertTrue(similarity > 0.10);
    assertFalse(Double.isNaN(similarity));
}
Also used : VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) FlatModelUtils(org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 2 with FlatModelUtils

use of org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils in project deeplearning4j by deeplearning4j.

The following example comes from the class Word2VecTests, method testW2VnegativeOnRestore.

@Test
public void testW2VnegativeOnRestore() throws Exception {
    // Line-based corpus; each line is trimmed by the iterator.
    SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

    // Whitespace tokenizer with common normalization applied to each token.
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new CommonPreprocessor());

    // Negative sampling (no hierarchic softmax) so the restore path below
    // has to carry the negative-sampling state across serialization.
    Word2Vec vec = new Word2Vec.Builder()
                    .minWordFrequency(1)
                    .iterations(3)
                    .batchSize(64)
                    .layerSize(100)
                    .stopWords(new ArrayList<String>())
                    .seed(42)
                    .learningRate(0.025)
                    .minLearningRate(0.001)
                    .sampling(0)
                    .elementsLearningAlgorithm(new SkipGram<VocabWord>())
                    .negativeSample(10)
                    .epochs(1)
                    .windowSize(5)
                    .useHierarchicSoftmax(false)
                    .allowParallelTokenization(true)
                    .modelUtils(new FlatModelUtils<VocabWord>())
                    .iterate(iter)
                    .tokenizerFactory(tokenizer)
                    .build();
    assertEquals(false, vec.getConfiguration().isUseHierarchicSoftmax());

    log.info("Fit 1");
    vec.fit();

    File modelFile = File.createTempFile("temp", "file");
    modelFile.deleteOnExit();
    WordVectorSerializer.writeWord2VecModel(vec, modelFile);

    // Restore with extendedModel=true: FlatModelUtils should survive the round trip.
    iter.reset();
    Word2Vec reloaded = WordVectorSerializer.readWord2VecModel(modelFile, true);
    reloaded.setTokenizerFactory(tokenizer);
    reloaded.setSentenceIterator(iter);
    assertEquals(false, reloaded.getConfiguration().isUseHierarchicSoftmax());
    assertTrue(reloaded.getModelUtils() instanceof FlatModelUtils);
    assertTrue(reloaded.getConfiguration().isAllowParallelTokenization());

    log.info("Fit 2");
    reloaded.fit();

    // Restore with extendedModel=false: model utils fall back to BasicModelUtils.
    iter.reset();
    reloaded = WordVectorSerializer.readWord2VecModel(modelFile, false);
    reloaded.setTokenizerFactory(tokenizer);
    reloaded.setSentenceIterator(iter);
    assertEquals(false, reloaded.getConfiguration().isUseHierarchicSoftmax());
    assertTrue(reloaded.getModelUtils() instanceof BasicModelUtils);

    log.info("Fit 3");
    reloaded.fit();
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) SkipGram(org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) UimaSentenceIterator(org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) BasicModelUtils(org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils) FlatModelUtils(org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) File(java.io.File) Test(org.junit.Test)

Example 3 with FlatModelUtils

use of org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils in project deeplearning4j by deeplearning4j.

The following example comes from the class Word2VecTests, method testUnknown1.

@Test
public void testUnknown1() throws Exception {
    // Line-based corpus; each line is trimmed by the iterator.
    SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

    // Whitespace tokenizer with common normalization applied to each token.
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new CommonPreprocessor());

    // Model configured with an explicit UNK element ("PEWPEW") so that any
    // out-of-vocabulary lookup resolves to that element's vector.
    Word2Vec vec = new Word2Vec.Builder()
                    .minWordFrequency(10)
                    .useUnknown(true)
                    .unknownElement(new VocabWord(1.0, "PEWPEW"))
                    .iterations(1)
                    .layerSize(100)
                    .stopWords(new ArrayList<String>())
                    .seed(42)
                    .learningRate(0.025)
                    .minLearningRate(0.001)
                    .sampling(0)
                    .elementsLearningAlgorithm(new CBOW<VocabWord>())
                    .epochs(1)
                    .windowSize(5)
                    .useHierarchicSoftmax(true)
                    .allowParallelTokenization(true)
                    .modelUtils(new FlatModelUtils<VocabWord>())
                    .iterate(iter)
                    .tokenizerFactory(tokenizer)
                    .build();
    vec.fit();

    assertTrue(vec.hasWord("PEWPEW"));
    assertTrue(vec.getVocab().containsWord("PEWPEW"));

    INDArray unknownVec = vec.getWordVectorMatrix("PEWPEW");
    assertNotEquals(null, unknownVec);

    // Serialize and restore; the UNK element must survive the round trip.
    File modelFile = File.createTempFile("temp", "file");
    modelFile.deleteOnExit();
    WordVectorSerializer.writeWord2VecModel(vec, modelFile);
    log.info("Original configuration: {}", vec.getConfiguration());

    Word2Vec restored = WordVectorSerializer.readWord2VecModel(modelFile);
    assertTrue(restored.hasWord("PEWPEW"));
    assertTrue(restored.getVocab().containsWord("PEWPEW"));

    INDArray restoredUnknownVec = restored.getWordVectorMatrix("PEWPEW");
    assertEquals(unknownVec, restoredUnknownVec);

    // An out-of-vocabulary token must resolve to the UNK vector,
    // both on the original model and on the restored one.
    INDArray junkVec = vec.getWordVectorMatrix("hhsd7d7sdnnmxc_SDsda");
    INDArray restoredJunkVec = restored.getWordVectorMatrix("hhsd7d7sdnnmxc_SDsda");
    log.info("Restored configuration: {}", restored.getConfiguration());
    assertEquals(unknownVec, junkVec);
    assertEquals(unknownVec, restoredJunkVec);
}
Also used : BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) ArrayList(java.util.ArrayList) SentenceIterator(org.deeplearning4j.text.sentenceiterator.SentenceIterator) UimaSentenceIterator(org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) INDArray(org.nd4j.linalg.api.ndarray.INDArray) FlatModelUtils(org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) File(java.io.File) Test(org.junit.Test)

Example 4 with FlatModelUtils

use of org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils in project deeplearning4j by deeplearning4j.

The following example comes from the class Word2VecTest, method testPortugeseW2V.

@Test
@Ignore
public void testPortugeseW2V() throws Exception {
    // Manual smoke test against a pre-trained Portuguese model on local disk
    // (hence @Ignore: the fixture path only exists on a developer machine).
    WordVectors vectors = WordVectorSerializer.loadTxtVectors(new File("/ext/Temp/para.txt"));

    // Exhaustive nearest-neighbour search instead of the default heuristic one.
    vectors.setModelUtils(new FlatModelUtils());

    Collection<String> nearest = vectors.wordsNearest("carro", 10);
    printWords("carro", nearest, vectors);

    nearest = vectors.wordsNearest("davi", 10);
    printWords("davi", nearest, vectors);
}
Also used : FlatModelUtils(org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) WordVectors(org.deeplearning4j.models.embeddings.wordvectors.WordVectors) File(java.io.File) Ignore(org.junit.Ignore) Test(org.junit.Test)

Aggregations

FlatModelUtils (org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils)4 Test (org.junit.Test)4 File (java.io.File)3 BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)2 SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator)2 UimaSentenceIterator (org.deeplearning4j.text.sentenceiterator.UimaSentenceIterator)2 CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor)2 DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)2 TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)2 Ignore (org.junit.Ignore)2 ArrayList (java.util.ArrayList)1 SkipGram (org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram)1 VectorsConfiguration (org.deeplearning4j.models.embeddings.loader.VectorsConfiguration)1 BasicModelUtils (org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils)1 WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors)1 AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator)1 AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1