Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
Class WordVectorSerializer, method readTextModel.
/**
 * Reads a word2vec model from a plain-text file. The file is read through a
 * {@link GZIPInputStream} when its name looks compressed according to
 * {@code GzipUtils.isCompressedFilename}; otherwise it is read as-is. Text is decoded as UTF-8.
 * <p>
 * The first line must contain the number of words and the vector length, separated by a
 * single space. Every following line contains a word followed by its vector components,
 * also space-separated. Occurrences of {@code whitespaceReplacement} inside a word are
 * turned back into spaces.
 *
 * @param modelFile the text (or gzipped text) model file
 * @return a Word2Vec instance whose vocab and lookup table were filled from the file
 * @throws FileNotFoundException if the file cannot be opened
 * @throws IOException if reading fails, the file is empty, or the header line is malformed
 * @throws NumberFormatException if the header or a vector component is not a valid number
 */
private static Word2Vec readTextModel(File modelFile) throws IOException, NumberFormatException {
    Word2Vec ret = new Word2Vec();
    // The raw file stream is its own resource so it is closed even when wrapping it
    // (e.g. the GZIPInputStream constructor reading the gzip header) fails.
    try (FileInputStream fileStream = new FileInputStream(modelFile);
         BufferedReader reader = new BufferedReader(new InputStreamReader(GzipUtils.isCompressedFilename(modelFile.getName()) ? new GZIPInputStream(fileStream) : fileStream, "UTF-8"))) {
        String line = reader.readLine();
        if (line == null) {
            // An empty file used to surface as a NullPointerException on line.split().
            throw new IOException("Model file is empty: " + modelFile.getAbsolutePath());
        }
        String[] initial = line.split(" ");
        if (initial.length < 2) {
            throw new IOException("Malformed header in " + modelFile.getAbsolutePath() + ": expected '<words> <layerSize>' but got '" + line + "'");
        }
        int words = Integer.parseInt(initial[0]);
        int layerSize = Integer.parseInt(initial[1]);

        INDArray syn0 = Nd4j.create(words, layerSize);
        VocabCache cache = new InMemoryLookupCache(false);

        int currLine = 0;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            assert split.length == layerSize + 1;
            String word = split[0].replaceAll(whitespaceReplacement, " ");

            // Everything after the word itself is a vector component.
            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            syn0.putRow(currLine, Nd4j.create(vector));

            // Words are indexed in file order: the next free index is the current vocab size.
            cache.addWordToIndex(cache.numWords(), word);
            cache.addToken(new VocabWord(1, word));
            cache.putVocabWord(word);
            currLine++;
        }

        InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
        lookupTable.setSyn0(syn0);
        ret.setVocab(cache);
        ret.setLookupTable(lookupTable);
    }
    return ret;
}
Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
Class WordVectorSerializer, method readWord2VecFromText.
/**
 * Builds a Word2Vec model from externally originated text files holding the word vectors
 * (syn0), the hierarchical-softmax layer (syn1) and the Huffman tree codes and points.
 * So, technically this method is compatible with any other w2v implementation that can
 * export those files.
 * <p>
 * In the codes and points files the first space-separated field of each line is the word,
 * decoded with {@code decodeB64}; the remaining fields are the numeric values.
 *
 * @param vectors text file with words and their weights, aka Syn0 (read via {@code loadTxt})
 * @param hs text file with HS layers, aka Syn1; may be empty, in which case syn1 is not set
 * @param h_codes text file with Huffman tree codes
 * @param h_points text file with Huffman tree points
 * @param configuration configuration used for negative sampling, the tokenizer factory and the builder
 * @return the assembled Word2Vec model, built with {@code resetModel(false)}
 * @throws IOException if any of the files cannot be read
 */
