Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class PartitionTrainingFunction, method trainAllAtOnce.
protected void trainAllAtOnce(List<Sequence<ShallowSequenceElement>> sequences) {
    Frame bigFrame = new Frame(BasicSequenceProvider.getInstance().getNextValue());
    for (Sequence<ShallowSequenceElement> sequence : sequences) {
        Frame frame = elementsLearningAlgorithm.frameSequence(sequence, new AtomicLong(119L), 25e-3f);
        bigFrame.stackMessages(frame.getMessages());
    }
    if (bigFrame.size() > 0)
        paramServer.execDistributed(bigFrame);
}
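The pattern here is batching: each sequence produces its own Frame of training messages, and all of them are stacked into a single big Frame so that the partition triggers only one distributed call to the parameter server instead of one per sequence. The standalone sketch below illustrates just that aggregation idea; the Frame and TrainingMessage types are simplified stand-ins, not the actual deeplearning4j classes.

import java.util.ArrayList;
import java.util.List;

public class MessageBatchingSketch {

    // Hypothetical stand-in for a training message produced from one sequence
    static class TrainingMessage {
        final long originatorId;
        TrainingMessage(long originatorId) { this.originatorId = originatorId; }
    }

    // Hypothetical stand-in for a Frame: a batch of messages sent in one network call
    static class Frame {
        private final List<TrainingMessage> messages = new ArrayList<>();
        void stackMessages(List<TrainingMessage> batch) { messages.addAll(batch); }
        int size() { return messages.size(); }
    }

    public static void main(String[] args) {
        // Simulate per-sequence message batches, then stack them into one big frame
        Frame bigFrame = new Frame();
        for (long sequenceId = 0; sequenceId < 3; sequenceId++) {
            List<TrainingMessage> perSequence = new ArrayList<>();
            perSequence.add(new TrainingMessage(sequenceId));
            bigFrame.stackMessages(perSequence);
        }
        // One distributed dispatch for the whole partition, as in trainAllAtOnce above
        if (bigFrame.size() > 0)
            System.out.println("dispatching " + bigFrame.size() + " messages in a single frame");
    }
}

In the real method, the unique frame id comes from BasicSequenceProvider.getInstance().getNextValue(), as shown above.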
Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class TrainingFunction, method call.
@Override
@SuppressWarnings("unchecked")
public void call(Sequence<T> sequence) throws Exception {
    /**
     * Depending on actual training mode, we'll either go for SkipGram/CBOW/PV-DM/PV-DBOW or whatever
     */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();
    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }
    /*
        at this moment we should have everything ready for actual initialization
        the only limitation we have - our sequence is detached from actual vocabulary, so we need to merge it back virtually
    */
    Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
    for (T element : sequence.getElements()) {
        // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
        ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
        if (reduced != null)
            mergedSequence.addElement(reduced);
    }
    // do the same with labels, transfer them, if any
    if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
        for (T label : sequence.getSequenceLabels()) {
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
            if (reduced != null)
                mergedSequence.addSequenceLabel(reduced);
        }
    }
    // FIXME: temporary hook
    if (sequence.size() > 0)
        paramServer.execDistributed(elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
    else
        log.warn("Skipping empty sequence...");
}
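Before training, the incoming sequence of full elements is re-attached to the broadcast shallow vocabulary: each element (and label) is looked up by its storage id, and elements that fell below the minWordFrequency threshold, and are therefore absent from the cache, are silently dropped. The sketch below mirrors only that merge step using plain Java collections; ShallowElement and the id-keyed map are hypothetical stand-ins for ShallowSequenceElement and the shallow vocab cache.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ShallowMergeSketch {

    // Hypothetical reduced element: only the storage id and a frequency survive
    static class ShallowElement {
        final long id;
        final double frequency;
        ShallowElement(long id, double frequency) { this.id = id; this.frequency = frequency; }
    }

    public static void main(String[] args) {
        // Hypothetical shallow vocab: id -> reduced element (elements below the
        // minWordFrequency threshold were never added, so lookups return null)
        Map<Long, ShallowElement> shallowVocab = new HashMap<>();
        shallowVocab.put(1L, new ShallowElement(1L, 42));
        shallowVocab.put(3L, new ShallowElement(3L, 7));

        // Incoming sequence referenced by storage ids; id 2 is out-of-vocabulary
        long[] sequenceIds = {1L, 2L, 3L};

        // Merge the detached sequence back against the vocabulary, dropping unknown ids
        List<ShallowElement> merged = new ArrayList<>();
        for (long id : sequenceIds) {
            ShallowElement reduced = shallowVocab.get(id);
            if (reduced != null)
                merged.add(reduced);
        }
        System.out.println("merged sequence length: " + merged.size()); // prints 2
    }
}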
Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class SparkCBOW, method frameSequence.
@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    // FIXME: totalElementsCount should have real value
    if (vectorsConfiguration.getSampling() > 0)
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
    int currentWindow = vectorsConfiguration.getWindow();
    if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
        currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
    }
    if (frame == null)
        synchronized (this) {
            if (frame == null)
                frame = new ThreadLocal<>();
        }
    if (frame.get() == null)
        frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    for (int i = 0; i < sequence.getElements().size(); i++) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        int b = (int) nextRandom.get() % currentWindow;
        int end = currentWindow * 2 + 1 - b;
        ShallowSequenceElement currentWord = sequence.getElementByIndex(i);
        List<Integer> intsList = new ArrayList<>();
        for (int a = b; a < end; a++) {
            if (a != currentWindow) {
                int c = i - currentWindow + a;
                if (c >= 0 && c < sequence.size()) {
                    ShallowSequenceElement lastWord = sequence.getElementByIndex(c);
                    intsList.add(lastWord.getIndex());
                }
            }
        }
        // just converting values to int
        int[] windowWords = new int[intsList.size()];
        for (int x = 0; x < windowWords.length; x++) {
            windowWords[x] = intsList.get(x);
        }
        if (windowWords.length < 1)
            continue;
        iterateSample(currentWord, windowWords, nextRandom, learningRate, false, 0, true, null);
    }
    Frame<CbowRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    return currentFrame;
}
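The inner loop above uses the classic word2vec windowing scheme: the linear congruential step nextRandom * 25214903917 + 11 produces an offset b that randomly shrinks the window, and context positions are taken symmetrically around the center word, which is skipped. The following self-contained sketch reproduces just that context-index selection on a toy sequence; the seed, sequence length, and window size are arbitrary example values.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

public class CbowWindowSketch {

    // Collect context indices around position i, mirroring the window logic above:
    // a random offset b shrinks the window symmetrically, and the center word is skipped
    static List<Integer> contextIndices(int i, int sequenceLength, int window, AtomicLong nextRandom) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); // word2vec-style LCG step
        int b = (int) (nextRandom.get() % window);
        int end = window * 2 + 1 - b;
        List<Integer> context = new ArrayList<>();
        for (int a = b; a < end; a++) {
            if (a == window)
                continue; // skip the center word itself
            int c = i - window + a;
            if (c >= 0 && c < sequenceLength)
                context.add(c);
        }
        return context;
    }

    public static void main(String[] args) {
        AtomicLong nextRandom = new AtomicLong(119L); // same seed as in the snippets above
        int sequenceLength = 10;
        int window = 5;
        for (int i = 0; i < sequenceLength; i++) {
            System.out.println("word " + i + " -> context " + contextIndices(i, sequenceLength, window, nextRandom));
        }
    }
}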
Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class SparkSkipGram, method frameSequence.
@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    // FIXME: totalElementsCount should have real value
    if (vectorsConfiguration.getSampling() > 0)
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());
    int currentWindow = vectorsConfiguration.getWindow();
    if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
        currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
    }
    if (frame == null)
        synchronized (this) {
            if (frame == null)
                frame = new ThreadLocal<>();
        }
    if (frame.get() == null)
        frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    for (int i = 0; i < sequence.size(); i++) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        ShallowSequenceElement word = sequence.getElementByIndex(i);
        if (word == null)
            continue;
        int b = (int) (nextRandom.get() % currentWindow);
        int end = currentWindow * 2 + 1 - b;
        for (int a = b; a < end; a++) {
            if (a != currentWindow) {
                int c = i - currentWindow + a;
                if (c >= 0 && c < sequence.size()) {
                    ShallowSequenceElement lastWord = sequence.getElementByIndex(c);
                    iterateSample(word, lastWord, nextRandom, learningRate);
                    nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
                }
            }
        }
    }
    // at this moment we should have something in ThreadLocal Frame, so we'll send it to VoidParameterServer for processing
    Frame<SkipGramRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<SkipGramRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));
    return currentFrame;
}
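Both frameSequence implementations keep their in-progress Frame in a lazily created ThreadLocal, using double-checked locking to create the holder once and then giving each worker thread its own frame with a fresh id. The sketch below shows that pattern in isolation with a simplified Frame stand-in; note that it declares the field volatile, which safe double-checked locking requires (the actual field declaration is not visible in the snippets above).

public class ThreadLocalFrameSketch {

    // Hypothetical stand-in for a per-thread batch of request messages
    static class Frame {
        final long id;
        Frame(long id) { this.id = id; }
    }

    private volatile ThreadLocal<Frame> frame; // hypothetical field; volatile makes the DCL safe
    private long frameIdCounter = 0;

    Frame currentFrame() {
        // Double-checked creation of the ThreadLocal holder itself
        if (frame == null) {
            synchronized (this) {
                if (frame == null)
                    frame = new ThreadLocal<>();
            }
        }
        // Each thread then lazily gets its own Frame with a fresh id
        if (frame.get() == null)
            frame.set(new Frame(nextId()));
        return frame.get();
    }

    private synchronized long nextId() {
        return ++frameIdCounter;
    }

    public static void main(String[] args) throws InterruptedException {
        ThreadLocalFrameSketch sketch = new ThreadLocalFrameSketch();
        Runnable task = () -> System.out.println(Thread.currentThread().getName() + " uses frame " + sketch.currentFrame().id);
        Thread t1 = new Thread(task, "worker-1");
        Thread t2 = new Thread(task, "worker-2");
        t1.start();
        t2.start();
        t1.join();
        t2.join();
    }
}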
Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class SparkSequenceVectors, method buildShallowVocabCache.
/**
 * This method builds the shallow vocabulary and the Huffman tree over it
 *
 * @param counter frequency counter keyed by element storage id
 * @return vocab cache of ShallowSequenceElements with Huffman indexes applied
 */
protected VocabCache<ShallowSequenceElement> buildShallowVocabCache(Counter<Long> counter) {
    // TODO: need simplified cache here, that will operate on Long instead of string labels
    VocabCache<ShallowSequenceElement> vocabCache = new AbstractCache<>();
    for (Long id : counter.keySet()) {
        ShallowSequenceElement shallowElement = new ShallowSequenceElement(counter.getCount(id), id);
        vocabCache.addToken(shallowElement);
    }
    // building huffman tree
    Huffman huffman = new Huffman(vocabCache.vocabWords());
    huffman.build();
    huffman.applyIndexes(vocabCache);
    return vocabCache;
}
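Once every distinct storage id has been registered with its frequency, the Huffman build gives more frequent elements shorter codes, which hierarchical-softmax style training exploits. The sketch below is not deeplearning4j's Huffman class; it is a minimal illustration of building such a tree over a hypothetical frequency map keyed by storage id and reading back the resulting code lengths.

import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

public class HuffmanSketch {

    // A node in the Huffman tree: either a vocabulary element (id >= 0) or an internal node (id == -1)
    static class Node {
        final long id;
        final double frequency;
        final Node left, right;
        Node(long id, double frequency, Node left, Node right) {
            this.id = id;
            this.frequency = frequency;
            this.left = left;
            this.right = right;
        }
    }

    // Build the tree by repeatedly merging the two least frequent nodes
    static Node build(Map<Long, Double> frequencies) {
        PriorityQueue<Node> queue = new PriorityQueue<>((a, b) -> Double.compare(a.frequency, b.frequency));
        for (Map.Entry<Long, Double> e : frequencies.entrySet())
            queue.add(new Node(e.getKey(), e.getValue(), null, null));
        while (queue.size() > 1) {
            Node a = queue.poll();
            Node b = queue.poll();
            queue.add(new Node(-1, a.frequency + b.frequency, a, b));
        }
        return queue.poll();
    }

    // Collect code lengths: frequent elements end up closer to the root (shorter codes)
    static void codeLengths(Node node, int depth, Map<Long, Integer> out) {
        if (node == null)
            return;
        if (node.left == null && node.right == null) {
            out.put(node.id, depth);
            return;
        }
        codeLengths(node.left, depth + 1, out);
        codeLengths(node.right, depth + 1, out);
    }

    public static void main(String[] args) {
        // Hypothetical frequencies keyed by storage id, standing in for the Counter<Long>
        Map<Long, Double> frequencies = new HashMap<>();
        frequencies.put(1L, 50.0);
        frequencies.put(2L, 20.0);
        frequencies.put(3L, 5.0);
        frequencies.put(4L, 2.0);

        Map<Long, Integer> lengths = new HashMap<>();
        codeLengths(build(frequencies), 0, lengths);
        lengths.forEach((id, len) -> System.out.println("element " + id + " -> code length " + len));
    }
}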