Example 41 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class TestEarlyStoppingSpark method testTimeTermination.

@Test
public void testTimeTermination() {
    //test termination after max time
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD).learningRate(1e-6).weightInit(WeightInit.XAVIER).list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                    //Initial score is ~2.5
                    new MaxScoreIterationTerminationCondition(7.5))
            .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc()))
            .modelSaver(saver).build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
            new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 15, 1, 0), esConf, net, irisData);
    long startTime = System.currentTimeMillis();
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    long endTime = System.currentTimeMillis();
    int durationSeconds = (int) ((endTime - startTime) / 1000);
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds >= 3);
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds <= 9);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used : InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver), MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition), DataSet(org.nd4j.linalg.dataset.DataSet), SparkEarlyStoppingTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingTrainer), SparkDataSetLossCalculator(org.deeplearning4j.spark.earlystopping.SparkDataSetLossCalculator), NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration), ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster), EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult), EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration), MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration), MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition), MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork), ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener), MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition), Test(org.junit.Test)
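
For comparison, the same time-based termination can also be driven locally with EarlyStoppingTrainer instead of SparkEarlyStoppingTrainer. A minimal illustrative sketch, assuming the MultiLayerConfiguration conf from the test above and a local IrisDataSetIterator in place of the Spark RDD:

DataSetIterator iris = new IrisDataSetIterator(150, 150);
EarlyStoppingConfiguration<MultiLayerNetwork> localEsConf =
        new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
                .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS))
                .scoreCalculator(new DataSetLossCalculator(iris, true)) //local analogue of SparkDataSetLossCalculator
                .modelSaver(new InMemoryModelSaver<MultiLayerNetwork>())
                .build();
IEarlyStoppingTrainer<MultiLayerNetwork> localTrainer = new EarlyStoppingTrainer(localEsConf, conf, iris);
EarlyStoppingResult<MultiLayerNetwork> localResult = localTrainer.fit();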

Example 42 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class IEvaluateFlatMapFunctionAdapter method call.

@Override
public Iterable<T> call(Iterator<DataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        return Collections.emptyList();
    }
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json.getValue()));
    network.init();
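    // Duplicate the broadcast parameters: this worker gets its own copy rather than mutating the shared broadcast array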
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);
    List<DataSet> collect = new ArrayList<>();
    int totalCount = 0;
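    // Accumulate DataSets until at least evalBatchSize examples are buffered, then evaluate them as one merged batch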
    while (dataSetIterator.hasNext()) {
        collect.clear();
        int nExamples = 0;
        while (dataSetIterator.hasNext() && nExamples < evalBatchSize) {
            DataSet next = dataSetIterator.next();
            nExamples += next.numExamples();
            collect.add(next);
        }
        totalCount += nExamples;
        DataSet data = DataSet.merge(collect);
        INDArray out;
        if (data.hasMaskArrays()) {
            out = network.output(data.getFeatureMatrix(), false, data.getFeaturesMaskArray(), data.getLabelsMaskArray());
        } else {
            out = network.output(data.getFeatureMatrix(), false);
        }
        if (data.getLabels().rank() == 3) {
            if (data.getLabelsMaskArray() == null) {
                evaluation.evalTimeSeries(data.getLabels(), out);
            } else {
                evaluation.evalTimeSeries(data.getLabels(), out, data.getLabelsMaskArray());
            }
        } else {
            evaluation.eval(data.getLabels(), out);
        }
    }
    if (log.isDebugEnabled()) {
        log.debug("Evaluated {} examples ", totalCount);
    }
    return Collections.singletonList(evaluation);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray), DataSet(org.nd4j.linalg.dataset.DataSet), ArrayList(java.util.ArrayList), MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
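
Each partition emits a single Evaluation object; on the driver, the per-partition results are then combined into one. A hedged sketch of that reduction, assuming evaluations is the JavaRDD<Evaluation> produced by this function:

Evaluation combined = evaluations.reduce((e1, e2) -> {
    e1.merge(e2); // accumulate e2's counts into e1
    return e1;
});
System.out.println(combined.stats());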

Example 43 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class ScoreExamplesWithKeyFunctionAdapter method call.

@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, DataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParameters(val);
    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<DataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, DataSet> t2 = iterator.next();
            DataSet ds = t2._2();
            int n = ds.numExamples();
            if (n != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if " + "data set contains more than 1 example (numExamples: " + n + ")");
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;
        DataSet data = DataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();
        for (int i = 0; i < doubleScores.length; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet), ArrayList(java.util.ArrayList), GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner), INDArray(org.nd4j.linalg.api.ndarray.INDArray), Tuple2(scala.Tuple2), MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
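
On the driver side, this adapter backs SparkDl4jMultiLayer.scoreExamples(...). A minimal usage sketch, assuming sparkNet is an initialized SparkDl4jMultiLayer and keyedData is a JavaPairRDD<String, DataSet> holding exactly one example per DataSet (as the adapter requires):

JavaPairRDD<String, Double> scores = sparkNet.scoreExamples(keyedData, true); // true: include regularization terms
Map<String, Double> scoreMap = scores.collectAsMap();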

Example 44 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class ParameterAveragingTrainingWorker method getInitialModel.

@Override
public MultiLayerNetwork getInitialModel() {
    if (configuration.isCollectTrainingStats())
        stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueStart();
    NetBroadcastTuple tuple = broadcast.getValue();
    if (configuration.isCollectTrainingStats())
        stats.logBroadcastGetValueEnd();
    //Don't want a shared configuration object: each worker may update its iteration count (for LR schedules etc.) individually
    MultiLayerNetwork net = new MultiLayerNetwork(tuple.getConfiguration().clone());
    //Can't share a parameter array across executors for parameter averaging; unsafeDuplication() already copies the
    //broadcast parameters, so cloneParametersArray is false here
    net.init(tuple.getParameters().unsafeDuplication(), false);
    if (tuple.getUpdaterState() != null) {
        //Can't have shared updater state
        net.setUpdater(new MultiLayerUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    configureListeners(net, tuple.getCounter().getAndIncrement());
    if (configuration.isCollectTrainingStats())
        stats.logInitEnd();
    return net;
}
Also used : MultiLayerUpdater(org.deeplearning4j.nn.updater.MultiLayerUpdater), GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner), NetBroadcastTuple(org.deeplearning4j.spark.api.worker.NetBroadcastTuple), MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
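
The worker above is constructed by a ParameterAveragingTrainingMaster. A minimal sketch of building one via its Builder API (the numbers are illustrative placeholders, not tuning recommendations):

TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(4) // 4 = examples per DataSet object in the RDD
        .batchSizePerWorker(16)
        .averagingFrequency(5) // average parameters across workers every 5 minibatches
        .workerPrefetchNumBatches(2)
        .build();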

Example 45 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class SparkDl4jMultiLayer method initNetwork.

private static MultiLayerNetwork initNetwork(MultiLayerConfiguration conf) {
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    return net;
}
Also used : MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
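
initNetwork(...) is a private helper; in user code, the same initialization happens inside the SparkDl4jMultiLayer wrapper. A hedged usage sketch, assuming sc (a JavaSparkContext), conf, a TrainingMaster tm, and trainingData (a JavaRDD<DataSet>) are already set up:

SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);
MultiLayerNetwork trained = sparkNet.fit(trainingData);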

Aggregations

MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 326 uses
Test (org.junit.Test): 277 uses
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 206 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 166 uses
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 111 uses
DataSet (org.nd4j.linalg.dataset.DataSet): 91 uses
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 70 uses
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 49 uses
NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution): 43 uses
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 41 uses
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 40 uses
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 38 uses
Random (java.util.Random): 34 uses
MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator): 30 uses
ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer): 28 uses
DL4JException (org.deeplearning4j.exception.DL4JException): 20 uses
Layer (org.deeplearning4j.nn.api.Layer): 20 uses
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 20 uses
File (java.io.File): 19 uses
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 19 uses