Search in sources :

Example 1 with SRW

use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.

The class GradientFinderTest, method setup.

/**
 * Prepares the fixture: configures a fresh SRW with an L2 regularization schedule and a
 * ReLU squashing function, initializes the trainer, seeds {@code query} with weight 1.0
 * on node "r0", and fills {@code examples} with magicNumber * magicNumber tab-separated
 * example strings serialized from {@code brGraph}.
 */
@Before
public void setup() {
    super.setup();
    this.srw = new SRW();
    this.srw.setRegularizer(new RegularizationSchedule(this.srw, new RegularizeL2()));
    this.srw.setSquashingFunction(new ReLU<String>());
    this.initTrainer();

    query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    examples = new ArrayList<String>();
    for (int posIdx = 0; posIdx < this.magicNumber; posIdx++) {
        for (int negIdx = 0; negIdx < this.magicNumber; negIdx++) {
            // Header fields, tab-separated, in the order they are written.
            StringBuilder line = new StringBuilder("r0");
            // query
            line.append("\t").append(nodes.getId("r0"));
            // pos
            line.append("\t").append(nodes.getId("b" + posIdx));
            // neg
            line.append("\t").append(nodes.getId("r" + negIdx));
            // nodes
            line.append("\t").append(brGraph.nodeSize());
            // edges
            line.append("\t").append(brGraph.edgeSize());
            // the label-dependency count is appended after this tab, once it is known
            line.append("\t");

            // Feature symbols (library ids 1..size), joined with ':'.
            StringBuilder graphText = new StringBuilder();
            int featureIdx = 0;
            while (featureIdx < brGraph.getFeatureSet().size()) {
                if (featureIdx > 0) {
                    graphText.append(":");
                }
                graphText.append(brGraph.featureLibrary.getSymbol(featureIdx + 1));
                featureIdx++;
            }

            // One "\tsrc->dst:feature@weight,..." entry per edge; along the way, count
            // (distinct features leaving a node) * (number of edges leaving that node).
            int dependencyCount = 0;
            for (int src = 0; src < brGraph.node_hi; src++) {
                HashSet<Integer> featuresSeen = new HashSet<Integer>();
                int edge = brGraph.node_near_lo[src];
                while (edge < brGraph.node_near_hi[src]) {
                    int dst = brGraph.edge_dest[edge];
                    graphText.append("\t").append(src).append("->").append(dst).append(":");
                    int label = brGraph.edge_labels_lo[edge];
                    while (label < brGraph.edge_labels_hi[edge]) {
                        featuresSeen.add(brGraph.label_feature_id[label]);
                        if (label > brGraph.edge_labels_lo[edge]) {
                            graphText.append(",");
                        }
                        graphText.append(brGraph.label_feature_id[label]);
                        graphText.append("@");
                        graphText.append(brGraph.label_feature_weight[label]);
                        label++;
                    }
                    edge++;
                }
                int outDegree = brGraph.node_near_hi[src] - brGraph.node_near_lo[src];
                dependencyCount += featuresSeen.size() * outDegree;
            }

            line.append(dependencyCount).append("\t").append(graphText);
            examples.add(line.toString());
        }
    }
}
Also used : RegularizationSchedule(edu.cmu.ml.proppr.learn.RegularizationSchedule) TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) RegularizeL2(edu.cmu.ml.proppr.learn.RegularizeL2) SRW(edu.cmu.ml.proppr.learn.SRW) HashSet(java.util.HashSet) Before(org.junit.Before)

Example 2 with SRW

use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.

The class SRWTest, method initSrw.

/**
 * Creates the SRW under test, gives it a RegularizationSchedule wrapping a new
 * Regularize, and sets {@code srw.c.apr.maxDepth} to 10.
 */
public void initSrw() {
    this.srw = new SRW();
    RegularizationSchedule schedule = new RegularizationSchedule(this.srw, new Regularize());
    this.srw.setRegularizer(schedule);
    this.srw.c.apr.maxDepth = 10;
}
Also used : SRW(edu.cmu.ml.proppr.learn.SRW)

Example 3 with SRW

use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.

The class CachingTrainer, method trainCached.

