Search in sources :

Example 1 with VectorsListener

Use of org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener in the deeplearning4j project (by deeplearning4j).

In the class SequenceVectors, method fit:

/**
 * Starts training over.
 *
 * <p>Validates the configuration and builds the vocabulary when a reset is requested or the
 * vocabulary is empty. It then initializes the lookup-table weights and runs {@code numEpochs}
 * training epochs. Each epoch starts one {@code AsyncSequencer} over {@code iterator} (with
 * {@code stopWords}) plus {@code workers} {@code VectorCalculationsThread}s, and waits for all of
 * them. It then logs progress and notifies every registered {@link VectorsListener} that accepts
 * the EPOCH event. Training stops early when the only active learning algorithm reports that its
 * early-termination condition was hit.
 *
 * @throws IllegalStateException if the CUDA backend reports more than one device, if neither
 *     element nor sequence training is enabled, if no iterator is set, or if the vocabulary or
 *     lookup table is missing or empty after vocabulary construction
 * @throws RuntimeException wrapping an {@link InterruptedException} if the calling thread is
 *     interrupted while waiting for an epoch to finish; the interrupt flag is restored first
 */
public void fit() {
    Properties props = Nd4j.getExecutioner().getEnvironmentInformation();
    // Null-safe: a backend that does not report a "backend" property must not trigger an NPE.
    if ("CUDA".equals(props.getProperty("backend"))
            && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
        throw new IllegalStateException("Multi-GPU word2vec/doc2vec isn't available atm");
    }

    Nd4j.getRandom().setSeed(configuration.getSeed());

    if (!trainElementsVectors && !trainSequenceVectors) {
        throw new IllegalStateException("You should define at least one training goal 'trainElementsRepresentation' or 'trainSequenceRepresentation'");
    }
    if (iterator == null) {
        throw new IllegalStateException("You can't fit() data without SequenceIterator defined");
    }

    if (resetModel || (lookupTable != null && vocab != null && vocab.numWords() == 0)) {
        // build vocabulary from scratch
        buildVocab();
    }

    // Validate BEFORE the memory report: it dereferences vocab, which may still be null here.
    if (vocab == null || lookupTable == null || vocab.numWords() == 0) {
        throw new IllegalStateException("You can't fit() model with empty Vocabulary or WeightLookupTable");
    }
    WordVectorSerializer.printOutProjectedMemoryUse(vocab.numWords(), configuration.getLayersSize(),
            (configuration.isUseHierarchicSoftmax() && configuration.getNegative() > 0) ? 3 : 2);

    if (!resetModel || existingModel != null) {
        // vocab/lookupTable may have been built externally: ask the lookup table to check its
        // initialization without a forced reset (exact resetWeights(false) semantics live in the
        // WeightLookupTable implementation)
        lookupTable.resetWeights(false);
    } else {
        // otherwise we reset weights, independent of actual current state of lookup table
        lookupTable.resetWeights(true);
        if (configuration.isPreciseWeightInit()) {
            // roll over the data once more, seeding each token's row deterministically
            log.info("Using precise weights init...");
            iterator.reset();
            while (iterator.hasMoreSequences()) {
                Sequence<T> sequence = iterator.nextSequence();
                initWeightsPrecisely(sequence.getElements());
                initWeightsPrecisely(sequence.getSequenceLabels());
            }
            this.iterator.reset();
        }
    }

    initLearners();
    log.info("Starting learning process...");
    final long trainingStart = System.currentTimeMillis();
    if (this.stopWords == null) {
        this.stopWords = new ArrayList<>();
    }

    for (int currentEpoch = 1; currentEpoch <= numEpochs; currentEpoch++) {
        final AtomicLong linesCounter = new AtomicLong(0);
        final AtomicLong wordsCounter = new AtomicLong(0);

        AsyncSequencer sequencer = new AsyncSequencer(this.iterator, this.stopWords);
        sequencer.start();

        final AtomicLong timer = new AtomicLong(System.currentTimeMillis());
        final List<VectorCalculationsThread> threads = new ArrayList<>();
        for (int x = 0; x < workers; x++) {
            VectorCalculationsThread worker = new VectorCalculationsThread(x, currentEpoch, wordsCounter,
                    vocab.totalWordOccurrences(), linesCounter, sequencer, timer);
            threads.add(worker);
            worker.start();
        }

        try {
            sequencer.join();
            for (VectorCalculationsThread worker : threads) {
                worker.join();
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack before bailing out.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }

        // TODO: fix this to non-exclusive termination
        // Early termination is honored only when exactly one learning algorithm is active.
        if (trainElementsVectors && elementsLearningAlgorithm != null
                && (!trainSequenceVectors || sequenceLearningAlgorithm == null)
                && elementsLearningAlgorithm.isEarlyTerminationHit()) {
            break;
        }
        if (trainSequenceVectors && sequenceLearningAlgorithm != null
                && (!trainElementsVectors || elementsLearningAlgorithm == null)
                && sequenceLearningAlgorithm.isEarlyTerminationHit()) {
            break;
        }

        // NOTE(review): the value reported as "learningRate" is minLearningRate — confirm intent.
        log.info("Epoch: [{}]; Words vectorized so far: [{}];  Lines vectorized so far: [{}]; learningRate: [{}]",
                currentEpoch, wordsCounter.get(), linesCounter.get(), minLearningRate);

        if (eventListeners != null && !eventListeners.isEmpty()) {
            for (VectorsListener listener : eventListeners) {
                if (listener.validateEvent(ListenerEvent.EPOCH, currentEpoch)) {
                    listener.processEvent(ListenerEvent.EPOCH, this, currentEpoch);
                }
            }
        }
    }
    log.info("Time spent on training: {} ms", System.currentTimeMillis() - trainingStart);
}

/**
 * Initializes, once, the lookup-table row of every vocabulary token referenced by
 * {@code elements} with a seed-derived random vector.
 *
 * <p>Each element is resolved through {@code vocab.tokenFor(label)}. Tokens not found in the
 * vocabulary, or already flagged as initialized, are skipped. Otherwise the row at the token's
 * index is assigned {@code (rand - 0.5) / layersSize} from an RNG seeded with
 * {@code seed * token.hashCode()}, and the token is flagged as initialized.
 *
 * @param elements sequence elements or sequence labels to initialize
 */
private void initWeightsPrecisely(Iterable<T> elements) {
    final int layersSize = configuration.getLayersSize();
    for (T element : elements) {
        T realElement = vocab.tokenFor(element.getLabel());
        if (realElement == null || realElement.isInit()) {
            continue;
        }
        Random rng = Nd4j.getRandomFactory().getNewRandomInstance(configuration.getSeed() * realElement.hashCode(), layersSize + 1);
        INDArray randArray = Nd4j.rand(new int[] {1, layersSize}, rng).subi(0.5).divi(layersSize);
        lookupTable.getWeights().getRow(realElement.getIndex()).assign(randArray);
        realElement.setInit(true);
    }
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) Random(org.nd4j.linalg.api.rng.Random) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VectorsListener(org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener)

Aggregations

AtomicLong (java.util.concurrent.atomic.AtomicLong): 1; VectorsListener (org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener): 1; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1; Random (org.nd4j.linalg.api.rng.Random): 1