Search in sources :

Example 1 with AsyncMultiDataSetIterator

use of org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator in project deeplearning4j by deeplearning4j.

In the class ComputationGraph, the method fit:

/**
     * Fit the ComputationGraph using a MultiDataSetIterator
     */
/**
 * Fit the ComputationGraph using a MultiDataSetIterator.
 * <p>
 * If the source iterator supports async operation, it is wrapped in an
 * {@link AsyncMultiDataSetIterator} (prefetch depth 2) so data loading overlaps
 * with training. Runs pretraining and/or backprop according to the configuration.
 *
 * @param multi iterator providing the training MultiDataSets
 */
public void fit(MultiDataSetIterator multi) {
    if (flattenedGradients == null)
        initGradientsView();
    // Wrap with async prefetch when supported, so host->device transfer overlaps compute
    MultiDataSetIterator multiDataSetIterator;
    if (multi.asyncSupported()) {
        multiDataSetIterator = new AsyncMultiDataSetIterator(multi, 2);
    } else {
        multiDataSetIterator = multi;
    }
    try {
        if (configuration.isPretrain()) {
            pretrain(multiDataSetIterator);
        }
        if (configuration.isBackprop()) {
            while (multiDataSetIterator.hasNext()) {
                MultiDataSet next = multiDataSetIterator.next();
                // Defensive: stop fitting if the iterator yields incomplete data
                if (next.getFeatures() == null || next.getLabels() == null)
                    break;
                if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArrays(),
                                    next.getLabelsMaskArrays());
                } else {
                    boolean hasMaskArrays = next.hasMaskArrays();
                    if (hasMaskArrays) {
                        setLayerMaskArrays(next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                    }
                    setInputs(next.getFeatures());
                    setLabels(next.getLabels());
                    // Lazily build the solver on first use
                    if (solver == null) {
                        solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners)
                                        .model(this).build();
                    }
                    solver.optimize();
                    if (hasMaskArrays) {
                        clearLayerMaskArrays();
                    }
                }
                Nd4j.getMemoryManager().invokeGcOccasionally();
            }
        }
    } finally {
        // Fix: shut down the async prefetch thread if we created one; previously it was
        // leaked on every call (cf. the cleanup done in ExecuteWorkerMultiDataSetFlatMapAdapter)
        if (multiDataSetIterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) multiDataSetIterator).shutdown();
        }
    }
}
Also used : SingletonMultiDataSetIterator(org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) Solver(org.deeplearning4j.optimize.Solver) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)

Example 2 with AsyncMultiDataSetIterator

use of org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator in project deeplearning4j by deeplearning4j.

In the class ExecuteWorkerMultiDataSetFlatMapAdapter, the method call:

/**
 * Runs the worker's training loop over one partition of MultiDataSets.
 * <p>
 * Batches the raw element iterator into minibatches of the configured size, optionally
 * wraps it with async prefetch, then feeds minibatches to the worker until either the
 * worker signals early termination (returns a non-null result), the data is exhausted,
 * or the per-worker minibatch budget is reached. Timing statistics are collected when
 * enabled in the worker's data configuration.
 *
 * @param dataSetIterator iterator over this partition's MultiDataSets
 * @return an iterable with at most one result element (empty when the partition has no data)
 * @throws Exception if the worker throws while processing a minibatch
 */
@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    boolean stats = dataConfig.isCollectTrainingStats();
    // Helper is only allocated when stats collection is on; all s.* calls are guarded by 'stats'
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();
    if (!dataSetIterator.hasNext()) {
        if (stats)
            s.logReturnTime();
        //Sometimes: no data
        return Collections.emptyList();
    }
    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    // Re-batch the raw elements into minibatches of the configured size
    MultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batchedIterator = new AsyncMultiDataSetIterator(batchedIterator, prefetchCount);
    }
    try {
        if (stats)
            s.logInitialModelBefore();
        ComputationGraph net = worker.getInitialModelGraph();
        if (stats)
            s.logInitialModelAfter();
        int miniBatchCount = 0;
        // <=0 means "no limit"
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            MultiDataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.getFeatures(0).size(0));
            if (stats) {
                // Stats path: time the minibatch and merge worker stats into the return value
                s.logProcessMinibatchBefore();
                // !hasNext() tells the worker this is the final minibatch of the partition
                Pair<R, SparkTrainingStats> result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                R result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
        //For some reason, we didn't return already. Normally this shouldn't happen
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) batchedIterator).shutdown();
        }
        // Flush any queued ops before returning control to Spark
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
Also used : IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) WorkerConfiguration(org.deeplearning4j.spark.api.WorkerConfiguration) StatsCalculationHelper(org.deeplearning4j.spark.api.stats.StatsCalculationHelper) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 3 with AsyncMultiDataSetIterator

