use of org.deeplearning4j.text.sentenceiterator.BasicLineIterator in project deeplearning4j by deeplearning4j.
the class WordVectorSerializerTest method testOutputStream.
/**
 * Trains a small Word2Vec model on {@code /big/raw_sentences.txt}, writes its vectors to a
 * temporary file through an {@link java.io.OutputStream}, reloads them with
 * {@code loadTxtVectors} and checks that the vector for "day" is unchanged.
 * Afterwards the full model is written with {@code writeWord2VecModel} and read back
 * with {@code readWord2VecModel}.
 *
 * @throws Exception if the corpus cannot be read or any serialization step fails
 */
@Test
public void testOutputStream() throws Exception {
    File file = File.createTempFile("tmp_ser", "ssa");
    file.deleteOnExit();

    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = new BasicLineIterator(inputFile);

    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    // The same cache instance is shared by the lookup table and the Word2Vec builder
    InMemoryLookupCache cache = new InMemoryLookupCache(false);
    WeightLookupTable table = new InMemoryLookupTable.Builder()
            .vectorLength(100)
            .useAdaGrad(false)
            .negative(5.0)
            .cache(cache)
            .lr(0.025f)
            .build();

    Word2Vec vec = new Word2Vec.Builder()
            .minWordFrequency(5)
            .iterations(1)
            .epochs(1)
            .layerSize(100)
            .lookupTable(table)
            .stopWords(new ArrayList<String>())
            .useAdaGrad(false)
            .negativeSample(5)
            .vocabCache(cache)
            .seed(42)
            .windowSize(5)
            .iterate(iter)
            .tokenizerFactory(t)
            .build();

    assertEquals(new ArrayList<String>(), vec.getStopWords());
    vec.fit();

    INDArray day1 = vec.getWordVectorMatrix("day");

    // Close the stream deterministically before the file is read back; the test previously
    // created the stream inline and never closed it itself.
    try (FileOutputStream stream = new FileOutputStream(file)) {
        WordVectorSerializer.writeWordVectors(vec, stream);
    }

    WordVectors vec2 = WordVectorSerializer.loadTxtVectors(file);
    INDArray day2 = vec2.getWordVectorMatrix("day");
    assertEquals(day1, day2);

    File tempFile = File.createTempFile("tetsts", "Fdfs");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeWord2VecModel(vec, tempFile);

    // No assertion on the restored model: this only verifies that reading back does not throw.
    Word2Vec vec3 = WordVectorSerializer.readWord2VecModel(tempFile);
}
use of org.deeplearning4j.text.sentenceiterator.BasicLineIterator in project deeplearning4j by deeplearning4j.
the class Word2VecTests method testWord2VecGoogleModelUptraining.
/**
 * Manual (ignored) scenario: loads the GoogleNews binary model from a hard-coded local path,
 * logs how long loading took, then continues training it on {@code inputFile} using CBOW
 * with negative sampling and hierarchic softmax switched off.
 *
 * @throws Exception if the model or the corpus cannot be read
 */
@Ignore
@Test
public void testWord2VecGoogleModelUptraining() throws Exception {
    final long loadStarted = System.currentTimeMillis();
    Word2Vec model = WordVectorSerializer.readWord2VecModel(
            new File("C:\\Users\\raver\\Downloads\\GoogleNews-vectors-negative300.bin.gz"),
            false);
    final long loadFinished = System.currentTimeMillis();

    log.info("Model loaded in {} msec", loadFinished - loadStarted);

    SentenceIterator sentences = new BasicLineIterator(inputFile.getAbsolutePath());

    // Tokenizer with the common pre-processor attached to every token
    TokenizerFactory tokenizers = new DefaultTokenizerFactory();
    tokenizers.setTokenPreProcessor(new CommonPreprocessor());

    // Attach the new corpus and tokenizer to the already loaded model
    model.setTokenizerFactory(tokenizers);
    model.setSentenceIterator(sentences);

    // Switch the training configuration to negative sampling only
    model.getConfiguration().setUseHierarchicSoftmax(false);
    model.getConfiguration().setNegative(5.0);

    model.setElementsLearningAlgorithm(new CBOW<VocabWord>());

    model.fit();
}
use of org.deeplearning4j.text.sentenceiterator.BasicLineIterator in project deeplearning4j by deeplearning4j.
the class Word2VecTests method testWord2VecAdaGrad.
/**
 * Trains Word2Vec on {@code inputFile} with hierarchic softmax and four workers, then checks
 * that "week", "night" and "year" are among the ten nearest words to "day".
 *
 * NOTE(review): despite the test name, the builder is configured with
 * {@code useAdaGrad(false)} - confirm whether that is intended.
 *
 * @throws Exception if the corpus cannot be read or training fails
 */
