Search in sources:

Example 1 with ScoreCalculator

Use of org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator in the deeplearning4j project (by deeplearning4j).

Method fit in class BaseEarlyStoppingTrainer.

/**
 * Trains the model with early stopping.
 *
 * <p>Each epoch calls {@code reset()} and then fits every minibatch from the training iterator
 * (single {@code DataSet}s when {@code train} is set, otherwise {@code trainMulti}). After each
 * minibatch the configured iteration termination conditions are checked against
 * {@code model.score()}. Every {@code evaluateEveryNEpochs} epochs (starting with epoch 0) the
 * configured {@link ScoreCalculator} scores the model (0.0 is reported when none is configured).
 * When a calculator is configured, a score lower than the best seen so far is saved as the new
 * best model. The epoch termination conditions are then checked.
 *
 * <p>Training ends when an iteration or epoch termination condition fires, or when fitting a
 * minibatch throws (reported as {@code TerminationReason.Error}).
 *
 * @return the early stopping result; {@code scoreVsEpoch} is keyed by the same epoch index that
 *     is reported as the best model epoch
 * @throws RuntimeException if the model saver fails to save or load a model
 */
@Override
public EarlyStoppingResult<T> fit() {
    log.info("Starting early stopping training");
    if (esConfig.getScoreCalculator() == null)
        log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
    //Initialize termination conditions (either list may be null):
    if (esConfig.getIterationTerminationConditions() != null) {
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            c.initialize();
        }
    }
    if (esConfig.getEpochTerminationConditions() != null) {
        for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
            c.initialize();
        }
    }
    if (listener != null) {
        listener.onStart(esConfig, model);
    }
    Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();
    int epochCount = 0;
    while (true) {
        reset();
        IterationTerminationCondition terminationReason = null;
        int iterCount = 0;
        while (iterator.hasNext()) {
            try {
                if (train != null) {
                    fit((DataSet) iterator.next());
                } else {
                    fit(trainMulti.next());
                }
            } catch (Exception e) {
                log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", epochCount, iterCount, e);
                //Load best model to return
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                return new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.Error, e.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
            }
            //Check per-iteration termination conditions; guard against a null list, as initialization does
            if (esConfig.getIterationTerminationConditions() != null) {
                double lastScore = model.score();
                for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
                    if (c.terminate(lastScore)) {
                        terminationReason = c;
                        break;
                    }
                }
            }
            if (terminationReason != null) {
                break;
            }
            iterCount++;
        }
        if (terminationReason != null) {
            //Handle termination condition:
            log.info("Hit per iteration termination condition at epoch {}, iteration {}. Reason: {}", epochCount, iterCount, terminationReason);
            if (esConfig.isSaveLastModel()) {
                //Save last model (no score has been calculated for this partial epoch):
                try {
                    esConfig.getModelSaver().saveLatestModel(model, 0.0);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            if (bestModel == null) {
                //No best model saved yet (e.g. very early termination): return the current model instead
                bestModel = model;
            }
            EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, terminationReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
            if (listener != null) {
                listener.onCompletion(result);
            }
            return result;
        }
        log.info("Completed training epoch {}", epochCount);
        //0 % n == 0, so the first epoch (epochCount == 0) is always evaluated
        if (epochCount % esConfig.getEvaluateEveryNEpochs() == 0) {
            //Calculate score at this epoch:
            ScoreCalculator sc = esConfig.getScoreCalculator();
            double score = (sc == null ? 0.0 : sc.calculateScore(model));
            //Key by the same epoch index used for bestModelEpoch and the log messages
            scoreVsEpoch.put(epochCount, score);
            if (sc != null && score < bestModelScore) {
                //Save best model:
                if (bestModelEpoch == -1) {
                    //First calculated/reported score
                    log.info("Score at epoch {}: {}", epochCount, score);
                } else {
                    log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", score, epochCount, bestModelScore, bestModelEpoch);
                }
                bestModelScore = score;
                bestModelEpoch = epochCount;
                try {
                    esConfig.getModelSaver().saveBestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving best model", e);
                }
            }
            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            if (listener != null) {
                listener.onEpoch(epochCount, score, esConfig, model);
            }
            //Check per-epoch termination conditions; guard against a null list, as initialization does
            EpochTerminationCondition termReason = null;
            if (esConfig.getEpochTerminationConditions() != null) {
                for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
                    if (c.terminate(epochCount, score)) {
                        termReason = c;
                        break;
                    }
                }
            }
            if (termReason != null) {
                log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount, termReason.toString());
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                //epochCount is zero-based, so epochCount + 1 epochs have completed
                EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount + 1, bestModel);
                if (listener != null) {
                    listener.onCompletion(result);
                }
                return result;
            }
        }
        epochCount++;
    }
}
Also used: EpochTerminationCondition (org.deeplearning4j.earlystopping.termination.EpochTerminationCondition), IOException (java.io.IOException), LinkedHashMap (java.util.LinkedHashMap), EarlyStoppingResult (org.deeplearning4j.earlystopping.EarlyStoppingResult), IterationTerminationCondition (org.deeplearning4j.earlystopping.termination.IterationTerminationCondition), ScoreCalculator (org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator)

