Search in sources :

Example 6 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The setup method of the class TrainerTest.

/**
 * Prepares the fixture for each test: installs a fresh SRW learner configured with a
 * scheduled L2 regularizer and a ReLU squashing function, initializes the trainer,
 * seeds the query vector with weight 1.0 on node "r0", and serializes one
 * PosNegRWExample for every ("b" + i, "r" + j) pair with i, j in [0, magicNumber).
 */
@Before
public void setup() {
    super.setup();

    // learner under test
    this.srw = new SRW();
    this.srw.setRegularizer(new RegularizationSchedule(this.srw, new RegularizeL2()));
    this.srw.setSquashingFunction(new ReLU<String>());
    this.initTrainer();

    // query vector: all mass on "r0"
    query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    // one serialized example per (b-node, r-node) combination
    examples = new ArrayList<String>();
    for (int bIndex = 0; bIndex < this.magicNumber; bIndex++) {
        for (int rIndex = 0; rIndex < this.magicNumber; rIndex++) {
            int[] bNode = new int[] { nodes.getId("b" + bIndex) };
            int[] rNode = new int[] { nodes.getId("r" + rIndex) };
            PosNegRWExample example = new PosNegRWExample(brGraph, query, bNode, rNode);
            examples.add(example.serialize());
        }
    }
}
Also used : RegularizationSchedule(edu.cmu.ml.proppr.learn.RegularizationSchedule) TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) RegularizeL2(edu.cmu.ml.proppr.learn.RegularizeL2) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) SRW(edu.cmu.ml.proppr.learn.SRW) Before(org.junit.Before)

Example 7 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The train method of the class Trainer.

/**
 * Trains a parameter vector over the given examples, one epoch at a time, until the
 * stopping criterion is satisfied.
 *
 * <p>In each epoch every example string is submitted to a worker pool as a Parse task
 * followed by a Train task, and a TraceLosses task is submitted to a single-threaded
 * cleanup pool. The submitting (master) thread throttles itself so that the backlog of
 * unfinished worker tasks stays bounded; see the comment inside the loop.
 *
 * @param masterFeatures feature table; installed on LearningGraphBuilder when non-empty
 * @param examples       serialized examples, iterated once per epoch
 * @param builder        graph builder handed to every Parse task
 * @param initialParamVec starting parameters, passed through the master learner's setupParams
 * @param numEpochs      epoch limit handed to the StoppingCriterion
 * @return the parameter vector produced by setupParams, as updated by the Train tasks
 */
