Usage of `org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement` in the deeplearning4j project (by deeplearning4j).
Found in class `SparkDBOW`, method `frameSequence`:
/**
 * Converts one DBOW training sequence into a {@link Frame} of training messages.
 *
 * <p>If sampling is enabled in the vectors configuration, the sequence is first subsampled via
 * {@code BaseSparkLearningAlgorithm.applySubsampling}. Every (label, element) pair of the
 * sequence is then passed to {@code iterateSample}. After each pair, {@code nextRandom} is
 * advanced with a linear congruential step. Finally, the frame held for the calling thread is
 * returned, and a fresh frame with a new sequence id is installed for that thread's next call.
 *
 * <p>DBOW's {@code iterateSample} takes no context window, so unlike skip-gram no window size is
 * selected here. The former variable-window lookup was dead code that only consumed randomness.
 *
 * @param sequence     sequence whose labels are paired with each of its elements
 * @param nextRandom   shared random state; advanced once per (label, element) pair
 * @param learningRate learning rate forwarded to {@code iterateSample}
 * @return the frame accumulated for the current thread
 */
@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    if (vectorsConfiguration.getSampling() > 0) {
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
    }

    // Lazily create the per-thread frame holder, re-checking under this instance's monitor.
    // NOTE(review): the `frame` field declaration is not visible here; confirm it is volatile
    // (or otherwise safely published) so the unsynchronized first check is sound.
    if (frame == null) {
        synchronized (this) {
            if (frame == null) {
                frame = new ThreadLocal<>();
            }
        }
    }
    if (frame.get() == null) {
        frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    }

    for (ShallowSequenceElement lastWord : sequence.getSequenceLabels()) {
        for (ShallowSequenceElement word : sequence.getElements()) {
            // presumably appends a training message to this thread's frame — confirm in iterateSample
            iterateSample(word, lastWord, nextRandom, learningRate);
            nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        }
    }

    // Hand off this thread's accumulated frame and start a fresh one for its next sequence.
    Frame<SkipGramRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    return currentFrame;
}
Aggregations