use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
the class Diagnostic method main.
/**
 * Runs the SRW training diagnostic.
 *
 * <p>Each line of the configured query file is parsed into a {@code PosNegRWExample} on a
 * parser thread pool; each parsed example is then trained on, once, on a separate trainer
 * thread pool. When both pools have drained, the parameters are cleaned up and the total
 * elapsed time is logged.
 *
 * <p>Lines the parser rejects with an {@link IllegalArgumentException} are logged and skipped
 * rather than handed to the learner. Any failure escaping the overall run is printed and the
 * JVM exits with status -1.
 *
 * @param args command-line arguments, interpreted by {@code ModuleConfiguration}
 */
public static void main(String[] args) {
    final StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();
        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, 0.75f, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");
        /* DiagSrwES: */
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        // The thread budget is split evenly between the two pools; each gets at least one thread.
        final int poolSize = c.nthreads > 1 ? c.nthreads / 2 : 1;
        final ExecutorService trainerPool = Executors.newFixedThreadPool(poolSize);
        final ExecutorService parserPool = Executors.newFixedThreadPool(poolSize);
        int i = 1;
        for (String s : new ParsedFile(groundedFile)) {
            final int id = i++;
            final String in = s;
            parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {
                @Override
                public PosNegRWExample call() throws Exception {
                    try {
                        //log.debug("Job start "+id);
                        //PosNegRWExample ret = parser.parse(in, b.copy());
                        log.debug("Parsing start " + id);
                        PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                        log.debug("Parsing done " + id);
                        //log.debug("Job done "+id);
                        return ret;
                    } catch (IllegalArgumentException e) {
                        log.error("Problem with #" + id, e);
                    }
                    // A null result marks a rejected line; the trainer task skips it.
                    return null;
                }
            }));
        }
        parserPool.shutdown();
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            trainerPool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        if (ex == null) {
                            // The parser task returned null for a rejected line: nothing to train on.
                            log.warn("Skipping training for #" + id + ": example failed to parse");
                            return;
                        }
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        // Preserve the interrupt so the pool's worker thread can observe it.
                        Thread.currentThread().interrupt();
                        log.error("Interrupted while waiting for parsed example #" + id, e);
                    } catch (ExecutionException e) {
                        log.error("Parsing failed for #" + id, e);
                    } catch (RuntimeException e) {
                        // The Future returned by submit() is discarded, so an uncaught exception
                        // here would otherwise disappear without a trace.
                        log.error("Training failed for #" + id, e);
                    }
                }
            });
        }
        trainerPool.shutdown();
        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Interrupted?", e);
        }
        /* /SrwES */
        // NOTE(review): the commented-out variants below are alternative diagnostic modes kept
        // for reference. They call parse(...) and trainOnExample(...) with different argument
        // lists than the live code above, so they presumably predate it; confirm before reviving.
        /* SrwTtwop:
        final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads>1?c.nthreads/2:1);
        Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
        m.executeJob(c.nthreads/2, new ParsedFile(groundedFile),
        new Transformer<String,PosNegRWExample>() {
        @Override
        public Callable<PosNegRWExample> transformer(final String in, final int id) {
        return new Callable<PosNegRWExample>() {
        @Override
        public PosNegRWExample call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Parsing start "+id);
        PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
        log.debug("Parsing done "+id);
        //log.debug("Job done "+id);
        return ret;
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return null;
        }};
        }}, new Cleanup<PosNegRWExample>() {
        @Override
        public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
        return new Runnable(){
        @Override
        public void run() {
        try {
        final PosNegRWExample ex = in.get();
        log.debug("Cleanup start "+id);
        trainerPool.submit(new Runnable() {
        @Override
        public void run(){
        log.debug("Training start "+id);
        srw.trainOnExample(params,ex);
        log.debug("Training done "+id);
        }
        });
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        trainerPool.shutdown();
        try {
        trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
        log.error("Interrupted?",e);
        }
        /SrwTtwop */
        /* all diag tasks except SrwO:
        Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
        m.executeJob(c.nthreads, new ParsedFile(groundedFile),
        new Transformer<String,PosNegRWExample>() {
        @Override
        public Callable<PosNegRWExample> transformer(final String in, final int id) {
        return new Callable<PosNegRWExample>() {
        @Override
        public PosNegRWExample call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Parsing start "+id);
        PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
        log.debug("Parsing done "+id);
        log.debug("Training start "+id);
        srw.trainOnExample(params,ret);
        log.debug("Training done "+id);
        //log.debug("Job done "+id);
        return ret;
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return null;
        }};
        }}, new Cleanup<PosNegRWExample>() {
        @Override
        public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
        return new Runnable(){
        //ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
        @Override
        public void run() {
        try {
        //done.add(in.get());
        in.get();
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup start "+id);
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        */
        /* SrwO:
        Multithreading<PosNegRWExample,Integer> m = new Multithreading<PosNegRWExample,Integer>(log);
        m.executeJob(c.nthreads, new PosNegRWExampleStreamer(new ParsedFile(groundedFile),new ArrayLearningGraphBuilder()),
        new Transformer<PosNegRWExample,Integer>() {
        @Override
        public Callable<Integer> transformer(final PosNegRWExample in, final int id) {
        return new Callable<Integer>() {
        @Override
        public Integer call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Training start "+id);
        srw.trainOnExample(params,in);
        log.debug("Training done "+id);
        //log.debug("Job done "+id);
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return in.length();
        }};
        }}, new Cleanup<Integer>() {
        @Override
        public Runnable cleanup(final Future<Integer> in, final int id) {
        return new Runnable(){
        //ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
        @Override
        public void run() {
        try {
        //done.add(in.get());
        in.get();
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup start "+id);
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        */
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        // Top-level boundary of a command-line tool: report the failure and exit non-zero.
        t.printStackTrace();
        System.exit(-1);
    }
}
use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
the class CachingTrainer method train.
/**
 * Parses every line of {@code exampleFile} into a {@code PosNegRWExample}, then delegates
 * training on the collected examples to {@code trainCached}.
 *
 * <p>Time spent obtaining each line and time spent parsing it are recorded separately in a
 * {@code TrainingStatistics} instance that is passed on to {@code trainCached}. Lines that
 * raise a {@code GraphFormatException} are logged and left out of the example list.
 *
 * @param masterFeatures feature table; installed on {@code LearningGraphBuilder} when non-empty
 * @param exampleFile source of example lines, one example per element
 * @param builder graph builder handed to the parser and to {@code trainCached}
 * @param initialParamVec starting parameters, forwarded to {@code trainCached}
 * @param numEpochs number of epochs, forwarded to {@code trainCached}
 * @return whatever {@code trainCached} returns
 */