public ParamVector<String, ?> train(SymbolTable<String> masterFeatures, Iterable<String> examples, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs) {
    ParamVector<String, ?> paramVec = this.masterLearner.setupParams(initialParamVec);
    if (masterFeatures.size() > 0)
        LearningGraphBuilder.setFeatures(masterFeatures);
    NamedThreadFactory workingThreads = new NamedThreadFactory("work-");
    NamedThreadFactory cleaningThreads = new NamedThreadFactory("cleanup-");
    ThreadPoolExecutor workingPool;
    ExecutorService cleanPool;
    TrainingStatistics total = new TrainingStatistics();
    StoppingCriterion stopper = new StoppingCriterion(numEpochs, this.stoppingPercent, this.stoppingEpoch);
    // the dataset size summary is only logged after the first epoch
    boolean graphSizesStatusLog = true;
    StatusLogger stattime = new StatusLogger();
    // repeat until ready to stop
    while (!stopper.satisified()) {
        // set up current epoch
        this.epoch++;
        for (SRW learner : this.learners.values()) {
            learner.setEpoch(epoch);
            learner.clearLoss();
        }
        log.info("epoch " + epoch + " ...");
        status.tick();
        // reset counters & file pointers
        this.statistics = new TrainingStatistics();
        workingThreads.reset();
        cleaningThreads.reset();
        workingPool = new ThreadPoolExecutor(this.nthreads, Integer.MAX_VALUE, 10, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(), workingThreads);
        cleanPool = Executors.newSingleThreadExecutor(cleaningThreads);
        // run examples
        // id is the id of the NEXT example to submit; (id - 1) examples have been submitted so far
        int id = 1;
        stattime.start();
        // countdown: -1 = not throttling, >0 = tasks still to be tagged with 'notify', 0 = time to wait
        int countdown = -1;
        Trainer notify = null;
        for (String s : examples) {
            if (log.isDebugEnabled())
                log.debug("Queue size " + (workingPool.getTaskCount() - workingPool.getCompletedTaskCount()));
            statistics.updateReadingStatistics(stattime.sinceLast());
            /*
             * Throttling behavior:
             * Once the number of unfinished tasks exceeds 1.5x the number of threads,
             * we add a 'notify' object to the next nthreads training tasks. Then, the
             * master thread gathers 'notify' signals until the number of unfinished tasks
             * is no longer greater than the number of threads. Then we start adding tasks again.
             *
             * This works more or less fine, since the master thread stops pulling examples
             * from disk when there are at most 2.5x training examples in the queue (that's
             * the original 1.5x, which could represent a maximum of 1.5x training examples,
             * plus the nthreads training tasks with active 'notify' objects). There are an
             * additional nthreads parsing tasks in the queue, but those don't take up much
             * memory so we don't care. This lets us read in a good-sized buffer without
             * blowing up the heap.
             *
             * Worst-case: None of the backlog is cleared before the master thread enters
             * the synchronized block. nthreads-1 threads will be training long jobs, and
             * the one free thread works through the 0.5x backlog and all nthreads countdown
             * examples. The notify() sent by the final countdown example will occur when
             * there are nthreads unfinished tasks in the queue, and the master thread will exit
             * the synchronized block and proceed.
             *
             * Best-case: The backlog is already cleared by the time the master thread enters
             * the synchronized block. The while() loop immediately exits, and the notify()
             * signals from the countdown examples have no effect.
             */
            if (countdown > 0) {
                if (log.isDebugEnabled())
                    log.debug("Countdown " + countdown);
                countdown--;
            } else if (countdown == 0) {
                if (log.isDebugEnabled())
                    log.debug("Countdown " + countdown + "; throttling:");
                countdown--;
                notify = null;
                try {
                    synchronized (this) {
                        if (log.isDebugEnabled())
                            log.debug("Clearing training queue...");
                        while (workingPool.getTaskCount() - workingPool.getCompletedTaskCount() > this.nthreads) this.wait();
                        if (log.isDebugEnabled())
                            log.debug("Queue cleared.");
                    }
                } catch (InterruptedException e) {
                    // NOTE(review): the interrupt is swallowed and training continues. Restoring the
                    // interrupt flag here would make every later wait() fail immediately and so
                    // disable throttling; a proper fix needs a decision on aborting the epoch.
                    e.printStackTrace();
                }
            } else if (workingPool.getTaskCount() - workingPool.getCompletedTaskCount() > 1.5 * this.nthreads) {
                if (log.isDebugEnabled())
                    log.debug("Starting countdown");
                countdown = this.nthreads;
                notify = this;
            }
            Future<PosNegRWExample> parsed = workingPool.submit(new Parse(s, builder, id));
            Future<ExampleStats> trained = workingPool.submit(new Train(parsed, paramVec, id, notify));
            cleanPool.submit(new TraceLosses(trained, id));
            id++;
            stattime.tick();
            // id was just advanced, so the number of examples submitted so far is id - 1
            if (log.isInfoEnabled() && status.due(1))
                log.info("parsed: " + (id - 1) + " trained: " + statistics.exampleSetSize);
        }
        // number of examples submitted this epoch (id itself is one past the last used id)
        int submitted = id - 1;
        cleanEpoch(workingPool, cleanPool, paramVec, stopper, id, total);
        if (graphSizesStatusLog) {
            // average over the examples actually submitted; max(1, ...) avoids dividing by zero on an empty dataset
            log.info("Dataset size stats: " + statistics.totalGraphSize + " total nodes / max " + statistics.maxGraphSize + " / avg " + (statistics.totalGraphSize / Math.max(1, submitted)));
            graphSizesStatusLog = false;
        }
    }
    log.info("Reading  statistics: min " + total.minReadTime + " / max " + total.maxReadTime + " / total " + total.readTime);
    log.info("Parsing  statistics: min " + total.minParseTime + " / max " + total.maxParseTime + " / total " + total.parseTime);
    log.info("Training statistics: min " + total.minTrainTime + " / max " + total.maxTrainTime + " / total " + total.trainTime);
    return paramVec;
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) NamedThreadFactory(edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory) StoppingCriterion(edu.cmu.ml.proppr.learn.tools.StoppingCriterion) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) LinkedBlockingQueue(java.util.concurrent.LinkedBlockingQueue) ExecutorService(java.util.concurrent.ExecutorService) SRW(edu.cmu.ml.proppr.learn.SRW) ThreadPoolExecutor(java.util.concurrent.ThreadPoolExecutor)

