Search in sources:

Example 61 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From the class TestMasking, method checkMaskArrayClearance.

/**
 * Smoke test: after every supported way of fitting a masked data set, no layer of the
 * network may still hold a mask array.
 *
 * <p>The scenario is run twice, once with truncated BPTT and once with standard backprop.
 * Each run fits the same single-layer RNN output network three ways: from a {@code DataSet},
 * from the raw feature/label/mask arrays, and from a {@code DataSetIterator}.
 */
@Test
public void checkMaskArrayClearance() {
    final boolean[] truncatedOptions = { true, false };
    for (int optionIdx = 0; optionIdx < truncatedOptions.length; optionIdx++) {
        //Simple "does it throw an exception" type test...
        BackpropType backpropType = truncatedOptions[optionIdx] ? BackpropType.TruncatedBPTT : BackpropType.Standard;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(1)
                .seed(12345)
                .list()
                .layer(0, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(1)
                        .nOut(1)
                        .build())
                .backpropType(backpropType)
                .tBPTTForwardLength(8)
                .tBPTTBackwardLength(8)
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // One example, one input/output, ten time steps; both mask arrays are all ones.
        DataSet data = new DataSet(
                Nd4j.linspace(1, 10, 10).reshape(1, 1, 10),
                Nd4j.linspace(2, 20, 10).reshape(1, 1, 10),
                Nd4j.ones(10),
                Nd4j.ones(10));

        // Variant 1: fit directly from the DataSet.
        net.fit(data);
        for (Layer layer : net.getLayers()) {
            assertNull(layer.getMaskArray());
        }

        // Variant 2: fit from the individual feature, label and mask arrays.
        net.fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        for (Layer layer : net.getLayers()) {
            assertNull(layer.getMaskArray());
        }

        // Variant 3: fit from an iterator that yields the same DataSet once.
        DataSetIterator singleItemIter = new ExistingDataSetIterator(Collections.singletonList(data).iterator());
        net.fit(singleItemIter);
        for (Layer layer : net.getLayers()) {
            assertNull(layer.getMaskArray());
        }
    }
}
Also used : RnnOutputLayer(org.deeplearning4j.nn.conf.layers.RnnOutputLayer) DataSet(org.nd4j.linalg.dataset.DataSet) ExistingDataSetIterator(org.deeplearning4j.datasets.iterator.ExistingDataSetIterator) Layer(org.deeplearning4j.nn.api.Layer) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) RnnOutputLayer(org.deeplearning4j.nn.conf.layers.RnnOutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) ExistingDataSetIterator(org.deeplearning4j.datasets.iterator.ExistingDataSetIterator) Test(org.junit.Test)

Example 62 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From the class TestOptimizers, method testOptimizersBasicMLPBackprop.

/**
 * Smoke test: builds one network per optimization algorithm from
 * {@code getMLPConfigIris(algorithm, 1)} and fits it on the Iris iterator.
 * The test passes as long as no call throws.
 */
@Test
public void testOptimizersBasicMLPBackprop() {
    //Basic tests of the 'does it throw an exception' variety.
    final DataSetIterator irisData = new IrisDataSetIterator(5, 50);
    final OptimizationAlgorithm[] algorithms = new OptimizationAlgorithm[] {
            OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
            OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
            OptimizationAlgorithm.CONJUGATE_GRADIENT,
            OptimizationAlgorithm.LBFGS };
    for (int idx = 0; idx < algorithms.length; idx++) {
        MultiLayerNetwork net = new MultiLayerNetwork(getMLPConfigIris(algorithms[idx], 1));
        net.init();
        // The iterator is shared across algorithms, so rewind it before every fit.
        irisData.reset();
        net.fit(irisData);
    }
}
Also used : OptimizationAlgorithm(org.deeplearning4j.nn.api.OptimizationAlgorithm) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Example 63 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From the class Dl4jServingRouteTest, method createRouteBuilder.

