use of org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage in project deeplearning4j by deeplearning4j.
the class SparkSkipGram method iterateSample.
protected void iterateSample(ShallowSequenceElement word, ShallowSequenceElement lastWord, AtomicLong nextRandom, double lr) {
if (word == null || lastWord == null || lastWord.getIndex() < 0 || word.getIndex() == lastWord.getIndex())
return;
/**
* all we want here, is actually very simple:
* we just build simple SkipGram frame, and send it over network
*/
int[] idxSyn1 = new int[0];
byte[] codes = new byte[0];
if (vectorsConfiguration.isUseHierarchicSoftmax()) {
idxSyn1 = new int[word.getCodeLength()];
codes = new byte[word.getCodeLength()];
for (int i = 0; i < word.getCodeLength(); i++) {
byte code = word.getCodes().get(i);
int point = word.getPoints().get(i);
if (point >= vocabCache.numWords() || point < 0)
continue;
codes[i] = code;
idxSyn1[i] = point;
}
}
short neg = (short) vectorsConfiguration.getNegative();
SkipGramRequestMessage sgrm = new SkipGramRequestMessage(word.getIndex(), lastWord.getIndex(), idxSyn1, codes, neg, lr, nextRandom.get());
// we just stackfor now
frame.get().stackMessage(sgrm);
}
use of org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage in project deeplearning4j by deeplearning4j.
the class SparkSkipGram method frameSequence.
@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
// FIXME: totalElementsCount should have real value
if (vectorsConfiguration.getSampling() > 0)
sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
int currentWindow = vectorsConfiguration.getWindow();
if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
}
if (frame == null)
synchronized (this) {
if (frame == null)
frame = new ThreadLocal<>();
}
if (frame.get() == null)
frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
for (int i = 0; i < sequence.size(); i++) {
nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
ShallowSequenceElement word = sequence.getElementByIndex(i);
if (word == null)
continue;
int b = (int) (nextRandom.get() % currentWindow);
int end = currentWindow * 2 + 1 - b;
for (int a = b; a < end; a++) {
if (a != currentWindow) {
int c = i - currentWindow + a;
if (c >= 0 && c < sequence.size()) {
ShallowSequenceElement lastWord = sequence.getElementByIndex(c);
iterateSample(word, lastWord, nextRandom, learningRate);
nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
}
}
}
}
// at this moment we should have something in ThreadLocal Frame, so we'll send it to VoidParameterServer for processing
Frame<SkipGramRequestMessage> currentFrame = frame.get();
frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
return currentFrame;
}
use of org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage in project deeplearning4j by deeplearning4j.
the class SparkDBOW method frameSequence.
@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
if (vectorsConfiguration.getSampling() > 0)
sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
int currentWindow = vectorsConfiguration.getWindow();
if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
}
if (frame == null)
synchronized (this) {
if (frame == null)
frame = new ThreadLocal<>();
}
if (frame.get() == null)
frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
for (ShallowSequenceElement lastWord : sequence.getSequenceLabels()) {
for (ShallowSequenceElement word : sequence.getElements()) {
iterateSample(word, lastWord, nextRandom, learningRate);
nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
}
}
// at this moment we should have something in ThreadLocal Frame, so we'll send it to VoidParameterServer for processing
Frame<SkipGramRequestMessage> currentFrame = frame.get();
frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
return currentFrame;
}
Aggregations