Use of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project:
the `call` method of the `FirstIterationFunctionAdapter` class.
/**
 * Processes a partition of (sentence, cumulative word count) pairs in batches,
 * training each batch for the configured number of iterations, and returns the
 * accumulated word-index -> vector entries.
 *
 * @param pairIter iterator over (sentence word list, cumulative word count) tuples
 * @return the entries of the internal index -> syn0 vector map after training
 */
@Override
public Iterable<Map.Entry<VocabWord, INDArray>> call(Iterator<Tuple2<List<VocabWord>, Long>> pairIter) {
    while (pairIter.hasNext()) {
        // Drain up to batchSize sentences from the iterator into the current batch.
        List<Pair<List<VocabWord>, Long>> sentenceBatch = new ArrayList<>();
        while (pairIter.hasNext() && sentenceBatch.size() < batchSize) {
            Tuple2<List<VocabWord>, Long> tuple = pairIter.next();
            sentenceBatch.add(Pair.of(tuple._1(), tuple._2()));
        }
        // Run the configured number of training passes over this batch.
        for (int pass = 0; pass < iterations; pass++) {
            for (Pair<List<VocabWord>, Long> entry : sentenceBatch) {
                List<VocabWord> sentence = entry.getKey();
                Long cumSumCount = entry.getValue();
                // Decay the learning rate linearly with training progress, floored at minAlpha.
                double sentenceAlpha = Math.max(minAlpha,
                                alpha - (alpha - minAlpha) * (cumSumCount / (double) totalWordCount));
                trainSentence(sentence, sentenceAlpha);
            }
        }
    }
    return indexSyn0VecMap.entrySet();
}
Use of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project:
the `skipGram` method of the `SecondIterationFunctionAdapter` class.
/**
 * Trains skip-gram pairs around the word at the given position in the sentence.
 * For each in-range context position inside the (b-shrunk) window, excluding the
 * center position itself, runs one {@code iterateSample} update.
 *
 * @param ithWordInSentence index of the center word within the sentence
 * @param vocabWordsList the sentence as a list of vocabulary words
 * @param b offset shrinking the effective window — presumably a random window
 *          reduction; verify at call sites
 * @param currentSentenceAlpha learning rate to use for this sentence
 */
public void skipGram(int ithWordInSentence, List<VocabWord> vocabWordsList, int b, double currentSentenceAlpha) {
    VocabWord center = vocabWordsList.get(ithWordInSentence);
    if (center == null || vocabWordsList.isEmpty()) {
        return;
    }
    int windowEnd = window * 2 + 1 - b;
    for (int offset = b; offset < windowEnd; offset++) {
        if (offset == window) {
            continue; // this offset is the center word itself
        }
        int contextIdx = ithWordInSentence - window + offset;
        if (contextIdx >= 0 && contextIdx < vocabWordsList.size()) {
            iterateSample(center, vocabWordsList.get(contextIdx), currentSentenceAlpha);
        }
    }
}
Use of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project:
the `call` method of the `KeySequenceConvertFunction` class.
/**
 * Converts a (key, text) pair into a labeled {@link Sequence} of vocabulary words:
 * the key becomes the sequence label, and each non-empty token of the text becomes
 * a sequence element.
 *
 * @param pair tuple of (label string, text to tokenize)
 * @return the resulting labeled sequence
 * @throws Exception if tokenization fails
 */
@Override
public Sequence<VocabWord> call(Tuple2<String, String> pair) throws Exception {
    Sequence<VocabWord> sequence = new Sequence<>();
    // The tuple's key becomes the sequence label.
    sequence.addSequenceLabel(new VocabWord(1.0, pair._1()));

    // Lazily build the tokenizer factory on first use.
    if (tokenizerFactory == null) {
        instantiateTokenizerFactory();
    }

    // Tokenize the value and append one element per non-empty token.
    for (String token : tokenizerFactory.create(pair._2()).getTokens()) {
        if (token != null && !token.isEmpty()) {
            sequence.addElement(new VocabWord(1.0, token));
        }
    }
    return sequence;
}
Use of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project:
the `call` method of the `SecondIterationFunctionAdapter` class.
/**
 * Processes a partition of (sentence, cumulative word count) pairs in batches,
 * training each batch for the configured number of iterations, and returns this
 * worker's split of the trained vectors from the shared vocab holder.
 *
 * @param pairIter iterator over (sentence word list, cumulative word count) tuples
 * @return the vocab holder's split of (word, vector) entries for this vocab
 */
@Override
public Iterable<Entry<VocabWord, INDArray>> call(Iterator<Tuple2<List<VocabWord>, Long>> pairIter) {
    // Acquire the shared vocab holder and seed it for weight initialization.
    this.vocabHolder = VocabHolder.getInstance();
    this.vocabHolder.setSeed(seed, vectorLength);

    // Negative sampling enabled: set up the shared negative-sampling holder.
    if (negative > 0) {
        negativeHolder = NegativeHolder.getInstance();
        negativeHolder.initHolder(vocab, expTable, this.vectorLength);
    }

    while (pairIter.hasNext()) {
        // Drain up to batchSize sentences from the iterator into the current batch.
        List<Pair<List<VocabWord>, Long>> sentenceBatch = new ArrayList<>();
        while (pairIter.hasNext() && sentenceBatch.size() < batchSize) {
            Tuple2<List<VocabWord>, Long> tuple = pairIter.next();
            sentenceBatch.add(Pair.of(tuple._1(), tuple._2()));
        }
        // Run the configured number of training passes over this batch.
        for (int pass = 0; pass < iterations; pass++) {
            for (Pair<List<VocabWord>, Long> entry : sentenceBatch) {
                List<VocabWord> sentence = entry.getKey();
                Long cumSumCount = entry.getValue();
                // Decay the learning rate linearly with training progress, floored at minAlpha.
                double sentenceAlpha = Math.max(minAlpha,
                                alpha - (alpha - minAlpha) * (cumSumCount / (double) totalWordCount));
                trainSentence(sentence, sentenceAlpha);
            }
        }
    }
    return vocabHolder.getSplit(vocab);
}
Use of org.deeplearning4j.models.word2vec.VocabWord in the deeplearning4j project:
the `skipGram` method of the `SentenceBatch` class.
/**
 * Trains skip-gram pairs around the word at position {@code i} in the sentence.
 * For each in-range context position inside the (b-shrunk) window, excluding the
 * center position itself, runs one {@code iterateSample} update and records the
 * touched rows in {@code changed}.
 *
 * @param param shared Word2Vec parameters; supplies the window size
 * @param i index of the current (center) word within the sentence
 * @param sentence the sentence to train on
 * @param b offset shrinking the effective window — presumably a random window
 *          reduction; verify at call sites
 * @param alpha the learning rate
 * @param changed accumulator of integer triples populated by {@code iterateSample}
 */
public void skipGram(Word2VecParam param, int i, List<VocabWord> sentence, int b, double alpha,
                List<Triple<Integer, Integer, Integer>> changed) {
    final VocabWord center = sentence.get(i);
    int window = param.getWindow();
    if (center == null || sentence.isEmpty()) {
        return;
    }
    int end = window * 2 + 1 - b;
    for (int a = b; a < end; a++) {
        if (a == window) {
            continue; // this offset is the center word itself
        }
        int c = i - window + a;
        if (c >= 0 && c < sentence.size()) {
            iterateSample(param, center, sentence.get(c), alpha, changed);
        }
    }
}
Aggregations