/**
 * Trains on an in-memory list of already-parsed examples, one epoch at a time, until the
 * stopping criterion is satisfied.
 *
 * <p>Each epoch creates a fixed-size training pool ({@code nthreads} threads) and a
 * single-thread pool for the TraceLosses tasks, submits one Train task per example, and
 * then hands both pools to {@code cleanEpoch}.
 *
 * @param examples examples to train on; shuffled IN PLACE each epoch when {@code shuffle} is set
 * @param builder not used by this method
 * @param initialParamVec starting parameters, passed to the master learner's {@code setupParams}
 * @param numEpochs epoch budget handed to the StoppingCriterion
 * @param total accumulator passed to {@code cleanEpoch}; its read/parse/train times are logged at the end
 * @return the parameter vector produced by {@code setupParams}, as updated by training
 */
public ParamVector<String, ?> trainCached(List<PosNegRWExample> examples, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs, TrainingStatistics total) {
    ParamVector<String, ?> paramVec = this.masterLearner.setupParams(initialParamVec);
    NamedThreadFactory trainThreads = new NamedThreadFactory("work-");
    ExecutorService trainPool;
    ExecutorService cleanPool;
    StoppingCriterion stopper = new StoppingCriterion(numEpochs, this.stoppingPercent, this.stoppingEpoch);
    // the dataset size summary is only logged after the first epoch
    boolean graphSizesStatusLog = true;
    // repeat until ready to stop
    while (!stopper.satisified()) {
        // set up current epoch
        this.epoch++;
        for (SRW learner : this.learners.values()) {
            learner.setEpoch(epoch);
            learner.clearLoss();
        }
        log.info("epoch " + epoch + " ...");
        status.tick();
        // reset counters & file pointers
        this.statistics = new TrainingStatistics();
        trainThreads.reset();
        trainPool = Executors.newFixedThreadPool(this.nthreads, trainThreads);
        cleanPool = Executors.newSingleThreadExecutor();
        // run examples; id is 1-based and ends up one past the last example submitted
        int id = 1;
        if (this.shuffle)
            Collections.shuffle(examples);
        for (PosNegRWExample s : examples) {
            Future<ExampleStats> trained = trainPool.submit(new Train(new PretendParse(s), paramVec, id, null));
            cleanPool.submit(new TraceLosses(trained, id));
            id++;
            if (log.isInfoEnabled() && status.due(1))
                log.info("queued: " + id + " trained: " + statistics.exampleSetSize);
        }
        cleanEpoch(trainPool, cleanPool, paramVec, stopper, id, total);
        if (graphSizesStatusLog) {
            // The number of examples is id - 1, not id; clamp to 1 so an empty example
            // list cannot cause a division by zero.
            int numExamples = Math.max(1, id - 1);
            log.info("Dataset size stats: " + statistics.totalGraphSize + " total nodes / max " + statistics.maxGraphSize + " / avg " + (statistics.totalGraphSize / numExamples));
            graphSizesStatusLog = false;
        }
    }
    log.info("Reading: " + total.readTime + " Parsing: " + total.parseTime + " Training: " + total.trainTime);
    return paramVec;
}
Also used : NamedThreadFactory(edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory) StoppingCriterion(edu.cmu.ml.proppr.learn.tools.StoppingCriterion) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) ExecutorService(java.util.concurrent.ExecutorService) SRW(edu.cmu.ml.proppr.learn.SRW)

Example 4 with SRW

use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.

The class Diagnostic, method main.

/**
 * Diagnostic entry point: parses every line of the configured query file into a
 * PosNegRWExample on one thread pool, trains the configured SRW on each parsed example
 * on a second pool, then cleans up the parameters and logs the elapsed time.
 *
 * <p>Each pool gets {@code nthreads / 2} threads (1 when {@code nthreads <= 1}). Examples
 * whose parse failed are skipped. Any Throwable that reaches the top level is printed and
 * the JVM exits with status -1.
 *
 * <p>The commented-out sections below are alternative diagnostic configurations
 * (SrwTtwop, "all diag tasks except SrwO", SrwO) kept for reference.
 *
 * @param args command-line arguments, handed to ModuleConfiguration
 */