/**
 * Creates the Camel route used by this test.
 *
 * <p>Before the route is built, the first batch of an {@code IrisDataSetIterator(150, 150)}
 * is stored in the {@code next} field and normalized. The route reads from
 * {@code direct:start}, replaces the message body with the Base64 encoding of that batch's
 * serialized feature matrix, sets the Kafka key and partition-key headers, and forwards the
 * message to the Kafka endpoint built from {@code kafkaCluster} and {@code topicName}.
 *
 * @return the route builder for this test
 * @throws Exception declared by the overridden method
 */
@Override
protected RouteBuilder createRouteBuilder() throws Exception {
    next = new IrisDataSetIterator(150, 150).next();
    next.normalizeZeroMeanZeroUnitVariance();
    return new RouteBuilder() {

        @Override
        public void configure() throws Exception {
            final String kafkaUri = String.format("kafka:%s?topic=%s&groupId=dl4j-serving", kafkaCluster.getBrokerList(), topicName);

            // Turns each incoming exchange into a Kafka-ready message carrying the features.
            final Processor featurePublisher = new Processor() {

                @Override
                public void process(Exchange exchange) throws Exception {
                    // Serialize the feature matrix into an in-memory buffer.
                    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
                    Nd4j.write(next.getFeatureMatrix(), new DataOutputStream(buffer));

                    // The body is sent as text, so the binary payload is Base64-encoded.
                    String encoded = Base64.encodeBase64String(buffer.toByteArray());
                    exchange.getIn().setBody(encoded, String.class);
                    exchange.getIn().setHeader(KafkaConstants.KEY, UUID.randomUUID().toString());
                    exchange.getIn().setHeader(KafkaConstants.PARTITION_KEY, "1");
                }
            };

            from("direct:start").process(featurePublisher).to(kafkaUri);
        }
    };
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) RouteBuilder(org.apache.camel.builder.RouteBuilder) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataOutputStream(java.io.DataOutputStream) ByteArrayOutputStream(java.io.ByteArrayOutputStream) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator)

Example 64 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From the class ExecuteWorkerFlatMapAdapter, method call.

/**
 * Runs the worker over the supplied data sets and returns its result wrapped in a
 * single-element list.
 *
 * <p>The incoming data sets are re-batched to the configured per-worker batch size and,
 * when a prefetch count greater than zero is configured, read through an
 * {@code AsyncDataSetIterator}. Minibatches are handed to the worker one at a time until
 * the worker returns a non-null result, the data runs out, or the configured maximum
 * number of minibatches has been processed. When stats collection is enabled, each phase is
 * logged through a {@code StatsCalculationHelper} and the built stats are attached to the
 * returned result.
 *
 * @param dataSetIterator source of the data sets to train on; may be empty
 * @return a single-element list holding the worker's result
 * @throws Exception declared by the overridden method
 */
@Override
public Iterable<R> call(Iterator<DataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    final boolean isGraph = dataConfig.isGraphNetwork();
    boolean stats = dataConfig.isCollectTrainingStats();
    // The helper only exists when stats are collected; every use below is guarded by 'stats'.
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();
    // No data at all: return the worker's "no data" result without loading a model.
    if (!dataSetIterator.hasNext()) {
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultNoDataWithStats();
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResultNoData());
        }
    }
    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    // Re-batch the raw data sets to the configured size; optionally add asynchronous prefetch.
    DataSetIterator batchedIterator = new IteratorDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batchedIterator = new AsyncDataSetIterator(batchedIterator, prefetchCount);
    }
    try {
        // Exactly one of 'net' / 'graph' is assigned, depending on isGraph.
        MultiLayerNetwork net = null;
        ComputationGraph graph = null;
        if (stats)
            s.logInitialModelBefore();
        if (isGraph)
            graph = worker.getInitialModelGraph();
        else
            net = worker.getInitialModel();
        if (stats)
            s.logInitialModelAfter();
        int miniBatchCount = 0;
        // A non-positive configured maximum means "no limit".
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            DataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.numExamples());
            // The third argument passed to the worker below is true when the iterator has no
            // further batch; presumably it marks the last minibatch - confirm against the worker API.
            // NOTE(review): it does not take maxMinibatches into account.
            if (stats) {
                s.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> result;
                if (isGraph)
                    result = worker.processMinibatchWithStats(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                R result;
                if (isGraph)
                    result = worker.processMinibatch(next, graph, !batchedIterator.hasNext());
                else
                    result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
        //For some reason, we didn't return already. Normally this shouldn't happen
        // Fall back to asking the worker for its final result from the current model.
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair;
            if (isGraph)
                pair = worker.getFinalResultWithStats(graph);
            else
                pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            if (isGraph)
                return Collections.singletonList(worker.getFinalResult(graph));
            else
                return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncDataSetIterator) {
            ((AsyncDataSetIterator) batchedIterator).shutdown();
        }
        // On a GridExecutioner, flush its queue (blocking call) before this method returns.
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) WorkerConfiguration(org.deeplearning4j.spark.api.WorkerConfiguration) StatsCalculationHelper(org.deeplearning4j.spark.api.stats.StatsCalculationHelper) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator)

