use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
the class TrainerTest method setup.
/**
 * Prepares the fixture for the trainer tests.
 *
 * <p>After the superclass setup, a fresh {@code SRW} is installed with an L2 regularization
 * schedule and a ReLU squashing function, the trainer is initialized, a query vector with
 * weight 1.0 on node "r0" is created, and one serialized {@code PosNegRWExample} is generated
 * for every pairing of node "b"+k with node "r"+p, where k and p each range over
 * [0, magicNumber).
 */
@Before
public void setup() {
    super.setup();

    // learner under test
    this.srw = new SRW();
    RegularizationSchedule schedule = new RegularizationSchedule(this.srw, new RegularizeL2());
    this.srw.setRegularizer(schedule);
    this.srw.setSquashingFunction(new ReLU<String>());
    this.initTrainer();

    // every example shares the same query: all weight on "r0"
    query = new TIntDoubleHashMap();
    int queryNode = nodes.getId("r0");
    query.put(queryNode, 1.0);

    // full cross product of "b" nodes and "r" nodes
    examples = new ArrayList<String>();
    for (int bIdx = 0; bIdx < this.magicNumber; bIdx++) {
        for (int rIdx = 0; rIdx < this.magicNumber; rIdx++) {
            int[] bNodes = { nodes.getId("b" + bIdx) };
            int[] rNodes = { nodes.getId("r" + rIdx) };
            PosNegRWExample example = new PosNegRWExample(brGraph, query, bNodes, rNodes);
            examples.add(example.serialize());
        }
    }
}
use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
the class Trainer method train.
/**
 * Trains a parameter vector over a stream of serialized examples, making one pass over
 * {@code examples} per epoch until the stopping criterion reports that it is satisfied.
 *
 * <p>Each epoch creates a fresh working pool ({@code nthreads} core threads) and a
 * single-threaded cleanup pool. For every example string three chained tasks are submitted:
 * {@code Parse}, {@code Train} (which consumes the parse future) and {@code TraceLosses}
 * (which consumes the train future). Submission is throttled as described in the comment
 * inside the loop, and {@code cleanEpoch} is invoked once every example has been submitted.
 *
 * @param masterFeatures  feature table; when non-empty it is installed via
 *                        {@code LearningGraphBuilder.setFeatures}
 * @param examples        serialized examples; iterated once per epoch
 * @param builder         graph builder handed to every {@code Parse} task
 * @param initialParamVec starting parameters, passed to {@code masterLearner.setupParams}
 * @param numEpochs       epoch limit handed to the {@code StoppingCriterion}
 * @return the parameter vector returned by {@code setupParams}, which is the same object
 *         handed to every {@code Train} task
 */