public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes, @NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
    // first we load syn0
    Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectors);
    InMemoryLookupTable lookupTable = pair.getFirst();
    lookupTable.setNegative(configuration.getNegative());
    if (configuration.getNegative() > 0)
        lookupTable.initNegative();
    VocabCache<VocabWord> vocab = (VocabCache<VocabWord>) pair.getSecond();

    // now we load syn1; try-with-resources closes the reader even if a line fails to parse
    List<INDArray> rows = new ArrayList<>();
    try (BufferedReader reader = new BufferedReader(new FileReader(hs))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            double[] array = new double[split.length];
            for (int i = 0; i < split.length; i++) {
                array[i] = Double.parseDouble(split[i]);
            }
            rows.add(Nd4j.create(array));
        }
    }

    // it's possible to have full model without syn1
    if (rows.size() > 0) {
        INDArray syn1 = Nd4j.vstack(rows);
        lookupTable.setSyn1(syn1);
    }

    // now we transform mappings into huffman tree points
    try (BufferedReader reader = new BufferedReader(new FileReader(h_points))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            // NOTE(review): assumes every word in this file is present in the vocab — confirm
            VocabWord word = vocab.wordFor(decodeB64(split[0]));
            List<Integer> points = new ArrayList<>();
            for (int i = 1; i < split.length; i++) {
                points.add(Integer.parseInt(split[i]));
            }
            word.setPoints(points);
        }
    }

    // now we transform mappings into huffman tree codes
    try (BufferedReader reader = new BufferedReader(new FileReader(h_codes))) {
        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            // NOTE(review): assumes every word in this file is present in the vocab — confirm
            VocabWord word = vocab.wordFor(decodeB64(split[0]));
            List<Byte> codes = new ArrayList<>();
            for (int i = 1; i < split.length; i++) {
                codes.add(Byte.parseByte(split[i]));
            }
            word.setCodes(codes);
            word.setCodeLength((short) codes.size());
        }
    }

    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).vocabCache(vocab).lookupTable(lookupTable).resetModel(false);
    // the tokenizer factory is optional: it is only applied when one can be derived from the configuration
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null)
        builder.tokenizerFactory(factory);
    return builder.build();
}
Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
Class WordVectorSerializerTest, method testIndexPersistence.
/**
 * Trains a small Word2Vec model, writes its word vectors to a temporary file, loads them
 * back with {@code loadTxtVectors} and checks that the document count and every word
 * vector of the original vocabulary are unchanged by the round trip.
 */
@Test
public void testIndexPersistence() throws Exception {
    File corpus = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator sentences = UimaSentenceIterator.createWithPath(corpus.getAbsolutePath());

    // Tokenize on white space, running each token through the common preprocessor
    TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec trained = new Word2Vec.Builder()
            .minWordFrequency(5)
            .iterations(1)
            .epochs(1)
            .layerSize(100)
            .stopWords(new ArrayList<String>())
            .useAdaGrad(false)
            .negativeSample(5)
            .seed(42)
            .windowSize(5)
            .iterate(sentences)
            .tokenizerFactory(tokenizerFactory)
            .build();
    trained.fit();
    VocabCache originalVocab = trained.getVocab();

    // Persist to a scratch file that is cleaned up when the JVM exits
    File scratch = File.createTempFile("temp", "w2v");
    scratch.deleteOnExit();
    WordVectorSerializer.writeWordVectors(trained, scratch);

    WordVectors restored = WordVectorSerializer.loadTxtVectors(scratch);
    VocabCache restoredVocab = restored.vocab();
    assertEquals(originalVocab.totalNumberOfDocs(), restoredVocab.totalNumberOfDocs());

    // Every word known to the trained model must map to the same vector after reloading
    for (VocabWord vocabWord : trained.getVocab().vocabWords()) {
        String label = vocabWord.getLabel();
        INDArray expected = trained.getWordVectorMatrix(label);
        INDArray actual = restored.getWordVectorMatrix(label);
        assertEquals(expected, actual);
    }
}
Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
Class Glove, method train.
/**
 * Train on the corpus.
 * <p>
 * Steps: build the vocabulary through a {@code TextPipeline}, create and reset a
 * {@code GloveWeightLookupTable}, compute co-occurrence counts (capped at the table's
 * max count), then for each iteration compute a {@code GloveChange} per word pair on the
 * cluster, apply all changes to the lookup table on the driver, and reshuffle the pairs.
 *
 * @param rdd the rdd to train
 * @return the vocab and weights
 * @throws Exception if the pipeline or any Spark stage fails
 */
