
Example 6 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From class TestParallelEarlyStopping, method testBadTuning.

@Test
public void testBadTuning() {
    //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1).updater(Updater.SGD)
            .learningRate(1.0) //Intentionally huge LR
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(10, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
            .iterationTerminationConditions(
                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                    new MaxScoreIterationTerminationCondition(10)) //Initial score is ~2.5
            .scoreCalculator(new DataSetLossCalculator(irisIter, true))
            .modelSaver(saver)
            .build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingParallelTrainer<>(esConf, net, irisIter, null, 2, 2, 1);
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(10).toString();
    assertEquals(expDetails, result.getTerminationDetails());
    assertTrue(result.getBestModelEpoch() <= 0);
    assertNotNull(result.getBestModel());
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)
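The test above expects the learning rate to be large enough that the minibatch score crosses the threshold of 10 within the first few epochs, so the max-score iteration condition fires long before the 5000-epoch cap. The plain-Java sketch below (no DL4J dependency) illustrates that mechanism under stated assumptions: it runs SGD on the toy loss f(w) = w² and checks a hypothetical `MaxScoreGuard` after every step. Neither class is part of DL4J, and treating NaN as "diverged" is a choice made for this sketch, not a documented property of `MaxScoreIterationTerminationCondition`.

```java
/**
 * Hypothetical divergence guard, loosely modeled on the idea behind
 * MaxScoreIterationTerminationCondition (this is not the DL4J class).
 */
final class MaxScoreGuard {
    private final double maxScore;

    MaxScoreGuard(double maxScore) {
        this.maxScore = maxScore;
    }

    /** True if the latest minibatch score exceeds the threshold, or is NaN (an assumption of this sketch). */
    boolean terminate(double lastScore) {
        return lastScore > maxScore || Double.isNaN(lastScore);
    }
}

final class DivergenceDemo {
    /**
     * Runs plain SGD on f(w) = w^2 starting from w = 1 and checks the guard after every step.
     * Returns the iteration index at which the guard fired, or -1 if it never fired.
     */
    static int run(double learningRate, double maxScore, int maxIterations) {
        MaxScoreGuard guard = new MaxScoreGuard(maxScore);
        double w = 1.0;
        for (int iter = 0; iter < maxIterations; iter++) {
            w -= learningRate * 2.0 * w; // gradient of w^2 is 2w, so each step scales w by (1 - 2 * lr)
            double score = w * w;
            if (guard.terminate(score)) {
                return iter;
            }
        }
        return -1;
    }
}
```

With `learningRate = 1.5` each step multiplies w by -2, so the score goes 4, 16, ... and a threshold of 10 trips at iteration 1; with `learningRate = 0.1` the score shrinks every step and the guard never fires.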

Example 7 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From class BaseEarlyStoppingTrainer, method fit.

@Override
public EarlyStoppingResult<T> fit() {
    log.info("Starting early stopping training");
    if (esConfig.getScoreCalculator() == null)
        log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
    //Initialize termination conditions:
    if (esConfig.getIterationTerminationConditions() != null) {
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            c.initialize();
        }
    }
    if (esConfig.getEpochTerminationConditions() != null) {
        for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
            c.initialize();
        }
    }
    if (listener != null) {
        listener.onStart(esConfig, model);
    }
    Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();
    int epochCount = 0;
    while (true) {
        reset();
        double lastScore;
        boolean terminate = false;
        IterationTerminationCondition terminationReason = null;
        int iterCount = 0;
        while (iterator.hasNext()) {
            try {
                if (train != null) {
                    fit((DataSet) iterator.next());
                } else
                    fit(trainMulti.next());
            } catch (Exception e) {
                log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", epochCount, iterCount, e);
                //Load best model to return
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                return new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.Error, e.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
            }
            //Check per-iteration termination conditions
            lastScore = model.score();
            for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
                if (c.terminate(lastScore)) {
                    terminate = true;
                    terminationReason = c;
                    break;
                }
            }
            if (terminate) {
                break;
            }
            iterCount++;
        }
        if (terminate) {
            //Handle termination condition:
            log.info("Hit per-iteration termination condition at epoch {}, iteration {}. Reason: {}", epochCount, iterCount, terminationReason);
            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(model, 0.0);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, terminationReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
            if (listener != null) {
                listener.onCompletion(result);
            }
            return result;
        }
        log.info("Completed training epoch {}", epochCount);
        //Evaluate every N epochs; epochCount % N == 0 already covers epoch 0
        if (epochCount % esConfig.getEvaluateEveryNEpochs() == 0) {
            //Calculate score at this epoch:
            ScoreCalculator sc = esConfig.getScoreCalculator();
            double score = (sc == null ? 0.0 : sc.calculateScore(model));
            //Key the score by the current epoch, consistent with bestModelEpoch and the log messages below
            scoreVsEpoch.put(epochCount, score);
            if (sc != null && score < bestModelScore) {
                //Save best model:
                if (bestModelEpoch == -1) {
                    //First calculated/reported score
                    log.info("Score at epoch {}: {}", epochCount, score);
                } else {
                    log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", score, epochCount, bestModelScore, bestModelEpoch);
                }
                bestModelScore = score;
                bestModelEpoch = epochCount;
                try {
                    esConfig.getModelSaver().saveBestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving best model", e);
                }
            }
            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            if (listener != null) {
                listener.onEpoch(epochCount, score, esConfig, model);
            }
            //Check per-epoch termination conditions:
            boolean epochTerminate = false;
            EpochTerminationCondition termReason = null;
            for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
                if (c.terminate(epochCount, score)) {
                    epochTerminate = true;
                    termReason = c;
                    break;
                }
            }
            if (epochTerminate) {
                log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount, termReason.toString());
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount + 1, bestModel);
                if (listener != null) {
                    listener.onCompletion(result);
                }
                return result;
            }
        }
        epochCount++;
    }
}
Also used: EpochTerminationCondition(org.deeplearning4j.earlystopping.termination.EpochTerminationCondition) IOException(java.io.IOException) LinkedHashMap(java.util.LinkedHashMap) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) IterationTerminationCondition(org.deeplearning4j.earlystopping.termination.IterationTerminationCondition) ScoreCalculator(org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator)
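The control flow of `fit()` above can be condensed into a dependency-free sketch, which may help when tracing its termination paths. Everything here is hypothetical: `EarlyStoppingLoopSketch`, its `IterationCondition` and `EpochCondition` interfaces, and the `trainStep`/`evaluate` callbacks stand in for the model, the data iterator, and the `ScoreCalculator`; model saving, listeners, and exception handling are omitted. As in the original, iteration conditions are checked after every minibatch, the held-out score is computed only every `evaluateEveryNEpochs` epochs (and epoch conditions are checked only then), a mid-epoch stop reports the number of completed epochs, and an epoch-level stop reports `epoch + 1`. Scores are keyed by the epoch index.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.DoubleSupplier;
import java.util.function.IntToDoubleFunction;

/** Hypothetical, dependency-free sketch of an early stopping loop (not DL4J code). */
final class EarlyStoppingLoopSketch {
    interface IterationCondition { boolean terminate(double lastMiniBatchScore); }
    interface EpochCondition { boolean terminate(int epoch, double score); }

    static final class Result {
        final String reason;     // "iteration" or "epoch"
        final int totalEpochs;
        final int bestEpoch;     // -1 if no evaluation happened
        final double bestScore;
        final Map<Integer, Double> scoreVsEpoch;

        Result(String reason, int totalEpochs, int bestEpoch, double bestScore, Map<Integer, Double> scoreVsEpoch) {
            this.reason = reason;
            this.totalEpochs = totalEpochs;
            this.bestEpoch = bestEpoch;
            this.bestScore = bestScore;
            this.scoreVsEpoch = Collections.unmodifiableMap(scoreVsEpoch);
        }
    }

    /**
     * @param batchesPerEpoch number of minibatches in one pass over the data
     * @param trainStep       fits one minibatch and returns its score
     * @param evaluate        returns the held-out score (lower is better) for an epoch
     */
    static Result fit(int batchesPerEpoch, DoubleSupplier trainStep, IntToDoubleFunction evaluate,
                      int evaluateEveryNEpochs, List<IterationCondition> iterationConditions,
                      List<EpochCondition> epochConditions) {
        Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();
        int bestEpoch = -1;
        double bestScore = Double.MAX_VALUE;
        for (int epoch = 0; ; epoch++) {
            for (int batch = 0; batch < batchesPerEpoch; batch++) {
                double lastScore = trainStep.getAsDouble();
                for (IterationCondition c : iterationConditions) {
                    if (c.terminate(lastScore)) {
                        // Stopped mid-epoch: report the number of epochs completed so far
                        return new Result("iteration", epoch, bestEpoch, bestScore, scoreVsEpoch);
                    }
                }
            }
            if (epoch % evaluateEveryNEpochs != 0) {
                continue; // not an evaluation epoch, so epoch conditions are skipped too
            }
            double score = evaluate.applyAsDouble(epoch);
            scoreVsEpoch.put(epoch, score);
            if (score < bestScore) {
                bestScore = score;
                bestEpoch = epoch;
            }
            for (EpochCondition c : epochConditions) {
                if (c.terminate(epoch, score)) {
                    return new Result("epoch", epoch + 1, bestEpoch, bestScore, scoreVsEpoch);
                }
            }
        }
    }
}
```

With a constant evaluation score and an epoch condition of `epoch + 1 >= 3`, the sketch stops with reason "epoch" after 3 epochs and keeps epoch 0 as the best model, because later scores are not strictly better.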

Example 8 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From class TestEarlyStoppingSparkCompGraph, method testNoImprovementNEpochsTermination.

@Test
public void testNoImprovementNEpochsTermination() {
    //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
    //Simulate this by setting LR = 0.0
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1).updater(Updater.SGD)
            .learningRate(0.0)
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
            .setOutputs("0")
            .pretrain(false).backprop(true)
            .build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(100),
                    new ScoreImprovementEpochTerminationCondition(5))
            .iterationTerminationConditions(new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
            .scoreCalculator(new SparkLossCalculatorComputationGraph(
                    irisData.map(new DataSetToMultiDataSetFn()), true, sc.sc()))
            .modelSaver(saver)
            .build();
    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    IEarlyStoppingTrainer<ComputationGraph> trainer = new SparkEarlyStoppingGraphTrainer(getContext().sc(), tm, esConf,
            net, irisData.map(new DataSetToMultiDataSetFn()));
    EarlyStoppingResult<ComputationGraph> result = trainer.fit();
    //Expect no score change due to 0 LR -> terminate after 6 total epochs
    //Normally expect 6 epochs exactly; get a little more than that here due to rounding + order of operations
    assertTrue(result.getTotalEpochs() < 12);
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DataSet(org.nd4j.linalg.dataset.DataSet) ScoreImprovementEpochTerminationCondition(org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) DataSetToMultiDataSetFn(org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) SparkEarlyStoppingGraphTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) Test(org.junit.Test)
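The expectation of about six epochs in `testNoImprovementNEpochsTermination` follows from how a "no improvement for N epochs" condition counts: with a learning rate of 0 the score never changes, so epoch 0 sets the best score and the condition fires once five further epochs pass without a strictly better score. The sketch below shows one plausible implementation under that reading. `NoImprovementCondition` is hypothetical, and DL4J's `ScoreImprovementEpochTerminationCondition` may differ in details such as a minimum-improvement threshold.

```java
/**
 * Hypothetical "no improvement in N epochs" condition (lower score = better).
 * Not the DL4J implementation; a sketch of the counting logic only.
 */
final class NoImprovementCondition {
    private final int maxEpochsWithNoImprovement;
    private int bestEpoch = -1;
    private double bestScore = Double.MAX_VALUE;

    NoImprovementCondition(int maxEpochsWithNoImprovement) {
        this.maxEpochsWithNoImprovement = maxEpochsWithNoImprovement;
    }

    boolean terminate(int epoch, double score) {
        if (bestEpoch == -1 || score < bestScore) {
            // First score seen, or a strict improvement: remember it and keep training
            bestScore = score;
            bestEpoch = epoch;
            return false;
        }
        // Stop once N epochs have passed since the best score was recorded
        return epoch - bestEpoch >= maxEpochsWithNoImprovement;
    }

    /** Feeds per-epoch scores in order; returns the epoch that triggers termination, or -1. */
    static int firstTerminatingEpoch(double[] scoresByEpoch, int n) {
        NoImprovementCondition condition = new NoImprovementCondition(n);
        for (int epoch = 0; epoch < scoresByEpoch.length; epoch++) {
            if (condition.terminate(epoch, scoresByEpoch[epoch])) {
                return epoch;
            }
        }
        return -1;
    }
}
```

For a flat score sequence and N = 5, the condition fires at epoch index 5, which is six epochs in total; if the score last improved at epoch 2, it fires at epoch 7.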

Example 9 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From class TestEarlyStoppingSparkCompGraph, method testTimeTermination.

@Test
public void testTimeTermination() {
    //Test termination after max time
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1).updater(Updater.SGD)
            .learningRate(1e-6)
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
            .setOutputs("0")
            .pretrain(false).backprop(true)
            .build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
            .iterationTerminationConditions(
                    new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                    new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
            .scoreCalculator(new SparkLossCalculatorComputationGraph(
                    irisData.map(new DataSetToMultiDataSetFn()), true, sc.sc()))
            .modelSaver(saver)
            .build();
    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    IEarlyStoppingTrainer<ComputationGraph> trainer = new SparkEarlyStoppingGraphTrainer(getContext().sc(), tm, esConf,
            net, irisData.map(new DataSetToMultiDataSetFn()));
    long startTime = System.currentTimeMillis();
    EarlyStoppingResult<ComputationGraph> result = trainer.fit();
    long endTime = System.currentTimeMillis();
    int durationSeconds = (int) ((endTime - startTime) / 1000);
    assertTrue(durationSeconds >= 3);
    assertTrue(durationSeconds <= 9);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DataSet(org.nd4j.linalg.dataset.DataSet) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) DataSetToMultiDataSetFn(org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) SparkEarlyStoppingGraphTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)
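`testTimeTermination` allows between 3 and 9 seconds because a wall-clock condition can only be polled between minibatches: training always runs somewhat past the limit before the next check sees it. The sketch below is a hypothetical `MaxTimeCondition` (not DL4J's `MaxTimeIterationTerminationCondition`) whose clock starts in `initialize()`, mirroring how `fit()` initializes every condition before training, and which takes an injectable nanosecond clock so it can be tested without sleeping.

```java
import java.util.concurrent.TimeUnit;
import java.util.function.LongSupplier;

/** Hypothetical wall-clock iteration condition; not the DL4J MaxTimeIterationTerminationCondition. */
final class MaxTimeCondition {
    private final long maxNanos;
    private final LongSupplier nanoClock;
    private long startNanos;

    MaxTimeCondition(long maxTime, TimeUnit unit, LongSupplier nanoClock) {
        this.maxNanos = unit.toNanos(maxTime);
        this.nanoClock = nanoClock;
    }

    MaxTimeCondition(long maxTime, TimeUnit unit) {
        // System.nanoTime is meant for measuring elapsed time, unlike currentTimeMillis
        this(maxTime, unit, System::nanoTime);
    }

    /** Called once before training starts. */
    void initialize() {
        startNanos = nanoClock.getAsLong();
    }

    /** Polled after each minibatch; a purely time-based condition ignores the score. */
    boolean terminate(double lastMiniBatchScore) {
        return nanoClock.getAsLong() - startNanos >= maxNanos;
    }
}
```

Because the check runs only after a minibatch completes, the measured duration is the limit plus whatever remains of the step in progress, which is presumably why the test's upper bound is generous.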

Example 10 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From class TestEarlyStoppingSparkCompGraph, method testBadTuning.

@Test
public void testBadTuning() {
    //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1).updater(Updater.SGD)
            .learningRate(2.0) //Intentionally huge LR
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY)
                    .lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
            .setOutputs("0")
            .pretrain(false).backprop(true)
            .build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
            .iterationTerminationConditions(
                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                    new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
            .scoreCalculator(new SparkLossCalculatorComputationGraph(
                    irisData.map(new DataSetToMultiDataSetFn()), true, sc.sc()))
            .modelSaver(saver)
            .build();
    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    IEarlyStoppingTrainer<ComputationGraph> trainer = new SparkEarlyStoppingGraphTrainer(getContext().sc(), tm, esConf,
            net, irisData.map(new DataSetToMultiDataSetFn()));
    EarlyStoppingResult<ComputationGraph> result = trainer.fit();
    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(7.5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used: InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) SparkEarlyStoppingGraphTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer) DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) DataSetToMultiDataSetFn(org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)
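Both `testBadTuning` variants (Examples 6 and 10) assert that `getTerminationDetails()` equals the `toString()` of the condition that fired. In `BaseEarlyStoppingTrainer.fit()` the iteration conditions are checked in configuration order after each minibatch, the first one that returns true wins, and its `toString()` becomes the details string. The sketch below illustrates only that selection rule; `NamedCondition` and its display strings are made up for illustration and do not reproduce DL4J's actual `toString()` formats.

```java
import java.util.List;
import java.util.function.DoublePredicate;

/** Hypothetical iteration condition with a display name (not a DL4J type). */
final class NamedCondition {
    final String details;
    final DoublePredicate fires;

    NamedCondition(String details, DoublePredicate fires) {
        this.details = details;
        this.fires = fires;
    }

    @Override
    public String toString() {
        return details;
    }

    /**
     * Checks conditions in configuration order after each minibatch score; the first that fires wins.
     * Returns "iteration=<i>, details=<toString()>", or null if nothing fired.
     */
    static String firstFired(double[] miniBatchScores, List<NamedCondition> conditions) {
        for (int i = 0; i < miniBatchScores.length; i++) {
            for (NamedCondition c : conditions) {
                if (c.fires.test(miniBatchScores[i])) {
                    return "iteration=" + i + ", details=" + c;
                }
            }
        }
        return null;
    }
}
```

Because order decides ties, a time limit listed first would be reported even if the score threshold were crossed on the same minibatch.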

Aggregations

EarlyStoppingResult (org.deeplearning4j.earlystopping.EarlyStoppingResult): 10
EarlyStoppingConfiguration (org.deeplearning4j.earlystopping.EarlyStoppingConfiguration): 7
InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 7
MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 7
MaxScoreIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition): 7
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 7
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 7
Test (org.junit.Test): 7
ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster): 6
DataSet (org.nd4j.linalg.dataset.DataSet): 6
MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 5
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 4
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 4
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 4
IOException (java.io.IOException): 3
ScoreCalculator (org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator): 3
EpochTerminationCondition (org.deeplearning4j.earlystopping.termination.EpochTerminationCondition): 3
IterationTerminationCondition (org.deeplearning4j.earlystopping.termination.IterationTerminationCondition): 3
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 3
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 3