Search in sources:

Example 1 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The class CachingTrainer, method trainCached.

/**
 * Trains on an in-memory list of already-parsed examples until the stopping criterion is met.
 *
 * <p>Each epoch resets every learner's epoch counter and loss, optionally shuffles
 * {@code examples} in place, submits one {@code Train} task per example to a fixed-size worker
 * pool and one {@code TraceLosses} task per example to a single-threaded pool, and then hands both
 * pools to {@code cleanEpoch}.
 *
 * @param examples parsed examples; shuffled in place when {@code this.shuffle} is set
 * @param builder not referenced by this method (examples are already parsed)
 * @param initialParamVec starting parameters, passed through {@code masterLearner.setupParams}
 * @param numEpochs epoch limit handed to the {@link StoppingCriterion}
 * @param total statistics accumulator passed to {@code cleanEpoch}; its read/parse/train times
 *     are logged when training finishes
 * @return the trained parameter vector
 */
public ParamVector<String, ?> trainCached(List<PosNegRWExample> examples, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs, TrainingStatistics total) {
    ParamVector<String, ?> paramVec = this.masterLearner.setupParams(initialParamVec);
    NamedThreadFactory trainThreads = new NamedThreadFactory("work-");
    StoppingCriterion stopper = new StoppingCriterion(numEpochs, this.stoppingPercent, this.stoppingEpoch);
    // dataset size statistics are only logged after the first epoch
    boolean graphSizesStatusLog = true;
    // repeat until ready to stop
    while (!stopper.satisified()) {
        // set up current epoch
        this.epoch++;
        for (SRW learner : this.learners.values()) {
            learner.setEpoch(epoch);
            learner.clearLoss();
        }
        log.info("epoch " + epoch + " ...");
        status.tick();
        // reset counters & file pointers
        this.statistics = new TrainingStatistics();
        trainThreads.reset();
        ExecutorService trainPool = Executors.newFixedThreadPool(this.nthreads, trainThreads);
        ExecutorService cleanPool = Executors.newSingleThreadExecutor();
        // run examples; ids are 1-based
        int id = 1;
        if (this.shuffle)
            Collections.shuffle(examples);
        for (PosNegRWExample s : examples) {
            Future<ExampleStats> trained = trainPool.submit(new Train(new PretendParse(s), paramVec, id, null));
            cleanPool.submit(new TraceLosses(trained, id));
            id++;
            // id has already been advanced past the example just queued
            if (log.isInfoEnabled() && status.due(1))
                log.info("queued: " + (id - 1) + " trained: " + statistics.exampleSetSize);
        }
        cleanEpoch(trainPool, cleanPool, paramVec, stopper, id, total);
        if (graphSizesStatusLog) {
            // id ends one past the number of queued examples; guard against an empty list
            int numQueued = id - 1;
            log.info("Dataset size stats: " + statistics.totalGraphSize + " total nodes / max " + statistics.maxGraphSize + " / avg " + (numQueued > 0 ? statistics.totalGraphSize / numQueued : 0));
            graphSizesStatusLog = false;
        }
    }
    log.info("Reading: " + total.readTime + " Parsing: " + total.parseTime + " Training: " + total.trainTime);
    return paramVec;
}
Also used : NamedThreadFactory(edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory) StoppingCriterion(edu.cmu.ml.proppr.learn.tools.StoppingCriterion) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) ExecutorService(java.util.concurrent.ExecutorService) SRW(edu.cmu.ml.proppr.learn.SRW)

Example 2 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The class Diagnostic, method main.

/**
 * Diagnostic driver: parses the grounded training file named by the configuration and runs a
 * single SRW training pass over it, logging the total wall-clock time.
 *
 * <p>Parsing and training run on two separate fixed-size pools, each sized to half of the
 * configured thread count (at least one thread). Each training task blocks on the future of the
 * parse task with the same id, so an example is trained as soon as it has been parsed. Examples
 * that fail to parse are logged and skipped.
 *
 * <p>Alternative pipelines that used to live here as commented-out code (SrwTtwop, SrwO and the
 * combined parse+train variant) were removed; they targeted an older parser/trainer API and remain
 * available in version control.
 *
 * @param args command-line arguments, interpreted by {@code ModuleConfiguration}
 */