public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> rdd) throws Exception {
    // Each `train()` can use different parameters
    final JavaSparkContext sc = new JavaSparkContext(rdd.context());
    final SparkConf conf = sc.getConf();
    // NOTE(review): vectorLength, useAdaGrad, negative, window, alpha and minAlpha are read here but
    // not used below; the lookup table is configured from the GlovePerformer.* keys and the
    // co-occurrence step uses the symmetric/windowSize fields. The reads are kept in case
    // assignVar has side effects — confirm and remove if not.
    final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
    final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
    final double negative = assignVar(NEGATIVE, conf, Double.class);
    final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
    final int window = assignVar(WINDOW, conf, Integer.class);
    final double alpha = assignVar(ALPHA, conf, Double.class);
    final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
    final int iterations = assignVar(ITERATIONS, conf, Integer.class);
    final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
    final String tokenizer = assignVar(TOKENIZER, conf, String.class);
    final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
    final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);

    // A plain HashMap rather than a double-brace initialised anonymous subclass: the anonymous
    // class would keep a reference to this enclosing instance inside the broadcast value.
    Map<String, Object> tokenizerVarMap = new HashMap<>();
    tokenizerVarMap.put("numWords", numWords);
    tokenizerVarMap.put("nGrams", nGrams);
    tokenizerVarMap.put("tokenizer", tokenizer);
    tokenizerVarMap.put("tokenPreprocessor", tokenPreprocessor);
    tokenizerVarMap.put("removeStop", removeStop);
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);

    TextPipeline pipeline = new TextPipeline(rdd, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // Get total word count
    Long totalWordCount = pipeline.getTotalWordCount();
    VocabCache<VocabWord> vocabCache = pipeline.getVocabCache();
    JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD = pipeline.getSentenceWordsCountRDD();
    final Pair<VocabCache<VocabWord>, Long> vocabAndNumWords = new Pair<>(vocabCache, totalWordCount);
    vocabCacheBroadcast = sc.broadcast(vocabAndNumWords.getFirst());

    final GloveWeightLookupTable gloveWeightLookupTable = new GloveWeightLookupTable.Builder().cache(vocabAndNumWords.getFirst()).lr(conf.getDouble(GlovePerformer.ALPHA, 0.01)).maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100)).vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300)).xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75)).build();
    gloveWeightLookupTable.resetWeights();
    // Start both AdaGrad histories at one
    gloveWeightLookupTable.getBiasAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().rows());
    gloveWeightLookupTable.getWeightAdaGrad().historicalGradient = Nd4j.ones(gloveWeightLookupTable.getSyn0().shape());
    log.info("Created lookup table of size " + Arrays.toString(gloveWeightLookupTable.getSyn0().shape()));

    CounterMap<String, String> coOccurrenceCounts = sentenceWordsCountRDD.map(new CoOccurrenceCalculator(symmetric, vocabCacheBroadcast, windowSize)).fold(new CounterMap<String, String>(), new CoOccurrenceCounts());
    Iterator<Pair<String, String>> pair2 = coOccurrenceCounts.getPairIterator();
    List<Triple<String, String, Double>> counts = new ArrayList<>();
    while (pair2.hasNext()) {
        Pair<String, String> next = pair2.next();
        // Cap each co-occurrence count at the lookup table's max count
        if (coOccurrenceCounts.getCount(next.getFirst(), next.getSecond()) > gloveWeightLookupTable.getMaxCount()) {
            coOccurrenceCounts.setCount(next.getFirst(), next.getSecond(), gloveWeightLookupTable.getMaxCount());
        }
        counts.add(new Triple<>(next.getFirst(), next.getSecond(), coOccurrenceCounts.getCount(next.getFirst(), next.getSecond())));
    }
    log.info("Calculated co occurrences");

    // (word1, word2, count) -> word1 -> (word2, count)
    JavaRDD<Triple<String, String, Double>> parallel = sc.parallelize(counts);
    JavaPairRDD<String, Tuple2<String, Double>> pairs = parallel.mapToPair(new PairFunction<Triple<String, String, Double>, String, Tuple2<String, Double>>() {
        @Override
        public Tuple2<String, Tuple2<String, Double>> call(Triple<String, String, Double> stringStringDoubleTriple) throws Exception {
            return new Tuple2<>(stringStringDoubleTriple.getFirst(), new Tuple2<>(stringStringDoubleTriple.getSecond(), stringStringDoubleTriple.getThird()));
        }
    });
    // Resolve both words of every pair to their VocabWord through the broadcast vocab
    JavaPairRDD<VocabWord, Tuple2<VocabWord, Double>> pairsVocab = pairs.mapToPair(new PairFunction<Tuple2<String, Tuple2<String, Double>>, VocabWord, Tuple2<VocabWord, Double>>() {
        @Override
        public Tuple2<VocabWord, Tuple2<VocabWord, Double>> call(Tuple2<String, Tuple2<String, Double>> stringTuple2Tuple2) throws Exception {
            VocabWord w1 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._1());
            VocabWord w2 = vocabCacheBroadcast.getValue().wordFor(stringTuple2Tuple2._2()._1());
            return new Tuple2<>(w1, new Tuple2<>(w2, stringTuple2Tuple2._2()._2()));
        }
    });

    for (int i = 0; i < iterations; i++) {
        JavaRDD<GloveChange> change = pairsVocab.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Double>>, GloveChange>() {
            @Override
            public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Double>> vocabWordTuple2Tuple2) throws Exception {
                VocabWord w1 = vocabWordTuple2Tuple2._1();
                VocabWord w2 = vocabWordTuple2Tuple2._2()._1();
                INDArray w1Vector = gloveWeightLookupTable.getSyn0().slice(w1.getIndex());
                INDArray w2Vector = gloveWeightLookupTable.getSyn0().slice(w2.getIndex());
                INDArray bias = gloveWeightLookupTable.getBias();
                double score = vocabWordTuple2Tuple2._2()._2();
                double xMax = gloveWeightLookupTable.getxMax();
                double maxCount = gloveWeightLookupTable.getMaxCount();
                //w1 * w2 + bias
                double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
                prediction += bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
                double weight = FastMath.pow(Math.min(1.0, (score / maxCount)), xMax);
                double fDiff = score > xMax ? prediction : weight * (prediction - Math.log(score));
                // Guard against NaN (e.g. from the logarithm) so the update stays finite
                if (Double.isNaN(fDiff))
                    fDiff = Nd4j.EPS_THRESHOLD;
                //amount of change
                double gradient = fDiff;
                Pair<INDArray, Double> w1Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w1, w1Vector, w2Vector, gradient);
                Pair<INDArray, Double> w2Update = update(gloveWeightLookupTable.getWeightAdaGrad(), gloveWeightLookupTable.getBiasAdaGrad(), gloveWeightLookupTable.getSyn0(), gloveWeightLookupTable.getBias(), w2, w2Vector, w1Vector, gradient);
                // NOTE(review): the last two arguments pass the bias history for w2 before w1, unlike
                // the w1-then-w2 order of the other arguments — confirm against GloveChange's constructor.
                return new GloveChange(w1, w2, w1Update.getFirst(), w2Update.getFirst(), w1Update.getSecond(), w2Update.getSecond(), fDiff, gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()), gloveWeightLookupTable.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()), gloveWeightLookupTable.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
            }
        });

        // Changes are gathered on the driver and applied to the lookup table one by one
        List<GloveChange> gloveChanges = change.collect();
        double error = 0.0;
        for (GloveChange change2 : gloveChanges) {
            change2.apply(gloveWeightLookupTable);
            error += change2.getError();
        }

        // Reshuffle the pairs for the next iteration; copy first so the shuffle never depends
        // on the mutability of the list returned by collect()
        List<Tuple2<VocabWord, Tuple2<VocabWord, Double>>> shuffled = new ArrayList<>(pairsVocab.collect());
        Collections.shuffle(shuffled);
        pairsVocab = sc.parallelizePairs(shuffled);
        log.info("Error at iteration " + i + " was " + error);
    }
    return new Pair<>(vocabAndNumWords.getFirst(), gloveWeightLookupTable);
}
Use of org.deeplearning4j.models.word2vec.wordstore.VocabCache in project deeplearning4j by deeplearning4j.
Class CoOccurrenceCalculator, method call.
/**
 * Computes distance-weighted co-occurrence counts for one sentence.
 * <p>
 * For every in-vocabulary word at position {@code i}, each in-vocabulary word at a later
 * position {@code j} inside the window contributes {@code 1.0 / (j - i + Nd4j.EPS_THRESHOLD)}.
 * The pair is keyed with the lower vocab index first; when {@code symmetric} is set the
 * reversed pair is incremented as well. Tokens not in the vocab (index below zero) and
 * pairs with the same vocab index are skipped.
 *
 * @param pair the sentence tokens (first element); the second element is not used here
 * @return the co-occurrence counts for this sentence
 */
