
Example 1 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From the class TestEarlyStoppingSpark, method testBadTuning.

@Test
public void testBadTuning() {
    //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD)
            .learningRate(10.0) //Intentionally huge LR
            .weightInit(WeightInit.XAVIER).list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY)
                    .lossFunction(LossFunctions.LossFunction.MSE).build())
            .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                    new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
            .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc()))
            .modelSaver(saver).build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
            new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 4, 1, 0), esConf, net, irisData);
    EarlyStoppingResult result = trainer.fit();
    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(7.5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) DataSet(org.nd4j.linalg.dataset.DataSet) SparkEarlyStoppingTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingTrainer) SparkDataSetLossCalculator(org.deeplearning4j.spark.earlystopping.SparkDataSetLossCalculator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)
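This test relies on a very large learning rate making the loss blow up, so that `MaxScoreIterationTerminationCondition(7.5)` fires before any epoch-level condition. The sketch below reproduces that effect without DL4J: plain gradient descent on f(x) = x^2, with a hypothetical `MaxScoreCheck` class standing in for the real condition. Treating NaN as divergence is an assumption of this sketch, not a statement about DL4J's implementation.

```java
/** Hypothetical stand-in for a max-score iteration termination condition (not DL4J's class). */
final class MaxScoreCheck {
    private final double maxScore;

    MaxScoreCheck(double maxScore) {
        this.maxScore = maxScore;
    }

    /** Stop when the score exceeds the threshold, or is NaN (assumed to mean divergence). */
    boolean terminate(double score) {
        return Double.isNaN(score) || score > maxScore;
    }
}

final class DivergenceDemo {
    /**
     * Runs plain gradient descent on f(x) = x^2 starting from x = 1 (score 1.0) and returns
     * the 0-based iteration at which MaxScoreCheck fires, or -1 if it never fires.
     */
    static int stoppingIteration(double learningRate, double maxScore, int maxIters) {
        MaxScoreCheck check = new MaxScoreCheck(maxScore);
        double x = 1.0;
        for (int iter = 0; iter < maxIters; iter++) {
            x -= learningRate * 2.0 * x;   // gradient of x^2 is 2x
            double score = x * x;          // the "loss" reported after this iteration
            if (check.terminate(score)) {
                return iter;
            }
        }
        return -1;
    }
}
```

With a learning rate of 10 each step multiplies x by 1 - 2*10 = -19, so the score jumps from 1 to 361 on the very first iteration. With 0.1 the score shrinks by a factor of 0.64 per step and the check never fires.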

Example 2 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From the class TestEarlyStoppingSpark, method testNoImprovementNEpochsTermination.

@Test
public void testNoImprovementNEpochsTermination() {
    //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
    //Simulate this by setting LR = 0.0
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD).learningRate(0.0)
            .weightInit(WeightInit.XAVIER).list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(100),
                    new ScoreImprovementEpochTerminationCondition(5))
            .iterationTerminationConditions(new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
            .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc()))
            .modelSaver(saver).build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
            new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 10, 1, 0), esConf, net, irisData);
    EarlyStoppingResult result = trainer.fit();
    //Expect no score change due to 0 LR -> terminate after 6 total epochs
    //Normally expect 6 epochs exactly; get a little more than that here due to rounding + order of operations
    assertTrue(result.getTotalEpochs() < 12);
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used : InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) DataSet(org.nd4j.linalg.dataset.DataSet) ScoreImprovementEpochTerminationCondition(org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition) SparkEarlyStoppingTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingTrainer) SparkDataSetLossCalculator(org.deeplearning4j.spark.earlystopping.SparkDataSetLossCalculator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) Test(org.junit.Test)
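Example 2 depends on `ScoreImprovementEpochTerminationCondition(5)` stopping once five epochs pass without a better score. A minimal sketch of that bookkeeping follows. The class name `NoImprovementCheck` is hypothetical, and the rule it uses (stop when `epoch - bestEpoch >= patience`) is an assumption chosen to match the test's comment that a constant score should end training after 6 epochs.

```java
/** Hypothetical sketch of "stop after N epochs without improvement" (not DL4J's source). */
final class NoImprovementCheck {
    private final int patience;
    private int bestEpoch = -1;
    private double bestScore = Double.POSITIVE_INFINITY;

    NoImprovementCheck(int patience) {
        this.patience = patience;
    }

    /** Lower scores are better; returns true once `patience` epochs have passed since the best one. */
    boolean terminate(int epoch, double score) {
        if (bestEpoch == -1 || score < bestScore) {
            bestEpoch = epoch;
            bestScore = score;
            return false;
        }
        return epoch - bestEpoch >= patience;
    }

    /** Feeds a score sequence epoch by epoch and returns how many epochs ran before stopping. */
    static int epochsRun(double[] scores, int patience) {
        NoImprovementCheck check = new NoImprovementCheck(patience);
        for (int epoch = 0; epoch < scores.length; epoch++) {
            if (check.terminate(epoch, scores[epoch])) {
                return epoch + 1;   // epochs are 0-based, so the count is index + 1
            }
        }
        return scores.length;
    }
}
```

A flat score (the zero-learning-rate case) makes epoch 0 the best epoch forever, so the check fires at epoch 5, that is, after 6 epochs.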

Example 3 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From the class TestEarlyStoppingSpark, method testTimeTermination.

@Test
public void testTimeTermination() {
    //test termination after max time
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD).learningRate(1e-6)
            .weightInit(WeightInit.XAVIER).list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                    new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
            .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc()))
            .modelSaver(saver).build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
            new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 15, 1, 0), esConf, net, irisData);
    long startTime = System.currentTimeMillis();
    EarlyStoppingResult result = trainer.fit();
    long endTime = System.currentTimeMillis();
    int durationSeconds = (int) ((endTime - startTime) / 1000);
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds >= 3);
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds <= 9);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used : InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) DataSet(org.nd4j.linalg.dataset.DataSet) SparkEarlyStoppingTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingTrainer) SparkDataSetLossCalculator(org.deeplearning4j.spark.earlystopping.SparkDataSetLossCalculator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)
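`MaxTimeIterationTerminationCondition` is time-based rather than score-based. In the Spark trainer of Example 4, iteration conditions are only checked after each full `fit(train)` call (see the TODO there), so a run can overshoot the 3-second limit, which is presumably why the test accepts anything up to 9 seconds. The sketch below shows the idea with an injectable clock so it can be tested deterministically; `MaxTimeCheck` and its methods are hypothetical names, not the DL4J API.

```java
import java.util.concurrent.TimeUnit;
import java.util.function.LongSupplier;

/** Hypothetical wall-clock termination check (illustrative, not DL4J's class). */
final class MaxTimeCheck {
    private final long maxNanos;
    private final LongSupplier clock;   // returns "now" in nanoseconds
    private long startNanos;

    MaxTimeCheck(long duration, TimeUnit unit, LongSupplier clock) {
        this.maxNanos = unit.toNanos(duration);
        this.clock = clock;
    }

    /** Convenience factory for a limit expressed in seconds. */
    static MaxTimeCheck ofSeconds(long seconds, LongSupplier clock) {
        return new MaxTimeCheck(seconds, TimeUnit.SECONDS, clock);
    }

    /** Records the start time, analogous to calling initialize() before training begins. */
    void initialize() {
        startNanos = clock.getAsLong();
    }

    /** The score argument is ignored: only elapsed time decides. */
    boolean terminate(double score) {
        return clock.getAsLong() - startNanos >= maxNanos;
    }
}
```

In real use the clock would simply be `System::nanoTime`, for example `MaxTimeCheck.ofSeconds(3, System::nanoTime)`.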

Example 4 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From the class BaseSparkEarlyStoppingTrainer, method fit.

@Override
public EarlyStoppingResult<T> fit() {
    log.info("Starting early stopping training");
    if (esConfig.getScoreCalculator() == null)
        log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
    //Initialize termination conditions:
    if (esConfig.getIterationTerminationConditions() != null) {
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            c.initialize();
        }
    }
    if (esConfig.getEpochTerminationConditions() != null) {
        for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
            c.initialize();
        }
    }
    if (listener != null)
        listener.onStart(esConfig, net);
    Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();
    int epochCount = 0;
    while (true) {
        //Iterate (do epochs) until termination condition hit
        double lastScore;
        boolean terminate = false;
        IterationTerminationCondition terminationReason = null;
        if (train != null)
            fit(train);
        else
            fitMulti(trainMulti);
        //TODO revisit per iteration termination conditions, ensuring they are evaluated *per averaging* not per epoch
        //Check per-iteration termination conditions
        lastScore = getScore();
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            if (c.terminate(lastScore)) {
                terminate = true;
                terminationReason = c;
                break;
            }
        }
        if (terminate) {
            //Handle termination condition:
            log.info("Hit per-iteration termination condition at epoch {}. Reason: {}", epochCount, terminationReason);
            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(net, 0.0);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, terminationReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
            if (listener != null)
                listener.onCompletion(result);
            return result;
        }
        log.info("Completed training epoch {}", epochCount);
        if ((epochCount == 0 && esConfig.getEvaluateEveryNEpochs() == 1) || epochCount % esConfig.getEvaluateEveryNEpochs() == 0) {
            //Calculate score at this epoch:
            ScoreCalculator sc = esConfig.getScoreCalculator();
            double score = (sc == null ? 0.0 : esConfig.getScoreCalculator().calculateScore(net));
            scoreVsEpoch.put(epochCount - 1, score);
            if (sc != null && score < bestModelScore) {
                //Save best model:
                if (bestModelEpoch == -1) {
                    //First calculated/reported score
                    log.info("Score at epoch {}: {}", epochCount, score);
                } else {
                    log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", score, epochCount, bestModelScore, bestModelEpoch);
                }
                bestModelScore = score;
                bestModelEpoch = epochCount;
                try {
                    esConfig.getModelSaver().saveBestModel(net, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving best model", e);
                }
            }
            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(net, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            if (listener != null)
                listener.onEpoch(epochCount, score, esConfig, net);
            //Check per-epoch termination conditions:
            boolean epochTerminate = false;
            EpochTerminationCondition termReason = null;
            for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
                if (c.terminate(epochCount, score)) {
                    epochTerminate = true;
                    termReason = c;
                    break;
                }
            }
            if (epochTerminate) {
                log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount, termReason.toString());
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount + 1, bestModel);
                if (listener != null)
                    listener.onCompletion(result);
                return result;
            }
        }
        //Advance the epoch counter on every pass, not only on evaluation epochs
        epochCount++;
    }
}
Also used : EpochTerminationCondition(org.deeplearning4j.earlystopping.termination.EpochTerminationCondition) IOException(java.io.IOException) LinkedHashMap(java.util.LinkedHashMap) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) IterationTerminationCondition(org.deeplearning4j.earlystopping.termination.IterationTerminationCondition) ScoreCalculator(org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator)
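Stripped of Spark, model saving and listeners, the loop in `fit()` is: train an epoch, score it every `evaluateEveryNEpochs` epochs, remember the best score and its epoch, then ask the epoch termination conditions whether to stop. The library-free sketch below captures that control flow under stated assumptions: `scoreAtEpoch` is a hypothetical stand-in for training plus `calculateScore`, and the only stopping rules are a patience check and an epoch cap. It advances the epoch counter on every pass, whether or not that epoch is evaluated, as Example 5 does with `epochCount++` after its evaluation block.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.IntToDoubleFunction;

/** Hypothetical sketch of an early-stopping epoch loop (not DL4J's implementation). */
final class EarlyStoppingLoop {
    /** Minimal result: best epoch and score, epochs run, and the score at each evaluated epoch. */
    static final class Result {
        final int bestEpoch;
        final double bestScore;
        final int totalEpochs;
        final Map<Integer, Double> scoreVsEpoch;

        Result(int bestEpoch, double bestScore, int totalEpochs, Map<Integer, Double> scoreVsEpoch) {
            this.bestEpoch = bestEpoch;
            this.bestScore = bestScore;
            this.totalEpochs = totalEpochs;
            this.scoreVsEpoch = scoreVsEpoch;
        }
    }

    static Result run(IntToDoubleFunction scoreAtEpoch, int evaluateEveryN, int patience, int maxEpochs) {
        Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();
        int bestEpoch = -1;
        double bestScore = Double.POSITIVE_INFINITY;
        for (int epoch = 0; epoch < maxEpochs; epoch++) {   // counter advances on every pass
            // (one epoch of training would run here)
            if (epoch % evaluateEveryN != 0) {
                continue;                                   // not an evaluation epoch
            }
            double score = scoreAtEpoch.applyAsDouble(epoch);
            scoreVsEpoch.put(epoch, score);
            if (score < bestScore) {                        // lower is better, as in the snippet
                bestScore = score;
                bestEpoch = epoch;                          // the best model would be saved here
            }
            if (epoch - bestEpoch >= patience) {            // epoch-level termination check
                return new Result(bestEpoch, bestScore, epoch + 1, scoreVsEpoch);
            }
        }
        return new Result(bestEpoch, bestScore, maxEpochs, scoreVsEpoch);
    }
}
```

With a constant score and a patience of 5 this returns after 6 epochs; evaluating only every second epoch delays the stop to 7 epochs, because the condition can only fire on evaluated epochs.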

Example 5 with EarlyStoppingResult

Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j by deeplearning4j.

From the class EarlyStoppingParallelTrainer, method fit.

@Override
public EarlyStoppingResult<T> fit() {
    log.info("Starting early stopping training");
    if (wrapper == null) {
        throw new IllegalStateException("Trainer has already exhausted its parallel wrapper instance. Please instantiate a new trainer.");
    }
    if (esConfig.getScoreCalculator() == null)
        log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
    //Initialize termination conditions:
    if (esConfig.getIterationTerminationConditions() != null) {
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            c.initialize();
        }
    }
    if (esConfig.getEpochTerminationConditions() != null) {
        for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
            c.initialize();
        }
    }
    if (listener != null) {
        listener.onStart(esConfig, model);
    }
    Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();
    // append the iteration listener
    int epochCount = 0;
    // iterate through epochs
    while (true) {
        // note that we don't call train.reset() because ParallelWrapper does it already
        try {
            if (train != null) {
                wrapper.fit(train);
            } else
                wrapper.fit(trainMulti);
        } catch (Exception e) {
            log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", epochCount, iterCount, e);
            //Load best model to return
            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            return new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.Error, e.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
        }
        if (terminate.get()) {
            //Handle termination condition:
            log.info("Hit per iteration termination condition at epoch {}, iteration {}. Reason: {}", epochCount, iterCount, terminationReason);
            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(model, 0.0);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            if (bestModel == null) {
                //Could occur with very early termination
                bestModel = model;
            }
            EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, terminationReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
            if (listener != null) {
                listener.onCompletion(result);
            }
            // clean up
            wrapper.shutdown();
            this.wrapper = null;
            return result;
        }
        log.info("Completed training epoch {}", epochCount);
        if ((epochCount == 0 && esConfig.getEvaluateEveryNEpochs() == 1) || epochCount % esConfig.getEvaluateEveryNEpochs() == 0) {
            //Calculate score at this epoch:
            ScoreCalculator sc = esConfig.getScoreCalculator();
            double score = (sc == null ? 0.0 : esConfig.getScoreCalculator().calculateScore(model));
            scoreVsEpoch.put(epochCount - 1, score);
            if (sc != null && score < bestModelScore) {
                //Save best model:
                if (bestModelEpoch == -1) {
                    //First calculated/reported score
                    log.info("Score at epoch {}: {}", epochCount, score);
                } else {
                    log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", score, epochCount, bestModelScore, bestModelEpoch);
                }
                bestModelScore = score;
                bestModelEpoch = epochCount;
                try {
                    esConfig.getModelSaver().saveBestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving best model", e);
                }
            }
            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            if (listener != null) {
                listener.onEpoch(epochCount, score, esConfig, model);
            }
            //Check per-epoch termination conditions:
            boolean epochTerminate = false;
            EpochTerminationCondition termReason = null;
            for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
                if (c.terminate(epochCount, score)) {
                    epochTerminate = true;
                    termReason = c;
                    wrapper.stopFit();
                    break;
                }
            }
            if (epochTerminate) {
                log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount, termReason.toString());
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount + 1, bestModel);
                if (listener != null) {
                    listener.onCompletion(result);
                }
                // clean up
                wrapper.shutdown();
                this.wrapper = null;
                return result;
            }
        }
        epochCount++;
    }
}
Also used : EpochTerminationCondition(org.deeplearning4j.earlystopping.termination.EpochTerminationCondition) IOException(java.io.IOException) AtomicDouble(com.google.common.util.concurrent.AtomicDouble) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) IterationTerminationCondition(org.deeplearning4j.earlystopping.termination.IterationTerminationCondition) ScoreCalculator(org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator)
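In the parallel trainer the per-iteration check does not happen inside this loop: `fit()` only reads a shared `terminate` flag (`terminate.get()`) and a `terminationReason` after `wrapper.fit(...)` returns, presumably set by an iteration listener attached to the model. Those fields are declared elsewhere and not shown, so the following is a hypothetical sketch of how such a flag could be shared safely between worker threads, using `AtomicBoolean.compareAndSet` so that only the first worker to trip the condition records the reason. `SharedTermination` and its methods are illustrative names, not DL4J classes.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical thread-safe termination flag shared by worker threads (not DL4J's code). */
final class SharedTermination {
    private final AtomicBoolean terminate = new AtomicBoolean(false);
    private final AtomicReference<String> reason = new AtomicReference<>();
    private final double maxScore;

    SharedTermination(double maxScore) {
        this.maxScore = maxScore;
    }

    /** Called by any worker after an iteration; only the first worker to trip the check records a reason. */
    void onIteration(int workerId, double score) {
        if (score > maxScore && terminate.compareAndSet(false, true)) {
            reason.set("worker " + workerId + ": score " + score + " > " + maxScore);
        }
    }

    boolean shouldTerminate() {
        return terminate.get();
    }

    /** Read after the workers have finished; null if nothing tripped. */
    String reason() {
        return reason.get();
    }

    /** Simulates one "epoch": each worker replays its own scores and stops early once the flag is set. */
    static SharedTermination runEpoch(double maxScore, double[][] scoresPerWorker) {
        SharedTermination t = new SharedTermination(maxScore);
        ExecutorService pool = Executors.newFixedThreadPool(scoresPerWorker.length);
        for (int w = 0; w < scoresPerWorker.length; w++) {
            final int id = w;
            pool.execute(() -> {
                for (double s : scoresPerWorker[id]) {
                    if (t.shouldTerminate()) {
                        break;                 // cooperative stop once any worker tripped the flag
                    }
                    t.onIteration(id, s);
                }
            });
        }
        pool.shutdown();                       // accept no new tasks, let running ones finish
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return t;
    }
}
```

The main loop then plays the role of `fit()`: after the workers finish it checks `shouldTerminate()` and, if set, builds its result from `reason()`.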

Aggregations

EarlyStoppingResult (org.deeplearning4j.earlystopping.EarlyStoppingResult): 10
EarlyStoppingConfiguration (org.deeplearning4j.earlystopping.EarlyStoppingConfiguration): 7
InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 7
MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 7
MaxScoreIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition): 7
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 7
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 7
Test (org.junit.Test): 7
ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster): 6
DataSet (org.nd4j.linalg.dataset.DataSet): 6
MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 5
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 4
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 4
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 4
IOException (java.io.IOException): 3
ScoreCalculator (org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator): 3
EpochTerminationCondition (org.deeplearning4j.earlystopping.termination.EpochTerminationCondition): 3
IterationTerminationCondition (org.deeplearning4j.earlystopping.termination.IterationTerminationCondition): 3
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 3
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 3