Use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
Class CachingTrainer, method trainCached.
/**
 * Trains on a list of already-parsed examples, reusing the same examples in every epoch.
 * <p>
 * Each epoch submits one {@code Train} task per example to a fixed-size pool and one
 * {@code TraceLosses} task per example to a single-thread pool, then hands both pools to
 * {@code cleanEpoch}. Epochs repeat until the {@link StoppingCriterion} built from
 * {@code numEpochs}, {@code stoppingPercent} and {@code stoppingEpoch} is satisfied.
 * <p>
 * Side effect: when {@code this.shuffle} is set, {@code examples} is shuffled in place
 * at the start of every epoch.
 *
 * @param examples        parsed training examples (shuffled in place when shuffling is enabled)
 * @param builder         not used by this method
 * @param initialParamVec starting parameters, passed through {@code masterLearner.setupParams}
 * @param numEpochs       epoch budget given to the stopping criterion
 * @param total           aggregate statistics handed to {@code cleanEpoch}; its timings are
 *                        logged once training stops
 * @return the trained parameter vector
 */
public ParamVector<String, ?> trainCached(List<PosNegRWExample> examples, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs, TrainingStatistics total) {
    ParamVector<String, ?> paramVec = this.masterLearner.setupParams(initialParamVec);
    NamedThreadFactory trainThreads = new NamedThreadFactory("work-");
    StoppingCriterion stopper = new StoppingCriterion(numEpochs, this.stoppingPercent, this.stoppingEpoch);
    // the dataset size summary is only logged after the first epoch
    boolean graphSizesStatusLog = true;
    // repeat until ready to stop
    while (!stopper.satisified()) {
        // set up current epoch
        this.epoch++;
        for (SRW learner : this.learners.values()) {
            learner.setEpoch(epoch);
            learner.clearLoss();
        }
        log.info("epoch " + epoch + " ...");
        status.tick();
        // reset counters & file pointers
        this.statistics = new TrainingStatistics();
        trainThreads.reset();
        ExecutorService trainPool = Executors.newFixedThreadPool(this.nthreads, trainThreads);
        ExecutorService cleanPool = Executors.newSingleThreadExecutor();
        // run examples; example ids are 1-based
        int id = 1;
        if (this.shuffle) {
            Collections.shuffle(examples);
        }
        for (PosNegRWExample s : examples) {
            Future<ExampleStats> trained = trainPool.submit(new Train(new PretendParse(s), paramVec, id, null));
            cleanPool.submit(new TraceLosses(trained, id));
            // at this point exactly `id` examples have been queued, so log before advancing
            if (log.isInfoEnabled() && status.due(1)) {
                log.info("queued: " + id + " trained: " + statistics.exampleSetSize);
            }
            id++;
        }
        // id is one past the last example id here, which is what cleanEpoch has always received
        cleanEpoch(trainPool, cleanPool, paramVec, stopper, id, total);
        if (graphSizesStatusLog) {
            // average over the number of examples (id would over-count by one)
            int numExamples = examples.size();
            log.info("Dataset size stats: " + statistics.totalGraphSize + " total nodes / max " + statistics.maxGraphSize + " / avg " + (numExamples > 0 ? statistics.totalGraphSize / numExamples : 0));
            graphSizesStatusLog = false;
        }
    }
    log.info("Reading: " + total.readTime + " Parsing: " + total.parseTime + " Training: " + total.trainTime);
    return paramVec;
}
Use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
Class Diagnostic, method main.
/**
 * Diagnostic entry point: parses every grounded example in the configured train file on one
 * thread pool, trains SRW on each parsed example on a second pool, and logs the elapsed time.
 * <p>
 * Both pools get half of {@code c.nthreads} (at least one thread each). Examples whose parse
 * failed are skipped. Failures during training are logged instead of being lost in an unread
 * {@code Future}. Alternative pipelines used for earlier diagnostics are kept commented out below.
 * Any {@link Throwable} escaping the pipeline prints a stack trace and exits with status -1.
 *
 * @param args command-line arguments, interpreted by {@code ModuleConfiguration}
 */