@Override
public CounterMap<String, String> call(Pair<List<String>, AtomicLong> pair) throws Exception {
    List<String> sentence = pair.getFirst();
    CounterMap<String, String> coOCurreneCounts = new CounterMap<>();
    VocabCache vocab = this.vocab.value();
    for (int i = 0; i < sentence.size(); i++) {
        String w1 = sentence.get(i);
        int wordIdx = vocab.indexOf(w1);
        // Check the index before touching the vocab entry. The earlier code dereferenced
        // vocab.wordFor(...) ahead of this guard (presumably null for unknown words), so an
        // out-of-vocabulary token could throw instead of being skipped.
        if (wordIdx < 0)
            continue;
        int windowStop = Math.min(i + windowSize + 1, sentence.size());
        // Start after i: position i itself always has the same index and would be skipped anyway
        for (int j = i + 1; j < windowStop; j++) {
            String w2 = sentence.get(j);
            int otherWord = vocab.indexOf(w2);
            if (otherWord < 0)
                continue;
            if (otherWord == wordIdx)
                continue;
            double increment = 1.0 / (j - i + Nd4j.EPS_THRESHOLD);
            // The word with the lower vocab index is the primary key of the pair
            String first = wordIdx < otherWord ? w1 : w2;
            String second = wordIdx < otherWord ? w2 : w1;
            coOCurreneCounts.incrementCount(first, second, increment);
            if (symmetric)
                coOCurreneCounts.incrementCount(second, first, increment);
        }
    }
    return coOCurreneCounts;
}
Aggregations