Use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.
Class GradientFinderTest, method setup().
/**
 * Builds the SRW under test (L2 regularization wrapped in a RegularizationSchedule,
 * ReLU squashing), initializes the trainer, and prepares the query vector and the
 * serialized training examples used by the gradient tests.
 *
 * One example is produced for every (k, p) pair with 0 <= k, p < magicNumber. Each is a
 * tab-separated line:
 *   "r0", id(r0) (query), id(b{k}) (pos), id(r{p}) (neg), node count, edge count,
 *   label-dependency count, colon-separated feature list, then one
 *   "u->v:featureId@weight,..." field per edge of brGraph.
 */
@Before
public void setup() {
    super.setup();
    this.srw = new SRW();
    this.srw.setRegularizer(new RegularizationSchedule(this.srw, new RegularizeL2()));
    this.srw.setSquashingFunction(new ReLU<String>());
    this.initTrainer();
    query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    // The trailing part of every example (label-dependency count, feature list, edge list)
    // is derived from brGraph alone and does not vary with k or p, so serialize it once
    // here instead of once per (k, p) pair.
    StringBuilder graph = new StringBuilder();
    for (int i = 0; i < brGraph.getFeatureSet().size(); i++) {
        if (i > 0)
            graph.append(":");
        graph.append(brGraph.featureLibrary.getSymbol(i + 1));
    }
    int labelDependencies = 0;
    for (int u = 0; u < brGraph.node_hi; u++) {
        HashSet<Integer> outgoingFeatures = new HashSet<Integer>();
        for (int ec = brGraph.node_near_lo[u]; ec < brGraph.node_near_hi[u]; ec++) {
            int v = brGraph.edge_dest[ec];
            graph.append("\t").append(u).append("->").append(v).append(":");
            for (int lc = brGraph.edge_labels_lo[ec]; lc < brGraph.edge_labels_hi[ec]; lc++) {
                outgoingFeatures.add(brGraph.label_feature_id[lc]);
                if (lc > brGraph.edge_labels_lo[ec])
                    graph.append(",");
                graph.append(brGraph.label_feature_id[lc]).append("@").append(brGraph.label_feature_weight[lc]);
            }
        }
        // (#distinct features on u's out-edges) x (#out-edges of u)
        labelDependencies += outgoingFeatures.size() * (brGraph.node_near_hi[u] - brGraph.node_near_lo[u]);
    }
    String graphSuffix = new StringBuilder().append(labelDependencies).append("\t").append(graph).toString();

    examples = new ArrayList<String>();
    for (int k = 0; k < this.magicNumber; k++) {
        for (int p = 0; p < this.magicNumber; p++) {
            StringBuilder serialized = new StringBuilder("r0")
                    .append("\t").append(nodes.getId("r0"))      // query
                    .append("\t").append(nodes.getId("b" + k))   // pos
                    .append("\t").append(nodes.getId("r" + p))   // neg
                    .append("\t").append(brGraph.nodeSize())     // nodes
                    .append("\t").append(brGraph.edgeSize())     // edges
                    .append("\t").append(graphSuffix);           // label deps, features, edge list
            examples.add(serialized.toString());
        }
    }
}
Use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.
Class SRWTest, method initSrw().
/**
 * Replaces the SRW under test with a fresh instance that uses a RegularizationSchedule
 * around a plain Regularize, and sets its c.apr.maxDepth to 10.
 */
public void initSrw() {
    SRW fresh = new SRW();
    this.srw = fresh;
    fresh.setRegularizer(new RegularizationSchedule(fresh, new Regularize()));
    fresh.c.apr.maxDepth = 10;
}
Use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.
Class CachingTrainer, method trainCached().
/**
 * Trains on an already-parsed list of examples, repeating epochs until the
 * stopping criterion is satisfied.
 *
 * Each epoch gets fresh thread pools: examples are trained on a pool of
 * {@code nthreads} workers, and their results are post-processed
 * (TraceLosses) on a single-threaded pool; {@link #cleanEpoch} then waits for
 * both and records losses/statistics.
 *
 * @param examples pre-parsed examples; shuffled IN PLACE each epoch when shuffling is enabled
 * @param builder not used by this method
 * @param initialParamVec starting parameters, passed through masterLearner.setupParams
 * @param numEpochs epoch limit handed to the StoppingCriterion
 * @param total accumulator passed to cleanEpoch; its read/parse/train times are logged at the end
 * @return the trained parameter vector
 */