public static void main(String[] args) {
    final StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();
        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, (float) 0.75, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");
        /* DiagSrwES: */
        // split the configured threads evenly between parsing and training
        final int poolSize = c.nthreads > 1 ? c.nthreads / 2 : 1;
        final ExecutorService trainerPool = Executors.newFixedThreadPool(poolSize);
        final ExecutorService parserPool = Executors.newFixedThreadPool(poolSize);
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        int i = 1;
        for (String s : new ParsedFile(groundedFile)) {
            final int id = i++;
            final String in = s;
            parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {

                @Override
                public PosNegRWExample call() throws Exception {
                    try {
                        log.debug("Parsing start " + id);
                        PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                        log.debug("Parsing done " + id);
                        return ret;
                    } catch (IllegalArgumentException e) {
                        // malformed example: report it and yield null so the trainer skips it
                        log.error("Problem with #" + id, e);
                        return null;
                    }
                }
            }));
        }
        parserPool.shutdown();
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            trainerPool.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        if (ex == null) {
                            log.warn("Skipping training for #" + id + ": example could not be parsed");
                            return;
                        }
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        log.error("Interrupted while waiting for #" + id, e);
                        Thread.currentThread().interrupt();
                    } catch (ExecutionException e) {
                        log.error("Parsing failed for #" + id, e);
                    } catch (RuntimeException e) {
                        // submit() would otherwise capture this in a Future nobody inspects
                        log.error("Training failed for #" + id, e);
                    }
                }
            });
        }
        trainerPool.shutdown();
        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            log.error("Interrupted?", e);
            Thread.currentThread().interrupt();
        }
        /* /SrwES */
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used : ArrayList(java.util.ArrayList) Callable(java.util.concurrent.Callable) SRW(edu.cmu.ml.proppr.learn.SRW) RWExampleParser(edu.cmu.ml.proppr.learn.tools.RWExampleParser) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) ExecutionException(java.util.concurrent.ExecutionException) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) ExecutorService(java.util.concurrent.ExecutorService) Future(java.util.concurrent.Future) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 3 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The class Trainer, method findGradient.

/**
 * Computes the loss gradient at {@code paramVec} over a stream of unparsed ("cooked") examples,
 * normalized by {@code statistics.numExamplesThisEpoch}. The parameters themselves are not trained.
 *
 * <p>For each example string, a {@code Parse} task and a {@code Grad} task are submitted to a
 * fixed-size worker pool. The {@code Grad} task accumulates into a shared gradient vector. A
 * single-threaded pool runs {@code TraceLosses} on each result.
 *
 * <p>Throttling works as follows. When the outstanding work queue exceeds 1.5x the thread count,
 * a countdown of {@code nthreads} submissions starts. The {@code Grad} tasks submitted during the
 * countdown receive {@code this} as their notify target. When the countdown expires, the
 * submitting thread waits on this object's monitor until at most {@code nthreads} tasks are
 * outstanding. NOTE(review): this presumes {@code Grad} calls {@code notify} on its target when
 * it finishes; confirm in {@code Grad}.
 *
 * @param masterFeatures if non-null and non-empty, installed via {@code LearningGraphBuilder.setFeatures}
 * @param examples one unparsed example per element
 * @param builder graph builder handed to each {@code Parse} task
 * @param paramVec parameters at which to evaluate the gradient; a fresh vector is created if null
 * @return the accumulated gradient divided by {@code statistics.numExamplesThisEpoch}, or left
 *     unnormalized if that count is zero
 */
