Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j: class TestEarlyStoppingSpark, method testBadTuning.
@Test
public void testBadTuning() {
    //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
                    .updater(Updater.SGD)
                    .learningRate(10.0) //Intentionally huge LR
                    .weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY)
                                    .lossFunction(LossFunctions.LossFunction.MSE).build())
                    .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    JavaRDD<DataSet> irisData = getIris();

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
                    .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                                    new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
                    .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc()))
                    .modelSaver(saver).build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
                    new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 4, 1, 0), esConf, net, irisData);

    EarlyStoppingResult result = trainer.fit();
    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(7.5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
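The assertions above check only the termination reason and details. As a minimal sketch (not part of the test class, and assuming the remaining EarlyStoppingResult accessors such as getBestModel(), getBestModelScore() and getBestModelEpoch()), a caller would typically read the rest of the result like this:

    //Minimal sketch: inspecting the EarlyStoppingResult returned by trainer.fit().
    //getBestModel()/getBestModelScore()/getBestModelEpoch() are assumed accessors, not taken from the test above.
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println("Termination reason:  " + result.getTerminationReason());
    System.out.println("Termination details: " + result.getTerminationDetails());
    System.out.println("Total epochs:        " + result.getTotalEpochs());
    System.out.println("Best epoch:          " + result.getBestModelEpoch());
    System.out.println("Best score:          " + result.getBestModelScore());
    MultiLayerNetwork bestNet = result.getBestModel();    //Restored via the EarlyStoppingModelSaver (InMemoryModelSaver here)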
Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j: class TestEarlyStoppingSpark, method testNoImprovementNEpochsTermination.
@Test
public void testNoImprovementNEpochsTermination() {
    //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
    //Simulate this by setting LR = 0.0
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
                    .updater(Updater.SGD).learningRate(0.0).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    JavaRDD<DataSet> irisData = getIris();

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(100),
                                    new ScoreImprovementEpochTerminationCondition(5))
                    .iterationTerminationConditions(new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
                    .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc()))
                    .modelSaver(saver).build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
                    new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 10, 1, 0), esConf, net, irisData);

    EarlyStoppingResult result = trainer.fit();
    //Expect no score change due to 0 LR -> terminate after 6 total epochs
    //Normally expect 6 epochs exactly; get a little more than that here due to rounding + order of operations
    assertTrue(result.getTotalEpochs() < 12);
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
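ScoreImprovementEpochTerminationCondition(5) above stops training once the score fails to improve for 5 consecutive epochs, which the zero learning rate guarantees almost immediately. As a hedged sketch, a two-argument form of the condition (assumed here; it may not exist in every release) would additionally require a minimum improvement before an epoch counts as improved:

    //Sketch only: treat improvements smaller than 0.001 as "no improvement" and stop
    //after 5 such epochs in a row. The two-argument constructor is an assumption.
    EarlyStoppingConfiguration<MultiLayerNetwork> esConfMinImprovement =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(100),
                                                    new ScoreImprovementEpochTerminationCondition(5, 0.001))
                                    .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc()))
                                    .modelSaver(new InMemoryModelSaver<MultiLayerNetwork>())
                                    .build();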
Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j: class TestEarlyStoppingSpark, method testTimeTermination.
@Test
public void testTimeTermination() {
    //test termination after max time
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
                    .updater(Updater.SGD).learningRate(1e-6).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    JavaRDD<DataSet> irisData = getIris();

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
                    .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                                    new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
                    .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, sc.sc()))
                    .modelSaver(saver).build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
                    new ParameterAveragingTrainingMaster(true, 4, 1, 150 / 15, 1, 0), esConf, net, irisData);

    long startTime = System.currentTimeMillis();
    EarlyStoppingResult result = trainer.fit();
    long endTime = System.currentTimeMillis();
    int durationSeconds = (int) (endTime - startTime) / 1000;
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds >= 3);
    assertTrue("durationSeconds = " + durationSeconds, durationSeconds <= 9);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
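The same EarlyStoppingConfiguration and EarlyStoppingResult types are used outside Spark as well. As a sketch under assumptions (the local EarlyStoppingTrainer, DataSetLossCalculator and IrisDataSetIterator from deeplearning4j-core are not shown in this test class), the time-bound run above could be reproduced on a single machine:

    //Sketch: single-machine analogue of the Spark time-termination test.
    //EarlyStoppingTrainer, DataSetLossCalculator and IrisDataSetIterator are assumed to be on the classpath.
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingConfiguration<MultiLayerNetwork> localConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
                    .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS))
                    .scoreCalculator(new DataSetLossCalculator(irisIter, true))
                    .modelSaver(new InMemoryModelSaver<MultiLayerNetwork>())
                    .build();
    IEarlyStoppingTrainer<MultiLayerNetwork> localTrainer = new EarlyStoppingTrainer(localConf, net, irisIter);
    EarlyStoppingResult<MultiLayerNetwork> localResult = localTrainer.fit();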
Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j: class BaseSparkEarlyStoppingTrainer, method fit.
@Override
public EarlyStoppingResult<T> fit() {
    log.info("Starting early stopping training");
    if (esConfig.getScoreCalculator() == null)
        log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");

    //Initialize termination conditions:
    if (esConfig.getIterationTerminationConditions() != null) {
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            c.initialize();
        }
    }
    if (esConfig.getEpochTerminationConditions() != null) {
        for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
            c.initialize();
        }
    }

    if (listener != null)
        listener.onStart(esConfig, net);

    Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();

    int epochCount = 0;
    while (true) {
        //Iterate (do epochs) until termination condition hit
        double lastScore;
        boolean terminate = false;
        IterationTerminationCondition terminationReason = null;

        if (train != null)
            fit(train);
        else
            fitMulti(trainMulti);

        //TODO revisit per iteration termination conditions, ensuring they are evaluated *per averaging* not per epoch
        //Check per-iteration termination conditions
        lastScore = getScore();
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            if (c.terminate(lastScore)) {
                terminate = true;
                terminationReason = c;
                break;
            }
        }

        if (terminate) {
            //Handle termination condition:
            log.info("Hit per iteration epoch termination condition at epoch {}, iteration {}. Reason: {}",
                            epochCount, epochCount, terminationReason);

            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(net, 0.0);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }

            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            EarlyStoppingResult<T> result = new EarlyStoppingResult<>(
                            EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                            terminationReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount,
                            bestModel);
            if (listener != null)
                listener.onCompletion(result);
            return result;
        }

        log.info("Completed training epoch {}", epochCount);

        if ((epochCount == 0 && esConfig.getEvaluateEveryNEpochs() == 1)
                        || epochCount % esConfig.getEvaluateEveryNEpochs() == 0) {
            //Calculate score at this epoch:
            ScoreCalculator sc = esConfig.getScoreCalculator();
            double score = (sc == null ? 0.0 : esConfig.getScoreCalculator().calculateScore(net));
            scoreVsEpoch.put(epochCount - 1, score);

            if (sc != null && score < bestModelScore) {
                //Save best model:
                if (bestModelEpoch == -1) {
                    //First calculated/reported score
                    log.info("Score at epoch {}: {}", epochCount, score);
                } else {
                    log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", score,
                                    epochCount, bestModelScore, bestModelEpoch);
                }
                bestModelScore = score;
                bestModelEpoch = epochCount;
                try {
                    esConfig.getModelSaver().saveBestModel(net, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving best model", e);
                }
            }

            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(net, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }

            if (listener != null)
                listener.onEpoch(epochCount, score, esConfig, net);

            //Check per-epoch termination conditions:
            boolean epochTerminate = false;
            EpochTerminationCondition termReason = null;
            for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
                if (c.terminate(epochCount, score)) {
                    epochTerminate = true;
                    termReason = c;
                    break;
                }
            }
            if (epochTerminate) {
                log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount,
                                termReason.toString());
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                EarlyStoppingResult<T> result = new EarlyStoppingResult<>(
                                EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(),
                                scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount + 1, bestModel);
                if (listener != null)
                    listener.onCompletion(result);
                return result;
            }
            epochCount++;
        }
    }
}
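The loop above records each evaluated score under epochCount - 1 in scoreVsEpoch and hands that map to the EarlyStoppingResult. As a minimal sketch (assuming the map is exposed through a getScoreVsEpoch() accessor, and with trainer standing for any IEarlyStoppingTrainer instance as in the tests above), the per-epoch score history can be read back like this:

    //Sketch: reading back the per-epoch score history collected by fit().
    //getScoreVsEpoch() is an assumed accessor for the scoreVsEpoch map built above.
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    for (Map.Entry<Integer, Double> entry : result.getScoreVsEpoch().entrySet()) {
        System.out.println("epoch " + entry.getKey() + ": score = " + entry.getValue());
    }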
Use of org.deeplearning4j.earlystopping.EarlyStoppingResult in project deeplearning4j: class EarlyStoppingParallelTrainer, method fit.
@Override
public EarlyStoppingResult<T> fit() {
    log.info("Starting early stopping training");
    if (wrapper == null) {
        throw new IllegalStateException(
                        "Trainer has already exhausted its parallel wrapper instance. Please instantiate a new trainer.");
    }
    if (esConfig.getScoreCalculator() == null)
        log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
    //Initialize termination conditions:
    if (esConfig.getIterationTerminationConditions() != null) {
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            c.initialize();
        }
    }
    if (esConfig.getEpochTerminationConditions() != null) {
        for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
            c.initialize();
        }
    }

    if (listener != null) {
        listener.onStart(esConfig, model);
    }

    Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();

    // append the iteration listener
    int epochCount = 0;

    // iterate through epochs
    while (true) {
        // note that we don't call train.reset() because ParallelWrapper does it already
        try {
            if (train != null) {
                wrapper.fit(train);
            } else
                wrapper.fit(trainMulti);
        } catch (Exception e) {
            log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", epochCount,
                            iterCount, e);
            //Load best model to return
            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            return new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.Error, e.toString(), scoreVsEpoch,
                            bestModelEpoch, bestModelScore, epochCount, bestModel);
        }

        if (terminate.get()) {
            //Handle termination condition:
            log.info("Hit per iteration termination condition at epoch {}, iteration {}. Reason: {}", epochCount,
                            iterCount, terminationReason);

            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(model, 0.0);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }

            T bestModel;
            try {
                bestModel = esConfig.getModelSaver().getBestModel();
            } catch (IOException e2) {
                throw new RuntimeException(e2);
            }
            if (bestModel == null) {
                //Could occur with very early termination
                bestModel = model;
            }

            EarlyStoppingResult<T> result = new EarlyStoppingResult<>(
                            EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                            terminationReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount,
                            bestModel);
            if (listener != null) {
                listener.onCompletion(result);
            }

            // clean up
            wrapper.shutdown();
            this.wrapper = null;
            return result;
        }

        log.info("Completed training epoch {}", epochCount);

        if ((epochCount == 0 && esConfig.getEvaluateEveryNEpochs() == 1)
                        || epochCount % esConfig.getEvaluateEveryNEpochs() == 0) {
            //Calculate score at this epoch:
            ScoreCalculator sc = esConfig.getScoreCalculator();
            double score = (sc == null ? 0.0 : esConfig.getScoreCalculator().calculateScore(model));
            scoreVsEpoch.put(epochCount - 1, score);

            if (sc != null && score < bestModelScore) {
                //Save best model:
                if (bestModelEpoch == -1) {
                    //First calculated/reported score
                    log.info("Score at epoch {}: {}", epochCount, score);
                } else {
                    log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", score,
                                    epochCount, bestModelScore, bestModelEpoch);
                }
                bestModelScore = score;
                bestModelEpoch = epochCount;
                try {
                    esConfig.getModelSaver().saveBestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving best model", e);
                }
            }

            if (esConfig.isSaveLastModel()) {
                //Save last model:
                try {
                    esConfig.getModelSaver().saveLatestModel(model, score);
                } catch (IOException e) {
                    throw new RuntimeException("Error saving most recent model", e);
                }
            }

            if (listener != null) {
                listener.onEpoch(epochCount, score, esConfig, model);
            }

            //Check per-epoch termination conditions:
            boolean epochTerminate = false;
            EpochTerminationCondition termReason = null;
            for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
                if (c.terminate(epochCount, score)) {
                    epochTerminate = true;
                    termReason = c;
                    wrapper.stopFit();
                    break;
                }
            }
            if (epochTerminate) {
                log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount,
                                termReason.toString());
                T bestModel;
                try {
                    bestModel = esConfig.getModelSaver().getBestModel();
                } catch (IOException e2) {
                    throw new RuntimeException(e2);
                }
                EarlyStoppingResult<T> result = new EarlyStoppingResult<>(
                                EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(),
                                scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount + 1, bestModel);
                if (listener != null) {
                    listener.onCompletion(result);
                }

                // clean up
                wrapper.shutdown();
                this.wrapper = null;
                return result;
            }
        }
        epochCount++;
    }
}
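Unlike the Spark trainer above, this implementation catches training exceptions and reports them through the result as TerminationReason.Error, with e.toString() as the termination details. As a minimal sketch (trainer again stands for any IEarlyStoppingTrainer instance), a caller can branch on the three possible reasons:

    //Sketch: distinguishing the termination outcomes the fit() implementations above can return.
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    switch (result.getTerminationReason()) {
        case EpochTerminationCondition:
        case IterationTerminationCondition:
            System.out.println("Stopped by a termination condition: " + result.getTerminationDetails());
            break;
        case Error:
            //EarlyStoppingParallelTrainer reports the caught exception's toString() here
            System.out.println("Training terminated due to an error: " + result.getTerminationDetails());
            break;
    }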