Search in sources:

Example 1 with IteratorMultiDataSetIterator

Use of org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator in the deeplearning4j project (by deeplearning4j).

From class ComputationGraphTestRNN, method checkMaskArrayClearance.

@Test
public void checkMaskArrayClearance() {
    // Smoke test, run once with truncated BPTT and once with standard backprop:
    // after every fit(...) overload, no mask arrays may remain attached to the graph or its layers.
    boolean[] backpropModes = { true, false };
    for (boolean useTbptt : backpropModes) {
        BackpropType backprop = useTbptt ? BackpropType.TruncatedBPTT : BackpropType.Standard;
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(1)
                .seed(12345)
                .graphBuilder()
                .addInputs("in")
                .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(1)
                        .nOut(1)
                        .build(), "in")
                .setOutputs("out")
                .backpropType(backprop)
                .tBPTTForwardLength(8)
                .tBPTTBackwardLength(8)
                .build();
        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        // One example, one channel, 10 time steps; all-ones masks for features and labels.
        INDArray features = Nd4j.linspace(1, 10, 10).reshape(1, 1, 10);
        INDArray labels = Nd4j.linspace(2, 20, 10).reshape(1, 1, 10);
        MultiDataSet data = new MultiDataSet(
                new INDArray[] { features },
                new INDArray[] { labels },
                new INDArray[] { Nd4j.ones(10) },
                new INDArray[] { Nd4j.ones(10) });

        // fit(MultiDataSet)
        net.fit(data);
        assertNoMaskArraysRemain(net);

        // fit(DataSet) built from the same arrays and masks
        DataSet ds = new DataSet(data.getFeatures(0), data.getLabels(0),
                data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
        net.fit(ds);
        assertNoMaskArraysRemain(net);

        // fit(INDArray[], INDArray[], INDArray[], INDArray[])
        net.fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
        assertNoMaskArraysRemain(net);

        // fit(MultiDataSetIterator)
        MultiDataSetIterator multiIter = new IteratorMultiDataSetIterator(
                Collections.singletonList((org.nd4j.linalg.dataset.api.MultiDataSet) data).iterator(), 1);
        net.fit(multiIter);
        assertNoMaskArraysRemain(net);

        // fit(DataSetIterator)
        DataSetIterator dsIter = new IteratorDataSetIterator(Collections.singletonList(ds).iterator(), 1);
        net.fit(dsIter);
        assertNoMaskArraysRemain(net);
    }
}

/**
 * Asserts that the graph holds no input or label mask arrays and that none of its layers holds a mask array.
 *
 * @param net the graph to inspect
 */
private static void assertNoMaskArraysRemain(ComputationGraph net) {
    assertNull(net.getInputMaskArrays());
    assertNull(net.getLabelMaskArrays());
    for (Layer layer : net.getLayers()) {
        assertNull(layer.getMaskArray());
    }
}
Also used : MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) Layer(org.deeplearning4j.nn.api.Layer) RnnOutputLayer(org.deeplearning4j.nn.conf.layers.RnnOutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) BaseRecurrentLayer(org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) Test(org.junit.Test)

Example 2 with IteratorMultiDataSetIterator

Use of org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator in the deeplearning4j project (by deeplearning4j).

From class ScoreFlatMapFunctionCGMultiDataSetAdapter, method call.

/**
 * Scores every minibatch of one partition with a freshly built ComputationGraph.
 * <p>
 * The graph is rebuilt from {@code json} and loaded with a duplicate of the broadcast {@code params}.
 * The incoming examples are re-batched into minibatches of {@code minibatchSize}.
 *
 * @param dataSetIterator the examples of this partition
 * @return one {@code (exampleCount, score * exampleCount)} tuple per minibatch, or a single
 *         {@code (0, 0.0)} tuple when the partition is empty
 * @throws IllegalStateException if the broadcast parameter count differs from the network's
 */
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    // An empty partition still contributes a neutral (0 examples, 0.0 score) entry.
    if (!dataSetIterator.hasNext()) {
        return Collections.singletonList(new Tuple2<>(0, 0.0));
    }

    // Re-batch the raw examples where appropriate.
    MultiDataSetIterator minibatches = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize);

    ComputationGraph graph = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    graph.init();

    // params.value() is shared by all executors on one machine; that is safe here because
    // scoring never modifies the parameters.
    INDArray broadcastParams = params.value().unsafeDuplication();
    boolean sizesMatch = broadcastParams.length() == graph.numParams(false);
    if (!sizesMatch) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    }
    graph.setParams(broadcastParams);

    List<Tuple2<Integer, Double>> weightedScores = new ArrayList<>();
    while (minibatches.hasNext()) {
        MultiDataSet batch = minibatches.next();
        double batchScore = graph.score(batch, false);
        int exampleCount = batch.getFeatures(0).size(0);
        // The score is multiplied by the example count, presumably so the caller can sum the
        // weighted scores and divide by the total count.
        weightedScores.add(new Tuple2<>(exampleCount, batchScore * exampleCount));
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
    return weightedScores;
}
Also used : IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) ArrayList(java.util.ArrayList) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) Tuple2(scala.Tuple2) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 3 with IteratorMultiDataSetIterator