public static void main(String[] args) {
    StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();
        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, (float) 0.75, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");
        /* DiagSrwES: */
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        final ExecutorService trainerPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);
        final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);
        // Stage 1: queue one parse job per input line (ids are 1-based).
        int i = 1;
        for (String s : new ParsedFile(groundedFile)) {
            final int id = i++;
            final String in = s;
            parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {

                @Override
                public PosNegRWExample call() throws Exception {
                    try {
                        //log.debug("Job start "+id);
                        //PosNegRWExample ret = parser.parse(in, b.copy());
                        log.debug("Parsing start " + id);
                        PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                        log.debug("Parsing done " + id);
                        //log.debug("Job done "+id);
                        return ret;
                    } catch (IllegalArgumentException e) {
                        System.err.println("Problem with #" + id);
                        e.printStackTrace();
                    }
                    // a null result marks an example that could not be parsed
                    return null;
                }
            }));
        }
        parserPool.shutdown();
        // Stage 2: queue one training job per parse result; each blocks until its parse completes.
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            trainerPool.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        if (ex == null) {
                            // The parser already reported the failure and returned null;
                            // do not hand a null example to the learner.
                            log.info("Skipping training for #" + id + ": no parsed example");
                            return;
                        }
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                        // keep the interrupt visible to the pool that owns this worker thread
                        Thread.currentThread().interrupt();
                    } catch (ExecutionException e) {
                        e.printStackTrace();
                    }
                }
            });
        }
        trainerPool.shutdown();
        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            log.error("Interrupted?", e);
        }
        /* /SrwES */
        /* SrwTtwop: 
			final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads>1?c.nthreads/2:1);
			Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
			m.executeJob(c.nthreads/2, new ParsedFile(groundedFile), 
					new Transformer<String,PosNegRWExample>() {
						@Override
						public Callable<PosNegRWExample> transformer(final String in, final int id) {
							return new Callable<PosNegRWExample>() {
								@Override
								public PosNegRWExample call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Parsing start "+id);
									PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
									log.debug("Parsing done "+id);
									//log.debug("Job done "+id);
									return ret;
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return null;
								}};
						}}, new Cleanup<PosNegRWExample>() {
							@Override
							public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
								return new Runnable(){
									@Override
									public void run() {
										try {
											final PosNegRWExample ex = in.get();
											log.debug("Cleanup start "+id);
											trainerPool.submit(new Runnable() {
													@Override
													public void run(){
														log.debug("Training start "+id);
														srw.trainOnExample(params,ex);
														log.debug("Training done "+id);
													}
												});
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			trainerPool.shutdown();
			try {
				trainerPool.awaitTermination(7, TimeUnit.DAYS);
			} catch (InterruptedException e) {
				log.error("Interrupted?",e);
			}

			 /SrwTtwop */
        /* all diag tasks except SrwO: 
			Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
			m.executeJob(c.nthreads, new ParsedFile(groundedFile), 
					new Transformer<String,PosNegRWExample>() {
						@Override
						public Callable<PosNegRWExample> transformer(final String in, final int id) {
							return new Callable<PosNegRWExample>() {
								@Override
								public PosNegRWExample call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Parsing start "+id);
									PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
									log.debug("Parsing done "+id);
									log.debug("Training start "+id);
									srw.trainOnExample(params,ret);
									log.debug("Training done "+id);
									//log.debug("Job done "+id);
									return ret;
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return null;
								}};
						}}, new Cleanup<PosNegRWExample>() {
							@Override
							public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
								return new Runnable(){
									//ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
									@Override
									public void run() {
										try {
											//done.add(in.get());
											in.get();
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup start "+id);
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			*/
        /* SrwO:
			   Multithreading<PosNegRWExample,Integer> m = new Multithreading<PosNegRWExample,Integer>(log);
			m.executeJob(c.nthreads, new PosNegRWExampleStreamer(new ParsedFile(groundedFile),new ArrayLearningGraphBuilder()), 
						 new Transformer<PosNegRWExample,Integer>() {
						@Override
						public Callable<Integer> transformer(final PosNegRWExample in, final int id) {
							return new Callable<Integer>() {
								@Override
								public Integer call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Training start "+id);
									srw.trainOnExample(params,in);
									log.debug("Training done "+id);
									//log.debug("Job done "+id);
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return in.length();
								}};
						}}, new Cleanup<Integer>() {
							@Override
							public Runnable cleanup(final Future<Integer> in, final int id) {
								return new Runnable(){
									//ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
									@Override
									public void run() {
										try {
											//done.add(in.get());
											in.get();
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup start "+id);
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			*/
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used : ArrayList(java.util.ArrayList) Callable(java.util.concurrent.Callable) SRW(edu.cmu.ml.proppr.learn.SRW) RWExampleParser(edu.cmu.ml.proppr.learn.tools.RWExampleParser) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) ExecutionException(java.util.concurrent.ExecutionException) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) ExecutorService(java.util.concurrent.ExecutorService) Future(java.util.concurrent.Future) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 5 with SRW