public ParamVector<String, ?> trainCached(List<PosNegRWExample> examples, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs, TrainingStatistics total) {
    ParamVector<String, ?> paramVec = this.masterLearner.setupParams(initialParamVec);
    NamedThreadFactory trainThreads = new NamedThreadFactory("work-");
    StoppingCriterion stopper = new StoppingCriterion(numEpochs, this.stoppingPercent, this.stoppingEpoch);
    boolean graphSizesStatusLog = true;
    // repeat until ready to stop
    while (!stopper.satisified()) {
        // set up current epoch
        this.epoch++;
        for (SRW learner : this.learners.values()) {
            learner.setEpoch(epoch);
            learner.clearLoss();
        }
        log.info("epoch " + epoch + " ...");
        status.tick();
        // reset counters & thread pools
        this.statistics = new TrainingStatistics();
        trainThreads.reset();
        ExecutorService trainPool = Executors.newFixedThreadPool(this.nthreads, trainThreads);
        ExecutorService cleanPool = Executors.newSingleThreadExecutor();
        // run examples; ids are 1-based, so after the loop id == (number queued) + 1
        int id = 1;
        if (this.shuffle)
            Collections.shuffle(examples);
        for (PosNegRWExample s : examples) {
            Future<ExampleStats> trained = trainPool.submit(new Train(new PretendParse(s), paramVec, id, null));
            cleanPool.submit(new TraceLosses(trained, id));
            // log before incrementing so "queued" is the number actually submitted
            if (log.isInfoEnabled() && status.due(1))
                log.info("queued: " + id + " trained: " + statistics.exampleSetSize);
            id++;
        }
        // cleanEpoch expects the next unused id and subtracts 1 itself
        cleanEpoch(trainPool, cleanPool, paramVec, stopper, id, total);
        if (graphSizesStatusLog) {
            int numExamples = id - 1;
            log.info("Dataset size stats: " + statistics.totalGraphSize + " total nodes / max " + statistics.maxGraphSize + " / avg " + (numExamples > 0 ? statistics.totalGraphSize / numExamples : 0));
            graphSizesStatusLog = false;
        }
    }
    log.info("Reading: " + total.readTime + " Parsing: " + total.parseTime + " Training: " + total.trainTime);
    return paramVec;
}
Use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.
Class Diagnostic, method main().
/**
 * Diagnostic harness: parses every grounded example in the training file on one
 * thread pool and trains the configured SRW on them on a second pool, logging
 * per-example start/finish messages at debug level.
 *
 * Arguments are interpreted by ModuleConfiguration (training file, threads,
 * throttle, SRW options). Any uncaught error is printed and the process exits
 * with status -1.
 */