Example 2 with ScoreCalculator

Use of org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator in the deeplearning4j project (by deeplearning4j).

Method fit in class EarlyStoppingParallelTrainer.

/**
 * Trains the model with early stopping, using the parallel wrapper to fit each epoch.
 *
 * <p>Each epoch is a single {@code wrapper.fit(...)} call over the training data. The wrapper
 * resets the data itself, so this method does not. Per-iteration termination is signalled
 * through the {@code terminate} flag and {@code terminationReason}. These are presumably set by
 * an iteration listener registered elsewhere in this class (TODO confirm). Every
 * {@code evaluateEveryNEpochs} epochs (starting with epoch 0) the configured
 * {@link ScoreCalculator} scores the model (0.0 is reported when none is configured). A lower
 * score than the best seen so far is saved as the new best model. The epoch termination
 * conditions are then checked.
 *
 * <p>On every exit path the wrapper is shut down and released, so a trainer instance can only
 * be fitted once.
 *
 * @return the early stopping result; {@code scoreVsEpoch} is keyed by the same epoch index that
 *     is reported as the best model epoch
 * @throws IllegalStateException if this trainer's wrapper has already been used up by a previous fit
 * @throws RuntimeException if the model saver fails to save or load a model
 */
@Override
public EarlyStoppingResult<T> fit() {
    log.info("Starting early stopping training");
    if (wrapper == null) {
        throw new IllegalStateException("Trainer has already exhausted it's parallel wrapper instance. Please instantiate a new trainer.");
    }
    if (esConfig.getScoreCalculator() == null)
        log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
    //Initialize termination conditions (either list may be null):
    if (esConfig.getIterationTerminationConditions() != null) {
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            c.initialize();
        }
    }
    if (esConfig.getEpochTerminationConditions() != null) {
        for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
            c.initialize();
        }
    }
    if (listener != null) {
        listener.onStart(esConfig, model);
    }
    Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();
    int epochCount = 0;
    // iterate through epochs
    while (true) {
        // note that we don't call train.reset() because ParallelWrapper does it already
        try {
            if (train != null) {
                wrapper.fit(train);
            } else {
                wrapper.fit(trainMulti);
            }
        } catch (Exception e) {
            log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", epochCount, iterCount, e);
            //Release the wrapper as the other exit paths do; best effort so the original failure is still reported
            try {
                wrapper.shutdown();
            } catch (Exception shutdownError) {
                log.warn("Error shutting down parallel wrapper after training failure", shutdownError);
            }
            this.wrapper = null;
            //Load best model to return
            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            return new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.Error, e.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
        }
        if (terminate.get()) {
            //Handle termination condition:
            log.info("Hit per iteration termination condition at epoch {}, iteration {}. Reason: {}", epochCount, iterCount, terminationReason);
            if (esConfig.isSaveLastModel()) {
                //Save last model (no score has been calculated for this partial epoch):
                try {
                    esConfig.getModelSaver().saveLatestModel(model, 0.0);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            if (bestModel == null) {
                //Could occur with very early termination
                bestModel = model;
            }
            EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, terminationReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
            if (listener != null) {
                listener.onCompletion(result);
            }
            // clean up
            wrapper.shutdown();
            this.wrapper = null;
            return result;
        }
        log.info("Completed training epoch {}", epochCount);
        //0 % n == 0, so the first epoch (epochCount == 0) is always evaluated
        if (epochCount % esConfig.getEvaluateEveryNEpochs() == 0) {
            //Calculate score at this epoch:
            ScoreCalculator sc = esConfig.getScoreCalculator();
            double score = (sc == null ? 0.0 : sc.calculateScore(model));
            //Key by the same epoch index used for bestModelEpoch and the log messages
            scoreVsEpoch.put(epochCount, score);
            if (sc != null && score < bestModelScore) {
                //Save best model:
                if (bestModelEpoch == -1) {
                    //First calculated/reported score
                    log.info("Score at epoch {}: {}", epochCount, score);
                } else {
                    log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", score, epochCount, bestModelScore, bestModelEpoch);
                }
                bestModelScore = score;
                bestModelEpoch = epochCount;
                try {
                    esConfig.getModelSaver().saveBestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving best model", e);
                }
            }
            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            if (listener != null) {
                listener.onEpoch(epochCount, score, esConfig, model);
            }
            //Check per-epoch termination conditions; guard against a null list, as initialization does
            EpochTerminationCondition termReason = null;
            if (esConfig.getEpochTerminationConditions() != null) {
                for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
                    if (c.terminate(epochCount, score)) {
                        termReason = c;
                        wrapper.stopFit();
                        break;
                    }
                }
            }
            if (termReason != null) {
                log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount, termReason.toString());
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                //epochCount is zero-based, so epochCount + 1 epochs have completed
                EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount + 1, bestModel);
                if (listener != null) {
                    listener.onCompletion(result);
                }
                // clean up
                wrapper.shutdown();
                this.wrapper = null;
                return result;
            }
        }
        epochCount++;
    }
}
Also used: EpochTerminationCondition (org.deeplearning4j.earlystopping.termination.EpochTerminationCondition), IOException (java.io.IOException), AtomicDouble (com.google.common.util.concurrent.AtomicDouble), EarlyStoppingResult (org.deeplearning4j.earlystopping.EarlyStoppingResult), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), IterationTerminationCondition (org.deeplearning4j.earlystopping.termination.IterationTerminationCondition), ScoreCalculator (org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator)