Example 8 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The gradient method of the class SRW.

/**
 * Builds the gradient map for one example: the regularization term plus the loss
 * gradient computed by {@code lossf}, with every entry then divided by
 * {@code example.length()}.
 *
 * <p>When the loss function reports zero nonzero entries, the zero-gradient counter is
 * incremented and, while that counter is below MAX_ZERO_LOGS, the example is appended
 * to the zero-gradient log buffer.
 *
 * @param params  current parameter vector
 * @param example example whose gradient is computed
 * @return the gradient map, keyed by int feature id
 */
protected TIntDoubleMap gradient(ParamVector<String, ?> params, PosNegRWExample example) {
    Set<String> localFeatures = this.regularizer.localFeatures(params, example.getGraph());
    TIntDoubleMap grad = new TIntDoubleHashMap(localFeatures.size());

    // regularization term first, then the loss term on top of it
    regularization(params, example, grad);
    int nonzeroCount = lossf.computeLossGradient(params, example, grad, this.cumloss, c);

    // scale every entry by the example length
    for (int featureId : grad.keys()) {
        double scaled = grad.get(featureId) / example.length();
        grad.put(featureId, scaled);
    }

    if (nonzeroCount != 0) {
        return grad;
    }

    // bookkeeping for examples that produced an all-zero loss gradient
    this.zeroGradientData.numZero++;
    if (this.zeroGradientData.numZero < MAX_ZERO_LOGS) {
        this.zeroGradientData.examples.append("\n").append(example);
    }
    return grad;
}
Also used : TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) TIntDoubleMap(gnu.trove.map.TIntDoubleMap)

Example 9 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The inference method of the class SRW.

/**
 * Initializes and fills the example's {@code p} and {@code dp} arrays.
 *
 * <p>Both arrays are allocated with one slot per node id below the graph's
 * {@code node_hi}; {@code p} is zeroed and then seeded from the example's query vector,
 * after which {@code inferenceUpdate} is run {@code c.apr.maxDepth} times.
 *
 * @param params  current parameter vector (not read directly in this method)
 * @param example example whose {@code p} and {@code dp} are (re)built
 * @param status  logger used to rate-limit the per-iteration progress message
 */
protected void inference(ParamVector<String, ?> params, PosNegRWExample example, StatusLogger status) {
    int nodeSlots = example.getGraph().node_hi;
    example.p = new double[nodeSlots];
    example.dp = new TIntDoubleMap[nodeSlots];
    Arrays.fill(example.p, 0.0);

    // seed p with the query distribution
    TIntDoubleIterator queryIt = example.getQueryVec().iterator();
    while (queryIt.hasNext()) {
        queryIt.advance();
        example.p[queryIt.key()] = queryIt.value();
    }

    // fixed number of update passes
    for (int iter = 1; iter <= c.apr.maxDepth; iter++) {
        if (log.isInfoEnabled() && status.due(3)) {
            log.info("APR: iter " + iter + " of " + (c.apr.maxDepth));
        }
        inferenceUpdate(example, status);
    }
}
Also used : PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator)

Example 10 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The computeLossGradient method of the class NormalizedPosLoss.

