
Example 1 with LossData

Use of edu.cmu.ml.proppr.learn.tools.LossData in the ProPPR project by TeamCohen.

From the class Trainer, method printLossOutput:

protected void printLossOutput(LossData lossThisEpoch) {
    System.out.print("avg training loss " + lossThisEpoch.total() + " on " + statistics.numExamplesThisEpoch + " examples");
    System.out.print(" = *:reg " + (lossThisEpoch.total() - lossThisEpoch.loss.get(LOSS.REGULARIZATION)));
    System.out.print(" : " + lossThisEpoch.loss.get(LOSS.REGULARIZATION));
    if (epoch > 1) {
        LossData diff = lossLastEpoch.diff(lossThisEpoch);
        System.out.println(" improved by " + diff.total() + " (*:reg " + (diff.total() - diff.loss.get(LOSS.REGULARIZATION)) + ":" + diff.loss.get(LOSS.REGULARIZATION) + ")");
        double percentImprovement = 100 * diff.total() / lossThisEpoch.total();
        System.out.println("pct reduction in training loss " + percentImprovement);
        // warn if there is a more than 1/2 of 1 percent increase in loss
        if (percentImprovement < -0.5) {
            System.out.println("WARNING: loss INCREASED by " + percentImprovement + " pct, i.e. total of " + (-diff.total()) + " - what's THAT about?");
        }
    } else
        System.out.println();
}
Also used : LossData(edu.cmu.ml.proppr.learn.tools.LossData)
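
The warning threshold in printLossOutput can be isolated into a small sketch. The class name LossReport, its method names, and the threshold constant below are hypothetical stand-ins for illustration; diffTotal and lossTotal correspond to lossLastEpoch.diff(lossThisEpoch).total() and lossThisEpoch.total() in the code above.

```java
// A minimal sketch of the percent-improvement check above. Only the
// arithmetic mirrors printLossOutput; the names here are invented.
public class LossReport {

    // Warn past a 1/2 of 1 percent *increase* in loss, as in the comment above.
    static final double WARN_THRESHOLD_PCT = 0.5;

    // Percent reduction in training loss; negative means the loss grew.
    static double percentImprovement(double diffTotal, double lossTotal) {
        return 100 * diffTotal / lossTotal;
    }

    static boolean shouldWarn(double diffTotal, double lossTotal) {
        return percentImprovement(diffTotal, lossTotal) < -WARN_THRESHOLD_PCT;
    }

    public static void main(String[] args) {
        // last epoch 10.0, this epoch 9.0 -> diff = +1.0: genuine improvement
        System.out.println(shouldWarn(1.0, 9.0));   // prints false
        // last epoch 9.0, this epoch 10.0 -> diff = -1.0: loss increased
        System.out.println(shouldWarn(-1.0, 10.0)); // prints true
    }
}
```

Note that the sign convention follows the Trainer code: diff is last-epoch loss minus this-epoch loss, so a positive value is an improvement.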

Example 2 with LossData

Use of edu.cmu.ml.proppr.learn.tools.LossData in the ProPPR project by TeamCohen.

From the class Trainer, method cleanEpoch:

/**
	 * End-of-epoch cleanup routine shared by Trainer, CachingTrainer. 
	 * Shuts down working thread, cleaning thread, regularizer, loss calculations, stopper calculations, 
	 * training statistics, and zero gradient statistics.
	 * @param workingPool
	 * @param cleanPool
	 * @param paramVec
	 * @param stopper
	 * @param n - number of examples
	 * @param stats
	 */
protected void cleanEpoch(ExecutorService workingPool, ExecutorService cleanPool, ParamVector<String, ?> paramVec, StoppingCriterion stopper, int n, TrainingStatistics stats) {
    workingPool.shutdown();
    try {
        workingPool.awaitTermination(7, TimeUnit.DAYS);
        cleanPool.shutdown();
        cleanPool.awaitTermination(7, TimeUnit.DAYS);
    } catch (InterruptedException e) {
        e.printStackTrace();
    }
    // finish any trailing updates for this epoch
    this.masterLearner.cleanupParams(paramVec, paramVec);
    // report loss status and signal the stopper
    LossData lossThisEpoch = new LossData();
    for (SRW learner : this.learners.values()) {
        lossThisEpoch.add(learner.cumulativeLoss());
    }
    lossThisEpoch.convertCumulativesToAverage(statistics.numExamplesThisEpoch);
    printLossOutput(lossThisEpoch);
    if (epoch > 1) {
        stopper.recordConsecutiveLosses(lossThisEpoch, lossLastEpoch);
    }
    lossLastEpoch = lossThisEpoch;
    ZeroGradientData zeros = this.masterLearner.new ZeroGradientData();
    for (SRW learner : this.learners.values()) {
        zeros.add(learner.getZeroGradientData());
    }
    if (zeros.numZero > 0) {
        log.info(zeros.numZero + " / " + n + " examples with 0 gradient");
        if (zeros.numZero / (float) n > MAX_PCT_ZERO_GRADIENT)
            log.warn("Having this many 0 gradients is unusual for supervised tasks. Try a different squashing function?");
    }
    stopper.recordEpoch();
    statistics.checkStatistics();
    stats.updateReadingStatistics(statistics.readTime);
    stats.updateParsingStatistics(statistics.parseTime);
    stats.updateTrainingStatistics(statistics.trainTime);
}
Also used : LossData(edu.cmu.ml.proppr.learn.tools.LossData) ZeroGradientData(edu.cmu.ml.proppr.learn.SRW.ZeroGradientData) SRW(edu.cmu.ml.proppr.learn.SRW)
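
The zero-gradient check at the end of cleanEpoch can also be sketched on its own. The class name ZeroGradCheck is hypothetical, and the value of MAX_PCT_ZERO_GRADIENT below (0.5) is an assumption for illustration; only the fraction comparison mirrors the Trainer code.

```java
// A minimal sketch of the zero-gradient warning condition above.
// The threshold value is assumed; the real constant lives in Trainer.
public class ZeroGradCheck {

    static final double MAX_PCT_ZERO_GRADIENT = 0.5; // assumed value

    // True when the fraction of zero-gradient examples exceeds the threshold,
    // matching "zeros.numZero / (float) n > MAX_PCT_ZERO_GRADIENT" above.
    static boolean tooManyZeroGradients(int numZero, int n) {
        return numZero / (float) n > MAX_PCT_ZERO_GRADIENT;
    }

    public static void main(String[] args) {
        System.out.println(tooManyZeroGradients(10, 100)); // prints false
        System.out.println(tooManyZeroGradients(80, 100)); // prints true
    }
}
```

As in the original, the division is done in floating point; integer division would silently round the ratio down to zero.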

Example 3 with LossData

Use of edu.cmu.ml.proppr.learn.tools.LossData in the ProPPR project by TeamCohen.

From the class DprSRW, method init:

private void init(double istayProb) {
    // set walk parameters here
    stayProb = istayProb;
    this.cumloss = new LossData();
}
Also used : LossData(edu.cmu.ml.proppr.learn.tools.LossData)
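
LossData accumulates cumulative losses per component during an epoch and then converts them to per-example averages, as in lossThisEpoch.convertCumulativesToAverage(statistics.numExamplesThisEpoch) in Example 2. The Map-based storage below is an assumption for illustration; the real LossData API is not reproduced here.

```java
import java.util.HashMap;
import java.util.Map;

// A minimal sketch of the cumulative-to-average conversion, assuming losses
// are kept as a map from component name to cumulative value.
public class AvgLoss {

    // Divide every cumulative loss component by the number of examples.
    static Map<String, Double> toAverage(Map<String, Double> cumloss, int numExamples) {
        Map<String, Double> avg = new HashMap<>(cumloss);
        avg.replaceAll((k, v) -> v / numExamples);
        return avg;
    }

    public static void main(String[] args) {
        Map<String, Double> cum = new HashMap<>();
        cum.put("LOG", 12.0);
        cum.put("REGULARIZATION", 3.0);
        // averages over 4 examples: LOG -> 3.0, REGULARIZATION -> 0.75
        System.out.println(toAverage(cum, 4));
    }
}
```

Averaging before printing is what makes the "avg training loss" figures in printLossOutput comparable across epochs with different example counts.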

Aggregations

LossData (edu.cmu.ml.proppr.learn.tools.LossData): 3 usages
SRW (edu.cmu.ml.proppr.learn.SRW): 1 usage
ZeroGradientData (edu.cmu.ml.proppr.learn.SRW.ZeroGradientData): 1 usage