public static void main(String[] args) {
    StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();
        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, (float) 0.75, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");
        /* DiagSrwES: */
        // half of the threads parse and half train, with at least one thread per pool
        final int poolSize = c.nthreads > 1 ? c.nthreads / 2 : 1;
        // every parse Future is retained until training has consumed it
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        final ExecutorService trainerPool = Executors.newFixedThreadPool(poolSize);
        final ExecutorService parserPool = Executors.newFixedThreadPool(poolSize);
        int i = 1;
        for (String s : new ParsedFile(groundedFile)) {
            final int id = i++;
            final String in = s;
            parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {
                @Override
                public PosNegRWExample call() throws Exception {
                    try {
                        //log.debug("Job start "+id);
                        //PosNegRWExample ret = parser.parse(in, b.copy());
                        log.debug("Parsing start " + id);
                        PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                        log.debug("Parsing done " + id);
                        //log.debug("Job done "+id);
                        return ret;
                    } catch (IllegalArgumentException e) {
                        System.err.println("Problem with #" + id);
                        e.printStackTrace();
                    }
                    // a null result marks a failed parse; the trainer skips it
                    return null;
                }
            }));
        }
        parserPool.shutdown();
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            trainerPool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        if (ex == null) {
                            // parsing failed and the parser task has already reported it
                            log.debug("Training skipped " + id);
                            return;
                        }
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        // restore the flag so the pool thread still observes the interrupt
                        Thread.currentThread().interrupt();
                        log.error("Interrupted while waiting for parse of #" + id, e);
                    } catch (ExecutionException e) {
                        log.error("Parsing failed for #" + id, e);
                    } catch (RuntimeException e) {
                        // the Future returned by submit() is never read, so log here or lose it
                        log.error("Training failed for #" + id, e);
                    }
                }
            });
        }
        trainerPool.shutdown();
        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            log.error("Interrupted?", e);
        }
        /* /SrwES */
        /* SrwTtwop:
        final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads>1?c.nthreads/2:1);
        Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
        m.executeJob(c.nthreads/2, new ParsedFile(groundedFile),
        new Transformer<String,PosNegRWExample>() {
        @Override
        public Callable<PosNegRWExample> transformer(final String in, final int id) {
        return new Callable<PosNegRWExample>() {
        @Override
        public PosNegRWExample call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Parsing start "+id);
        PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
        log.debug("Parsing done "+id);
        //log.debug("Job done "+id);
        return ret;
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return null;
        }};
        }}, new Cleanup<PosNegRWExample>() {
        @Override
        public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
        return new Runnable(){
        @Override
        public void run() {
        try {
        final PosNegRWExample ex = in.get();
        log.debug("Cleanup start "+id);
        trainerPool.submit(new Runnable() {
        @Override
        public void run(){
        log.debug("Training start "+id);
        srw.trainOnExample(params,ex);
        log.debug("Training done "+id);
        }
        });
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        trainerPool.shutdown();
        try {
        trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
        log.error("Interrupted?",e);
        }
        /SrwTtwop */
        /* all diag tasks except SrwO:
        Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
        m.executeJob(c.nthreads, new ParsedFile(groundedFile),
        new Transformer<String,PosNegRWExample>() {
        @Override
        public Callable<PosNegRWExample> transformer(final String in, final int id) {
        return new Callable<PosNegRWExample>() {
        @Override
        public PosNegRWExample call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Parsing start "+id);
        PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
        log.debug("Parsing done "+id);
        log.debug("Training start "+id);
        srw.trainOnExample(params,ret);
        log.debug("Training done "+id);
        //log.debug("Job done "+id);
        return ret;
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return null;
        }};
        }}, new Cleanup<PosNegRWExample>() {
        @Override
        public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
        return new Runnable(){
        //ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
        @Override
        public void run() {
        try {
        //done.add(in.get());
        in.get();
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup start "+id);
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        */
        /* SrwO:
        Multithreading<PosNegRWExample,Integer> m = new Multithreading<PosNegRWExample,Integer>(log);
        m.executeJob(c.nthreads, new PosNegRWExampleStreamer(new ParsedFile(groundedFile),new ArrayLearningGraphBuilder()),
        new Transformer<PosNegRWExample,Integer>() {
        @Override
        public Callable<Integer> transformer(final PosNegRWExample in, final int id) {
        return new Callable<Integer>() {
        @Override
        public Integer call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Training start "+id);
        srw.trainOnExample(params,in);
        log.debug("Training done "+id);
        //log.debug("Job done "+id);
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return in.length();
        }};
        }}, new Cleanup<Integer>() {
        @Override
        public Runnable cleanup(final Future<Integer> in, final int id) {
        return new Runnable(){
        //ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
        @Override
        public void run() {
        try {
        //done.add(in.get());
        in.get();
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup start "+id);
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        */
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
Class Trainer, method findGradient.
/**
 * Computes the loss gradient over a stream of unparsed examples, normalized by the example
 * (query) count held in {@code this.statistics.numExamplesThisEpoch}.
 * <p>
 * Every example gets a {@code Parse} task and a {@code Grad} task on a fixed-size work pool.
 * The {@code Grad} tasks accumulate into one shared gradient vector. A single-thread pool runs
 * {@code TraceLosses} on each result.
 * <p>
 * Throttling: once more than 1.5x {@code nthreads} tasks are pending, the reader queues
 * {@code nthreads} more examples and then blocks on this trainer's monitor until at most
 * {@code nthreads} tasks are pending. The {@code Grad} tasks queued during that countdown
 * receive this trainer as their notify target (presumably so they can wake the reader —
 * confirm in {@code Grad}).
 *
 * @param masterFeatures if non-empty, installed via {@code LearningGraphBuilder.setFeatures}
 * @param examples       raw example strings, parsed with {@code builder}
 * @param builder        graph builder handed to each {@code Parse} task
 * @param paramVec       parameters at which to evaluate the gradient; created when {@code null}
 * @return the accumulated gradient, with every entry divided by the example count
 */