public static void main(String[] args) {
    StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();
        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, (float) 0.75, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        // register these ids with the SRW's fixed-weight rules
        // (presumably excluded from weight updates -- confirm in SRW)
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");

        // DiagSrwES: parse and train on separate pools, each with half the threads
        int poolSize = c.nthreads > 1 ? c.nthreads / 2 : 1;
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        final ExecutorService trainerPool = Executors.newFixedThreadPool(poolSize);
        final ExecutorService parserPool = Executors.newFixedThreadPool(poolSize);
        int i = 1;
        for (String s : new ParsedFile(groundedFile)) {
            final int id = i++;
            final String in = s;
            parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {
                @Override
                public PosNegRWExample call() throws Exception {
                    try {
                        log.debug("Parsing start " + id);
                        PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                        log.debug("Parsing done " + id);
                        return ret;
                    } catch (IllegalArgumentException e) {
                        System.err.println("Problem with #" + id);
                        e.printStackTrace();
                    }
                    // signals "unparseable" to the training task below
                    return null;
                }
            }));
        }
        parserPool.shutdown();
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            // The Future returned by submit() is discarded, so anything thrown out of run()
            // would be captured there and never reported; all failures are logged inside.
            trainerPool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        if (ex == null) {
                            // the parser task already reported why; don't hand null to the trainer
                            log.warn("Skipping example #" + id + ": it could not be parsed");
                            return;
                        }
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        log.error("Interrupted while waiting for example #" + id, e);
                    } catch (ExecutionException e) {
                        log.error("Parsing failed for example #" + id, e.getCause());
                    } catch (RuntimeException e) {
                        log.error("Training failed for example #" + id, e);
                    }
                }
            });
        }
        trainerPool.shutdown();
        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Interrupted?", e);
        }
        // Other harness variants (SrwTtwop, SrwO, single-pool parse+train) used to be kept
        // here as commented-out code. They targeted an older API (GroundedExampleParser,
        // 2-argument trainOnExample) and no longer compiled; see version-control history.
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Use of edu.cmu.ml.proppr.learn.SRW in project ProPPR by TeamCohen.
Class Trainer, method cleanEpoch().
/**
 * End-of-epoch cleanup routine shared by Trainer, CachingTrainer.
 * Waits for the working and cleaning pools to finish, applies trailing parameter
 * updates, computes and prints the epoch's average loss, feeds the stopping
 * criterion, reports zero-gradient counts, checks the epoch's training statistics
 * and folds its read/parse/train times into {@code stats}.
 *
 * @param workingPool pool running this epoch's training tasks; shut down here
 * @param cleanPool pool running this epoch's post-processing tasks; always shut down here
 * @param paramVec parameter vector being trained
 * @param stopper stopping criterion to update with this epoch's loss
 * @param n the next unused 1-based example id, i.e. one more than the number of
 *          examples submitted this epoch
 * @param stats running totals updated with this epoch's timing statistics
 */
protected void cleanEpoch(ExecutorService workingPool, ExecutorService cleanPool, ParamVector<String, ?> paramVec, StoppingCriterion stopper, int n, TrainingStatistics stats) {
    int numExamples = n - 1;
    workingPool.shutdown();
    try {
        workingPool.awaitTermination(7, TimeUnit.DAYS);
        cleanPool.shutdown();
        cleanPool.awaitTermination(7, TimeUnit.DAYS);
    } catch (InterruptedException e) {
        log.error("Interrupted while waiting for epoch " + epoch + " to finish", e);
        Thread.currentThread().interrupt();
    } finally {
        // If the first wait was interrupted, the shutdown above was never reached;
        // shutdown() is idempotent, so make sure the cleaning pool is always released.
        cleanPool.shutdown();
    }
    // finish any trailing updates for this epoch
    this.masterLearner.cleanupParams(paramVec, paramVec);
    // loss status and signalling the stopper
    LossData lossThisEpoch = new LossData();
    for (SRW learner : this.learners.values()) {
        lossThisEpoch.add(learner.cumulativeLoss());
    }
    lossThisEpoch.convertCumulativesToAverage(statistics.numExamplesThisEpoch);
    printLossOutput(lossThisEpoch);
    if (epoch > 1) {
        stopper.recordConsecutiveLosses(lossThisEpoch, lossLastEpoch);
    }
    lossLastEpoch = lossThisEpoch;
    ZeroGradientData zeros = this.masterLearner.new ZeroGradientData();
    for (SRW learner : this.learners.values()) {
        zeros.add(learner.getZeroGradientData());
    }
    if (zeros.numZero > 0) {
        log.info(zeros.numZero + " / " + numExamples + " examples with 0 gradient");
        if (numExamples > 0 && zeros.numZero / (float) numExamples > MAX_PCT_ZERO_GRADIENT)
            log.warn("Having this many 0 gradients is unusual for supervised tasks. Try a different squashing function?");
    }
    stopper.recordEpoch();
    statistics.checkStatistics();
    stats.updateReadingStatistics(statistics.readTime);
    stats.updateParsingStatistics(statistics.parseTime);
    stats.updateTrainingStatistics(statistics.trainTime);
}
Aggregations