use of org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator in project deeplearning4j by deeplearning4j.

In the class ParallelWrapper, the method fit:

/**
 * Fit the wrapped model using a MultiDataSetIterator, dispatching minibatches
 * round-robin to the worker threads ("zoo") and averaging parameters every
 * {@code averagingFrequency} iterations.
 *
 * @param source iterator providing the training MultiDataSets
 */
public synchronized void fit(@NonNull MultiDataSetIterator source) {
    stopFit.set(false);
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            // we pass true here, to tell Trainer to use MultiDataSet queue for training
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), true);
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    } else {
        // Workers already running: just switch them over to the MultiDataSet queue
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt].useMDS = true;
        }
    }
    source.reset();
    // Wrap with async prefetch when configured and supported by the source iterator
    MultiDataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
    } else {
        iterator = source;
    }
    try {
        AtomicInteger locker = new AtomicInteger(0);
        while (iterator.hasNext() && !stopFit.get()) {
            MultiDataSet dataSet = iterator.next();
            if (dataSet == null)
                throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
            /*
                 now dataSet should be dispatched to next free workers, until all workers are busy. And then we should block till all finished.
                */
            int pos = locker.getAndIncrement();
            zoo[pos].feedMultiDataSet(dataSet);
            /*
                    if all workers are dispatched now, join till all are finished
                */
            if (pos + 1 == workers || !iterator.hasNext()) {
                iterationsCounter.incrementAndGet();
                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                    try {
                        zoo[cnt].waitTillRunning();
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
                Nd4j.getMemoryManager().invokeGcOccasionally();
                /*
                        average model, and propagate it to whole
                    */
                if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                    double score = getScore(locker);
                    // averaging updaters state
                    if (model instanceof ComputationGraph) {
                        averageUpdatersState(locker, score);
                    } else {
                        throw new RuntimeException("MultiDataSet must only be used with ComputationGraph model");
                    }
                    if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                        for (int cnt = 0; cnt < workers; cnt++) {
                            zoo[cnt].updateModel(model);
                        }
                    }
                }
                // Reset the round-robin position for the next dispatch round
                locker.set(0);
            }
        }
    } finally {
        // Fix: shut down the async prefetch thread if we created one; previously it was leaked
        if (iterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) iterator).shutdown();
        }
    }
    // sanity checks, or the dataset may never average
    // NOTE(review): wasAveraged is read here but never assigned in this method —
    // TODO confirm it is updated by getScore()/averageUpdatersState() or elsewhere
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    //            throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
//        iterationsCounter.set(0);
}
Also used : MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 4 with AsyncMultiDataSetIterator

use of org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator in project deeplearning4j by deeplearning4j.

In the class ParameterServerParallelWrapper, the method fit:

/**
 * Fit using a MultiDataSetIterator: initializes the wrapper on first use, then feeds
 * each MultiDataSet into the parameter-server training queue via {@code addObject}.
 *
 * @param multiDataSetIterator source of training MultiDataSets
 */
public void fit(MultiDataSetIterator multiDataSetIterator) {
    if (!init)
        init(multiDataSetIterator);
    // Wrap with async prefetch when configured and supported by the source iterator
    MultiDataSetIterator iterator;
    if (preFetchSize > 0 && multiDataSetIterator.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(multiDataSetIterator, preFetchSize);
    } else {
        iterator = multiDataSetIterator;
    }
    try {
        while (iterator.hasNext()) {
            org.nd4j.linalg.dataset.api.MultiDataSet next = iterator.next();
            addObject(next);
        }
    } finally {
        // Fix: shut down the async prefetch thread if we created one; previously it was leaked
        if (iterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) iterator).shutdown();
        }
    }
}
Also used : MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)

Aggregations

AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)4 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)4 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)3 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)2 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 IteratorMultiDataSetIterator (org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator)1 SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator)1 Solver (org.deeplearning4j.optimize.Solver)1 WorkerConfiguration (org.deeplearning4j.spark.api.WorkerConfiguration)1 SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats)1 StatsCalculationHelper (org.deeplearning4j.spark.api.stats.StatsCalculationHelper)1 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)1 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)1