public ParamVector<String, ?> train(SymbolTable<String> masterFeatures, Iterable<String> examples, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs) {
    ParamVector<String, ?> paramVec = this.masterLearner.setupParams(initialParamVec);
    if (masterFeatures.size() > 0)
        LearningGraphBuilder.setFeatures(masterFeatures);
    NamedThreadFactory workingThreads = new NamedThreadFactory("work-");
    NamedThreadFactory cleaningThreads = new NamedThreadFactory("cleanup-");
    ThreadPoolExecutor workingPool;
    ExecutorService cleanPool;
    TrainingStatistics total = new TrainingStatistics();
    StoppingCriterion stopper = new StoppingCriterion(numEpochs, this.stoppingPercent, this.stoppingEpoch);
    // the dataset size summary is logged after the first epoch only
    boolean graphSizesStatusLog = true;
    StatusLogger stattime = new StatusLogger();
    // repeat until ready to stop
    while (!stopper.satisified()) {
        // set up current epoch
        this.epoch++;
        for (SRW learner : this.learners.values()) {
            learner.setEpoch(epoch);
            learner.clearLoss();
        }
        log.info("epoch " + epoch + " ...");
        status.tick();
        // reset counters & file pointers
        this.statistics = new TrainingStatistics();
        workingThreads.reset();
        cleaningThreads.reset();
        workingPool = new ThreadPoolExecutor(this.nthreads, Integer.MAX_VALUE, 10, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(), workingThreads);
        cleanPool = Executors.newSingleThreadExecutor(cleaningThreads);
        // run examples
        // id is the 1-based id of the NEXT example to submit
        int id = 1;
        stattime.start();
        // countdown < 0: not throttling; > 0: examples left to tag with 'notify'; == 0: time to wait
        int countdown = -1;
        Trainer notify = null;
        for (String s : examples) {
            if (log.isDebugEnabled())
                log.debug("Queue size " + (workingPool.getTaskCount() - workingPool.getCompletedTaskCount()));
            statistics.updateReadingStatistics(stattime.sinceLast());
            /*
             * Throttling behavior:
             * Once the number of unfinished tasks exceeds 1.5x the number of threads,
             * we add a 'notify' object to the next nthreads training tasks. Then, the
             * master thread gathers 'notify' signals until the number of unfinished tasks
             * is no longer greater than the number of threads. Then we start adding tasks again.
             *
             * This works more or less fine, since the master thread stops pulling examples
             * from disk when there are then a maximum of 2.5x training examples in the queue (that's
             * the original 1.5x, which could represent a maximum of 1.5x training examples,
             * plus the nthreads training tasks with active 'notify' objects. There's an
             * additional nthreads parsing tasks in the queue but those don't take up much
             * memory so we don't care). This lets us read in a good-sized buffer without
             * blowing up the heap.
             *
             * Worst-case: None of the backlog is cleared before the master thread enters
             * the synchronized block. nthreads-1 threads will be training long jobs, and
             * the one free thread works through the 0.5x backlog and all nthreads countdown
             * examples. The notify() sent by the final countdown example will occur when
             * there are nthreads unfinished tasks in the queue, and the master thread will exit
             * the synchronized block and proceed.
             *
             * Best-case: The backlog is already cleared by the time the master thread enters
             * the synchronized block. The while() loop immediately exits, and the notify()
             * signals from the countdown examples have no effect.
             */
            if (countdown > 0) {
                if (log.isDebugEnabled())
                    log.debug("Countdown " + countdown);
                countdown--;
            } else if (countdown == 0) {
                if (log.isDebugEnabled())
                    log.debug("Countdown " + countdown + "; throttling:");
                countdown--;
                notify = null;
                try {
                    synchronized (this) {
                        if (log.isDebugEnabled())
                            log.debug("Clearing training queue...");
                        while (workingPool.getTaskCount() - workingPool.getCompletedTaskCount() > this.nthreads) this.wait();
                        if (log.isDebugEnabled())
                            log.debug("Queue cleared.");
                    }
                } catch (InterruptedException e) {
                    // NOTE(review): the interrupt is only reported here, not restored; the wait is
                    // abandoned for this round and submission simply continues. Confirm that is intended.
                    e.printStackTrace();
                }
            } else if (workingPool.getTaskCount() - workingPool.getCompletedTaskCount() > 1.5 * this.nthreads) {
                if (log.isDebugEnabled())
                    log.debug("Starting countdown");
                countdown = this.nthreads;
                notify = this;
            }
            Future<PosNegRWExample> parsed = workingPool.submit(new Parse(s, builder, id));
            Future<ExampleStats> trained = workingPool.submit(new Train(parsed, paramVec, id, notify));
            cleanPool.submit(new TraceLosses(trained, id));
            id++;
            stattime.tick();
            if (log.isInfoEnabled() && status.due(1))
                log.info("parsed: " + id + " trained: " + statistics.exampleSetSize);
        }
        cleanEpoch(workingPool, cleanPool, paramVec, stopper, id, total);
        if (graphSizesStatusLog) {
            // id started at 1 and was incremented once per example, so the number of examples
            // submitted this epoch is id - 1; dividing by id would understate the average.
            int numExamples = id - 1;
            log.info("Dataset size stats: " + statistics.totalGraphSize + " total nodes / max " + statistics.maxGraphSize + " / avg " + (numExamples > 0 ? statistics.totalGraphSize / numExamples : 0));
            graphSizesStatusLog = false;
        }
    }
    log.info("Reading statistics: min " + total.minReadTime + " / max " + total.maxReadTime + " / total " + total.readTime);
    log.info("Parsing statistics: min " + total.minParseTime + " / max " + total.maxParseTime + " / total " + total.parseTime);
    log.info("Training statistics: min " + total.minTrainTime + " / max " + total.maxTrainTime + " / total " + total.trainTime);
    return paramVec;
}
use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
the class SRW method gradient.
/**
 * Builds the gradient map for a single example.
 *
 * <p>The map is sized from the regularizer's local feature set for the example's graph,
 * receives the regularization term, then the loss term from {@code lossf}, and finally has
 * every entry divided by {@code example.length()}. When the loss function reports zero
 * nonzero entries, the zero-gradient counter is incremented and, while that counter is
 * below {@code MAX_ZERO_LOGS}, the example is appended to the zero-gradient log buffer.
 *
 * @param params  current parameter vector
 * @param example example to differentiate
 * @return the gradient map keyed by the ids written by the regularizer and loss function
 */