public ParamVector<String, ?> findGradient(SymbolTable<String> masterFeatures, Iterable<String> examples, LearningGraphBuilder builder, ParamVector<String, ?> paramVec) {
    log.info("Computing gradient on cooked examples...");
    ParamVector<String, ?> sumGradient = new SimpleParamVector<String>();
    if (paramVec == null) {
        paramVec = createParamVector();
    }
    paramVec = this.masterLearner.setupParams(paramVec);
    if (masterFeatures != null && masterFeatures.size() > 0)
        LearningGraphBuilder.setFeatures(masterFeatures);
    NamedThreadFactory workThreads = new NamedThreadFactory("work-");
    ExecutorService workPool = Executors.newFixedThreadPool(this.nthreads, workThreads);
    ExecutorService cleanPool = Executors.newSingleThreadExecutor();
    // newFixedThreadPool currently returns a ThreadPoolExecutor; cast once for the queue metrics
    ThreadPoolExecutor workExecutor = (ThreadPoolExecutor) workPool;
    // run examples; ids are 1-based
    int id = 1;
    // -1: not throttling; >0: submissions left before blocking; 0: block on this iteration
    int countdown = -1;
    Trainer notify = null;
    status.start();
    for (String s : examples) {
        if (log.isInfoEnabled() && status.due())
            log.info(id + " examples read...");
        long queueSize = workExecutor.getTaskCount() - workExecutor.getCompletedTaskCount();
        if (log.isDebugEnabled())
            log.debug("Queue size " + queueSize);
        if (countdown > 0) {
            if (log.isDebugEnabled())
                log.debug("Countdown " + countdown);
            countdown--;
        } else if (countdown == 0) {
            if (log.isDebugEnabled())
                log.debug("Countdown " + countdown + "; throttling:");
            countdown--;
            notify = null;
            try {
                synchronized (this) {
                    if (log.isDebugEnabled())
                        log.debug("Clearing training queue...");
                    while ((workExecutor.getTaskCount() - workExecutor.getCompletedTaskCount()) > this.nthreads) this.wait();
                    if (log.isDebugEnabled())
                        log.debug("Queue cleared.");
                }
            } catch (InterruptedException e) {
                // NOTE(review): interrupt status is deliberately not restored here, matching the
                // previous behaviour of continuing to submit; restoring it would make the
                // remaining throttle waits and the final awaitTermination return immediately.
                log.error("Interrupted while throttling", e);
            }
        } else if (queueSize > 1.5 * this.nthreads) {
            if (log.isDebugEnabled())
                log.debug("Starting countdown");
            countdown = this.nthreads;
            notify = this;
        }
        Future<PosNegRWExample> parsed = workPool.submit(new Parse(s, builder, id));
        Future<ExampleStats> gradfound = workPool.submit(new Grad(parsed, paramVec, sumGradient, id, notify));
        cleanPool.submit(new TraceLosses(gradfound, id));
        id++;
    }
    // Everything is queued: shut both pools down before waiting, so that neither leaks threads if
    // the wait is interrupted. Already-queued TraceLosses tasks still run after shutdown().
    workPool.shutdown();
    cleanPool.shutdown();
    try {
        workPool.awaitTermination(7, TimeUnit.DAYS);
        cleanPool.awaitTermination(7, TimeUnit.DAYS);
    } catch (InterruptedException e) {
        log.error("Interrupted?", e);
        Thread.currentThread().interrupt();
    }
    this.masterLearner.cleanupParams(paramVec, sumGradient);
    //WW: renormalize by the total number of queries
    double numExamples = this.statistics.numExamplesThisEpoch;
    // an empty input would otherwise turn every entry into NaN/Infinity
    if (numExamples > 0) {
        for (String feature : sumGradient.keySet()) {
            // replaces the value of an existing key only; no structural modification
            double unnormf = sumGradient.get(feature);
            sumGradient.put(feature, unnormf / numExamples);
        }
    }
    return sumGradient;
}
Also used : NamedThreadFactory(edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector) ExecutorService(java.util.concurrent.ExecutorService) ThreadPoolExecutor(java.util.concurrent.ThreadPoolExecutor)

Example 4 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The class PosNegLoss, method computeLossGradient.

/**
 * Adds the gradient of the positive/negative log loss for {@code example} into {@code gradient}
 * and records each loss term in {@code lossdata}.
 *
 * <p>For every positive node a with clipped probability p_a, each non-zero derivative d adds
 * -d / p_a, and the loss -log(p_a) is recorded. For every negative node b, each non-zero
 * derivative d adds beta * d / (1 - p_b), and the loss -log(1 - p_b) is recorded.
 *
 * <p>When {@code c.delta < 0.5}, beta is the "negative instance booster"
 * log(1/h) / log(1/(1-h)) with h = (max positive probability) + delta; otherwise beta = 1.
 *
 * @param params not referenced by this implementation
 * @param example example whose {@code p} / {@code dp} arrays have already been computed
 * @param gradient accumulator, updated in place
 * @param lossdata loss accumulator, updated in place
 * @param c options supplying {@code delta}
 * @return the number of non-zero derivative entries that contributed to {@code gradient}
 */
