Example 46 with ClassPathResource

Use of org.datavec.api.util.ClassPathResource in the deeplearning4j project.

From the class ParagraphVectorsTest, method testParagraphVectorsVocabBuilding1.

/*
    @Test
    public void testWord2VecRunThroughVectors() throws Exception {
        ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
        File file = resource.getFile().getParentFile();
        LabelAwareSentenceIterator iter = LabelAwareUimaSentenceIterator.createWithPath(file.getAbsolutePath());

        TokenizerFactory t = new UimaTokenizerFactory();

        ParagraphVectors vec = new ParagraphVectors.Builder()
                .minWordFrequency(1).iterations(5).labels(Arrays.asList("label1", "deeple"))
                .layerSize(100)
                .stopWords(new ArrayList<String>())
                .windowSize(5).iterate(iter).tokenizerFactory(t).build();

        assertEquals(new ArrayList<String>(), vec.getStopWords());

        vec.fit();
        double sim = vec.similarity("day", "night");
        log.info("day/night similarity: " + sim);
        new File("cache.ser").delete();
    }
    */
/**
 * This test checks how the vocabulary is built from the provided SentenceIterator, without labels.
 *
 * @throws Exception
 */
@Test
public void testParagraphVectorsVocabBuilding1() throws Exception {
    ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
    File file = resource.getFile();
    SentenceIterator iter = new BasicLineIterator(file);
    int numberOfLines = 0;
    while (iter.hasNext()) {
        iter.nextSentence();
        numberOfLines++;
    }
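    // Rewind the iterator: the loop above consumed it, and it's reused for training below.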
    iter.reset();
    InMemoryLookupCache cache = new InMemoryLookupCache(false);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    ParagraphVectors vec = new ParagraphVectors.Builder()
                    .minWordFrequency(1)
                    .iterations(5)
                    .layerSize(100)
                    .windowSize(5)
                    .iterate(iter)
                    .vocabCache(cache)
                    .tokenizerFactory(t)
                    .build();
    vec.buildVocab();
    LabelsSource source = vec.getLabelsSource();
    log.info("Number of lines in corpus: " + numberOfLines);
    assertEquals(numberOfLines, source.getLabels().size());
    assertEquals(97162, source.getLabels().size());
    assertNotEquals(null, cache);
    assertEquals(97406, cache.numWords());
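    // numWords() here counts document labels plus distinct words: 97406 = 97162 labels + 244 words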
    // the proper number of distinct words for minWordFrequency = 1 is 244
    assertEquals(244, cache.numWords() - source.getLabels().size());
}
Imports used:

import java.io.File;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Test;
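For quick reference, this is the ClassPathResource pattern the test above relies on, reduced to a minimal sketch; it assumes the named resource is present on the test classpath:

import java.io.File;
import org.datavec.api.util.ClassPathResource;

public class ClassPathResourceSketch {
    public static void main(String[] args) throws Exception {
        // Resolve a resource bundled on the classpath to a real File on disk.
        File file = new ClassPathResource("/big/raw_sentences.txt").getFile();
        System.out.println("Resolved to: " + file.getAbsolutePath());
    }
}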

Example 47 with ClassPathResource

Use of org.datavec.api.util.ClassPathResource in the deeplearning4j project.

From the class ParagraphVectorsTest, method testParagraphVectorsModelling1.

/**
 * This test doesn't really care about the actual results; we only care about equality
 * between the live model and the restored models.
 *
 * @throws Exception
 */
@Test
public void testParagraphVectorsModelling1() throws Exception {
    ClassPathResource resource = new ClassPathResource("/big/raw_sentences.txt");
    File file = resource.getFile();
    SentenceIterator iter = new BasicLineIterator(file);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    LabelsSource source = new LabelsSource("DOC_");
    ParagraphVectors vec = new ParagraphVectors.Builder()
                    .minWordFrequency(1)
                    .iterations(5)
                    .seed(119)
                    .epochs(1)
                    .layerSize(150)
                    .learningRate(0.025)
                    .labelsSource(source)
                    .windowSize(5)
                    .sequenceLearningAlgorithm(new DM<VocabWord>())
                    .iterate(iter)
                    .trainWordVectors(true)
                    .tokenizerFactory(t)
                    .workers(4)
                    .sampling(0)
                    .build();
    vec.fit();
    VocabCache<VocabWord> cache = vec.getVocab();
    File fullFile = File.createTempFile("paravec", "tests");
    fullFile.deleteOnExit();
    INDArray originalSyn1_17 = ((InMemoryLookupTable) vec.getLookupTable()).getSyn1().getRow(17).dup();
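    // this row of syn1 (hidden-layer weights) is compared against the restored model later in the test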
    WordVectorSerializer.writeParagraphVectors(vec, fullFile);
    int cnt1 = cache.wordFrequency("day");
    int cnt2 = cache.wordFrequency("me");
    assertNotEquals(1, cnt1);
    assertNotEquals(1, cnt2);
    assertNotEquals(cnt1, cnt2);
    assertEquals(97406, cache.numWords());
    assertTrue(vec.hasWord("DOC_16392"));
    assertTrue(vec.hasWord("DOC_3720"));
    List<String> result = new ArrayList<>(vec.nearestLabels(vec.getWordVectorMatrix("DOC_16392"), 10));
    System.out.println("nearest labels: " + result);
    for (String label : result) {
        System.out.println(label + "/DOC_16392: " + vec.similarity(label, "DOC_16392"));
    }
    assertTrue(result.contains("DOC_16392"));
    //assertTrue(result.contains("DOC_21383"));
    /*
        We have a few lines that contain closely related words.
        These sentences should be fairly close to each other in vector space.
     */
    // line 3721: This is my way .
    // line 6348: This is my case .
    // line 9836: This is my house .
    // line 12493: This is my world .
    // line 16393: This is my work .
    // this is a special sentence that has nothing in common with the previous sentences
    // line 9853: We now have one .
    double similarityD = vec.similarity("day", "night");
    log.info("day/night similarity: " + similarityD);
    if (similarityD < 0.0) {
        log.info("Day: " + Arrays.toString(vec.getWordVectorMatrix("day").dup().data().asDouble()));
        log.info("Night: " + Arrays.toString(vec.getWordVectorMatrix("night").dup().data().asDouble()));
    }
    List<String> labelsOriginal = vec.labelsSource.getLabels();
    double similarityW = vec.similarity("way", "work");
    log.info("way/work similarity: " + similarityW);
    double similarityH = vec.similarity("house", "world");
    log.info("house/world similarity: " + similarityH);
    double similarityC = vec.similarity("case", "way");
    log.info("case/way similarity: " + similarityC);
    double similarity1 = vec.similarity("DOC_9835", "DOC_12492");
    log.info("9835/12492 similarity: " + similarity1);
    //        assertTrue(similarity1 > 0.7d);
    double similarity2 = vec.similarity("DOC_3720", "DOC_16392");
    log.info("3720/16392 similarity: " + similarity2);
    //        assertTrue(similarity2 > 0.7d);
    double similarity3 = vec.similarity("DOC_6347", "DOC_3720");
    log.info("6347/3720 similarity: " + similarity3);
    //        assertTrue(similarity2 > 0.7d);
    // the similarity in this case should be significantly lower
    double similarityX = vec.similarity("DOC_3720", "DOC_9852");
    log.info("3720/9852 similarity: " + similarityX);
    assertTrue(similarityX < 0.5d);
    File tempFile = File.createTempFile("paravec", "ser");
    tempFile.deleteOnExit();
    INDArray day = vec.getWordVectorMatrix("day").dup();
    /*
            Testing txt serialization
         */
    File tempFile2 = File.createTempFile("paravec", "ser");
    tempFile2.deleteOnExit();
    WordVectorSerializer.writeWordVectors(vec, tempFile2);
    ParagraphVectors vec3 = WordVectorSerializer.readParagraphVectorsFromText(tempFile2);
    INDArray day3 = vec3.getWordVectorMatrix("day").dup();
    List<String> labelsRestored = vec3.labelsSource.getLabels();
    assertEquals(day, day3);
    assertEquals(labelsOriginal.size(), labelsRestored.size());
    /*
         Testing binary serialization
        */
    SerializationUtils.saveObject(vec, tempFile);
    ParagraphVectors vec2 = (ParagraphVectors) SerializationUtils.readObject(tempFile);
    INDArray day2 = vec2.getWordVectorMatrix("day").dup();
    List<String> labelsBinary = vec2.labelsSource.getLabels();
    assertEquals(day, day2);
    tempFile.delete();
    assertEquals(labelsOriginal.size(), labelsBinary.size());
    INDArray original = vec.getWordVectorMatrix("DOC_16392").dup();
    INDArray originalPreserved = original.dup();
    INDArray inferredA1 = vec.inferVector("This is my work .");
    INDArray inferredB1 = vec.inferVector("This is my work .");
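    // inferring the same sentence twice should produce nearly identical vectors; the cosine check below enforces that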
    double cosAO1 = Transforms.cosineSim(inferredA1.dup(), original.dup());
    double cosAB1 = Transforms.cosineSim(inferredA1.dup(), inferredB1.dup());
    log.info("Cos O/A: {}", cosAO1);
    log.info("Cos A/B: {}", cosAB1);
    //        assertTrue(cosAO1 > 0.45);
    assertTrue(cosAB1 > 0.95);
    //assertArrayEquals(inferredA.data().asDouble(), inferredB.data().asDouble(), 0.01);
    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(fullFile);
    restoredVectors.setTokenizerFactory(t);
    INDArray restoredSyn1_17 = ((InMemoryLookupTable) restoredVectors.getLookupTable()).getSyn1().getRow(17).dup();
    assertEquals(originalSyn1_17, restoredSyn1_17);
    INDArray originalRestored = vec.getWordVectorMatrix("DOC_16392").dup();
    assertEquals(originalPreserved, originalRestored);
    INDArray inferredA2 = restoredVectors.inferVector("This is my work .");
    INDArray inferredB2 = restoredVectors.inferVector("This is my work .");
    INDArray inferredC2 = restoredVectors.inferVector("world way case .");
    double cosAO2 = Transforms.cosineSim(inferredA2.dup(), original.dup());
    double cosAB2 = Transforms.cosineSim(inferredA2.dup(), inferredB2.dup());
    double cosAAX = Transforms.cosineSim(inferredA1.dup(), inferredA2.dup());
    double cosAC2 = Transforms.cosineSim(inferredC2.dup(), inferredA2.dup());
    log.info("Cos A2/B2: {}", cosAB2);
    log.info("Cos A1/A2: {}", cosAAX);
    log.info("Cos O/A2: {}", cosAO2);
    log.info("Cos C2/A2: {}", cosAC2);
    log.info("Vector: {}", Arrays.toString(inferredA1.data().asFloat()));
    log.info("cosAO2: {}", cosAO2);
    //  assertTrue(cosAO2 > 0.45);
    assertTrue(cosAB2 > 0.95);
    assertTrue(cosAAX > 0.95);
}
Imports used:

import java.io.File;
import java.util.ArrayList;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
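The save/restore round trip at the heart of this test, reduced to a minimal sketch; modelFile stands in for the temp file used above, and every call appears in the test itself:

// Persist a trained model, restore it, and verify that inference is preserved.
WordVectorSerializer.writeParagraphVectors(vec, modelFile);
ParagraphVectors restored = WordVectorSerializer.readParagraphVectors(modelFile);
// The tokenizer factory is not serialized, so it must be re-attached after loading.
restored.setTokenizerFactory(t);
INDArray a = vec.inferVector("This is my work .");
INDArray b = restored.inferVector("This is my work .");
// Both models should infer nearly the same vector: cosine similarity close to 1.0.
double cos = Transforms.cosineSim(a.dup(), b.dup());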

Example 48 with ClassPathResource

Use of org.datavec.api.util.ClassPathResource in the deeplearning4j project.

From the class ParagraphVectorsTest, method testParagraphVectorsOverExistingWordVectorsModel.

/*
    In this test we build a w2v model, then reuse its vocab and weights for ParagraphVectors.
    There's no need to run this test on Travis; run it manually, only for problem detection.
*/
@Test
public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception {
    // we build w2v from multiple sources, to cover everything
    ClassPathResource resource_sentences = new ClassPathResource("/big/raw_sentences.txt");
    ClassPathResource resource_mixed = new ClassPathResource("/paravec");
    SentenceIterator iter = new AggregatingSentenceIterator.Builder()
                    .addSentenceIterator(new BasicLineIterator(resource_sentences.getFile()))
                    .addSentenceIterator(new FileSentenceIterator(resource_mixed.getFile()))
                    .build();
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    Word2Vec wordVectors = new Word2Vec.Builder()
                    .minWordFrequency(1)
                    .batchSize(250)
                    .iterations(1)
                    .epochs(3)
                    .learningRate(0.025)
                    .layerSize(150)
                    .minLearningRate(0.001)
                    .elementsLearningAlgorithm(new SkipGram<VocabWord>())
                    .useHierarchicSoftmax(true)
                    .windowSize(5)
                    .iterate(iter)
                    .tokenizerFactory(t)
                    .build();
    wordVectors.fit();
    VocabWord day_A = wordVectors.getVocab().tokenFor("day");
    INDArray vector_day1 = wordVectors.getWordVectorMatrix("day").dup();
    // At this point we have a ready w2v model. It's time to use it for ParagraphVectors.
    FileLabelAwareIterator labelAwareIterator = new FileLabelAwareIterator.Builder().addSourceFolder(new ClassPathResource("/paravec/labeled").getFile()).build();
    // documents from this iterator will be used for classification
    FileLabelAwareIterator unlabeledIterator = new FileLabelAwareIterator.Builder().addSourceFolder(new ClassPathResource("/paravec/unlabeled").getFile()).build();
    // we're building the classifier now, with the pre-built w2v model passed in
    ParagraphVectors paragraphVectors = new ParagraphVectors.Builder()
                    .iterate(labelAwareIterator)
                    .learningRate(0.025)
                    .minLearningRate(0.001)
                    .iterations(5)
                    .epochs(1)
                    .layerSize(150)
                    .tokenizerFactory(t)
                    .sequenceLearningAlgorithm(new DBOW<VocabWord>())
                    .useHierarchicSoftmax(true)
                    .trainWordVectors(false)
                    .useExistingWordVectors(wordVectors)
                    .build();
    paragraphVectors.fit();
    VocabWord day_B = paragraphVectors.getVocab().tokenFor("day");
    assertEquals(day_A.getIndex(), day_B.getIndex());
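    // matching indices indicate the ParagraphVectors model reused the w2v vocabulary instead of rebuilding it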
    /*
        double similarityD = wordVectors.similarity("day", "night");
        log.info("day/night similarity: " + similarityD);
        assertTrue(similarityD > 0.5d);
        */
    INDArray vector_day2 = paragraphVectors.getWordVectorMatrix("day").dup();
    double crossDay = arraysSimilarity(vector_day1, vector_day2);
    log.info("Day1: " + vector_day1);
    log.info("Day2: " + vector_day2);
    log.info("Cross-Day similarity: " + crossDay);
    log.info("Cross-Day similiarity 2: " + Transforms.cosineSim(vector_day1, vector_day2));
    assertTrue(crossDay > 0.9d);
    /*
        Here we're checking cross-vocabulary equality.
     */
    /*
        Random rnd = new Random();
        VocabCache<VocabWord> cacheP = paragraphVectors.getVocab();
        VocabCache<VocabWord> cacheW = wordVectors.getVocab();
        for (int x = 0; x < 1000; x++) {
            int idx = rnd.nextInt(cacheW.numWords());
        
            String wordW = cacheW.wordAtIndex(idx);
            String wordP = cacheP.wordAtIndex(idx);
        
            assertEquals(wordW, wordP);
        
            INDArray arrayW = wordVectors.getWordVectorMatrix(wordW);
            INDArray arrayP = paragraphVectors.getWordVectorMatrix(wordP);
        
            double simWP = Transforms.cosineSim(arrayW, arrayP);
            assertTrue(simWP >= 0.9);
        }
        */
    log.info("Zfinance: " + paragraphVectors.getWordVectorMatrix("Zfinance"));
    log.info("Zhealth: " + paragraphVectors.getWordVectorMatrix("Zhealth"));
    log.info("Zscience: " + paragraphVectors.getWordVectorMatrix("Zscience"));
    LabelledDocument document = unlabeledIterator.nextDocument();
    log.info("Results for document '" + document.getLabel() + "'");
    List<String> results = new ArrayList<>(paragraphVectors.predictSeveral(document, 3));
    for (String result : results) {
        double sim = paragraphVectors.similarityToLabel(document, result);
        log.info("Similarity to [" + result + "] is [" + sim + "]");
    }
    String topPrediction = paragraphVectors.predict(document);
    assertEquals("Zfinance", topPrediction);
}
Imports used:

import java.util.ArrayList;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
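The test calls a helper, arraysSimilarity(...), that is not part of this excerpt. A plausible reconstruction is sketched below; this is an assumption, not the project's actual code, and the real implementation may differ:

// Hypothetical reconstruction of the helper used above.
private static double arraysSimilarity(INDArray array1, INDArray array2) {
    if (array1.equals(array2))
        return 1.0;
    // Normalize copies so the comparison is purely directional.
    INDArray vector1 = Transforms.unitVec(array1.dup());
    INDArray vector2 = Transforms.unitVec(array2.dup());
    return Transforms.cosineSim(vector1, vector2);
}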

Example 49 with ClassPathResource

Use of org.datavec.api.util.ClassPathResource in the deeplearning4j project.

From the class BagOfWordsVectorizerTest, method testBagOfWordsVectorizer.

@Test
public void testBagOfWordsVectorizer() throws Exception {
    File rootDir = new ClassPathResource("rootdir").getFile();
    LabelAwareSentenceIterator iter = new LabelAwareFileSentenceIterator(rootDir);
    List<String> labels = Arrays.asList("label1", "label2");
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    BagOfWordsVectorizer vectorizer = new BagOfWordsVectorizer.Builder()
                    .setMinWordFrequency(1)
                    .setStopWords(new ArrayList<String>())
                    .setTokenizerFactory(tokenizerFactory)
                    .setIterator(iter)
                    .allowParallelTokenization(false)
                    .build();
    vectorizer.fit();
    VocabWord word = vectorizer.getVocabCache().wordFor("file.");
    assumeNotNull(word);
    assertEquals(word, vectorizer.getVocabCache().tokenFor("file."));
    assertEquals(2, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(2, word.getSequencesCount());
    assertEquals(2, word.getElementFrequency(), 0.1);
    VocabWord word1 = vectorizer.getVocabCache().wordFor("1");
    assertEquals(1, word1.getSequencesCount());
    assertEquals(1, word1.getElementFrequency(), 0.1);
    log.info("Labels used: " + vectorizer.getLabelsSource().getLabels());
    assertEquals(2, vectorizer.getLabelsSource().getNumberOfLabelsUsed());
    ///////////////////
    INDArray array = vectorizer.transform("This is 2 file.");
    log.info("Transformed array: " + array);
    assertEquals(5, array.columns());
    VocabCache<VocabWord> vocabCache = vectorizer.getVocabCache();
    assertEquals(2, array.getDouble(vocabCache.tokenFor("This").getIndex()), 0.1);
    assertEquals(2, array.getDouble(vocabCache.tokenFor("is").getIndex()), 0.1);
    assertEquals(2, array.getDouble(vocabCache.tokenFor("file.").getIndex()), 0.1);
    assertEquals(0, array.getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
    assertEquals(1, array.getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);
    DataSet dataSet = vectorizer.vectorize("This is 2 file.", "label2");
    assertEquals(array, dataSet.getFeatureMatrix());
    INDArray labelz = dataSet.getLabels();
    log.info("Labels array: " + labelz);
    int idx2 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(labelz))).getFinalResult();
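    // IMax is an argmax: idx2 is the position of the 1.0 in the one-hot label vector for "label2"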
    //        assertEquals(1.0, dataSet.getLabels().getDouble(0), 0.1);
    //        assertEquals(0.0, dataSet.getLabels().getDouble(1), 0.1);
    dataSet = vectorizer.vectorize("This is 1 file.", "label1");
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("This").getIndex()), 0.1);
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("is").getIndex()), 0.1);
    assertEquals(2, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("file.").getIndex()), 0.1);
    assertEquals(1, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
    assertEquals(0, dataSet.getFeatureMatrix().getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);
    int idx1 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels()))).getFinalResult();
    //assertEquals(0.0, dataSet.getLabels().getDouble(0), 0.1);
    //assertEquals(1.0, dataSet.getLabels().getDouble(1), 0.1);
    assertNotEquals(idx2, idx1);
    // Serialization check
    File tempFile = File.createTempFile("fdsf", "fdfsdf");
    tempFile.deleteOnExit();
    SerializationUtils.saveObject(vectorizer, tempFile);
    BagOfWordsVectorizer vectorizer2 = SerializationUtils.readObject(tempFile);
    vectorizer2.setTokenizerFactory(tokenizerFactory);
    dataSet = vectorizer2.vectorize("This is 2 file.", "label2");
    assertEquals(array, dataSet.getFeatureMatrix());
}
Imports used:

import java.io.File;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.dataset.DataSet;

Example 50 with ClassPathResource

Use of org.datavec.api.util.ClassPathResource in the deeplearning4j project.

From the class TfidfVectorizerTest, method testTfIdfVectorizer.

@Test
public void testTfIdfVectorizer() throws Exception {
    File rootDir = new ClassPathResource("tripledir").getFile();
    LabelAwareSentenceIterator iter = new LabelAwareFileSentenceIterator(rootDir);
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    TfidfVectorizer vectorizer = new TfidfVectorizer.Builder()
                    .setMinWordFrequency(1)
                    .setStopWords(new ArrayList<String>())
                    .setTokenizerFactory(tokenizerFactory)
                    .setIterator(iter)
                    .allowParallelTokenization(false)
                    .build();
    vectorizer.fit();
    VocabWord word = vectorizer.getVocabCache().wordFor("file.");
    assumeNotNull(word);
    assertEquals(word, vectorizer.getVocabCache().tokenFor("file."));
    assertEquals(3, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(3, word.getSequencesCount());
    assertEquals(3, word.getElementFrequency(), 0.1);
    VocabWord word1 = vectorizer.getVocabCache().wordFor("1");
    assertEquals(1, word1.getSequencesCount());
    assertEquals(1, word1.getElementFrequency(), 0.1);
    log.info("Labels used: " + vectorizer.getLabelsSource().getLabels());
    assertEquals(3, vectorizer.getLabelsSource().getNumberOfLabelsUsed());
    assertEquals(3, vectorizer.getVocabCache().totalNumberOfDocs());
    assertEquals(11, vectorizer.numWordsEncountered());
    INDArray vector = vectorizer.transform("This is 3 file.");
    log.info("TF-IDF vector: " + Arrays.toString(vector.data().asDouble()));
    VocabCache<VocabWord> vocabCache = vectorizer.getVocabCache();
    assertEquals(.04402, vector.getDouble(vocabCache.tokenFor("This").getIndex()), 0.001);
    assertEquals(.04402, vector.getDouble(vocabCache.tokenFor("is").getIndex()), 0.001);
    assertEquals(0.119, vector.getDouble(vocabCache.tokenFor("3").getIndex()), 0.001);
    assertEquals(0, vector.getDouble(vocabCache.tokenFor("file.").getIndex()), 0.001);
    DataSet dataSet = vectorizer.vectorize("This is 3 file.", "label3");
    //assertEquals(0.0, dataSet.getLabels().getDouble(0), 0.1);
    //assertEquals(0.0, dataSet.getLabels().getDouble(1), 0.1);
    //assertEquals(1.0, dataSet.getLabels().getDouble(2), 0.1);
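    // the label vector should be one-hot: exactly one of the three entries is set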
    int cnt = 0;
    for (int i = 0; i < 3; i++) {
        if (dataSet.getLabels().getDouble(i) > 0.1)
            cnt++;
    }
    assertEquals(1, cnt);
    File tempFile = File.createTempFile("somefile", "Dsdas");
    tempFile.deleteOnExit();
    SerializationUtils.saveObject(vectorizer, tempFile);
    TfidfVectorizer vectorizer2 = SerializationUtils.readObject(tempFile);
    vectorizer2.setTokenizerFactory(tokenizerFactory);
    dataSet = vectorizer2.vectorize("This is 3 file.", "label2");
    assertEquals(vector, dataSet.getFeatureMatrix());
}
Imports used:

import java.io.File;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
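For orientation, the textbook TF-IDF weighting is sketched below. This is an illustration only: dl4j's implementation may differ in smoothing and normalization, so the asserted values in the test above come from its own formula, not this one.

// A minimal sketch of standard TF-IDF; not dl4j's exact formula.
static double tfidf(double countInDoc, double docLength, double totalDocs, double docsWithTerm) {
    double tf = countInDoc / docLength;                // term frequency within the document
    double idf = Math.log10(totalDocs / docsWithTerm); // inverse document frequency
    return tf * idf;
}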

Aggregations

ClassPathResource (org.datavec.api.util.ClassPathResource): 71 uses
Test (org.junit.Test): 63 uses
File (java.io.File): 45 uses
TokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory): 28 uses
BasicLineIterator (org.deeplearning4j.text.sentenceiterator.BasicLineIterator): 27 uses
DefaultTokenizerFactory (org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory): 27 uses
VocabWord (org.deeplearning4j.models.word2vec.VocabWord): 23 uses
SentenceIterator (org.deeplearning4j.text.sentenceiterator.SentenceIterator): 23 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 23 uses
CommonPreprocessor (org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor): 20 uses
SentenceTransformer (org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer): 12 uses
AbstractCache (org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache): 11 uses
WordVectors (org.deeplearning4j.models.embeddings.wordvectors.WordVectors): 10 uses
AbstractSequenceIterator (org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator): 10 uses
ArrayList (java.util.ArrayList): 9 uses
Word2Vec (org.deeplearning4j.models.word2vec.Word2Vec): 8 uses
AggregatingSentenceIterator (org.deeplearning4j.text.sentenceiterator.AggregatingSentenceIterator): 7 uses
FileSentenceIterator (org.deeplearning4j.text.sentenceiterator.FileSentenceIterator): 7 uses
DataSet (org.nd4j.linalg.dataset.DataSet): 7 uses
InputStream (java.io.InputStream): 6 uses