Search in sources :

Example 1 with AsyncDataSetIterator

use of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in project deeplearning4j by deeplearning4j.

the class ParallelWrapper method fit.

/**
     * This method takes a DataSetIterator and starts training over it by scheduling DataSets to different executors
     *
     * @param source the DataSetIterator to train on
     */
public synchronized void fit(@NonNull DataSetIterator source) {
    stopFit.set(false);
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // if we're using MagicQueue here, attach each worker thread to a specific device
            if (isMQ)
                Nd4j.getAffinityManager().attachThreadToDevice(zoo[cnt], cnt % Nd4j.getAffinityManager().getNumberOfDevices());
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    }
    source.reset();
    DataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        if (isMQ) {
            if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
                log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers, Nd4j.getAffinityManager().getNumberOfDevices());
            MagicQueue queue = new MagicQueue.Builder().setCapacityPerFlow(8).setMode(MagicQueue.Mode.SEQUENTIAL).setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices()).build();
            iterator = new AsyncDataSetIterator(source, prefetchSize, queue);
        } else
            iterator = new AsyncDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    AtomicInteger locker = new AtomicInteger(0);
    int whiles = 0;
    while (iterator.hasNext() && !stopFit.get()) {
        whiles++;
        DataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as DataSet");
        /*
            now the dataSet is dispatched to the next free worker until all workers are busy; then we block until all have finished.
         */
        int pos = locker.getAndIncrement();
        if (zoo == null)
            throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
        zoo[pos].feedDataSet(dataSet);
        /*
            if all workers have been dispatched, wait until all have finished
         */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
                average the models and propagate the result to all workers
             */
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof MultiLayerNetwork) {
                    if (averageUpdaters) {
                        Updater updater = ((MultiLayerNetwork) model).getUpdater();
                        int batchSize = 0;
                        if (updater != null && updater.getStateViewArray() != null) {
                            if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                                List<INDArray> updaters = new ArrayList<>();
                                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    updaters.add(workerModel.getUpdater().getStateViewArray());
                                    batchSize += workerModel.batchSize();
                                }
                                Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
                            } else {
                                INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                                int cnt = 0;
                                for (; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    state.addi(workerModel.getUpdater().getStateViewArray().dup());
                                    batchSize += workerModel.batchSize();
                                }
                                state.divi(cnt);
                                updater.setStateViewArray((MultiLayerNetwork) model, state, false);
                            }
                        }
                    }
                    ((MultiLayerNetwork) model).setScore(score);
                } else if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                }
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            locker.set(0);
        }
    }
    // sanity check: averaging may never trigger for some batch size / worker count / frequency combinations
    if (!wasAveraged)
        log.warn("Parameters were never averaged during this fit() call. Check the ratios of batch size, number of workers, and averaging frequency.");
    //            throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
//        iterationsCounter.set(0);
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) DataSet(org.nd4j.linalg.dataset.api.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)
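
The pattern the wrapper relies on is simple: wrap a source DataSetIterator in an AsyncDataSetIterator with a prefetch queue, consume it, then stop the background thread. Below is a minimal self-contained sketch of that pattern, built only from calls shown in these examples (the two-argument constructor, shutdown(), ExistingDataSetIterator); the queue size of 4 and the toy data are arbitrary illustration values.

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class AsyncPrefetchSketch {
    public static void main(String[] args) {
        // Toy in-memory data standing in for a real source iterator
        List<DataSet> data = new ArrayList<>();
        for (int i = 0; i < 100; i++)
            data.add(new DataSet(Nd4j.create(new float[] { 1f, 2f, 3f }),
                            Nd4j.create(new float[] { 1f, 2f, 3f })));
        // Prefetch up to 4 DataSets ahead of the consumer; queue size is an arbitrary choice here
        AsyncDataSetIterator async =
                        new AsyncDataSetIterator(new ExistingDataSetIterator(data), 4);
        try {
            while (async.hasNext()) {
                DataSet ds = async.next();
                // consume ds: feed it to a model, evaluate it, etc.
            }
        } finally {
            // Stop the background prefetch thread, as Example 2 does in its finally block
            async.shutdown();
        }
    }
}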

Example 2 with AsyncDataSetIterator

use of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in project deeplearning4j by deeplearning4j.

the class ExecuteWorkerFlatMapAdapter method call.

@Override
public Iterable<R> call(Iterator<DataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    final boolean isGraph = dataConfig.isGraphNetwork();
    boolean stats = dataConfig.isCollectTrainingStats();
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();
    if (!dataSetIterator.hasNext()) {
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultNoDataWithStats();
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResultNoData());
        }
    }
    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    DataSetIterator batchedIterator = new IteratorDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batchedIterator = new AsyncDataSetIterator(batchedIterator, prefetchCount);
    }
    try {
        MultiLayerNetwork net = null;
        ComputationGraph graph = null;
        if (stats)
            s.logInitialModelBefore();
        if (isGraph)
            graph = worker.getInitialModelGraph();
        else
            net = worker.getInitialModel();
        if (stats)
            s.logInitialModelAfter();
        int miniBatchCount = 0;
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            DataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.numExamples());
            if (stats) {
                s.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> result;
                if (isGraph)
                    result = worker.processMinibatchWithStats(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                R result;
                if (isGraph)
                    result = worker.processMinibatch(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
        //Loop ended without an early-return result; normally the worker returns a result on the final minibatch
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair;
            if (isGraph)
                pair = worker.getFinalResultWithStats(graph);
            else
                pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            if (isGraph)
                return Collections.singletonList(worker.getFinalResult(graph));
            else
                return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncDataSetIterator) {
            ((AsyncDataSetIterator) batchedIterator).shutdown();
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) WorkerConfiguration(org.deeplearning4j.spark.api.WorkerConfiguration) StatsCalculationHelper(org.deeplearning4j.spark.api.stats.StatsCalculationHelper) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
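
The finally block above is the detail worth copying: when wrapping is conditional, keep the reference typed as the plain DataSetIterator and test with instanceof before shutting down, so the background thread is stopped on every exit path. A condensed sketch of just that guard, using only the constructors and methods shown in this example:

import java.util.Iterator;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class GuardedShutdownSketch {
    static void consumeAll(Iterator<DataSet> rawIterator, int batchSize, int prefetchCount) {
        // Re-batch the raw iterator, then optionally add background prefetch
        DataSetIterator batched = new IteratorDataSetIterator(rawIterator, batchSize);
        if (prefetchCount > 0) {
            batched = new AsyncDataSetIterator(batched, prefetchCount);
        }
        try {
            while (batched.hasNext()) {
                DataSet next = batched.next();
                // process next ...
            }
        } finally {
            // Only the async wrapper owns a background thread that must be stopped
            if (batched instanceof AsyncDataSetIterator) {
                ((AsyncDataSetIterator) batched).shutdown();
            }
        }
    }
}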

Example 3 with AsyncDataSetIterator

use of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in project deeplearning4j by deeplearning4j.

the class MagicQueueTest method testSequentialIterable.

@Test
public void testSequentialIterable() throws Exception {
    List<DataSet> list = new ArrayList<>();
    for (int i = 0; i < 1024; i++)
        list.add(new DataSet(Nd4j.create(new float[] { 1f, 2f, 3f }), Nd4j.create(new float[] { 1f, 2f, 3f })));
    int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
    ExistingDataSetIterator edsi = new ExistingDataSetIterator(list);
    MagicQueue queue = new MagicQueue.Builder().setMode(MagicQueue.Mode.SEQUENTIAL).setCapacityPerFlow(32).build();
    AsyncDataSetIterator adsi = new AsyncDataSetIterator(edsi, 10, queue);
    int cnt = 0;
    while (adsi.hasNext()) {
        DataSet ds = adsi.next();
        // making sure dataset isn't null
        assertNotEquals("Failed on round " + cnt, null, ds);
        // making sure the device for this array is the next one in round-robin order
        assertEquals(cnt % numDevices, Nd4j.getAffinityManager().getDeviceForArray(ds.getFeatures()).intValue());
        assertEquals(cnt % numDevices, Nd4j.getAffinityManager().getDeviceForArray(ds.getLabels()).intValue());
        cnt++;
    }
    assertEquals(list.size(), cnt);
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) ExistingDataSetIterator(org.deeplearning4j.datasets.iterator.ExistingDataSetIterator) ArrayList(java.util.ArrayList) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) Test(org.junit.Test)
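
This test also shows the three-argument constructor that accepts a pre-built MagicQueue, which is how ParallelWrapper in Example 1 round-robins DataSets across devices. A sketch of that device-aware variant follows; the MagicQueue import path is an assumption (the examples use the class unqualified), while every call is taken from Examples 1 and 3.

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
// assumed package for MagicQueue; adjust to wherever it lives in your version
import org.deeplearning4j.parallelism.MagicQueue;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class MagicQueueSketch {
    public static void main(String[] args) {
        List<DataSet> data = new ArrayList<>();
        for (int i = 0; i < 64; i++)
            data.add(new DataSet(Nd4j.create(new float[] { 1f, 2f, 3f }),
                            Nd4j.create(new float[] { 1f, 2f, 3f })));
        // SEQUENTIAL mode hands DataSets to devices in round-robin order,
        // which is exactly what the assertions in the test above verify
        MagicQueue queue = new MagicQueue.Builder()
                        .setMode(MagicQueue.Mode.SEQUENTIAL)
                        .setCapacityPerFlow(32)
                        .setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices())
                        .build();
        AsyncDataSetIterator adsi =
                        new AsyncDataSetIterator(new ExistingDataSetIterator(data), 10, queue);
        while (adsi.hasNext()) {
            DataSet ds = adsi.next();
            int device = Nd4j.getAffinityManager().getDeviceForArray(ds.getFeatures()).intValue();
            // ds now resides on "device"; dispatch it to the worker pinned there
        }
    }
}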

Example 4 with AsyncDataSetIterator

use of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in project deeplearning4j by deeplearning4j.

the class MultiLayerNetwork method fit.

@Override
public void fit(DataSetIterator iterator) {
    DataSetIterator iter;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        iter = new AsyncDataSetIterator(iterator, 2);
    } else {
        iter = iterator;
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }
    if (layerWiseConfigurations.isPretrain()) {
        pretrain(iter);
        if (iter.resetSupported()) {
            iter.reset();
        }
    //            while (iter.hasNext()) {
    //                DataSet next = iter.next();
    //                if (next.getFeatureMatrix() == null || next.getLabels() == null)
    //                    break;
    //                setInput(next.getFeatureMatrix());
    //                setLabels(next.getLabels());
    //                finetune();
    //            }
    }
    if (layerWiseConfigurations.isBackprop()) {
        update(TaskUtils.buildTask(iter));
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        while (iter.hasNext()) {
            DataSet next = iter.next();
            if (next.getFeatureMatrix() == null || next.getLabels() == null)
                break;
            boolean hasMaskArrays = next.hasMaskArrays();
            if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(), next.getLabelsMaskArray());
            } else {
                if (hasMaskArrays)
                    setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                setInput(next.getFeatureMatrix());
                setLabels(next.getLabels());
                if (solver == null) {
                    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
                solver.optimize();
            }
            if (hasMaskArrays)
                clearLayerMaskArrays();
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    } else if (layerWiseConfigurations.isPretrain()) {
        log.warn("Warning: finetune is not applied.");
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}
Also used : Solver(org.deeplearning4j.optimize.Solver) DataSet(org.nd4j.linalg.dataset.DataSet) TrainingListener(org.deeplearning4j.optimize.api.TrainingListener) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
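
Examples 4 and 5 both guard the wrap with asyncSupported(), and Example 4 additionally checks resetSupported() before rewinding. Those guards are worth extracting; here is a small sketch of them as helpers, using only methods that appear in the examples:

import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class DefensiveWrapSketch {
    // Wrap only when the source allows asynchronous use, mirroring Examples 4 and 5
    static DataSetIterator maybeAsync(DataSetIterator source, int queueSize) {
        return source.asyncSupported() ? new AsyncDataSetIterator(source, queueSize) : source;
    }

    // Rewind only when the iterator supports it (single-pass streams do not)
    static void rewindIfPossible(DataSetIterator iter) {
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
    }
}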

Example 5 with AsyncDataSetIterator

use of org.deeplearning4j.datasets.iterator.AsyncDataSetIterator in project deeplearning4j by deeplearning4j.

the class ComputationGraph method fit.

/**
     * Fit the ComputationGraph using a DataSetIterator.
     * Note that this method can only be used with ComputationGraphs with 1 input and 1 output
     */
public void fit(DataSetIterator iterator) {
    if (flattenedGradients == null)
        initGradientsView();
    if (numInputArrays != 1 || numOutputArrays != 1)
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with multiple inputs or outputs using a DataSetIterator");
    DataSetIterator dataSetIterator;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        dataSetIterator = new AsyncDataSetIterator(iterator, 2);
    } else
        dataSetIterator = iterator;
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }
    if (configuration.isPretrain()) {
        pretrain(dataSetIterator);
    }
    if (configuration.isBackprop()) {
        update(TaskUtils.buildTask(dataSetIterator));
        while (dataSetIterator.hasNext()) {
            DataSet next = dataSetIterator.next();
            if (next.getFeatures() == null || next.getLabels() == null)
                break;
            boolean hasMaskArrays = next.hasMaskArrays();
            if (hasMaskArrays) {
                INDArray[] fMask = (next.getFeaturesMaskArray() != null ? new INDArray[] { next.getFeaturesMaskArray() } : null);
                INDArray[] lMask = (next.getLabelsMaskArray() != null ? new INDArray[] { next.getLabelsMaskArray() } : null);
                setLayerMaskArrays(fMask, lMask);
            }
            if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(new INDArray[] { next.getFeatures() }, new INDArray[] { next.getLabels() }, (hasMaskArrays ? new INDArray[] { next.getFeaturesMaskArray() } : null), (hasMaskArrays ? new INDArray[] { next.getLabelsMaskArray() } : null));
            } else {
                setInput(0, next.getFeatures());
                setLabel(0, next.getLabels());
                if (solver == null) {
                    //TODO: don't like this
                    solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners).model(this).build();
                }
                solver.optimize();
            }
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}
Also used : Solver(org.deeplearning4j.optimize.Solver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSet(org.nd4j.linalg.dataset.api.DataSet) TrainingListener(org.deeplearning4j.optimize.api.TrainingListener) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) SingletonMultiDataSetIterator(org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)
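
The Javadoc above restricts fit(DataSetIterator) to graphs with one input and one output; the import list hints at the MultiDataSet counterparts for everything else. The sketch below assumes AsyncMultiDataSetIterator mirrors AsyncDataSetIterator's (iterator, queueSize) constructor, which these examples do not confirm; the remaining classes come from the import lists above.

import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class MultiDataSetPrefetchSketch {
    public static void main(String[] args) {
        // One two-input, one-output MultiDataSet wrapped as a single-element iterator
        MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(
                        new INDArray[] { Nd4j.create(new float[] { 1f, 2f }), Nd4j.create(new float[] { 3f }) },
                        new INDArray[] { Nd4j.create(new float[] { 1f }) });
        MultiDataSetIterator iter = new SingletonMultiDataSetIterator(mds);
        // Assumption: same (iterator, queueSize) shape as AsyncDataSetIterator
        MultiDataSetIterator async = new AsyncMultiDataSetIterator(iter, 2);
        while (async.hasNext()) {
            MultiDataSet next = async.next();
            // feed next to a multi-input/multi-output ComputationGraph
        }
    }
}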

Aggregations

AsyncDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncDataSetIterator)6 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)5 DataSet (org.nd4j.linalg.dataset.DataSet)4 AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)3 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)3 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)2 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)2 Solver (org.deeplearning4j.optimize.Solver)2 TrainingListener (org.deeplearning4j.optimize.api.TrainingListener)2 INDArray (org.nd4j.linalg.api.ndarray.INDArray)2 DataSet (org.nd4j.linalg.dataset.api.DataSet)2 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)2 ArrayList (java.util.ArrayList)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 ExistingDataSetIterator (org.deeplearning4j.datasets.iterator.ExistingDataSetIterator)1 IteratorDataSetIterator (org.deeplearning4j.datasets.iterator.IteratorDataSetIterator)1 SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator)1 Updater (org.deeplearning4j.nn.api.Updater)1 ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)1 WorkerConfiguration (org.deeplearning4j.spark.api.WorkerConfiguration)1