@Override
public int computeLossGradient(ParamVector params, PosNegRWExample example, TIntDoubleMap gradient, LossData lossdata, SRWOptions c) {
    int nonzero = 0;
    // add empirical loss gradient term
    // positive examples
    double pmax = 0;
    for (int a : example.getPosList()) {
        double pa = clip(example.p[a]);
        if (pa > pmax)
            pmax = pa;
        for (TIntDoubleIterator da = example.dp[a].iterator(); da.hasNext(); ) {
            da.advance();
            if (da.value() == 0)
                continue;
            nonzero++;
            double aterm = -da.value() / pa;
            gradient.adjustOrPutValue(da.key(), aterm, aterm);
        }
        if (log.isDebugEnabled())
            log.debug("+p=" + pa);
        lossdata.add(LOSS.LOG, -Math.log(pa));
    }
    //negative instance booster
    double h = pmax + c.delta;
    double beta = 1;
    if (c.delta < 0.5) {
        beta = (Math.log(1 / h)) / (Math.log(1 / (1 - h)));
        // For h > 1 (e.g. pmax > 1 - delta) or h <= 0 the ratio is NaN or infinite, which would
        // poison every negative gradient term below; fall back to no boosting in that case.
        if (Double.isNaN(beta) || Double.isInfinite(beta))
            beta = 1;
    }
    // negative examples
    for (int b : example.getNegList()) {
        double pb = clip(example.p[b]);
        for (TIntDoubleIterator db = example.dp[b].iterator(); db.hasNext(); ) {
            db.advance();
            if (db.value() == 0)
                continue;
            nonzero++;
            double bterm = beta * db.value() / (1 - pb);
            gradient.adjustOrPutValue(db.key(), bterm, bterm);
        }
        if (log.isDebugEnabled())
            log.debug("-p=" + pb);
        lossdata.add(LOSS.LOG, -Math.log(1.0 - pb));
    }
    return nonzero;
}
Also used : PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator)

Example 5 with PosNegRWExample

use of edu.cmu.ml.proppr.examples.PosNegRWExample in project ProPPR by TeamCohen.

The class SRWTest, method testLearn1.

/**
 * Checks that a single SRW training step on the red/blue graph does not make things worse.
 * The walk starts at node r0, blue nodes are the positives and red nodes the negatives; the
 * loss must either already be zero or strictly decrease after training.
 */
@Test
public void testLearn1() {
    TIntDoubleMap startVector = new TIntDoubleHashMap();
    startVector.put(nodes.getId("r0"), 1.0);

    int[] positiveIds = new int[blues.size()];
    int nextPos = 0;
    for (String blueNode : blues) {
        positiveIds[nextPos] = nodes.getId(blueNode);
        nextPos++;
    }

    int[] negativeIds = new int[reds.size()];
    int nextNeg = 0;
    for (String redNode : reds) {
        negativeIds[nextNeg] = nodes.getId(redNode);
        nextNeg++;
    }

    PosNegRWExample example = factory.makeExample("learn1", brGraph, startVector, positiveIds, negativeIds);

    ParamVector<String, ?> trainedParams = uniformParams.copy();
    double lossBefore = makeLoss(trainedParams, example);
    srw.clearLoss();
    srw.trainOnExample(trainedParams, example, new StatusLogger());
    double lossAfter = makeLoss(trainedParams, example);

    String message = String.format("preloss %f >=? postloss %f", lossBefore, lossAfter);
    boolean zeroOrImproved = lossBefore == 0 || lossBefore > lossAfter;
    assertTrue(message, zeroOrImproved);
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) TIntDoubleMap(gnu.trove.map.TIntDoubleMap) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) Test(org.junit.Test)

Aggregations

PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample)12 SRW (edu.cmu.ml.proppr.learn.SRW)4 StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)4 ExecutorService (java.util.concurrent.ExecutorService)4 NamedThreadFactory (edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory)3 TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator)3 TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap)3 RWExampleParser (edu.cmu.ml.proppr.learn.tools.RWExampleParser)2 StoppingCriterion (edu.cmu.ml.proppr.learn.tools.StoppingCriterion)2 TIntDoubleMap (gnu.trove.map.TIntDoubleMap)2 ArrayList (java.util.ArrayList)2 ThreadPoolExecutor (java.util.concurrent.ThreadPoolExecutor)2 ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)1 GraphFormatException (edu.cmu.ml.proppr.graph.GraphFormatException)1 RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule)1 RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2)1 ModuleConfiguration (edu.cmu.ml.proppr.util.ModuleConfiguration)1 ParsedFile (edu.cmu.ml.proppr.util.ParsedFile)1 SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector)1 Callable (java.util.concurrent.Callable)1