protected TIntDoubleMap gradient(ParamVector<String, ?> params, PosNegRWExample example) {
    Set<String> localFeatures = this.regularizer.localFeatures(params, example.getGraph());
    TIntDoubleMap grad = new TIntDoubleHashMap(localFeatures.size());

    // regularization contribution first, then the loss contribution
    regularization(params, example, grad);
    int nonzeroCount = lossf.computeLossGradient(params, example, grad, this.cumloss, c);

    // scale every accumulated entry by the example length
    int[] featureIds = grad.keys();
    for (int idx = 0; idx < featureIds.length; idx++) {
        int featureId = featureIds[idx];
        grad.put(featureId, grad.get(featureId) / example.length());
    }

    if (nonzeroCount != 0) {
        return grad;
    }

    // keep a bounded record of examples that produced an all-zero loss gradient
    this.zeroGradientData.numZero++;
    if (this.zeroGradientData.numZero < MAX_ZERO_LOGS) {
        this.zeroGradientData.examples.append("\n").append(example);
    }
    return grad;
}
use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
the class SRW method inference.
/**
 * Runs the iterative inference pass for one example, (re)building its {@code p} and
 * {@code dp} fields.
 *
 * <p>{@code p} is allocated with {@code node_hi} slots (taken from the example's graph),
 * zero-filled, and seeded from the example's query vector. {@code dp} is allocated with the
 * same length; its entries are not assigned directly in this method. Afterwards
 * {@code inferenceUpdate} is invoked {@code c.apr.maxDepth} times.
 *
 * @param params  current parameter vector (not read directly in this method)
 * @param example example whose {@code p} and {@code dp} are filled
 * @param status  used to rate-limit the progress log line and passed on to each update
 */
protected void inference(ParamVector<String, ?> params, PosNegRWExample example, StatusLogger status) {
    final int nodeCount = example.getGraph().node_hi;
    example.p = new double[nodeCount];
    example.dp = new TIntDoubleMap[nodeCount];
    Arrays.fill(example.p, 0.0);

    // seed p with the query distribution
    TIntDoubleIterator seed = example.getQueryVec().iterator();
    while (seed.hasNext()) {
        seed.advance();
        example.p[seed.key()] = seed.value();
    }

    // fixed number of update sweeps; 'iter' is the 1-based sweep number used in the log line
    int iter = 0;
    while (iter < c.apr.maxDepth) {
        iter++;
        if (log.isInfoEnabled() && status.due(3)) {
            log.info("APR: iter " + iter + " of " + (c.apr.maxDepth));
        }
        inferenceUpdate(example, status);
    }
}
use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
the class NormalizedPosLoss method computeLossGradient.
/**
 * Accumulates the gradient of the normalized positive loss for one example into
 * {@code gradient} and records the loss value in {@code lossdata}.
 *
 * <p>Let P be the clipped sum of clipped {@code p} over the positive nodes and PN the clipped
 * sum of clipped {@code p} over the positive and negative nodes. The method records
 * {@code -log(P)} and {@code log(PN)} under {@code LOSS.LOG} (two separate {@code add}
 * calls), and for every nonzero derivative entry passes {@code -dp[a]/P} (for each positive
 * node a) and {@code dp[x]/PN} (for each positive or negative node x) to
 * {@code gradient.adjustOrPutValue}.
 *
 * <p>If the example has no positive nodes or no negative nodes, nothing is written and 0 is
 * returned.
 *
 * @param params   parameter vector (not read in this method)
 * @param example  example providing {@code p}, {@code dp} and the positive/negative node lists
 * @param gradient map that receives the gradient contributions
 * @param lossdata sink for the loss terms
 * @param c        options (not read in this method)
 * @return the number of nonzero derivative entries visited; entries of positive nodes are
 *         counted once in the P pass and once more in the PN pass
 */
@Override
public int computeLossGradient(ParamVector params, PosNegRWExample example, TIntDoubleMap gradient, LossData lossdata, SRWOptions c) {
    final int[] positives = example.getPosList();
    final int[] negatives = example.getNegList();

    // With one of the two lists empty there is no gradient to contribute.
    if (positives.length == 0 || negatives.length == 0) {
        return 0;
    }

    int visited = 0;

    // ---- numerator term: -log(P) ----
    double posMass = 0;
    for (int node : positives) {
        posMass += clip(example.p[node]);
    }
    posMass = clip(posMass);

    for (int node : positives) {
        TIntDoubleIterator it = example.dp[node].iterator();
        while (it.hasNext()) {
            it.advance();
            if (it.value() != 0) {
                visited++;
                double term = -it.value() / posMass;
                gradient.adjustOrPutValue(it.key(), term, term);
            }
        }
    }
    lossdata.add(LOSS.LOG, -Math.log(posMass));

    // ---- normalizer term: +log(PN), positives first, then negatives ----
    final int[][] labelled = { positives, negatives };

    double totalMass = 0;
    for (int[] group : labelled) {
        for (int node : group) {
            totalMass += clip(example.p[node]);
        }
    }
    totalMass = clip(totalMass);

    for (int[] group : labelled) {
        for (int node : group) {
            TIntDoubleIterator it = example.dp[node].iterator();
            while (it.hasNext()) {
                it.advance();
                if (it.value() != 0) {
                    visited++;
                    double term = it.value() / totalMass;
                    gradient.adjustOrPutValue(it.key(), term, term);
                }
            }
        }
    }
    lossdata.add(LOSS.LOG, Math.log(totalMass));

    return visited;
}
Aggregations