Search in sources:

Example 1 with VectorsConfiguration

Use of org.deeplearning4j.models.embeddings.loader.VectorsConfiguration in project deeplearning4j by deeplearning4j.

From class SequenceVectorsTest, method testInternalVocabConstruction:

@Test
public void testInternalVocabConstruction() throws Exception {
    ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
    File file = resource.getFile();
    BasicLineIterator underlyingIterator = new BasicLineIterator(file);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
    AbstractSequenceIterator<VocabWord> sequenceIterator = new AbstractSequenceIterator.Builder<>(transformer).build();
    SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
            .minWordFrequency(5)
            .iterate(sequenceIterator)
            .batchSize(250)
            .iterations(1)
            .epochs(1)
            .resetModel(false)
            .trainElementsRepresentation(true)
            .build();
    logger.info("Fitting model...");
    vectors.fit();
    logger.info("Model ready...");
    double sim = vectors.similarity("day", "night");
    logger.info("Day/night similarity: " + sim);
    assertTrue(sim > 0.6d);
    Collection<String> labels = vectors.wordsNearest("day", 10);
    logger.info("Nearest labels to 'day': " + labels);
}
Also used: BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) SentenceTransformer(org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer) ClassPathResource(org.datavec.api.util.ClassPathResource) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) File(java.io.File) Test(org.junit.Test)
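
VectorsConfiguration is the hyperparameter holder passed to each builder on this page. A minimal sketch of populating it directly, using only setters that the later examples here also use (the builder hand-off in the comment is illustrative):

import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;

public class VectorsConfigurationSketch {
    public static void main(String[] args) {
        // Populate the configuration up front instead of relying purely on builder calls.
        VectorsConfiguration configuration = new VectorsConfiguration();
        configuration.setIterations(1);
        configuration.setLayersSize(150);
        configuration.setLearningRate(0.025);

        // The object is then handed to a builder, e.g.:
        // new SequenceVectors.Builder<VocabWord>(configuration)...
        System.out.println("Layer size: " + configuration.getLayersSize());
    }
}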

Example 2 with VectorsConfiguration

Use of org.deeplearning4j.models.embeddings.loader.VectorsConfiguration in project deeplearning4j by deeplearning4j.

From class SequenceVectorsTest, method testDeepWalk:

@Test
@Ignore
public void testDeepWalk() throws Exception {
    Heartbeat.getInstance().disableHeartbeat();
    AbstractCache<Blogger> vocabCache = new AbstractCache.Builder<Blogger>().build();
    Graph<Blogger, Double> graph = buildGraph();
    GraphWalker<Blogger> walker = new PopularityWalker.Builder<>(graph)
            .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
            .setWalkLength(40)
            .setWalkDirection(WalkDirection.FORWARD_UNIQUE)
            .setRestartProbability(0.05)
            .setPopularitySpread(10)
            .setPopularityMode(PopularityMode.MAXIMUM)
            .setSpreadSpectrum(SpreadSpectrum.PROPORTIONAL)
            .build();
    /*
        GraphWalker<Blogger> walker = new RandomWalker.Builder<Blogger>(graph)
                .setNoEdgeHandling(NoEdgeHandling.RESTART_ON_DISCONNECTED)
                .setWalkLength(40)
                .setWalkDirection(WalkDirection.RANDOM)
                .setRestartProbability(0.05)
                .build();
        */
    GraphTransformer<Blogger> graphTransformer = new GraphTransformer.Builder<>(graph)
            .setGraphWalker(walker)
            .shuffleOnReset(true)
            .setVocabCache(vocabCache)
            .build();
    Blogger blogger = graph.getVertex(0).getValue();
    assertEquals(119, blogger.getElementFrequency(), 0.001);
    logger.info("Blogger: " + blogger);
    AbstractSequenceIterator<Blogger> sequenceIterator = new AbstractSequenceIterator.Builder<>(graphTransformer).build();
    WeightLookupTable<Blogger> lookupTable = new InMemoryLookupTable.Builder<Blogger>()
            .lr(0.025)
            .vectorLength(150)
            .useAdaGrad(false)
            .cache(vocabCache)
            .seed(42)
            .build();
    lookupTable.resetWeights(true);
    SequenceVectors<Blogger> vectors = new SequenceVectors.Builder<Blogger>(new VectorsConfiguration())
            .lookupTable(lookupTable)
            .iterate(sequenceIterator)
            .vocabCache(vocabCache)
            .batchSize(1000)
            .iterations(1)
            .epochs(10)
            .resetModel(false)
            .trainElementsRepresentation(true)
            .trainSequencesRepresentation(false)
            .elementsLearningAlgorithm(new SkipGram<Blogger>())
            .learningRate(0.025)
            .layerSize(150)
            .sampling(0)
            .negativeSample(0)
            .windowSize(4)
            .workers(6)
            .seed(42)
            .build();
    vectors.fit();
    vectors.setModelUtils(new FlatModelUtils());
    //     logger.info("12: " + Arrays.toString(vectors.getWordVector("12")));
    double sim = vectors.similarity("12", "72");
    Collection<String> list = vectors.wordsNearest("12", 20);
    logger.info("12->72: " + sim);
    printWords("12", list, vectors);
    assertTrue(sim > 0.10);
    assertFalse(Double.isNaN(sim));
}
Also used: VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) FlatModelUtils(org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) Ignore(org.junit.Ignore) Test(org.junit.Test)
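
One detail worth noting above: setModelUtils(new FlatModelUtils()) swaps the similarity reader after training, so that wordsNearest()/similarity() scan all elements exhaustively for exact results. A small hedged sketch of that pattern; the generic parameter on FlatModelUtils is my assumption:

import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;

public class ModelUtilsSketch {
    // Swap in an exhaustive reader after fit(), as the test above does.
    // FlatModelUtils trades lookup speed for exact nearest-neighbour answers.
    static <T extends SequenceElement> void useExactReader(SequenceVectors<T> vectors) {
        vectors.setModelUtils(new FlatModelUtils<T>());
    }
}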

Example 3 with VectorsConfiguration

Use of org.deeplearning4j.models.embeddings.loader.VectorsConfiguration in project deeplearning4j by deeplearning4j.

From class WordVectorSerializerTest, method testParaVecSerialization1:

@Test
public void testParaVecSerialization1() throws Exception {
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setIterations(14123);
    configuration.setLayersSize(156);
    INDArray syn0 = Nd4j.rand(100, configuration.getLayersSize());
    INDArray syn1 = Nd4j.rand(100, configuration.getLayersSize());
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    for (int i = 0; i < 100; i++) {
        VocabWord word = new VocabWord((float) i, "word_" + i);
        List<Integer> points = new ArrayList<>();
        List<Byte> codes = new ArrayList<>();
        int num = org.apache.commons.lang3.RandomUtils.nextInt(1, 20);
        for (int x = 0; x < num; x++) {
            points.add(org.apache.commons.lang3.RandomUtils.nextInt(1, 100000));
            codes.add(org.apache.commons.lang3.RandomUtils.nextBytes(10)[0]);
        }
        if (org.apache.commons.lang3.RandomUtils.nextInt(0, 10) < 3) {
            word.markAsLabel(true);
        }
        word.setIndex(i);
        word.setPoints(points);
        word.setCodes(codes);
        cache.addToken(word);
        cache.addWordToIndex(i, word.getLabel());
    }
    InMemoryLookupTable<VocabWord> lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
            .vectorLength(configuration.getLayersSize())
            .cache(cache)
            .build();
    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);
    ParagraphVectors originalVectors = new ParagraphVectors.Builder(configuration).vocabCache(cache).lookupTable(lookupTable).build();
    File tempFile = File.createTempFile("paravec", "tests");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeParagraphVectors(originalVectors, tempFile);
    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(tempFile);
    InMemoryLookupTable<VocabWord> restoredLookupTable = (InMemoryLookupTable<VocabWord>) restoredVectors.getLookupTable();
    AbstractCache<VocabWord> restoredVocab = (AbstractCache<VocabWord>) restoredVectors.getVocab();
    assertEquals(restoredLookupTable.getSyn0(), lookupTable.getSyn0());
    assertEquals(restoredLookupTable.getSyn1(), lookupTable.getSyn1());
    for (int i = 0; i < cache.numWords(); i++) {
        assertEquals(cache.elementAtIndex(i).isLabel(), restoredVocab.elementAtIndex(i).isLabel());
        assertEquals(cache.wordAtIndex(i), restoredVocab.wordAtIndex(i));
        assertEquals(cache.elementAtIndex(i).getElementFrequency(), restoredVocab.elementAtIndex(i).getElementFrequency(), 0.1f);
        List<Integer> originalPoints = cache.elementAtIndex(i).getPoints();
        List<Integer> restoredPoints = restoredVocab.elementAtIndex(i).getPoints();
        assertEquals(originalPoints.size(), restoredPoints.size());
        for (int x = 0; x < originalPoints.size(); x++) {
            assertEquals(originalPoints.get(x), restoredPoints.get(x));
        }
        List<Byte> originalCodes = cache.elementAtIndex(i).getCodes();
        List<Byte> restoredCodes = restoredVocab.elementAtIndex(i).getCodes();
        assertEquals(originalCodes.size(), restoredCodes.size());
        for (int x = 0; x < originalCodes.size(); x++) {
            assertEquals(originalCodes.get(x), restoredCodes.get(x));
        }
    }
}
Also used: VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) ParagraphVectors(org.deeplearning4j.models.paragraphvectors.ParagraphVectors) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) File(java.io.File) Test(org.junit.Test)
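
The deliberately odd values above (14123 iterations, layer size 156) make it easy to also verify that the configuration itself survives the round trip, not just the weights. A hedged sketch of such a check, assuming a getConfiguration() accessor on ParagraphVectors and getters matching the setters used above:

import static org.junit.Assert.assertEquals;

import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;

public class ConfigRoundTripSketch {
    // Intended to be called with the restoredVectors from the test above.
    static void assertConfigurationRestored(ParagraphVectors restoredVectors) {
        VectorsConfiguration conf = restoredVectors.getConfiguration();
        assertEquals(14123, conf.getIterations());
        assertEquals(156, conf.getLayersSize());
    }
}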

Example 4 with VectorsConfiguration

Use of org.deeplearning4j.models.embeddings.loader.VectorsConfiguration in project deeplearning4j by deeplearning4j.

From class ParagraphVectorsTest, method testGensimEquality:

/**
     * Special test to check d2v inference against a pre-trained gensim model.
     */
@Ignore
@Test
public void testGensimEquality() throws Exception {
    INDArray expA = Nd4j.create(new double[] { -0.02461922, -0.00801059, -0.01821643, 0.0167951, 0.02240154, -0.00414107, -0.0022868, 0.00278438, -0.00651088, -0.02066556, -0.01045411, -0.02853066, 0.00153375, 0.02707097, -0.00754221, -0.02795872, -0.00275301, -0.01455731, -0.00981289, 0.01557207, -0.005259, 0.00355505, 0.01503531, -0.02185878, 0.0339283, -0.05049067, 0.02849454, -0.01242505, 0.00438659, -0.03037345, 0.01866657, -0.00740161, -0.01850279, 0.00851284, -0.01774663, -0.01976997, -0.03317627, 0.00372983, 0.01313218, -0.00041131, 0.00089357, -0.0156924, 0.01278253, -0.01596088, -0.01415407, -0.01795845, 0.00558284, -0.00529536, -0.03508032, 0.00725479, -0.01910841, -0.0008098, 0.00614283, -0.00926585, 0.01761538, -0.00272953, -0.01483113, 0.02062481, -0.03134528, 0.03416841, -0.0156226, -0.01418961, -0.00817538, 0.01848741, 0.00444605, 0.01090323, 0.00746163, -0.02490317, 0.00835013, 0.01091823, -0.0177979, 0.0207753, -0.00854185, 0.04269911, 0.02786852, 0.00179449, 0.00303065, -0.00127148, -0.01589409, -0.01110292, 0.01736244, -0.01177608, 0.00110929, 0.01790557, -0.01800732, 0.00903072, 0.00210271, 0.0103053, -0.01508116, 0.00336775, 0.00319031, -0.00982859, 0.02409827, -0.0079536, 0.01347831, -0.02555985, 0.00282605, 0.00350526, -0.00471707, -0.00592073, -0.01009063, -0.02396305, 0.02643895, -0.05487461, -0.01710705, -0.0082839, 0.01322765, 0.00098093, 0.01707118, 0.00290805, 0.03256396, 0.00277155, 0.00350602, 0.0096487, -0.0062662, 0.0331796, -0.01758772, 0.0295204, 0.00295053, -0.00670782, 0.02172252, 0.00172433, 0.0122977, -0.02401575, 0.01179839, -0.01646545, -0.0242724, 0.01318037, -0.00745518, -0.00400624, -0.01735787, 0.01627645, 0.04445697, -0.0189355, 0.01315041, 0.0131585, 0.01770667, -0.00114554, 0.00581599, 0.00745188, -0.01318868, -0.00801476, -0.00884938, 0.00084786, 0.02578231, -0.01312729, -0.02047793, 0.00485749, -0.00342519, -0.00744475, 0.01180929, 0.02871456, 0.01483848, -0.00696516, 0.02003011, -0.01721076, -0.0124568, -0.0114492, -0.00970469, 0.01971609, 0.01599673, -0.01426137, 0.00808409, -0.01431519, 0.01187332, 0.00144421, -0.00459554, 0.00384032, 0.00866845, 0.00265177, -0.01003456, 0.0289338, 0.00353483, -0.01664903, -0.03050662, 0.01305057, -0.0084294, -0.01615093, -0.00897918, 0.00768479, 0.02155688, 0.01594496, 0.00034328, -0.00557031, -0.00256555, 0.03939554, 0.00274235, 0.001288, 0.02933025, 0.0070212, -0.00573742, 0.00883708, 0.00829396, -0.01100356, -0.02653269, -0.01023274, 0.03079773, -0.00765917, 0.00949703, 0.01212146, -0.01362515, -0.0076843, -0.00290596, -0.01707907, 0.02899382, -0.00089925, 0.01510732, 0.02378234, -0.00947305, 0.0010998, -0.00558241, 0.00057873, 0.01098226, -0.02019168, -0.013942, -0.01639287, -0.00675588, -0.00400709, -0.02914054, -0.00433462, 0.01551765, -0.03552055, 0.01681101, -0.00629782, -0.01698086, 0.01891401, 0.03597684, 0.00888052, -0.01587857, 0.00935822, 0.00931327, -0.0128156, 0.05170929, -0.01811879, 0.02096679, 0.00897546, 0.00132624, -0.01796336, 0.01888563, -0.01142226, -0.00805926, 0.00049782, -0.02151541, 0.00747257, 0.023373, -0.00198183, 0.02968843, 0.00443042, -0.00328569, -0.04200815, 0.01306543, -0.01608924, -0.01604842, 0.03137267, 0.0266054, 0.00172526, -0.01205696, 0.00047532, 0.00321026, 0.00671424, 0.01710422, -0.01129941, 0.00268044, -0.01065434, -0.01107133, 0.00036135, -0.02991677, 0.02351665, -0.00343891, -0.01736755, -0.00100577, -0.00312481, -0.01083809, 0.00387084, 0.01136449, 0.01675043, -0.01978249, -0.00765182, 0.02746241, -0.01082247, -0.01587164, 0.01104732, -0.00878782, 
-0.00497555, -0.00186257, -0.02281011, 0.00141792, 0.00432851, -0.01290263, -0.00387155, 0.00802639, -0.00761913, 0.01508144, 0.02226428, 0.0107248, 0.01003709, 0.01587571, 0.00083492, -0.01632052, -0.00435973 });
    INDArray expB = Nd4j.create(new double[] { -0.02465764, 0.00756337, -0.0268607, 0.01588023, 0.01580242, -0.00150542, 0.00116652, 0.0021577, -0.00754891, -0.02441176, -0.01271976, -0.02015191, 0.00220599, 0.03722657, -0.01629612, -0.02779619, -0.01157856, -0.01937938, -0.00744667, 0.01990043, -0.00505888, 0.00573646, 0.00385467, -0.0282531, 0.03484593, -0.05528606, 0.02428633, -0.01510474, 0.00153177, -0.03637344, 0.01747423, -0.00090738, -0.02199888, 0.01410434, -0.01710641, -0.01446697, -0.04225266, 0.00262217, 0.00871943, 0.00471594, 0.0101348, -0.01991908, 0.00874325, -0.00606416, -0.01035323, -0.01376545, 0.00451507, -0.01220307, -0.04361237, 0.00026028, -0.02401881, 0.00580314, 0.00238946, -0.01325974, 0.01879044, -0.00335623, -0.01631887, 0.02222102, -0.02998703, 0.03190075, -0.01675236, -0.01799807, -0.01314015, 0.01950069, 0.0011723, 0.01013178, 0.01093296, -0.034143, 0.00420227, 0.01449351, -0.00629987, 0.01652851, -0.01286825, 0.03314656, 0.03485073, 0.01120341, 0.01298241, 0.0019494, -0.02420256, -0.0063762, 0.01527091, -0.00732881, 0.0060427, 0.019327, -0.02068196, 0.00876712, 0.00292274, 0.01312969, -0.01529114, 0.0021757, -0.00565621, -0.01093122, 0.02758765, -0.01342688, 0.01606117, -0.02666447, 0.00541112, 0.00375426, -0.00761796, 0.00136015, -0.01169962, -0.03012749, 0.03012953, -0.05491332, -0.01137303, -0.01392103, 0.01370098, -0.00794501, 0.0248435, 0.00319645, 0.04261713, -0.00364211, 0.00780485, 0.01182583, -0.00647098, 0.03291231, -0.02515565, 0.03480943, 0.00119836, -0.00490694, 0.02615346, -0.00152456, 0.00196142, -0.02326461, 0.00603225, -0.02414703, -0.02540966, 0.0072112, -0.01090273, -0.00505061, -0.02196866, 0.00515245, 0.04981546, -0.02237269, -0.00189305, 0.0169786, 0.01782372, -0.00430022, 0.00551226, 0.00293861, -0.01337168, -0.00302476, -0.01869966, 0.00270757, 0.03199976, -0.01614617, -0.02716484, 0.01560035, -0.01312686, -0.01604082, 0.01347521, 0.03229654, 0.00707219, -0.00588392, 0.02444809, -0.01068742, -0.0190814, -0.00556385, -0.00462766, 0.01283929, 0.02001247, -0.00837629, -0.00041943, -0.02298774, 0.00874839, 0.00434907, -0.00963332, 0.00476905, 0.00793049, -0.00212557, -0.01839353, 0.03345517, 0.00838255, -0.0157447, -0.0376134, 0.01059611, -0.02323246, -0.01326356, -0.01116734, 0.00598869, 0.0211626, 0.01872963, -0.0038276, -0.01208279, -0.00989125, 0.04147648, 0.00181867, -0.00369355, 0.02312465, 0.0048396, 0.00564515, 0.01317832, -0.0057621, -0.01882041, -0.02869064, -0.00670661, 0.02585443, -0.01108428, 0.01411031, 0.01204507, -0.01244726, -0.00962342, -0.00205239, -0.01653971, 0.02871559, -0.00772978, 0.0214524, 0.02035478, -0.01324312, 0.00169302, -0.00064739, 0.00531795, 0.01059279, -0.02455794, -0.00002782, -0.0068906, -0.0160858, -0.0031842, -0.02295724, 0.01481094, 0.01769004, -0.02925742, 0.02050495, -0.00029003, -0.02815636, 0.02467367, 0.03419458, 0.00654938, -0.01847546, 0.00999932, 0.00059222, -0.01722176, 0.05172159, -0.01548486, 0.01746444, 0.007871, 0.0078471, -0.02414417, 0.01898077, -0.01470176, -0.00299465, 0.00368212, -0.02474656, 0.01317451, 0.03706085, -0.00032923, 0.02655881, 0.0013586, -0.0120303, -0.05030316, 0.0222294, -0.0070967, -0.02150935, 0.03254268, 0.01369857, 0.00246183, -0.02253576, -0.00551247, 0.00787363, 0.01215617, 0.02439827, -0.01104699, -0.00774596, -0.01898127, -0.01407653, 0.00195514, -0.03466602, 0.01560903, -0.01239944, -0.02474852, 0.00155114, 0.00089324, -0.01725949, -0.00011816, 0.00742845, 0.01247074, -0.02467943, -0.00679623, 0.01988366, -0.00626181, -0.02396477, 0.01052101, -0.01123178, 
-0.00386291, -0.00349261, -0.02714747, -0.00563315, 0.00228767, -0.01303677, -0.01971108, 0.00014759, -0.00346399, 0.02220698, 0.01979946, -0.00526076, 0.00647453, 0.01428513, 0.00223467, -0.01690172, -0.0081715 });
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setIterations(5);
    configuration.setLearningRate(0.01);
    configuration.setUseHierarchicSoftmax(true);
    configuration.setNegative(0);
    Word2Vec w2v = WordVectorSerializer.readWord2VecFromText(
            new File("/home/raver119/Downloads/gensim_models_for_dl4j/word"),
            new File("/home/raver119/Downloads/gensim_models_for_dl4j/hs"),
            new File("/home/raver119/Downloads/gensim_models_for_dl4j/hs_code"),
            new File("/home/raver119/Downloads/gensim_models_for_dl4j/hs_mapping"),
            configuration);
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
    assertNotNull(w2v.getLookupTable());
    assertNotNull(w2v.getVocab());
    ParagraphVectors d2v = new ParagraphVectors.Builder(configuration)
            .useExistingWordVectors(w2v)
            .sequenceLearningAlgorithm(new DM<VocabWord>())
            .tokenizerFactory(tokenizerFactory)
            .resetModel(false)
            .build();
    assertNotNull(d2v.getLookupTable());
    assertNotNull(d2v.getVocab());
    assertSame(w2v.getVocab(), d2v.getVocab());
    assertSame(w2v.getLookupTable(), d2v.getLookupTable());
    String textA = "Donald Trump referred to President Obama as “your president” during the first presidential debate on Monday, much to many people’s chagrin on social media. Trump, made the reference after saying that the greatest threat facing the world is nuclear weapons. He then turned to Hillary Clinton and said, “Not global warming like you think and your President thinks,” referring to Obama.";
    String textB = "The comment followed Trump doubling down on his false claims about the so-called birther conspiracy theory about Obama. People following the debate were immediately angered that Trump implied Obama is not his president.";
    String textC = "practice of trust owned Trump for example indeed and conspiracy between provoke";
    INDArray arrayA = d2v.inferVector(textA);
    INDArray arrayB = d2v.inferVector(textB);
    INDArray arrayC = d2v.inferVector(textC);
    assertNotNull(arrayA);
    assertNotNull(arrayB);
    // unitVec returns a normalized copy rather than scaling in place, so keep the results
    arrayA = Transforms.unitVec(arrayA);
    arrayB = Transforms.unitVec(arrayB);
    expA = Transforms.unitVec(expA);
    expB = Transforms.unitVec(expB);
    double simX = Transforms.cosineSim(arrayA, arrayB);
    double simC = Transforms.cosineSim(arrayA, arrayC);
    double simB = Transforms.cosineSim(arrayB, expB);
    log.info("SimilarityX: {}", simX);
    log.info("SimilarityC: {}", simC);
    log.info("SimilarityB: {}", simB);
}
Also used: DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) DM(org.deeplearning4j.models.embeddings.learning.impl.sequence.DM) File(java.io.File) Ignore(org.junit.Ignore) Test(org.junit.Test)
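
The comparison above boils down to unit-normalizing vectors and taking their cosine. A self-contained sketch with toy data; note that Transforms.unitVec returns a normalized copy rather than modifying its argument:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class CosineSimSketch {
    public static void main(String[] args) {
        // unitVec scales each vector to unit norm and returns the result.
        INDArray a = Transforms.unitVec(Nd4j.create(new double[] { 1.0, 2.0, 3.0 }));
        INDArray b = Transforms.unitVec(Nd4j.create(new double[] { 1.0, 2.0, 2.5 }));
        // cosineSim returns the cosine of the angle between the vectors, in [-1, 1].
        double sim = Transforms.cosineSim(a, b);
        System.out.println("Cosine similarity: " + sim);
    }
}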

Example 5 with VectorsConfiguration

Use of org.deeplearning4j.models.embeddings.loader.VectorsConfiguration in project deeplearning4j by deeplearning4j.

From class SequenceVectorsTest, method testAbstractW2VModel:

@Test
public void testAbstractW2VModel() throws Exception {
    ClassPathResource resource = new ClassPathResource("big/raw_sentences.txt");
    File file = resource.getFile();
    logger.info("dtype: {}", Nd4j.dataType());
    AbstractCache<VocabWord> vocabCache = new AbstractCache.Builder<VocabWord>().build();
    /*
        First we build a line iterator.
     */
    BasicLineIterator underlyingIterator = new BasicLineIterator(file);
    /*
        Now we need a way to convert lines into sequences of VocabWords.
        In this example, that's SentenceTransformer.
     */
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    SentenceTransformer transformer = new SentenceTransformer.Builder().iterator(underlyingIterator).tokenizerFactory(t).build();
    /*
        And we pack that transformer into an AbstractSequenceIterator.
     */
    AbstractSequenceIterator<VocabWord> sequenceIterator = new AbstractSequenceIterator.Builder<>(transformer).build();
    /*
        Now we build the vocabulary out of the sequence iterator.
        We could skip this phase and just set SequenceVectors.resetModel(true),
        in which case the vocabulary is built internally (a sketch of that
        shortcut follows this example).
     */
    VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>().addSource(sequenceIterator, 5).setTargetVocabCache(vocabCache).build();
    constructor.buildJointVocabulary(false, true);
    assertEquals(242, vocabCache.numWords());
    assertEquals(634303, vocabCache.totalWordOccurrences());
    VocabWord wordz = vocabCache.wordFor("day");
    logger.info("Wordz: " + wordz);
    /*
        Time to build the WeightLookupTable instance for our new model.
     */
    WeightLookupTable<VocabWord> lookupTable = new InMemoryLookupTable.Builder<VocabWord>()
            .lr(0.025)
            .vectorLength(150)
            .useAdaGrad(false)
            .cache(vocabCache)
            .build();
    /*
        resetWeights() is needed only if SequenceVectors.resetModel() is set to false;
        if it is set to true, the reset is performed internally.
     */
    lookupTable.resetWeights(true);
    /*
        Now we can build the SequenceVectors model that suits our needs.
     */
    SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
            .minWordFrequency(5)
            .lookupTable(lookupTable)
            .iterate(sequenceIterator)
            .vocabCache(vocabCache)
            .batchSize(250)
            .iterations(1)
            .epochs(1)
            .resetModel(false)
            .trainElementsRepresentation(true)
            .trainSequencesRepresentation(false)
            .build();
    /*
            Now, after all options are set, we just call fit()
         */
    logger.info("Starting training...");
    vectors.fit();
    logger.info("Model saved...");
    /*
        As soon as fit() exits, the model is considered built and we can test it.
        Please note: all similarity context is handled via SequenceElement labels, so if you're
        using SequenceVectors to build models for complex objects/relations, take care of
        label uniqueness and meaning yourself.
     */
    double sim = vectors.similarity("day", "night");
    logger.info("Day/night similarity: " + sim);
    assertTrue(sim > 0.6d);
    Collection<String> labels = vectors.wordsNearest("day", 10);
    logger.info("Nearest labels to 'day': " + labels);
}
Also used: BasicLineIterator(org.deeplearning4j.text.sentenceiterator.BasicLineIterator) VectorsConfiguration(org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) CommonPreprocessor(org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) DefaultTokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) VocabConstructor(org.deeplearning4j.models.word2vec.wordstore.VocabConstructor) SentenceTransformer(org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer) ClassPathResource(org.datavec.api.util.ClassPathResource) AbstractSequenceIterator(org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) File(java.io.File) Test(org.junit.Test)
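
As the vocabulary comment in this test notes, the explicit VocabConstructor and WeightLookupTable phases can be skipped. A hedged sketch of that shortcut: pass the same kind of sequence iterator and set resetModel(true) so vocabulary and weights are constructed internally during fit():

import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
import org.deeplearning4j.models.word2vec.VocabWord;

public class InternalVocabSketch {
    static SequenceVectors<VocabWord> fitWithInternalVocab(AbstractSequenceIterator<VocabWord> sequenceIterator) {
        SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
                .minWordFrequency(5)
                .iterate(sequenceIterator)
                .batchSize(250)
                .iterations(1)
                .epochs(1)
                // resetModel(true): vocabulary and weights are built internally,
                // so no VocabConstructor / WeightLookupTable setup is needed.
                .resetModel(true)
                .trainElementsRepresentation(true)
                .build();
        vectors.fit();
        return vectors;
    }
}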

Aggregations

VectorsConfiguration (org.deeplearning4j.models.embeddings.loader.VectorsConfiguration) 6
Test (org.junit.Test) 6
File (java.io.File) 5
AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator) 4
VocabWord (org.deeplearning4j.models.word2vec.VocabWord) 4
CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor) 4
DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory) 4
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) 4
ClassPathResource (org.datavec.api.util.ClassPathResource) 3
SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer) 3
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) 3
BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator) 3
Ignore (org.junit.Ignore) 3
InMemoryLookupTable (org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) 2
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 2
ArrayList (java.util.ArrayList) 1
GloVe (org.deeplearning4j.models.embeddings.learning.impl.elements.GloVe) 1
DM (org.deeplearning4j.models.embeddings.learning.impl.sequence.DM) 1
FlatModelUtils (org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils) 1
ParagraphVectors (org.deeplearning4j.models.paragraphvectors.ParagraphVectors) 1