Example 6 with VectorsConfiguration

Use of org.deeplearning4j.models.embeddings.loader.VectorsConfiguration in the deeplearning4j project by deeplearning4j.

From the class SequenceVectorsTest, method testGlove1.

@Ignore
@Test
public void testGlove1() throws Exception {
    logger.info("Max available memory: " + Runtime.getRuntime().maxMemory());
    ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
    File file = resource.getFile();
    BasicLineIterator underlyingIterator = new BasicLineIterator(file);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    // Turn each line of the corpus into a tokenized sequence.
    SentenceTransformer transformer = new SentenceTransformer.Builder()
            .iterator(underlyingIterator)
            .tokenizerFactory(t)
            .build();
    AbstractSequenceIterator<VocabWord> sequenceIterator = new AbstractSequenceIterator.Builder<>(transformer).build();
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setWindow(5);
    configuration.setLearningRate(0.06);
    configuration.setLayersSize(100);
    // Train element (word) vectors only, using GloVe as the elements learning algorithm.
    SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(configuration)
            .iterate(sequenceIterator)
            .iterations(1)
            .epochs(45)
            .elementsLearningAlgorithm(new GloVe.Builder<VocabWord>()
                    .shuffle(true)
                    .symmetric(true)
                    .learningRate(0.05)
                    .alpha(0.75)
                    .xMax(100.0)
                    .build())
            .resetModel(true)
            .trainElementsRepresentation(true)
            .trainSequencesRepresentation(false)
            .build();
    vectors.fit();
    double sim = vectors.similarity("day", "night");
    logger.info("Day/night similarity: " + sim);
    sim = vectors.similarity("day", "another");
    logger.info("Day/another similarity: " + sim);
    sim = vectors.similarity("night", "year");
    logger.info("Night/year similarity: " + sim);
    sim = vectors.similarity("night", "me");
    logger.info("Night/me similarity: " + sim);
    sim = vectors.similarity("day", "know");
    logger.info("Day/know similarity: " + sim);
    sim = vectors.similarity("best", "police");
    logger.info("Best/police similarity: " + sim);
    Collection<String> labels = vectors.wordsNearest("day", 10);
    logger.info("Nearest labels to 'day': " + labels);
    // The only hard check in this test: related words should end up close together.
    sim = vectors.similarity("day", "night");
    assertTrue(sim > 0.6d);
}
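
The test above passes alpha(0.75) and xMax(100.0) to the GloVe builder and then asserts on a similarity score. The following is a minimal plain-Java sketch of what those values usually mean, not the deeplearning4j implementation. It assumes the standard GloVe weighting function f(x) = (x / xMax)^alpha for x < xMax and 1 otherwise, and it assumes that similarity reports cosine similarity between two word vectors. The class name GloveSketch and both method names are hypothetical and exist only for illustration.

```java
public class GloveSketch {

    // Standard GloVe weighting: down-weights rare co-occurrences and caps
    // the weight at 1 once the count reaches xMax. (Assumed formula; the
    // library's internals may differ in detail.)
    static double weight(double cooccurrence, double xMax, double alpha) {
        if (cooccurrence >= xMax) {
            return 1.0;
        }
        return Math.pow(cooccurrence / xMax, alpha);
    }

    // Cosine similarity of two equal-length vectors; returns 0 when either
    // vector has zero length so that the result is never NaN.
    static double cosineSimilarity(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same length");
        }
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0.0 || normB == 0.0) {
            return 0.0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        // Same alpha and xMax as in the test's GloVe builder.
        System.out.println("weight(10)  = " + weight(10.0, 100.0, 0.75));
        System.out.println("weight(250) = " + weight(250.0, 100.0, 0.75));

        // Toy vectors, made up for the example.
        double[] day = {0.9, 0.1, 0.3};
        double[] night = {0.8, 0.2, 0.4};
        System.out.println("cosine = " + cosineSimilarity(day, night));
    }
}
```

Under these assumptions, a pair that co-occurs 100 times or more contributes with full weight, and the threshold of 0.6 in the test is a bound on the cosine of the angle between the two trained vectors.
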
Also used:
GloVe (org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe)
BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator)
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory)
DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory)
VectorsConfiguration (org.deeplearning4j.models.embeddings.loader.VectorsConfiguration)
VocabWord (org.deeplearning4j.models.word2vec.VocabWord)
SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer)
ClassPathResource (org.datavec.api.util.ClassPathResource)
CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor)
AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator)
File (java.io.File)
Ignore (org.junit.Ignore)
Test (org.junit.Test)

Aggregations

VectorsConfiguration (org.deeplearning4j.models.embeddings.loader.VectorsConfiguration): 6
Test (org.junit.Test): 6
File (java.io.File): 5
AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator): 4
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 4
CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 4
DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 4
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 4
ClassPathResource (org.datavec.api.util.ClassPathResource): 3
SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer): 3
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 3
BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 3
Ignore (org.junit.Ignore): 3
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable): 2
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2
ArrayList (java.util.ArrayList): 1
GloVe (org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe): 1
DM (org.deeplearning4j.models.embeddings.learning.impl.sequence.DM): 1
FlatModelUtils (org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils): 1
ParagraphVectors (org.deeplearning4j.models.paragraphvectors.ParagraphVectors): 1