use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.

The class Trainer, method cleanEpoch.

/**
 * End-of-epoch cleanup routine shared by Trainer, CachingTrainer.
 * Shuts down the working pool and the cleaning pool, finishes trailing parameter updates,
 * aggregates and prints the epoch loss, signals the stopper, reports zero-gradient
 * counts, and folds this epoch's timing statistics into {@code stats}.
 *
 * @param workingPool pool running the training tasks; shut down and awaited first
 * @param cleanPool pool running the cleanup tasks; shut down after the working pool drains
 * @param paramVec parameter vector passed to the master learner's {@code cleanupParams}
 * @param stopper stopping criterion that receives the loss comparison and the epoch tick
 * @param n one more than the number of examples (it is decremented on entry to obtain
 *          the example count used in the zero-gradient report)
 * @param stats accumulator updated with this epoch's read, parse and train times
 */
protected void cleanEpoch(ExecutorService workingPool, ExecutorService cleanPool, ParamVector<String, ?> paramVec, StoppingCriterion stopper, int n, TrainingStatistics stats) {
    n = n - 1;
    workingPool.shutdown();
    try {
        workingPool.awaitTermination(7, TimeUnit.DAYS);
        cleanPool.shutdown();
        cleanPool.awaitTermination(7, TimeUnit.DAYS);
    } catch (InterruptedException e) {
        e.printStackTrace();
    } finally {
        // If the wait on workingPool was interrupted, the shutdown call above was skipped;
        // make sure the clean pool is never left running. A repeated shutdown is a no-op.
        cleanPool.shutdown();
    }
    // finish any trailing updates for this epoch
    this.masterLearner.cleanupParams(paramVec, paramVec);
    // loss status and signalling the stopper
    LossData lossThisEpoch = new LossData();
    for (SRW learner : this.learners.values()) {
        lossThisEpoch.add(learner.cumulativeLoss());
    }
    lossThisEpoch.convertCumulativesToAverage(statistics.numExamplesThisEpoch);
    printLossOutput(lossThisEpoch);
    // there is no previous epoch to compare against on the first pass
    if (epoch > 1) {
        stopper.recordConsecutiveLosses(lossThisEpoch, lossLastEpoch);
    }
    lossLastEpoch = lossThisEpoch;
    // zero-gradient report, summed over all learners
    ZeroGradientData zeros = this.masterLearner.new ZeroGradientData();
    for (SRW learner : this.learners.values()) {
        zeros.add(learner.getZeroGradientData());
    }
    if (zeros.numZero > 0) {
        log.info(zeros.numZero + " / " + n + " examples with 0 gradient");
        if (zeros.numZero / (float) n > MAX_PCT_ZERO_GRADIENT)
            log.warn("Having this many 0 gradients is unusual for supervised tasks. Try a different squashing function?");
    }
    stopper.recordEpoch();
    statistics.checkStatistics();
    stats.updateReadingStatistics(statistics.readTime);
    stats.updateParsingStatistics(statistics.parseTime);
    stats.updateTrainingStatistics(statistics.trainTime);
}
Also used : LossData(edu.cmu.ml.proppr.learn.tools.LossData) ZeroGradientData(edu.cmu.ml.proppr.learn.SRW.ZeroGradientData) SRW(edu.cmu.ml.proppr.learn.SRW)

Aggregations

SRW (edu.cmu.ml.proppr.learn.SRW)7 PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample)4 ExecutorService (java.util.concurrent.ExecutorService)3 RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule)2 RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2)2 StoppingCriterion (edu.cmu.ml.proppr.learn.tools.StoppingCriterion)2 StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)2 NamedThreadFactory (edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory)2 TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap)2 Before (org.junit.Before)2 ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)1 ZeroGradientData (edu.cmu.ml.proppr.learn.SRW.ZeroGradientData)1 LossData (edu.cmu.ml.proppr.learn.tools.LossData)1 RWExampleParser (edu.cmu.ml.proppr.learn.tools.RWExampleParser)1 ModuleConfiguration (edu.cmu.ml.proppr.util.ModuleConfiguration)1 ParsedFile (edu.cmu.ml.proppr.util.ParsedFile)1 ArrayList (java.util.ArrayList)1 HashSet (java.util.HashSet)1 Callable (java.util.concurrent.Callable)1 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)1