public ParamVector<String, ?> findGradient(SymbolTable<String> masterFeatures, Iterable<String> examples, LearningGraphBuilder builder, ParamVector<String, ?> paramVec) {
    log.info("Computing gradient on cooked examples...");
    ParamVector<String, ?> sumGradient = new SimpleParamVector<String>();
    if (paramVec == null) {
        paramVec = createParamVector();
    }
    paramVec = this.masterLearner.setupParams(paramVec);
    if (masterFeatures != null && masterFeatures.size() > 0) {
        LearningGraphBuilder.setFeatures(masterFeatures);
    }
    //
    // //WW: accumulate example-size normalized gradient
    // for (PosNegRWExample x : examples) {
    //// this.learner.initializeFeatures(paramVec,x.getGraph());
    // this.learner.accumulateGradient(paramVec, x, sumGradient);
    // k++;
    // }
    NamedThreadFactory workThreads = new NamedThreadFactory("work-");
    // keep the ThreadPoolExecutor type so the pending-task count can be read without casts
    ThreadPoolExecutor workPool = (ThreadPoolExecutor) Executors.newFixedThreadPool(this.nthreads, workThreads);
    ExecutorService cleanPool = Executors.newSingleThreadExecutor();
    // run examples
    int id = 1;
    int countdown = -1;
    Trainer notify = null;
    status.start();
    for (String s : examples) {
        if (log.isInfoEnabled() && status.due()) {
            log.info(id + " examples read...");
        }
        long queueSize = pendingTaskCount(workPool);
        if (log.isDebugEnabled()) {
            log.debug("Queue size " + queueSize);
        }
        if (countdown > 0) {
            if (log.isDebugEnabled()) {
                log.debug("Countdown " + countdown);
            }
            countdown--;
        } else if (countdown == 0) {
            if (log.isDebugEnabled()) {
                log.debug("Countdown " + countdown + "; throttling:");
            }
            countdown--;
            notify = null;
            try {
                synchronized (this) {
                    if (log.isDebugEnabled()) {
                        log.debug("Clearing training queue...");
                    }
                    while (pendingTaskCount(workPool) > this.nthreads) {
                        this.wait();
                    }
                    if (log.isDebugEnabled()) {
                        log.debug("Queue cleared.");
                    }
                }
            } catch (InterruptedException e) {
                // logged and otherwise ignored, matching the shutdown wait below
                log.error("Interrupted?", e);
            }
        } else if (queueSize > 1.5 * this.nthreads) {
            if (log.isDebugEnabled()) {
                log.debug("Starting countdown");
            }
            countdown = this.nthreads;
            notify = this;
        }
        Future<PosNegRWExample> parsed = workPool.submit(new Parse(s, builder, id));
        Future<ExampleStats> gradfound = workPool.submit(new Grad(parsed, paramVec, sumGradient, id, notify));
        cleanPool.submit(new TraceLosses(gradfound, id));
        id++;
    }
    // Every task has been queued by now, so both pools can stop accepting work. Shutting
    // cleanPool down before waiting (instead of after workPool finishes) guarantees its
    // thread is released even if the wait below is interrupted.
    workPool.shutdown();
    cleanPool.shutdown();
    try {
        workPool.awaitTermination(7, TimeUnit.DAYS);
        cleanPool.awaitTermination(7, TimeUnit.DAYS);
    } catch (InterruptedException e) {
        log.error("Interrupted?", e);
    }
    this.masterLearner.cleanupParams(paramVec, sumGradient);
    //WW: renormalize by the total number of queries
    // (the query count is stored in numExamplesThisEpoch)
    for (String feature : sumGradient.keySet()) {
        double unnormf = sumGradient.get(feature);
        double norm = unnormf / this.statistics.numExamplesThisEpoch;
        // overwrites an existing key only, so the key set is not structurally modified
        sumGradient.put(feature, norm);
    }
    return sumGradient;
}

