Use of edu.cmu.ml.proppr.learn.SRW.ZeroGradientData in project ProPPR by TeamCohen.
The class Trainer, method cleanEpoch.
/**
 * End-of-epoch cleanup routine shared by Trainer, CachingTrainer.
 * Shuts down the working pool and the cleaning pool (waiting for each to drain),
 * finishes trailing parameter updates, aggregates and reports the per-learner loss,
 * signals the stopper, reports zero-gradient statistics, and pushes the epoch's
 * timing figures into the supplied training statistics.
 *
 * If the calling thread is interrupted while waiting for a pool to terminate, the
 * interruption is logged, the thread's interrupt status is restored, and the
 * remaining bookkeeping still runs.
 *
 * @param workingPool executor running this epoch's training tasks; shut down and awaited first
 * @param cleanPool executor running this epoch's cleanup tasks; shut down and awaited after workingPool
 * @param paramVec parameter vector handed to the master learner's cleanupParams
 * @param stopper stopping criterion that receives the consecutive losses and the end-of-epoch signal
 * @param n - number of examples; decremented by one before it is used as the
 *          denominator of the zero-gradient report
 * @param stats receives the reading, parsing and training times recorded for this epoch
 */
protected void cleanEpoch(ExecutorService workingPool, ExecutorService cleanPool, ParamVector<String, ?> paramVec, StoppingCriterion stopper, int n, TrainingStatistics stats) {
    // NOTE(review): presumably the caller passes a counter that is one past the last
    // example id -- confirm against the call sites before touching this adjustment.
    n = n - 1;
    workingPool.shutdown();
    try {
        workingPool.awaitTermination(7, TimeUnit.DAYS);
        // The cleaning pool is only shut down once the working pool has drained.
        cleanPool.shutdown();
        cleanPool.awaitTermination(7, TimeUnit.DAYS);
    } catch (InterruptedException e) {
        // Do not swallow the interruption: report it and restore the flag so that
        // code further up the stack can observe that this thread was interrupted.
        // NOTE(review): if the interrupt arrives while waiting on workingPool,
        // cleanPool.shutdown() is skipped (unchanged from the previous behaviour).
        log.warn("Interrupted while waiting for epoch thread pools to terminate", e);
        Thread.currentThread().interrupt();
    }
    // finish any trailing updates for this epoch
    this.masterLearner.cleanupParams(paramVec, paramVec);
    // loss status and signalling the stopper
    LossData lossThisEpoch = new LossData();
    for (SRW learner : this.learners.values()) {
        lossThisEpoch.add(learner.cumulativeLoss());
    }
    lossThisEpoch.convertCumulativesToAverage(statistics.numExamplesThisEpoch);
    printLossOutput(lossThisEpoch);
    // A consecutive-loss comparison needs a previous epoch to compare against.
    if (epoch > 1) {
        stopper.recordConsecutiveLosses(lossThisEpoch, lossLastEpoch);
    }
    lossLastEpoch = lossThisEpoch;
    // zero-gradient statistics, summed over every learner
    ZeroGradientData zeros = this.masterLearner.new ZeroGradientData();
    for (SRW learner : this.learners.values()) {
        zeros.add(learner.getZeroGradientData());
    }
    if (zeros.numZero > 0) {
        log.info(zeros.numZero + " / " + n + " examples with 0 gradient");
        if (zeros.numZero / (float) n > MAX_PCT_ZERO_GRADIENT) {
            log.warn("Having this many 0 gradients is unusual for supervised tasks. Try a different squashing function?");
        }
    }
    stopper.recordEpoch();
    statistics.checkStatistics();
    // hand this epoch's timing figures to the caller-supplied statistics object
    stats.updateReadingStatistics(statistics.readTime);
    stats.updateParsingStatistics(statistics.parseTime);
    stats.updateTrainingStatistics(statistics.trainTime);
}
Aggregations