Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class SparkDM, method frameSequence:
@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate) {
    if (vectorsConfiguration.getSampling() > 0)
        sequence = BaseSparkLearningAlgorithm.applySubsampling(sequence, nextRandom, 10L, vectorsConfiguration.getSampling());

    int currentWindow = vectorsConfiguration.getWindow();
    if (vectorsConfiguration.getVariableWindows() != null && vectorsConfiguration.getVariableWindows().length != 0) {
        currentWindow = vectorsConfiguration.getVariableWindows()[RandomUtils.nextInt(vectorsConfiguration.getVariableWindows().length)];
    }

    if (frame == null)
        synchronized (this) {
            if (frame == null)
                frame = new ThreadLocal<>();
        }

    if (frame.get() == null)
        frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));

    for (int i = 0; i < sequence.getElements().size(); i++) {
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        int b = (int) nextRandom.get() % currentWindow;
        int end = currentWindow * 2 + 1 - b;
        ShallowSequenceElement currentWord = sequence.getElementByIndex(i);
        List<Integer> intsList = new ArrayList<>();
        for (int a = b; a < end; a++) {
            if (a != currentWindow) {
                int c = i - currentWindow + a;
                if (c >= 0 && c < sequence.size()) {
                    ShallowSequenceElement lastWord = sequence.getElementByIndex(c);
                    intsList.add(lastWord.getIndex());
                }
            }
        }

        // basically it's the same as CBOW, we just add labels here
        if (sequence.getSequenceLabels() != null) {
            for (ShallowSequenceElement label : sequence.getSequenceLabels()) {
                intsList.add(label.getIndex());
            }
        } else
            // FIXME: we probably should throw this exception earlier?
            throw new DL4JInvalidInputException("Sequence passed via RDD has no labels within, nothing to learn here");

        // just converting values to int
        int[] windowWords = new int[intsList.size()];
        for (int x = 0; x < windowWords.length; x++) {
            windowWords[x] = intsList.get(x);
        }

        if (windowWords.length < 1)
            continue;

        iterateSample(currentWord, windowWords, nextRandom, learningRate, false, 0, true, null);
    }

    Frame<CbowRequestMessage> currentFrame = frame.get();
    frame.set(new Frame<CbowRequestMessage>(BasicSequenceProvider.getInstance().getNextValue()));

    return currentFrame;
}
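The window arithmetic above (the random shrink b, the loop bound end, and the offset c) is easy to misread, so here is a small standalone sketch, not part of the DL4J sources, that simply prints which context positions the loop visits for one concrete choice of window, b and center index:

// Standalone illustration (not DL4J code) of the context-window arithmetic above.
public class WindowSketch {
    public static void main(String[] args) {
        int window = 5;          // currentWindow
        int b = 2;               // random shrink, 0 <= b < window
        int i = 7;               // index of the current word
        int sequenceLength = 10; // sequence.size()

        for (int a = b; a < window * 2 + 1 - b; a++) {
            if (a == window)
                continue;        // skip the current word itself
            int c = i - window + a;
            if (c >= 0 && c < sequenceLength)
                System.out.println("context index: " + c);
        }
        // Prints 4, 5, 6, 8, 9: an asymmetric window of (window - b) words
        // on each side of position 7, clipped to the sequence bounds.
    }
}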
Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class SparkSequenceVectorsTest, method testFrequenciesCount:
@Test
public void testFrequenciesCount() throws Exception {
    JavaRDD<Sequence<VocabWord>> sequences = sc.parallelize(sequencesCyclic);

    SparkSequenceVectors<VocabWord> seqVec = new SparkSequenceVectors<>();
    seqVec.fitSequences(sequences);

    Counter<Long> counter = seqVec.getCounter();

    // element "0" should have frequency of 20
    assertEquals(20, counter.getCount(0L), 1e-5);

    // elements 1 - 9 should have frequencies of 10
    for (int e = 1; e < sequencesCyclic.get(0).getElements().size() - 1; e++) {
        assertEquals(10, counter.getCount(sequencesCyclic.get(0).getElementByIndex(e).getStorageId()), 1e-5);
    }

    VocabCache<ShallowSequenceElement> shallowVocab = seqVec.getShallowVocabCache();
    assertEquals(10, shallowVocab.numWords());

    ShallowSequenceElement zero = shallowVocab.tokenFor(0L);
    ShallowSequenceElement first = shallowVocab.tokenFor(1L);

    assertNotEquals(null, zero);

    assertEquals(20.0, zero.getElementFrequency(), 1e-5);
    assertEquals(0, zero.getIndex());

    assertEquals(10.0, first.getElementFrequency(), 1e-5);
}
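The test relies on the sc context and the sequencesCyclic fixture, both prepared outside this method. A minimal fixture consistent with the assertions (element 0 counted 20 times, elements 1 to 9 counted 10 times each) could look like the following sketch; this is an assumption about the fixture's shape, not the original setup code:

// Hypothetical fixture consistent with the assertions above: 10 sequences, each
// holding elements "0".."9" once plus an extra "0", so element 0 is counted 20 times
// while elements 1..9 are counted 10 times each.
List<Sequence<VocabWord>> sequencesCyclic = new ArrayList<>();
for (int s = 0; s < 10; s++) {
    Sequence<VocabWord> sequence = new Sequence<>();
    for (int e = 0; e < 10; e++)
        sequence.addElement(new VocabWord(1.0, String.valueOf(e)));
    sequence.addElement(new VocabWord(1.0, "0")); // the duplicate that doubles element 0's count
    sequencesCyclic.add(sequence);
}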
Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class DistributedFunction, method call:
@Override
public ExportContainer<T> call(T word) throws Exception {
    if (shallowVocabCache == null)
        shallowVocabCache = shallowVocabBroadcast.getValue();

    ExportContainer<T> container = new ExportContainer<>();

    ShallowSequenceElement reduced = shallowVocabCache.tokenFor(word.getStorageId());
    word.setIndex(reduced.getIndex());

    container.setElement(word);
    container.setArray(VoidParameterServer.getInstance().getVector(reduced.getIndex()));

    return container;
}
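The null check on shallowVocabCache follows a common Spark idiom: the broadcast value is dereferenced lazily inside call(), so the potentially large vocab cache is fetched once per executor rather than being serialized together with the function object. A generic sketch of that pattern, with hypothetical names and payload, unrelated to the DL4J classes:

import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import java.util.Map;

// Generic sketch of lazy broadcast dereferencing (hypothetical names, not DL4J code).
public class LazyBroadcastLookup implements Function<String, Integer> {
    private final Broadcast<Map<String, Integer>> broadcast;
    private transient Map<String, Integer> cache; // fetched per executor, never serialized with the function

    public LazyBroadcastLookup(Broadcast<Map<String, Integer>> broadcast) {
        this.broadcast = broadcast;
    }

    @Override
    public Integer call(String word) {
        if (cache == null)
            cache = broadcast.getValue(); // first call on this executor pulls the broadcast value
        return cache.getOrDefault(word, -1);
    }
}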
Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class PartitionTrainingFunction, method call:
@SuppressWarnings("unchecked")
@Override
public void call(Iterator<Sequence<T>> sequenceIterator) throws Exception {
    /**
     * first we initialize
     */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();

    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();

        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        driver = elementsLearningAlgorithm.getTrainingDriver();

        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }

    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();

    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm != null)
        elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm != null)
        sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }

    List<Sequence<ShallowSequenceElement>> sequences = new ArrayList<>();

    // now we roll through Sequences and prepare/convert/"learn" them
    while (sequenceIterator.hasNext()) {
        Sequence<T> sequence = sequenceIterator.next();
        Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
        for (T element : sequence.getElements()) {
            // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
            if (reduced != null)
                mergedSequence.addElement(reduced);
        }

        // do the same with labels, transfer them, if any
        if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
            for (T label : sequence.getSequenceLabels()) {
                ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
                if (reduced != null)
                    mergedSequence.addSequenceLabel(reduced);
            }
        }

        sequences.add(mergedSequence);
        if (sequences.size() >= 8) {
            trainAllAtOnce(sequences);
            sequences.clear();
        }
    }
    if (sequences.size() > 0) {
        // finishing training round, to make sure we don't have trails
        trainAllAtOnce(sequences);
        sequences.clear();
    }
}
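PartitionTrainingFunction is a Spark VoidFunction over an iterator of sequences, so it is applied once per partition. A hedged wiring sketch follows; the constructor arguments below are placeholders suggested by the broadcast fields used above, not a verified signature:

// Hedged wiring sketch; the constructor arguments are placeholders, not the verified signature.
JavaRDD<Sequence<VocabWord>> corpus = sc.parallelize(sequencesCyclic);
corpus.foreachPartition(new PartitionTrainingFunction<>(vocabCacheBroadcast, configurationBroadcast,
                paramServerConfigurationBroadcast));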
Use of org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement in project deeplearning4j by deeplearning4j.
The class BaseSparkLearningAlgorithm, method applySubsampling:
public static Sequence<ShallowSequenceElement> applySubsampling(@NonNull Sequence<ShallowSequenceElement> sequence, @NonNull AtomicLong nextRandom, long totalElementsCount, double prob) {
    Sequence<ShallowSequenceElement> result = new Sequence<>();

    // subsampling implementation, if subsampling threshold met, just continue to next element
    if (prob > 0) {
        result.setSequenceId(sequence.getSequenceId());

        if (sequence.getSequenceLabels() != null)
            result.setSequenceLabels(sequence.getSequenceLabels());

        if (sequence.getSequenceLabel() != null)
            result.setSequenceLabel(sequence.getSequenceLabel());

        for (ShallowSequenceElement element : sequence.getElements()) {
            double numWords = (double) totalElementsCount;
            double ran = (Math.sqrt(element.getElementFrequency() / (prob * numWords)) + 1) * (prob * numWords) / element.getElementFrequency();

            nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

            if (ran < (nextRandom.get() & 0xFFFF) / (double) 65536) {
                continue;
            }
            result.addElement(element);
        }
        return result;
    } else
        return sequence;
}
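The ran expression is the word2vec-style keep probability: an element is dropped more aggressively as its frequency grows relative to prob * totalElementsCount. A rough standalone illustration, not part of the DL4J sources (the clipping to 1.0 is only for display):

// Keep-probability illustration for the subsampling formula above.
double prob = 1e-3;        // subsampling threshold
double numWords = 10_000;  // totalElementsCount
for (double freq : new double[] {5, 50, 500}) {
    double ran = (Math.sqrt(freq / (prob * numWords)) + 1) * (prob * numWords) / freq;
    System.out.printf("frequency=%4.0f  keep probability ~ %.3f%n", freq, Math.min(ran, 1.0));
}
// Prints roughly 1.000, 0.647 and 0.161: rare elements are almost always kept,
// very frequent ones survive subsampling far less often.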