Example 3 with ScoreCalculator

Use of org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator in the deeplearning4j project (by deeplearning4j).

Method fit in class BaseSparkEarlyStoppingTrainer.

/**
 * Trains the network with early stopping.
 *
 * <p>Each loop iteration fits one epoch ({@code fit(train)} when {@code train} is set, otherwise
 * {@code fitMulti(trainMulti)}). The iteration termination conditions are currently checked once
 * per epoch against {@code getScore()}, not per averaging step. Every {@code evaluateEveryNEpochs}
 * epochs (starting with epoch 0) the configured {@link ScoreCalculator} scores the network (0.0
 * is reported when none is configured). A lower score than the best seen so far is saved as the
 * new best model. The epoch termination conditions are then checked.
 *
 * <p>The epoch counter advances after every epoch, whether or not that epoch was evaluated. This
 * prevents an endless loop when {@code evaluateEveryNEpochs > 1}.
 *
 * @return the early stopping result; {@code scoreVsEpoch} is keyed by the same epoch index that
 *     is reported as the best model epoch
 * @throws RuntimeException if the model saver fails to save or load a model
 */
@Override
public EarlyStoppingResult<T> fit() {
    log.info("Starting early stopping training");
    if (esConfig.getScoreCalculator() == null)
        log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
    //Initialize termination conditions (either list may be null):
    if (esConfig.getIterationTerminationConditions() != null) {
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            c.initialize();
        }
    }
    if (esConfig.getEpochTerminationConditions() != null) {
        for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
            c.initialize();
        }
    }
    if (listener != null) {
        listener.onStart(esConfig, net);
    }
    Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();
    int epochCount = 0;
    while (true) {
        //Iterate (do epochs) until termination condition hit
        if (train != null) {
            fit(train);
        } else {
            fitMulti(trainMulti);
        }
        //TODO revisit per iteration termination conditions, ensuring they are evaluated *per averaging* not per epoch
        //Check per-iteration termination conditions; guard against a null list, as initialization does
        IterationTerminationCondition terminationReason = null;
        if (esConfig.getIterationTerminationConditions() != null) {
            double lastScore = getScore();
            for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
                if (c.terminate(lastScore)) {
                    terminationReason = c;
                    break;
                }
            }
        }
        if (terminationReason != null) {
            //Handle termination condition (checked once per epoch, so there is no separate iteration index):
            log.info("Hit per iteration termination condition at epoch {}. Reason: {}", epochCount, terminationReason);
            if (esConfig.isSaveLastModel()) {
                //Save last model (no score has been calculated for this epoch):
                try {
                    esConfig.getModelSaver().saveLatestModel(net, 0.0);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            if (bestModel == null) {
                //No best model saved yet (e.g. very early termination): return the current network instead
                bestModel = net;
            }
            EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, terminationReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel);
            if (listener != null) {
                listener.onCompletion(result);
            }
            return result;
        }
        log.info("Completed training epoch {}", epochCount);
        //0 % n == 0, so the first epoch (epochCount == 0) is always evaluated
        if (epochCount % esConfig.getEvaluateEveryNEpochs() == 0) {
            //Calculate score at this epoch:
            ScoreCalculator sc = esConfig.getScoreCalculator();
            double score = (sc == null ? 0.0 : sc.calculateScore(net));
            //Key by the same epoch index used for bestModelEpoch and the log messages
            scoreVsEpoch.put(epochCount, score);
            if (sc != null && score < bestModelScore) {
                //Save best model:
                if (bestModelEpoch == -1) {
                    //First calculated/reported score
                    log.info("Score at epoch {}: {}", epochCount, score);
                } else {
                    log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", score, epochCount, bestModelScore, bestModelEpoch);
                }
                bestModelScore = score;
                bestModelEpoch = epochCount;
                try {
                    esConfig.getModelSaver().saveBestModel(net, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving best model", e);
                }
            }
            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(net, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }
            if (listener != null) {
                listener.onEpoch(epochCount, score, esConfig, net);
            }
            //Check per-epoch termination conditions; guard against a null list, as initialization does
            EpochTerminationCondition termReason = null;
            if (esConfig.getEpochTerminationConditions() != null) {
                for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
                    if (c.terminate(epochCount, score)) {
                        termReason = c;
                        break;
                    }
                }
            }
            if (termReason != null) {
                log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount, termReason.toString());
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                //epochCount is zero-based, so epochCount + 1 epochs have completed
                EarlyStoppingResult<T> result = new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount + 1, bestModel);
                if (listener != null) {
                    listener.onCompletion(result);
                }
                return result;
            }
        }
        //Advance on every epoch, evaluated or not; incrementing only inside the branch above loops forever when evaluateEveryNEpochs > 1
        epochCount++;
    }
}
Also used: EpochTerminationCondition (org.deeplearning4j.earlystopping.termination.EpochTerminationCondition), IOException (java.io.IOException), LinkedHashMap (java.util.LinkedHashMap), EarlyStoppingResult (org.deeplearning4j.earlystopping.EarlyStoppingResult), IterationTerminationCondition (org.deeplearning4j.earlystopping.termination.IterationTerminationCondition), ScoreCalculator (org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator)

Aggregations

IOException (java.io.IOException): 3, EarlyStoppingResult (org.deeplearning4j.earlystopping.EarlyStoppingResult): 3, ScoreCalculator (org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator): 3, EpochTerminationCondition (org.deeplearning4j.earlystopping.termination.EpochTerminationCondition): 3, IterationTerminationCondition (org.deeplearning4j.earlystopping.termination.IterationTerminationCondition): 3, LinkedHashMap (java.util.LinkedHashMap): 2, AtomicDouble (com.google.common.util.concurrent.AtomicDouble): 1, AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1