Use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.
In the class PartitionTrainingFunction, the method call:
@SuppressWarnings("unchecked")
@Override
public void call(Iterator<Sequence<T>> sequenceIterator) throws Exception {
    /**
     * first we initialize
     */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();

    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();

        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                                .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        driver = elementsLearningAlgorithm.getTrainingDriver();

        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }

    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();

    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                            .forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm != null)
        elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
                            .forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm != null)
        sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }

    List<Sequence<ShallowSequenceElement>> sequences = new ArrayList<>();

    // now we roll through Sequences and prepare/convert/"learn" them
    while (sequenceIterator.hasNext()) {
        Sequence<T> sequence = sequenceIterator.next();
        Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
        for (T element : sequence.getElements()) {
            // it's possible to get null here, i.e. if the frequency for this element is below the minWordFrequency threshold
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
            if (reduced != null)
                mergedSequence.addElement(reduced);
        }

        // do the same with labels, transfer them, if any
        if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
            for (T label : sequence.getSequenceLabels()) {
                ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
                if (reduced != null)
                    mergedSequence.addSequenceLabel(reduced);
            }
        }

        sequences.add(mergedSequence);
        if (sequences.size() >= 8) {
            trainAllAtOnce(sequences);
            sequences.clear();
        }
    }

    if (sequences.size() > 0) {
        // finishing the training round, to make sure we don't leave trailing sequences behind
        trainAllAtOnce(sequences);
        sequences.clear();
    }
}
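
Two details are worth noting here: each element is converted to its ShallowSequenceElement counterpart via storage id (silently dropping elements pruned from the vocab), and training happens in micro-batches of 8 sequences with a final flush for any remainder. The batching pattern itself is generic; below is a minimal, self-contained sketch of it. The Batcher class and the trainer callback are hypothetical names for illustration, not part of deeplearning4j.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical distillation of the micro-batching pattern used in call():
// accumulate items, flush every BATCH_SIZE, then flush the trailing remainder.
public class Batcher<T> {
    private static final int BATCH_SIZE = 8; // same threshold as in PartitionTrainingFunction

    private final List<T> buffer = new ArrayList<>();
    private final Consumer<List<T>> trainer;

    public Batcher(Consumer<List<T>> trainer) {
        this.trainer = trainer;
    }

    public void consume(Iterator<T> iterator) {
        while (iterator.hasNext()) {
            buffer.add(iterator.next());
            if (buffer.size() >= BATCH_SIZE) {
                trainer.accept(buffer);   // analogous to trainAllAtOnce(sequences)
                buffer.clear();
            }
        }
        if (!buffer.isEmpty()) {          // final flush, so no trailing items are lost
            trainer.accept(buffer);
            buffer.clear();
        }
    }

    public static void main(String[] args) {
        Batcher<String> batcher = new Batcher<>(batch ->
                        System.out.println("training on " + batch.size() + " items"));
        batcher.consume(Arrays.asList("a", "b", "c", "d", "e", "f", "g", "h", "i", "j").iterator());
        // prints: training on 8 items, then training on 2 items
    }
}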
Use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.
In the class BaseSparkLearningAlgorithm, the method applySubsampling:
public static Sequence<ShallowSequenceElement> applySubsampling(@NonNull Sequence<ShallowSequenceElement> sequence,
                @NonNull AtomicLong nextRandom, long totalElementsCount, double prob) {
    Sequence<ShallowSequenceElement> result = new Sequence<>();

    // subsampling implementation, if subsampling threshold met, just continue to next element
    if (prob > 0) {
        result.setSequenceId(sequence.getSequenceId());

        if (sequence.getSequenceLabels() != null)
            result.setSequenceLabels(sequence.getSequenceLabels());

        if (sequence.getSequenceLabel() != null)
            result.setSequenceLabel(sequence.getSequenceLabel());

        for (ShallowSequenceElement element : sequence.getElements()) {
            double numWords = (double) totalElementsCount;
            double ran = (Math.sqrt(element.getElementFrequency() / (prob * numWords)) + 1) * (prob * numWords)
                            / element.getElementFrequency();
            nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));

            if (ran < (nextRandom.get() & 0xFFFF) / (double) 65536) {
                continue;
            }
            result.addElement(element);
        }
        return result;
    } else
        return sequence;
}
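
The ran value computed above is word2vec's standard subsampling keep-probability, (sqrt(f / (t * N)) + 1) * (t * N) / f for element frequency f, threshold t (the prob argument) and total element count N, compared against a 16-bit sample drawn from a linear-congruential generator. A small standalone demo of how that probability behaves; all frequencies, seeds, and corpus sizes below are made-up example values:

import java.util.concurrent.atomic.AtomicLong;

// Hypothetical standalone demo of the keep-probability used in applySubsampling.
public class SubsamplingDemo {
    // keep probability: (sqrt(f / (t * N)) + 1) * (t * N) / f
    static double keepProbability(double frequency, double prob, double numWords) {
        return (Math.sqrt(frequency / (prob * numWords)) + 1) * (prob * numWords) / frequency;
    }

    public static void main(String[] args) {
        double numWords = 1_000_000; // assumed total element count
        double t = 1e-3;             // assumed subsampling threshold ("prob")

        // frequent elements get a keep probability well below 1; rare ones land above 1 and are never dropped
        System.out.println(keepProbability(50_000, t, numWords)); // ~0.16 for a very frequent element
        System.out.println(keepProbability(100, t, numWords));    // > 1 for a rare element

        // same linear-congruential update as applySubsampling, with 16 bits taken as a uniform sample in [0, 1)
        AtomicLong nextRandom = new AtomicLong(42);
        nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11));
        double uniform = (nextRandom.get() & 0xFFFF) / (double) 65536;
        System.out.println("element is dropped when its keep probability < " + uniform);
    }
}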
Use of org.deeplearning4j.models.sequencevectors.sequence.Sequence in project deeplearning4j by deeplearning4j.
In the class SparkWord2Vec, the method fitSentences:
public void fitSentences(JavaRDD<String> sentences) {
    /**
     * Basically all we want here is tokenization, to get JavaRDD<Sequence<VocabWord>> out of Strings, and then we just go for SeqVec
     */
    validateConfiguration();

    final JavaSparkContext context = new JavaSparkContext(sentences.context());
    broadcastEnvironment(context);

    JavaRDD<Sequence<VocabWord>> seqRdd = sentences.map(new TokenizerFunction(configurationBroadcast));

    // now that we have the new RDD, just pass it to SeqVec
    super.fitSequences(seqRdd);
}
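
A minimal driver sketch showing where fitSentences() sits. The local Spark context, the no-arg SparkWord2Vec constructor, and its import path are assumptions here; real code would configure the model (learning algorithm, vector length, etc.) for your dl4j-spark version before fitting.

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.spark.models.word2vec.SparkWord2Vec; // import path assumed

// Hypothetical driver: build an RDD of raw sentences and hand it to fitSentences(),
// which tokenizes each String into Sequence<VocabWord> and delegates to fitSequences().
public class FitSentencesExample {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("w2v-example");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            JavaRDD<String> corpus = sc.parallelize(Arrays.asList(
                            "the quick brown fox",
                            "jumped over the lazy dog"));

            SparkWord2Vec word2Vec = new SparkWord2Vec(); // assumed constructor; configure before fitting
            word2Vec.fitSentences(corpus);
        }
    }
}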