@Test
public void testWord2VecAdaGrad() throws Exception {
    SentenceIterator sentences = new BasicLineIterator(inputFile.getAbsolutePath());

    TokenizerFactory tokenizers = new DefaultTokenizerFactory();
    tokenizers.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec model = new Word2Vec.Builder()
            .minWordFrequency(5)
            .iterations(5)
            .learningRate(0.025)
            .layerSize(100)
            .seed(42)
            .batchSize(13500)
            .sampling(0)
            .negativeSample(0)
            .windowSize(5)
            .modelUtils(new BasicModelUtils<VocabWord>())
            .useAdaGrad(false)
            .useHierarchicSoftmax(true)
            .iterate(sentences)
            .workers(4)
            .tokenizerFactory(tokenizers)
            .build();

    model.fit();

    Collection<String> nearestToDay = model.wordsNearest("day", 10);
    log.info(Arrays.toString(nearestToDay.toArray()));

    // The size of the result is deliberately not asserted here.
    double dayNightSimilarity = model.similarity("day", "night");
    log.info("Day/night similarity: " + dayNightSimilarity);

    // Checked in the same order as before: week, night, year
    for (String expected : new String[] {"week", "night", "year"}) {
        assertTrue(nearestToDay.contains(expected));
    }
}
use of org.deeplearning4j.text.sentenceiterator.BasicLineIterator in project deeplearning4j by deeplearning4j.
the class Word2VecTests method testRunWord2Vec.
/**
 * End-to-end Word2Vec run with the SkipGram learning algorithm on {@code inputFile}.
 *
 * After training, the model is written to a temporary file via {@code writeFullModel}, the
 * ten nearest neighbours of "day" and the day/night similarity are checked, vectors fetched
 * in bulk are compared against per-word lookups, and finally the vectors are written to
 * {@code pathToWriteto}.
 *
 * @throws Exception if the corpus cannot be read or any training/serialization step fails
 */