/**
 * Returns how many tasks submitted to {@code pool} have not completed yet. The value is
 * approximate, since it is derived from {@link ThreadPoolExecutor}'s task counters.
 */
private static long pendingTaskCount(ThreadPoolExecutor pool) {
    return pool.getTaskCount() - pool.getCompletedTaskCount();
}
Use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
Class PosNegLoss, method computeLossGradient.
/**
 * Accumulates the positive/negative log-loss gradient of one example into {@code gradient}
 * and records the per-answer loss terms in {@code lossdata}.
 * <p>
 * Each positive answer {@code a} adds loss {@code -log(p[a])} and gradient terms
 * {@code -dp[a]/p[a]}. Each negative answer {@code b} adds loss {@code -log(1 - p[b])} and
 * gradient terms {@code beta * dp[b]/(1 - p[b])}. All probabilities pass through {@code clip}
 * first.
 * <p>
 * When {@code c.delta < 0.5}, the negative-instance booster
 * {@code beta = log(1/h) / log(1/(1-h))} is applied, where {@code h} is the largest clipped
 * positive probability plus {@code delta}, capped at 1. Otherwise {@code beta = 1}. The
 * recorded loss for negatives is not scaled by {@code beta}.
 *
 * @param params   not read by this implementation
 * @param example  example supplying the answer lists and {@code p}/{@code dp}
 * @param gradient accumulator, updated in place with {@code adjustOrPutValue}
 * @param lossdata receives one {@code LOSS.LOG} term per positive and per negative answer
 * @param c        options; only {@code delta} is read here
 * @return the number of nonzero derivative terms accumulated (a feature may be counted
 *         once per answer)
 */