Example 65 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From the class MultiLayerNetwork, method fit.

/**
 * Trains this network on the data supplied by the given iterator.
 *
 * <p>Training listeners are notified before and after training. If pretraining is enabled
 * in the configuration, {@code pretrain} is run first and the iterator is reset when it
 * supports resetting. If backprop is enabled, every data set from the iterator is then
 * fitted, using truncated BPTT or the solver depending on the configured backprop type.
 * Iteration stops early at the first data set whose features or labels are {@code null}.
 *
 * <p>When the iterator supports it, it is read through an {@code AsyncDataSetIterator}
 * for background prefetch. That wrapper is created here, so it is also shut down here
 * before the method returns, whether training finishes, stops early, or throws. An
 * iterator passed in by the caller is never shut down.
 *
 * @param iterator source of the training data
 */
@Override
public void fit(DataSetIterator iterator) {
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    AsyncDataSetIterator asyncWrapper = null;
    final DataSetIterator iter;
    if (iterator.asyncSupported()) {
        asyncWrapper = new AsyncDataSetIterator(iterator, 2);
        iter = asyncWrapper;
    } else {
        iter = iterator;
    }
    try {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
        if (layerWiseConfigurations.isPretrain()) {
            pretrain(iter);
            if (iter.resetSupported()) {
                iter.reset();
            }
        }
        if (layerWiseConfigurations.isBackprop()) {
            update(TaskUtils.buildTask(iter));
            // An exhausted iterator is rewound so the loop below has data to consume.
            if (!iter.hasNext() && iter.resetSupported()) {
                iter.reset();
            }
            while (iter.hasNext()) {
                DataSet next = iter.next();
                if (next.getFeatureMatrix() == null || next.getLabels() == null)
                    break;
                boolean hasMaskArrays = next.hasMaskArrays();
                if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                    doTruncatedBPTT(next.getFeatureMatrix(), next.getLabels(), next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                } else {
                    if (hasMaskArrays)
                        setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray());
                    setInput(next.getFeatureMatrix());
                    setLabels(next.getLabels());
                    // The solver is created on first use and reused afterwards.
                    if (solver == null) {
                        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                    }
                    solver.optimize();
                }
                // Masks belong to a single data set; drop them before the next one.
                if (hasMaskArrays)
                    clearLayerMaskArrays();
                Nd4j.getMemoryManager().invokeGcOccasionally();
            }
        } else if (layerWiseConfigurations.isPretrain()) {
            log.warn("Warning: finetune is not applied.");
        }
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    } finally {
        // Only the wrapper created above is shut down, never the caller's own iterator.
        if (asyncWrapper != null) {
            asyncWrapper.shutdown();
        }
    }
}
Also used : Solver(org.deeplearning4j.optimize.Solver) DataSet(org.nd4j.linalg.dataset.DataSet) TrainingListener(org.deeplearning4j.optimize.api.TrainingListener) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator)

Aggregations

DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator) 147; Test (org.junit.Test) 133; IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) 90; DataSet (org.nd4j.linalg.dataset.DataSet) 79; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork) 70; MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration) 63; INDArray (org.nd4j.linalg.api.ndarray.INDArray) 61; MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) 53; NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration) 49; ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener) 43; OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer) 30; DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer) 24; MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) 21; ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration) 19; InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) 17; MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) 17; ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph) 17; ListDataSetIterator (org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) 16; BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest) 16; MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) 14