@Test
public void testRunWord2Vec() throws Exception {
    // Sentence source backed by inputFile
    SentenceIterator sentences = new BasicLineIterator(inputFile.getAbsolutePath());

    // Tokenizer with the common pre-processor attached to every token
    TokenizerFactory tokenizers = new DefaultTokenizerFactory();
    tokenizers.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec model = new Word2Vec.Builder()
            .minWordFrequency(1)
            .iterations(3)
            .batchSize(64)
            .layerSize(100)
            .stopWords(new ArrayList<String>())
            .seed(42)
            .learningRate(0.025)
            .minLearningRate(0.001)
            .sampling(0)
            .elementsLearningAlgorithm(new SkipGram<VocabWord>())
            .epochs(1)
            .windowSize(5)
            .allowParallelTokenization(true)
            .modelUtils(new BasicModelUtils<VocabWord>())
            .iterate(sentences)
            .tokenizerFactory(tokenizers)
            .build();

    assertEquals(new ArrayList<String>(), model.getStopWords());
    model.fit();

    File fullModelFile = File.createTempFile("temp", "temp");
    fullModelFile.deleteOnExit();
    WordVectorSerializer.writeFullModel(model, fullModelFile.getAbsolutePath());

    final String[] expectedNeighbours = {"week", "night", "year"};

    // First lookup: size, similarity bounds and expected neighbours
    Collection<String> nearestToDay = model.wordsNearest("day", 10);
    printWords("day", nearestToDay, model);
    assertEquals(10, nearestToDay.size());

    double dayNightSimilarity = model.similarity("day", "night");
    log.info("Day/night similarity: " + dayNightSimilarity);
    assertTrue(dayNightSimilarity < 1.0);
    assertTrue(dayNightSimilarity > 0.4);

    for (String expected : expectedNeighbours) {
        assertTrue(nearestToDay.contains(expected));
    }
    assertFalse(nearestToDay.contains(null));

    // Second lookup with the same arguments must still contain the expected neighbours
    nearestToDay = model.wordsNearest("day", 10);
    printWords("day", nearestToDay, model);
    for (String expected : expectedNeighbours) {
        assertTrue(nearestToDay.contains(expected));
    }

    new File("cache.ser").delete();

    // Bulk lookup: row i of the returned matrix must equal the single-word vector of label i
    ArrayList<String> labels = new ArrayList<>(Arrays.asList("day", "night", "week"));
    INDArray matrix = model.getWordVectors(labels);
    for (int row = 0; row < labels.size(); row++) {
        assertEquals(matrix.getRow(row), model.getWordVectorMatrix(labels.get(row)));
    }

    WordVectorSerializer.writeWordVectors(model, pathToWriteto);
}
use of org.deeplearning4j.text.sentenceiterator.BasicLineIterator in project deeplearning4j by deeplearning4j.
the class AbstractCoOccurrences method iterator.
/**
 * Returns an iterator over element pairs and their weights, read line by line from the
 * target file.
 *
 * Each line is split on single spaces; the first two fields are parsed as integer indexes
 * that are resolved through the vocab cache, and the third field is parsed as the weight.
 *
 * Developer's note: the underlying line iterator is wrapped in a PrefetchingSentenceIterator
 * and then in a SynchronizedSentenceIterator; thread safety of the returned iterator is
 * delegated to that wrapper.
 *
 * @return read-only iterator of ((element1, element2), weight) entries; {@code remove()} is
 *         not supported
 * @throws RuntimeException if the underlying sentence iterator cannot be created
 */
public Iterator<Pair<Pair<T, T>, Double>> iterator() {
    final SentenceIterator iterator;
    try {
        iterator = new SynchronizedSentenceIterator(
                new PrefetchingSentenceIterator.Builder(new BasicLineIterator(targetFile))
                        .setFetchSize(500000)
                        .build());
    } catch (Exception e) {
        logger.error("Target file was not found on last stage!");
        throw new RuntimeException(e);
    }

    return new Iterator<Pair<Pair<T, T>, Double>>() {
        /*
            iterator should be built on top of current text file with all pairs
         */
        @Override
        public boolean hasNext() {
            return iterator.hasNext();
        }

        @Override
        public Pair<Pair<T, T>, Double> next() {
            String line = iterator.nextSentence();
            // Honour the Iterator contract instead of failing with a NullPointerException below
            if (line == null) {
                throw new java.util.NoSuchElementException("No more co-occurrence pairs available");
            }

            String[] strings = line.split(" ");
            // Report malformed lines with their content rather than a bare index error
            if (strings.length < 3) {
                throw new IllegalStateException("Malformed co-occurrence line: [" + line + "]");
            }

            // parseInt avoids boxing an Integer only to unbox it again
            T element1 = vocabCache.elementAtIndex(Integer.parseInt(strings[0]));
            T element2 = vocabCache.elementAtIndex(Integer.parseInt(strings[1]));
            Double weight = Double.valueOf(strings[2]);

            return new Pair<>(new Pair<>(element1, element2), weight);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("remove() method can't be supported on read-only interface");
        }
    };
}
Aggregations