@Override
public ParamVector<String, ?> train(SymbolTable<String> masterFeatures, Iterable<String> exampleFile, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs) {
    ArrayList<PosNegRWExample> parsedExamples = new ArrayList<PosNegRWExample>();
    RWExampleParser exampleParser = new RWExampleParser();
    if (masterFeatures.size() > 0) {
        LearningGraphBuilder.setFeatures(masterFeatures);
    }
    int lineNumber = 0;
    StatusLogger stopwatch = new StatusLogger();
    TrainingStatistics stats = new TrainingStatistics();
    boolean reportedProgress = false;
    for (String line : exampleFile) {
        // Whatever elapsed since the previous tick was spent fetching this line.
        stats.updateReadingStatistics(stopwatch.sinceLast());
        lineNumber++;
        try {
            stopwatch.tick();
            PosNegRWExample parsedExample = exampleParser.parse(line, builder, masterLearner);
            stats.updateParsingStatistics(stopwatch.sinceLast());
            parsedExamples.add(parsedExample);
            if (status.due()) {
                log.info("Parsed " + lineNumber + " ...");
                reportedProgress = true;
            }
        } catch (GraphFormatException e) {
            log.error("Trouble with #" + lineNumber, e);
        }
        // Restart the clock so the next iteration measures only its own read time.
        stopwatch.tick();
    }
    // The summary is only emitted when at least one progress message was logged.
    if (reportedProgress) {
        log.info("Total parsed: " + lineNumber);
    }
    return trainCached(parsedExamples, builder, initialParamVec, numEpochs, stats);
}
use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
the class SimpleProgramProverTest method test.
/**
 * Proves {@code coworker(steve, X)} with a {@code DfsProver} against the program loaded from
 * {@code PROGRAM}, and checks that exactly two solutions come back and that each of
 * "steve" and "sven" appears exactly once among them.
 */
