
Example 1 with ZeroGradientData

Use of edu.cmu.ml.proppr.learn.SRW.ZeroGradientData in project ProPPR by TeamCohen.

Method cleanEpoch of class Trainer.

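/**
 * End-of-epoch cleanup routine shared by Trainer and CachingTrainer.
 * Shuts down the working and cleaning thread pools, then finalizes regularizer updates,
 * loss calculations, stopper calculations, training statistics, and zero-gradient statistics.
 * @param workingPool pool of working threads; shut down first
 * @param cleanPool pool of cleaning threads; shut down once workingPool has terminated
 * @param paramVec parameter vector passed to the master learner's cleanupParams
 * @param stopper stopping criterion that records consecutive losses and the finished epoch
 * @param n number of examples
 * @param stats training statistics updated with this epoch's read, parse, and train times
 */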
protected void cleanEpoch(ExecutorService workingPool, ExecutorService cleanPool, ParamVector<String, ?> paramVec, StoppingCriterion stopper, int n, TrainingStatistics stats) {
    n = n - 1;
    workingPool.shutdown();
    try {
        workingPool.awaitTermination(7, TimeUnit.DAYS);
        cleanPool.shutdown();
        cleanPool.awaitTermination(7, TimeUnit.DAYS);
    } catch (InterruptedException e) {
        e.printStackTrace();
    }
    // finish any trailing updates for this epoch
    this.masterLearner.cleanupParams(paramVec, paramVec);
    // loss status and signalling the stopper
    LossData lossThisEpoch = new LossData();
    for (SRW learner : this.learners.values()) {
        lossThisEpoch.add(learner.cumulativeLoss());
    }
    lossThisEpoch.convertCumulativesToAverage(statistics.numExamplesThisEpoch);
    printLossOutput(lossThisEpoch);
    if (epoch > 1) {
        stopper.recordConsecutiveLosses(lossThisEpoch, lossLastEpoch);
    }
    lossLastEpoch = lossThisEpoch;
    ZeroGradientData zeros = this.masterLearner.new ZeroGradientData();
    for (SRW learner : this.learners.values()) {
        zeros.add(learner.getZeroGradientData());
    }
    if (zeros.numZero > 0) {
        log.info(zeros.numZero + " / " + n + " examples with 0 gradient");
        if (zeros.numZero / (float) n > MAX_PCT_ZERO_GRADIENT)
            log.warn("Having this many 0 gradients is unusual for supervised tasks. Try a different squashing function?");
    }
    stopper.recordEpoch();
    statistics.checkStatistics();
    stats.updateReadingStatistics(statistics.readTime);
    stats.updateParsingStatistics(statistics.parseTime);
    stats.updateTrainingStatistics(statistics.trainTime);
}
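
The ordering of the two shutdowns in cleanEpoch matters when working tasks hand follow-up work to the cleaning pool: the cleaning pool can only be closed once no working task can submit to it anymore. The following is a minimal, standalone sketch of that pattern, not ProPPR code. The class name EpochShutdownSketch, the helper shutdownInOrder, and the demo workload are assumptions made for illustration. Unlike the method above, the sketch checks the result of awaitTermination and restores the interrupt flag instead of printing the stack trace.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical illustration class; not part of ProPPR.
final class EpochShutdownSketch {

    /**
     * Shuts down the working pool, waits for it to drain, and only then
     * shuts down the cleaning pool and waits for that one as well.
     * Returns true if both pools terminated within the timeout.
     */
    static boolean shutdownInOrder(ExecutorService workingPool, ExecutorService cleanPool,
                                   long timeout, TimeUnit unit) {
        workingPool.shutdown(); // no new working tasks; queued ones still run
        try {
            if (!workingPool.awaitTermination(timeout, unit)) {
                return false;
            }
            // Safe now: no working task is left that could submit cleaning work.
            cleanPool.shutdown();
            return cleanPool.awaitTermination(timeout, unit);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt for callers
            return false;
        }
    }

    /** Runs five working tasks that each enqueue one cleaning task; returns the number cleaned. */
    static int runDemo() {
        ExecutorService workingPool = Executors.newFixedThreadPool(2);
        ExecutorService cleanPool = Executors.newSingleThreadExecutor();
        AtomicInteger cleaned = new AtomicInteger();
        for (int i = 0; i < 5; i++) {
            workingPool.execute(() -> cleanPool.execute(() -> cleaned.incrementAndGet()));
        }
        boolean finished = shutdownInOrder(workingPool, cleanPool, 10, TimeUnit.SECONDS);
        return finished ? cleaned.get() : -1;
    }
}
```

Closing the cleaning pool first would risk a RejectedExecutionException in any working task that still tries to enqueue cleanup work.
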
Also used: LossData (edu.cmu.ml.proppr.learn.tools.LossData), ZeroGradientData (edu.cmu.ml.proppr.learn.SRW.ZeroGradientData), SRW (edu.cmu.ml.proppr.learn.SRW)
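
The zero-gradient check in cleanEpoch sums per-learner counts and warns when their share of the examples exceeds a threshold. A self-contained sketch of that arithmetic follows. It is an illustration under assumptions, not the ProPPR implementation: ZeroCounts is a hypothetical stand-in for SRW.ZeroGradientData (only the numZero field and add method appear in the snippet above), and the threshold value 0.2 is invented because the real value of MAX_PCT_ZERO_GRADIENT is not shown here.

```java
import java.util.List;

// Hypothetical illustration class; not part of ProPPR.
final class ZeroGradientSketch {

    // Assumed value for illustration only; the real constant is not shown in the snippet.
    static final float MAX_PCT_ZERO_GRADIENT = 0.2f;

    /** Stand-in for the per-learner zero-gradient record. */
    static final class ZeroCounts {
        int numZero;

        ZeroCounts(int numZero) {
            this.numZero = numZero;
        }

        /** Accumulates another learner's count; a null record is ignored. */
        void add(ZeroCounts other) {
            if (other != null) {
                numZero += other.numZero;
            }
        }
    }

    /** Sums the counts reported by every learner into one record. */
    static ZeroCounts aggregate(List<ZeroCounts> perLearner) {
        ZeroCounts total = new ZeroCounts(0);
        for (ZeroCounts counts : perLearner) {
            total.add(counts);
        }
        return total;
    }

    /** True when the fraction of zero-gradient examples exceeds the threshold. */
    static boolean tooManyZeroGradients(int numZero, int numExamples) {
        if (numExamples <= 0) {
            return false; // nothing to compare against; avoids division by zero
        }
        // The float cast forces floating-point division, as in the snippet above.
        return numZero / (float) numExamples > MAX_PCT_ZERO_GRADIENT;
    }
}
```

Without the cast, integer division would truncate the ratio to 0 whenever numZero is smaller than numExamples, and the warning would never fire.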
