Use of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.
The class ParagraphVectorsTest, method testDirectInference.
@Test
public void testDirectInference() throws Exception {
ClassPathResource resource_sentences = new ClassPathResource("/big/raw_sentences.txt");
ClassPathResource resource_mixed = new ClassPathResource("/paravec");
SentenceIterator iter = new AggregatingSentenceIterator.Builder()
        .addSentenceIterator(new BasicLineIterator(resource_sentences.getFile()))
        .addSentenceIterator(new FileSentenceIterator(resource_mixed.getFile()))
        .build();
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
Word2Vec wordVectors = new Word2Vec.Builder()
        .minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
        .learningRate(0.025).layerSize(150).minLearningRate(0.001)
        .elementsLearningAlgorithm(new SkipGram<VocabWord>())
        .useHierarchicSoftmax(true).windowSize(5)
        .iterate(iter).tokenizerFactory(t).build();
wordVectors.fit();
ParagraphVectors pv = new ParagraphVectors.Builder()
        .tokenizerFactory(t).iterations(10).useHierarchicSoftmax(true)
        .trainWordVectors(true).useExistingWordVectors(wordVectors)
        .negativeSample(0).sequenceLearningAlgorithm(new DM<VocabWord>())
        .build();
INDArray vec1 = pv.inferVector("This text is pretty awesome");
INDArray vec2 = pv.inferVector("Fantastic process of crazy things happening inside just for history purposes");
log.info("vec1/vec2: {}", Transforms.cosineSim(vec1, vec2));
}
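As a quick sanity check that is not part of the original test: inferring the same text twice with the fitted pv instance should yield near-identical vectors, while the two unrelated sentences above typically score lower. A minimal sketch under that assumption:
INDArray again1 = pv.inferVector("This text is pretty awesome");
INDArray again2 = pv.inferVector("This text is pretty awesome");
// two inferences of the same text; with enough inference iterations this is usually close to 1.0
log.info("Same-text inference similarity: {}", Transforms.cosineSim(again1, again2));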
Use of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.
The class ParagraphVectorsTest, method testParagraphVectorsOverExistingWordVectorsModel.
/*
    In this test we build a w2v model and then use its vocab and weights for ParagraphVectors.
    There's no need to run this test on Travis; use it manually, and only for problem detection.
*/
@Test
public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception {
// we build the w2v model from multiple sources to cover everything
ClassPathResource resource_sentences = new ClassPathResource("/big/raw_sentences.txt");
ClassPathResource resource_mixed = new ClassPathResource("/paravec");
SentenceIterator iter = new AggregatingSentenceIterator.Builder()
        .addSentenceIterator(new BasicLineIterator(resource_sentences.getFile()))
        .addSentenceIterator(new FileSentenceIterator(resource_mixed.getFile()))
        .build();
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
Word2Vec wordVectors = new Word2Vec.Builder()
        .minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
        .learningRate(0.025).layerSize(150).minLearningRate(0.001)
        .elementsLearningAlgorithm(new SkipGram<VocabWord>())
        .useHierarchicSoftmax(true).windowSize(5)
        .iterate(iter).tokenizerFactory(t).build();
wordVectors.fit();
VocabWord day_A = wordVectors.getVocab().tokenFor("day");
INDArray vector_day1 = wordVectors.getWordVectorMatrix("day").dup();
// At this point we have a ready w2v model. It's time to use it for ParagraphVectors
FileLabelAwareIterator labelAwareIterator = new FileLabelAwareIterator.Builder()
        .addSourceFolder(new ClassPathResource("/paravec/labeled").getFile()).build();
// documents from this iterator will be used for classification
FileLabelAwareIterator unlabeledIterator = new FileLabelAwareIterator.Builder()
        .addSourceFolder(new ClassPathResource("/paravec/unlabeled").getFile()).build();
// we're building the classifier now, with the pre-built w2v model passed in
ParagraphVectors paragraphVectors = new ParagraphVectors.Builder()
        .iterate(labelAwareIterator).learningRate(0.025).minLearningRate(0.001)
        .iterations(5).epochs(1).layerSize(150).tokenizerFactory(t)
        .sequenceLearningAlgorithm(new DBOW<VocabWord>()).useHierarchicSoftmax(true)
        .trainWordVectors(false).useExistingWordVectors(wordVectors).build();
paragraphVectors.fit();
VocabWord day_B = paragraphVectors.getVocab().tokenFor("day");
assertEquals(day_A.getIndex(), day_B.getIndex());
/*
double similarityD = wordVectors.similarity("day", "night");
log.info("day/night similarity: " + similarityD);
assertTrue(similarityD > 0.5d);
*/
INDArray vector_day2 = paragraphVectors.getWordVectorMatrix("day").dup();
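// arraysSimilarity() is a similarity helper defined elsewhere in this test class (not shown here)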
double crossDay = arraysSimilarity(vector_day1, vector_day2);
log.info("Day1: " + vector_day1);
log.info("Day2: " + vector_day2);
log.info("Cross-Day similarity: " + crossDay);
log.info("Cross-Day similiarity 2: " + Transforms.cosineSim(vector_day1, vector_day2));
assertTrue(crossDay > 0.9d);
// Here we check cross-vocabulary equality:
/*
Random rnd = new Random();
VocabCache<VocabWord> cacheP = paragraphVectors.getVocab();
VocabCache<VocabWord> cacheW = wordVectors.getVocab();
for (int x = 0; x < 1000; x++) {
int idx = rnd.nextInt(cacheW.numWords());
String wordW = cacheW.wordAtIndex(idx);
String wordP = cacheP.wordAtIndex(idx);
assertEquals(wordW, wordP);
INDArray arrayW = wordVectors.getWordVectorMatrix(wordW);
INDArray arrayP = paragraphVectors.getWordVectorMatrix(wordP);
double simWP = Transforms.cosineSim(arrayW, arrayP);
assertTrue(simWP >= 0.9);
}
*/
log.info("Zfinance: " + paragraphVectors.getWordVectorMatrix("Zfinance"));
log.info("Zhealth: " + paragraphVectors.getWordVectorMatrix("Zhealth"));
log.info("Zscience: " + paragraphVectors.getWordVectorMatrix("Zscience"));
LabelledDocument document = unlabeledIterator.nextDocument();
log.info("Results for document '" + document.getLabel() + "'");
List<String> results = new ArrayList<>(paragraphVectors.predictSeveral(document, 3));
for (String result : results) {
double sim = paragraphVectors.similarityToLabel(document, result);
log.info("Similarity to [" + result + "] is [" + sim + "]");
}
String topPrediction = paragraphVectors.predict(document);
assertEquals("Zfinance", topPrediction);
}
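Once word vectors and paragraph vectors have been combined like this, the resulting model can be persisted for later inference. A minimal sketch, assuming the WordVectorSerializer round trip for ParagraphVectors; the file name is hypothetical:
File modelFile = new File("paravec_combined_model.zip");
WordVectorSerializer.writeParagraphVectors(paragraphVectors, modelFile);
ParagraphVectors restored = WordVectorSerializer.readParagraphVectors(modelFile);
// the tokenizer factory is not part of the serialized model and must be re-attached
restored.setTokenizerFactory(t);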
Use of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.
The class ParagraphVectorsTest, method testGensimEquality.
/**
 * Special test to check d2v inference against a pre-trained gensim model.
 */
@Ignore
@Test
public void testGensimEquality() throws Exception {
INDArray expA = Nd4j.create(new double[] { -0.02461922, -0.00801059, -0.01821643, 0.0167951, 0.02240154, -0.00414107, -0.0022868, 0.00278438, -0.00651088, -0.02066556, -0.01045411, -0.02853066, 0.00153375, 0.02707097, -0.00754221, -0.02795872, -0.00275301, -0.01455731, -0.00981289, 0.01557207, -0.005259, 0.00355505, 0.01503531, -0.02185878, 0.0339283, -0.05049067, 0.02849454, -0.01242505, 0.00438659, -0.03037345, 0.01866657, -0.00740161, -0.01850279, 0.00851284, -0.01774663, -0.01976997, -0.03317627, 0.00372983, 0.01313218, -0.00041131, 0.00089357, -0.0156924, 0.01278253, -0.01596088, -0.01415407, -0.01795845, 0.00558284, -0.00529536, -0.03508032, 0.00725479, -0.01910841, -0.0008098, 0.00614283, -0.00926585, 0.01761538, -0.00272953, -0.01483113, 0.02062481, -0.03134528, 0.03416841, -0.0156226, -0.01418961, -0.00817538, 0.01848741, 0.00444605, 0.01090323, 0.00746163, -0.02490317, 0.00835013, 0.01091823, -0.0177979, 0.0207753, -0.00854185, 0.04269911, 0.02786852, 0.00179449, 0.00303065, -0.00127148, -0.01589409, -0.01110292, 0.01736244, -0.01177608, 0.00110929, 0.01790557, -0.01800732, 0.00903072, 0.00210271, 0.0103053, -0.01508116, 0.00336775, 0.00319031, -0.00982859, 0.02409827, -0.0079536, 0.01347831, -0.02555985, 0.00282605, 0.00350526, -0.00471707, -0.00592073, -0.01009063, -0.02396305, 0.02643895, -0.05487461, -0.01710705, -0.0082839, 0.01322765, 0.00098093, 0.01707118, 0.00290805, 0.03256396, 0.00277155, 0.00350602, 0.0096487, -0.0062662, 0.0331796, -0.01758772, 0.0295204, 0.00295053, -0.00670782, 0.02172252, 0.00172433, 0.0122977, -0.02401575, 0.01179839, -0.01646545, -0.0242724, 0.01318037, -0.00745518, -0.00400624, -0.01735787, 0.01627645, 0.04445697, -0.0189355, 0.01315041, 0.0131585, 0.01770667, -0.00114554, 0.00581599, 0.00745188, -0.01318868, -0.00801476, -0.00884938, 0.00084786, 0.02578231, -0.01312729, -0.02047793, 0.00485749, -0.00342519, -0.00744475, 0.01180929, 0.02871456, 0.01483848, -0.00696516, 0.02003011, -0.01721076, -0.0124568, -0.0114492, -0.00970469, 0.01971609, 0.01599673, -0.01426137, 0.00808409, -0.01431519, 0.01187332, 0.00144421, -0.00459554, 0.00384032, 0.00866845, 0.00265177, -0.01003456, 0.0289338, 0.00353483, -0.01664903, -0.03050662, 0.01305057, -0.0084294, -0.01615093, -0.00897918, 0.00768479, 0.02155688, 0.01594496, 0.00034328, -0.00557031, -0.00256555, 0.03939554, 0.00274235, 0.001288, 0.02933025, 0.0070212, -0.00573742, 0.00883708, 0.00829396, -0.01100356, -0.02653269, -0.01023274, 0.03079773, -0.00765917, 0.00949703, 0.01212146, -0.01362515, -0.0076843, -0.00290596, -0.01707907, 0.02899382, -0.00089925, 0.01510732, 0.02378234, -0.00947305, 0.0010998, -0.00558241, 0.00057873, 0.01098226, -0.02019168, -0.013942, -0.01639287, -0.00675588, -0.00400709, -0.02914054, -0.00433462, 0.01551765, -0.03552055, 0.01681101, -0.00629782, -0.01698086, 0.01891401, 0.03597684, 0.00888052, -0.01587857, 0.00935822, 0.00931327, -0.0128156, 0.05170929, -0.01811879, 0.02096679, 0.00897546, 0.00132624, -0.01796336, 0.01888563, -0.01142226, -0.00805926, 0.00049782, -0.02151541, 0.00747257, 0.023373, -0.00198183, 0.02968843, 0.00443042, -0.00328569, -0.04200815, 0.01306543, -0.01608924, -0.01604842, 0.03137267, 0.0266054, 0.00172526, -0.01205696, 0.00047532, 0.00321026, 0.00671424, 0.01710422, -0.01129941, 0.00268044, -0.01065434, -0.01107133, 0.00036135, -0.02991677, 0.02351665, -0.00343891, -0.01736755, -0.00100577, -0.00312481, -0.01083809, 0.00387084, 0.01136449, 0.01675043, -0.01978249, -0.00765182, 0.02746241, -0.01082247, -0.01587164, 0.01104732, -0.00878782, 
-0.00497555, -0.00186257, -0.02281011, 0.00141792, 0.00432851, -0.01290263, -0.00387155, 0.00802639, -0.00761913, 0.01508144, 0.02226428, 0.0107248, 0.01003709, 0.01587571, 0.00083492, -0.01632052, -0.00435973 });
INDArray expB = Nd4j.create(new double[] { -0.02465764, 0.00756337, -0.0268607, 0.01588023, 0.01580242, -0.00150542, 0.00116652, 0.0021577, -0.00754891, -0.02441176, -0.01271976, -0.02015191, 0.00220599, 0.03722657, -0.01629612, -0.02779619, -0.01157856, -0.01937938, -0.00744667, 0.01990043, -0.00505888, 0.00573646, 0.00385467, -0.0282531, 0.03484593, -0.05528606, 0.02428633, -0.01510474, 0.00153177, -0.03637344, 0.01747423, -0.00090738, -0.02199888, 0.01410434, -0.01710641, -0.01446697, -0.04225266, 0.00262217, 0.00871943, 0.00471594, 0.0101348, -0.01991908, 0.00874325, -0.00606416, -0.01035323, -0.01376545, 0.00451507, -0.01220307, -0.04361237, 0.00026028, -0.02401881, 0.00580314, 0.00238946, -0.01325974, 0.01879044, -0.00335623, -0.01631887, 0.02222102, -0.02998703, 0.03190075, -0.01675236, -0.01799807, -0.01314015, 0.01950069, 0.0011723, 0.01013178, 0.01093296, -0.034143, 0.00420227, 0.01449351, -0.00629987, 0.01652851, -0.01286825, 0.03314656, 0.03485073, 0.01120341, 0.01298241, 0.0019494, -0.02420256, -0.0063762, 0.01527091, -0.00732881, 0.0060427, 0.019327, -0.02068196, 0.00876712, 0.00292274, 0.01312969, -0.01529114, 0.0021757, -0.00565621, -0.01093122, 0.02758765, -0.01342688, 0.01606117, -0.02666447, 0.00541112, 0.00375426, -0.00761796, 0.00136015, -0.01169962, -0.03012749, 0.03012953, -0.05491332, -0.01137303, -0.01392103, 0.01370098, -0.00794501, 0.0248435, 0.00319645, 0.04261713, -0.00364211, 0.00780485, 0.01182583, -0.00647098, 0.03291231, -0.02515565, 0.03480943, 0.00119836, -0.00490694, 0.02615346, -0.00152456, 0.00196142, -0.02326461, 0.00603225, -0.02414703, -0.02540966, 0.0072112, -0.01090273, -0.00505061, -0.02196866, 0.00515245, 0.04981546, -0.02237269, -0.00189305, 0.0169786, 0.01782372, -0.00430022, 0.00551226, 0.00293861, -0.01337168, -0.00302476, -0.01869966, 0.00270757, 0.03199976, -0.01614617, -0.02716484, 0.01560035, -0.01312686, -0.01604082, 0.01347521, 0.03229654, 0.00707219, -0.00588392, 0.02444809, -0.01068742, -0.0190814, -0.00556385, -0.00462766, 0.01283929, 0.02001247, -0.00837629, -0.00041943, -0.02298774, 0.00874839, 0.00434907, -0.00963332, 0.00476905, 0.00793049, -0.00212557, -0.01839353, 0.03345517, 0.00838255, -0.0157447, -0.0376134, 0.01059611, -0.02323246, -0.01326356, -0.01116734, 0.00598869, 0.0211626, 0.01872963, -0.0038276, -0.01208279, -0.00989125, 0.04147648, 0.00181867, -0.00369355, 0.02312465, 0.0048396, 0.00564515, 0.01317832, -0.0057621, -0.01882041, -0.02869064, -0.00670661, 0.02585443, -0.01108428, 0.01411031, 0.01204507, -0.01244726, -0.00962342, -0.00205239, -0.01653971, 0.02871559, -0.00772978, 0.0214524, 0.02035478, -0.01324312, 0.00169302, -0.00064739, 0.00531795, 0.01059279, -0.02455794, -0.00002782, -0.0068906, -0.0160858, -0.0031842, -0.02295724, 0.01481094, 0.01769004, -0.02925742, 0.02050495, -0.00029003, -0.02815636, 0.02467367, 0.03419458, 0.00654938, -0.01847546, 0.00999932, 0.00059222, -0.01722176, 0.05172159, -0.01548486, 0.01746444, 0.007871, 0.0078471, -0.02414417, 0.01898077, -0.01470176, -0.00299465, 0.00368212, -0.02474656, 0.01317451, 0.03706085, -0.00032923, 0.02655881, 0.0013586, -0.0120303, -0.05030316, 0.0222294, -0.0070967, -0.02150935, 0.03254268, 0.01369857, 0.00246183, -0.02253576, -0.00551247, 0.00787363, 0.01215617, 0.02439827, -0.01104699, -0.00774596, -0.01898127, -0.01407653, 0.00195514, -0.03466602, 0.01560903, -0.01239944, -0.02474852, 0.00155114, 0.00089324, -0.01725949, -0.00011816, 0.00742845, 0.01247074, -0.02467943, -0.00679623, 0.01988366, -0.00626181, -0.02396477, 0.01052101, -0.01123178, 
-0.00386291, -0.00349261, -0.02714747, -0.00563315, 0.00228767, -0.01303677, -0.01971108, 0.00014759, -0.00346399, 0.02220698, 0.01979946, -0.00526076, 0.00647453, 0.01428513, 0.00223467, -0.01690172, -0.0081715 });
VectorsConfiguration configuration = new VectorsConfiguration();
configuration.setIterations(5);
configuration.setLearningRate(0.01);
configuration.setUseHierarchicSoftmax(true);
configuration.setNegative(0);
Word2Vec w2v = WordVectorSerializer.readWord2VecFromText(
        new File("/home/raver119/Downloads/gensim_models_for_dl4j/word"),
        new File("/home/raver119/Downloads/gensim_models_for_dl4j/hs"),
        new File("/home/raver119/Downloads/gensim_models_for_dl4j/hs_code"),
        new File("/home/raver119/Downloads/gensim_models_for_dl4j/hs_mapping"),
        configuration);
TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
assertNotEquals(null, w2v.getLookupTable());
assertNotEquals(null, w2v.getVocab());
ParagraphVectors d2v = new ParagraphVectors.Builder(configuration)
        .useExistingWordVectors(w2v)
        .sequenceLearningAlgorithm(new DM<VocabWord>())
        .tokenizerFactory(tokenizerFactory)
        .resetModel(false)
        .build();
assertNotEquals(null, d2v.getLookupTable());
assertNotEquals(null, d2v.getVocab());
assertTrue(d2v.getVocab() == w2v.getVocab());
assertTrue(d2v.getLookupTable() == w2v.getLookupTable());
String textA = "Donald Trump referred to President Obama as “your president” during the first presidential debate on Monday, much to many people’s chagrin on social media. Trump, made the reference after saying that the greatest threat facing the world is nuclear weapons. He then turned to Hillary Clinton and said, “Not global warming like you think and your President thinks,” referring to Obama.";
String textB = "The comment followed Trump doubling down on his false claims about the so-called birther conspiracy theory about Obama. People following the debate were immediately angered that Trump implied Obama is not his president.";
String textC = "practice of trust owned Trump for example indeed and conspiracy between provoke";
INDArray arrayA = d2v.inferVector(textA);
INDArray arrayB = d2v.inferVector(textB);
INDArray arrayC = d2v.inferVector(textC);
assertNotEquals(null, arrayA);
assertNotEquals(null, arrayB);
Transforms.unitVec(arrayA);
Transforms.unitVec(arrayB);
Transforms.unitVec(expA);
Transforms.unitVec(expB);
double simX = Transforms.cosineSim(arrayA, arrayB);
double simC = Transforms.cosineSim(arrayA, arrayC);
double simB = Transforms.cosineSim(arrayB, expB);
log.info("SimilarityX: {}", simX);
log.info("SimilarityC: {}", simC);
log.info("SimilarityB: {}", simB);
}
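For reference, Transforms.cosineSim used throughout these checks is the normalized dot product. An equivalent computation spelled out with plain Nd4j ops, where a and b stand for any two non-zero vectors (e.g. arrayA and expA):
double dot = a.mul(b).sumNumber().doubleValue(); // a · b
double normA = a.norm2Number().doubleValue(); // ||a||
double normB = b.norm2Number().doubleValue(); // ||b||
double cosine = dot / (normA * normB); // cos(a, b), in [-1, 1]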
Use of org.deeplearning4j.models.word2vec.Word2Vec in project deeplearning4j by deeplearning4j.
The class Word2VecDataSetIteratorTest, method testIterator1.
/**
 * Basically, all we want from this test is to finish without exceptions.
 */
@Test
public void testIterator1() throws Exception {
File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
Word2Vec vec = new Word2Vec.Builder()
        .minWordFrequency(10) // we make sure we'll have some missing words
        .iterations(1).learningRate(0.025).layerSize(150).seed(42)
        .sampling(0).negativeSample(0).useHierarchicSoftmax(true).windowSize(5)
        .modelUtils(new BasicModelUtils<VocabWord>()).useAdaGrad(false)
        .iterate(iter).workers(8).tokenizerFactory(t)
        .elementsLearningAlgorithm(new CBOW<VocabWord>()).build();
vec.fit();
List<String> labels = new ArrayList<>();
labels.add("positive");
labels.add("negative");
Word2VecDataSetIterator iterator = new Word2VecDataSetIterator(vec, getLASI(iter, labels), labels, 1);
INDArray array = iterator.next().getFeatures();
while (iterator.hasNext()) {
DataSet ds = iterator.next();
assertArrayEquals(array.shape(), ds.getFeatureMatrix().shape());
}
}
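getLASI() above is a helper defined elsewhere in the test class and not shown on this page. A plausible minimal sketch, assuming it wraps the SentenceIterator in a LabelAwareSentenceIterator that hands each sentence a random label from the given list; the implementation details are an assumption for illustration, not the actual helper:
// Hypothetical getLASI-style helper (requires java.util.Random, java.util.Collections,
// org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator).
protected LabelAwareSentenceIterator getLASI(final SentenceIterator iterator, final List<String> labels) {
    iterator.reset();
    return new LabelAwareSentenceIterator() {
        private final Random random = new Random(42);

        @Override
        public String currentLabel() {
            // one of the provided labels, chosen at random per sentence
            return labels.get(random.nextInt(labels.size()));
        }

        @Override
        public List<String> currentLabels() {
            return Collections.singletonList(currentLabel());
        }

        @Override
        public String nextSentence() {
            return iterator.nextSentence();
        }

        @Override
        public boolean hasNext() {
            return iterator.hasNext();
        }

        @Override
        public void reset() {
            iterator.reset();
        }

        @Override
        public void finish() {
            iterator.finish();
        }

        @Override
        public SentencePreProcessor getPreProcessor() {
            return iterator.getPreProcessor();
        }

        @Override
        public void setPreProcessor(SentencePreProcessor preProcessor) {
            iterator.setPreProcessor(preProcessor);
        }
    };
}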