@Test
public void test() throws IOException, LogicProgramException {
    WamProgram program = WamBaseProgram.load(new File(PROGRAM));
    Goal coworkerGoal = new Goal("coworker", new ConstantArgument("steve"), new ConstantArgument("X"));
    Query query = new Query(coworkerGoal);
    System.out.println("Query: " + query.toString());
    ProofGraph graph = new StateProofGraph(query, apr, program);
    Prover dfs = new DfsProver(apr);
    Map<String, Double> solutions = dfs.solutions(graph, new StatusLogger());
    assertEquals(2, solutions.size());

    // Occurrence counters for every binding we expect to see, all starting at zero.
    HashMap<String, Integer> seenCounts = new HashMap<String, Integer>();
    seenCounts.put("steve", 0);
    seenCounts.put("sven", 0);
    System.out.println("Query: " + query.toString());
    for (String solution : solutions.keySet()) {
        System.out.println(solution);
        // The bound value is the text after the first ':' of the solution key.
        String value = solution.split(":")[1];
        System.out.println("Got solution: " + value);
        Integer countSoFar = seenCounts.get(value);
        if (countSoFar != null) {
            seenCounts.put(value, countSoFar + 1);
        }
    }
    for (Map.Entry<String, Integer> entry : seenCounts.entrySet()) {
        assertEquals(entry.getKey(), 1, entry.getValue().intValue());
    }
}
use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
the class FactsPluginTest method test.
/**
 * Loads the facts file named by {@code GrounderTest.FACTS} as a plugin, proves
 * {@code validClass(X)} with a {@code DprProver} over an otherwise empty program, and
 * checks that exactly two solutions are returned.
 */
@Test
public void test() throws LogicProgramException {
    APROptions options = new APROptions();
    FactsPlugin facts = FactsPlugin.load(options, new File(GrounderTest.FACTS), false);
    // No rules are loaded into the program: the only other input to the proof graph is the plugin.
    WamProgram emptyProgram = new WamBaseProgram();
    Query query = Query.parse("validClass(X)");
    StateProofGraph graph = new StateProofGraph(query, options, emptyProgram, facts);
    Prover dpr = new DprProver();
    Map<String, Double> solutions = dpr.solutions(graph, new StatusLogger());
    assertEquals(2, solutions.size());
}
use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.
the class SRWTest method makeLoss.
/**
 * Computes the loss {@code srw} reports for a single example.
 *
 * <p>The learner's accumulated loss is cleared first, then the gradient is accumulated for
 * one example built by {@code factory} from {@code brGraph}, {@code query}, {@code pos} and
 * {@code neg}. The gradient itself goes into a throwaway vector; only the loss is returned.
 *
 * @param srw learner under test
 * @param paramVec parameters to evaluate the loss at
 * @param query query vector for the example
 * @param pos positive node ids
 * @param neg negative node ids
 * @return the total of the learner's cumulative loss after processing the example
 */
public double makeLoss(SRW srw, ParamVector<String, ?> paramVec, TIntDoubleMap query, int[] pos, int[] neg) {
    srw.clearLoss();
    srw.accumulateGradient(
            paramVec,
            factory.makeExample("loss", brGraph, query, pos, neg),
            new SimpleParamVector<String>(),
            new StatusLogger());
    double totalLoss = srw.cumulativeLoss().total();
    return totalLoss;
}
Aggregations