Use of org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener in project deeplearning4j by deeplearning4j: the fit() method of the SequenceVectors class. fit() validates the training configuration, builds the vocabulary if needed, initializes the lookup-table weights, then runs numEpochs epochs of multi-threaded training; after each epoch it notifies every registered VectorsListener through the ListenerEvent.EPOCH hook.
/**
 * Starts training over
 */
public void fit() {
    Properties props = Nd4j.getExecutioner().getEnvironmentInformation();
    if (props.getProperty("backend").equals("CUDA")) {
        if (Nd4j.getAffinityManager().getNumberOfDevices() > 1)
            throw new IllegalStateException("Multi-GPU word2vec/doc2vec isn't available atm");
        //if (!NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())
        //    throw new IllegalStateException("Running Word2Vec on multi-gpu system requires P2P support between GPUs, which looks to be unavailable on your system.");
    }

    Nd4j.getRandom().setSeed(configuration.getSeed());

    AtomicLong timeSpent = new AtomicLong(0);
    if (!trainElementsVectors && !trainSequenceVectors)
        throw new IllegalStateException("You should define at least one training goal 'trainElementsRepresentation' or 'trainSequenceRepresentation'");
    if (iterator == null)
        throw new IllegalStateException("You can't fit() data without SequenceIterator defined");

    if (resetModel || (lookupTable != null && vocab != null && vocab.numWords() == 0)) {
        // build vocabulary from scratch
        buildVocab();
    }

    WordVectorSerializer.printOutProjectedMemoryUse(vocab.numWords(), configuration.getLayersSize(),
            configuration.isUseHierarchicSoftmax() && configuration.getNegative() > 0 ? 3 : 2);

    if (vocab == null || lookupTable == null || vocab.numWords() == 0)
        throw new IllegalStateException("You can't fit() model with empty Vocabulary or WeightLookupTable");

    // if the vocab and lookupTable were built externally, we only need to verify that the lookupTable was properly initialized
    if (!resetModel || existingModel != null) {
        lookupTable.resetWeights(false);
    } else {
        // otherwise we reset the weights, regardless of the lookup table's current state
        lookupTable.resetWeights(true);

        // if preciseWeightInit is enabled, we roll over the data once more
        if (configuration.isPreciseWeightInit()) {
            log.info("Using precise weights init...");
            iterator.reset();
            while (iterator.hasMoreSequences()) {
                Sequence<T> sequence = iterator.nextSequence();

                // initializing elements, only once per element
                for (T element : sequence.getElements()) {
                    T realElement = vocab.tokenFor(element.getLabel());
                    if (realElement != null && !realElement.isInit()) {
                        Random rng = Nd4j.getRandomFactory().getNewRandomInstance(
                                configuration.getSeed() * realElement.hashCode(),
                                configuration.getLayersSize() + 1);
                        INDArray randArray = Nd4j.rand(new int[] {1, configuration.getLayersSize()}, rng)
                                .subi(0.5).divi(configuration.getLayersSize());
                        lookupTable.getWeights().getRow(realElement.getIndex()).assign(randArray);
                        realElement.setInit(true);
                    }
                }

                // initializing labels, only once per label
                for (T label : sequence.getSequenceLabels()) {
                    T realElement = vocab.tokenFor(label.getLabel());
                    if (realElement != null && !realElement.isInit()) {
                        Random rng = Nd4j.getRandomFactory().getNewRandomInstance(
                                configuration.getSeed() * realElement.hashCode(),
                                configuration.getLayersSize() + 1);
                        INDArray randArray = Nd4j.rand(new int[] {1, configuration.getLayersSize()}, rng)
                                .subi(0.5).divi(configuration.getLayersSize());
                        lookupTable.getWeights().getRow(realElement.getIndex()).assign(randArray);
                        realElement.setInit(true);
                    }
                }
            }
            this.iterator.reset();
        }
    }
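    // Note: the "precise" init path above seeds each element's RNG with
    // configuration.getSeed() * realElement.hashCode(), so every vocabulary element
    // receives the same initial vector on every run, independent of the order in
    // which sequences arrive. Nd4j.rand(...) draws uniformly from [0, 1], so after
    // subi(0.5).divi(layersSize) the initial weights are uniform in
    // [-0.5/layersSize, +0.5/layersSize].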
    initLearners();

    log.info("Starting learning process...");
    timeSpent.set(System.currentTimeMillis());
    if (this.stopWords == null)
        this.stopWords = new ArrayList<>();

    for (int currentEpoch = 1; currentEpoch <= numEpochs; currentEpoch++) {
        final AtomicLong linesCounter = new AtomicLong(0);
        final AtomicLong wordsCounter = new AtomicLong(0);

        AsyncSequencer sequencer = new AsyncSequencer(this.iterator, this.stopWords);
        sequencer.start();

        //final VectorCalculationsThread[] threads = new VectorCalculationsThread[workers];
        final AtomicLong timer = new AtomicLong(System.currentTimeMillis());
        final List<VectorCalculationsThread> threads = new ArrayList<>();
        for (int x = 0; x < workers; x++) {
            threads.add(x, new VectorCalculationsThread(x, currentEpoch, wordsCounter,
                    vocab.totalWordOccurrences(), linesCounter, sequencer, timer));
            threads.get(x).start();
        }

        try {
            sequencer.join();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        for (int x = 0; x < workers; x++) {
            try {
                threads.get(x).join();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
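        // At this point the producer (AsyncSequencer) and every consumer
        // (VectorCalculationsThread) have been joined, so the epoch is complete and
        // wordsCounter/linesCounter hold their final values for this epoch.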
        // TODO: fix this to non-exclusive termination
        if (trainElementsVectors && elementsLearningAlgorithm != null
                && (!trainSequenceVectors || sequenceLearningAlgorithm == null)
                && elementsLearningAlgorithm.isEarlyTerminationHit()) {
            break;
        }
        if (trainSequenceVectors && sequenceLearningAlgorithm != null
                && (!trainElementsVectors || elementsLearningAlgorithm == null)
                && sequenceLearningAlgorithm.isEarlyTerminationHit()) {
            break;
        }

        log.info("Epoch: [" + currentEpoch + "]; Words vectorized so far: [" + wordsCounter.get()
                + "]; Lines vectorized so far: [" + linesCounter.get()
                + "]; learningRate: [" + minLearningRate + "]");

        if (eventListeners != null && !eventListeners.isEmpty()) {
            for (VectorsListener listener : eventListeners) {
                if (listener.validateEvent(ListenerEvent.EPOCH, currentEpoch))
                    listener.processEvent(ListenerEvent.EPOCH, this, currentEpoch);
            }
        }
    }

    log.info("Time spent on training: {} ms", System.currentTimeMillis() - timeSpent.get());
}
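Since this page is about VectorsListener, here is a minimal sketch of a custom listener wired into the epoch loop above. The validateEvent/processEvent signatures are inferred from the calls inside fit() (validateEvent gates the callback, processEvent receives this and the epoch number); the ListenerEvent and VectorsConfiguration import paths, the vocab() accessor, and the setVectorsListeners(...) builder method are assumptions based on the deeplearning4j package layout, so verify them against your release.

import java.util.Collections;

import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.enums.ListenerEvent;
import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener;
import org.deeplearning4j.models.word2vec.VocabWord;

/**
 * Hypothetical listener that logs progress once per epoch. fit() first calls
 * validateEvent(...) and only invokes processEvent(...) when it returns true.
 */
public class EpochLoggingListener implements VectorsListener<VocabWord> {

    @Override
    public boolean validateEvent(ListenerEvent event, long argument) {
        // fit() fires (ListenerEvent.EPOCH, currentEpoch) at the end of each epoch
        return event == ListenerEvent.EPOCH;
    }

    @Override
    public void processEvent(ListenerEvent event, SequenceVectors<VocabWord> sequenceVectors, long argument) {
        // 'argument' carries the epoch number passed from the loop in fit()
        System.out.println("Epoch " + argument + " finished; vocabulary size: "
                + sequenceVectors.vocab().numWords());
    }
}

Usage sketch (sequenceIterator stands in for a pre-built SequenceIterator<VocabWord>):

SequenceVectors<VocabWord> vectors = new SequenceVectors.Builder<VocabWord>(new VectorsConfiguration())
        .iterate(sequenceIterator)
        .setVectorsListeners(Collections.singletonList(new EpochLoggingListener()))
        .build();
vectors.fit(); // the listener now fires after every epoch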