/**
 * Adds the gradient of the loss {@code -log(sumPos) + log(sumPosNeg)} to
 * {@code gradient} and records both log terms in {@code lossdata}, where
 * {@code sumPos} is the clipped sum of clipped {@code p} over the positive nodes and
 * {@code sumPosNeg} is the same over positive and negative nodes together.
 *
 * <p>If the example has no positive nodes or no negative nodes, nothing is added and
 * 0 is returned.
 *
 * @param params   parameter vector (not read in this method)
 * @param example  example providing {@code p}, {@code dp} and the positive/negative node lists
 * @param gradient map the gradient terms are accumulated into
 * @param lossdata receives the two LOSS.LOG contributions
 * @param c        options (not read in this method)
 * @return the number of nonzero derivative entries visited; entries of positive nodes
 *         are visited in both passes and therefore counted twice
 */
@Override
public int computeLossGradient(ParamVector params, PosNegRWExample example, TIntDoubleMap gradient, LossData lossdata, SRWOptions c) {
    // If there are no negative nodes or no positive nodes, the probability of the
    // positive nodes is zero or 1, and the empirical loss gradient is zero.
    if (example.getNegList().length == 0 || example.getPosList().length == 0) {
        return 0;
    }

    int nonzero = 0;

    // ---- numerator term: -log(sum of p over positive nodes) ----
    double sumPos = 0;
    for (int posNode : example.getPosList()) {
        sumPos += clip(example.p[posNode]);
    }
    sumPos = clip(sumPos);

    for (int posNode : example.getPosList()) {
        TIntDoubleIterator it = example.dp[posNode].iterator();
        while (it.hasNext()) {
            it.advance();
            if (it.value() != 0) {
                nonzero++;
                double term = -it.value() / sumPos;
                gradient.adjustOrPutValue(it.key(), term, term);
            }
        }
    }
    lossdata.add(LOSS.LOG, -Math.log(sumPos));

    // ---- normalizer term: +log(sum of p over positive and negative nodes) ----
    double sumPosNeg = 0;
    for (int posNode : example.getPosList()) {
        sumPosNeg += clip(example.p[posNode]);
    }
    for (int negNode : example.getNegList()) {
        sumPosNeg += clip(example.p[negNode]);
    }
    sumPosNeg = clip(sumPosNeg);

    for (int posNode : example.getPosList()) {
        TIntDoubleIterator it = example.dp[posNode].iterator();
        while (it.hasNext()) {
            it.advance();
            if (it.value() != 0) {
                nonzero++;
                double term = it.value() / sumPosNeg;
                gradient.adjustOrPutValue(it.key(), term, term);
            }
        }
    }
    for (int negNode : example.getNegList()) {
        TIntDoubleIterator it = example.dp[negNode].iterator();
        while (it.hasNext()) {
            it.advance();
            if (it.value() != 0) {
                nonzero++;
                double term = it.value() / sumPosNeg;
                gradient.adjustOrPutValue(it.key(), term, term);
            }
        }
    }
    lossdata.add(LOSS.LOG, Math.log(sumPosNeg));

    return nonzero;
}
Also used : PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator)

Aggregations

PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample)12 SRW (edu.cmu.ml.proppr.learn.SRW)4 StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)4 ExecutorService (java.util.concurrent.ExecutorService)4 NamedThreadFactory (edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory)3 TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator)3 TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap)3 RWExampleParser (edu.cmu.ml.proppr.learn.tools.RWExampleParser)2 StoppingCriterion (edu.cmu.ml.proppr.learn.tools.StoppingCriterion)2 TIntDoubleMap (gnu.trove.map.TIntDoubleMap)2 ArrayList (java.util.ArrayList)2 ThreadPoolExecutor (java.util.concurrent.ThreadPoolExecutor)2 ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)1 GraphFormatException (edu.cmu.ml.proppr.graph.GraphFormatException)1 RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule)1 RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2)1 ModuleConfiguration (edu.cmu.ml.proppr.util.ModuleConfiguration)1 ParsedFile (edu.cmu.ml.proppr.util.ParsedFile)1 SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector)1 Callable (java.util.concurrent.Callable)1