Use of org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator in the deeplearning4j project (by deeplearning4j).

From class ExecuteWorkerMultiDataSetFlatMapAdapter, method call.

/**
 * Feeds one partition of MultiDataSets to {@code worker} and returns the worker's result.
 * <p>
 * The incoming examples are re-batched to {@code getBatchSizePerWorker()} examples. When
 * {@code getPrefetchNumBatches() > 0`, the batches are wrapped in an AsyncMultiDataSetIterator
 * for prefetching. Minibatches are passed to the worker until one of three things happens:
 * the worker returns a non-null result, the data runs out, or {@code getMaxBatchesPerWorker()}
 * minibatches (unbounded when that value is not positive) have been consumed. Timing statistics
 * are recorded only when {@code isCollectTrainingStats()} is true.
 *
 * @param dataSetIterator the examples of this partition
 * @return an empty list when the partition has no data; otherwise a single-element list holding the worker result
 * @throws Exception as declared by the signature; presumably propagated from worker/model calls
 */
@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    boolean stats = dataConfig.isCollectTrainingStats();
    // s is null when stats are disabled; every use below is guarded by `if (stats)`.
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();
    if (!dataSetIterator.hasNext()) {
        if (stats)
            s.logReturnTime();
        //Sometimes: no data
        return Collections.emptyList();
    }
    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    MultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batchedIterator = new AsyncMultiDataSetIterator(batchedIterator, prefetchCount);
    }
    try {
        if (stats)
            s.logInitialModelBefore();
        ComputationGraph net = worker.getInitialModelGraph();
        if (stats)
            s.logInitialModelAfter();
        int miniBatchCount = 0;
        // A non-positive configured maximum means "no limit".
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
        // miniBatchCount is incremented only when hasNext() is true (short-circuit &&).
        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            MultiDataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.getFeatures(0).size(0));
            // The third argument, !batchedIterator.hasNext(), presumably tells the worker that this is the
            // final minibatch. It does not account for maxMinibatches: if that limit ends the loop, no
            // minibatch was flagged as final, and control reaches the fallback after the loop.
            if (stats) {
                s.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                R result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
        //For some reason, we didn't return already. Normally this shouldn't happen
        // (reached when the worker never returned a non-null result, e.g. when maxMinibatches cut the loop short).
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) batchedIterator).shutdown();
        }
        // Flush any queued ND4J ops when a GridExecutioner is in use.
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
Also used : IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) WorkerConfiguration(org.deeplearning4j.spark.api.WorkerConfiguration) StatsCalculationHelper(org.deeplearning4j.spark.api.stats.StatsCalculationHelper) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Aggregations

IteratorMultiDataSetIterator (org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator)3 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)3 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)2 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)2 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)2 ArrayList (java.util.ArrayList)1 AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)1 IteratorDataSetIterator (org.deeplearning4j.datasets.iterator.IteratorDataSetIterator)1 Layer (org.deeplearning4j.nn.api.Layer)1 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)1 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)1 RnnOutputLayer (org.deeplearning4j.nn.conf.layers.RnnOutputLayer)1 BaseRecurrentLayer (org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer)1 WorkerConfiguration (org.deeplearning4j.spark.api.WorkerConfiguration)1 SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats)1 StatsCalculationHelper (org.deeplearning4j.spark.api.stats.StatsCalculationHelper)1 Test (org.junit.Test)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 DataSet (org.nd4j.linalg.dataset.DataSet)1 MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)1