@Override
public int computeLossGradient(ParamVector params, PosNegRWExample example, TIntDoubleMap gradient, LossData lossdata, SRWOptions c) {
    PosNegRWExample ex = example;
    int nonzero = 0;
    // add empirical loss gradient term
    // positive examples
    double pmax = 0;
    for (int a : ex.getPosList()) {
        double pa = clip(ex.p[a]);
        if (pa > pmax)
            pmax = pa;
        for (TIntDoubleIterator da = ex.dp[a].iterator(); da.hasNext(); ) {
            da.advance();
            if (da.value() == 0)
                continue;
            nonzero++;
            double aterm = -da.value() / pa;
            gradient.adjustOrPutValue(da.key(), aterm, aterm);
        }
        if (log.isDebugEnabled())
            log.debug("+p=" + pa);
        lossdata.add(LOSS.LOG, -Math.log(pa));
    }
    // negative instance booster
    // h is capped at 1: for h > 1, 1/(1-h) is negative and Math.log returns NaN, which would
    // make every negative gradient term NaN. At h == 1 the formula already evaluates to
    // beta = 0 (0 / +Infinity), so the cap extends it continuously. For h <= 0 (possible only
    // when delta <= 0) beta would be infinite or NaN, so the booster is skipped.
    double h = Math.min(pmax + c.delta, 1.0);
    double beta = 1;
    if (c.delta < 0.5 && h > 0)
        beta = (Math.log(1 / h)) / (Math.log(1 / (1 - h)));
    // negative examples
    for (int b : ex.getNegList()) {
        double pb = clip(ex.p[b]);
        for (TIntDoubleIterator db = ex.dp[b].iterator(); db.hasNext(); ) {
            db.advance();
            if (db.value() == 0)
                continue;
            nonzero++;
            double bterm = beta * db.value() / (1 - pb);
            gradient.adjustOrPutValue(db.key(), bterm, bterm);
        }
        if (log.isDebugEnabled())
            log.debug("-p=" + pb);
        lossdata.add(LOSS.LOG, -Math.log(1.0 - pb));
    }
    return nonzero;
}
Use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.
Class SRWTest, method testLearn1.
/**
 * Checks that one SRW training step on the red/blue graph reduces the loss. The walk starts
 * at node r0; blue nodes are the positive answers and red nodes the negative ones. The test
 * also passes when the initial loss is already zero.
 */
@Test
public void testLearn1() {
    TIntDoubleMap startVec = new TIntDoubleHashMap();
    startVec.put(nodes.getId("r0"), 1.0);

    int[] positives = new int[blues.size()];
    int next = 0;
    for (String blue : blues) {
        positives[next] = nodes.getId(blue);
        next++;
    }

    int[] negatives = new int[reds.size()];
    next = 0;
    for (String red : reds) {
        negatives[next] = nodes.getId(red);
        next++;
    }

    PosNegRWExample example = factory.makeExample("learn1", brGraph, startVec, positives, negatives);
    // ParamVector weightVec = new SimpleParamVector();
    // weightVec.put("fromb",1.01);
    // weightVec.put("tob",1.0);
    // weightVec.put("fromr",1.03);
    // weightVec.put("tor",1.0);
    // weightVec.put("id(restart)",1.02);

    // train a copy so the shared uniform parameters stay untouched
    ParamVector<String, ?> trainedParams = uniformParams.copy();
    double lossBefore = makeLoss(trainedParams, example);
    srw.clearLoss();
    srw.trainOnExample(trainedParams, example, new StatusLogger());
    double lossAfter = makeLoss(trainedParams, example);

    boolean improved = lossBefore == 0 || lossBefore > lossAfter;
    assertTrue(String.format("preloss %f >=? postloss %f", lossBefore, lossAfter), improved);
}
Aggregations