use of edu.cmu.ml.proppr.util.ModuleConfiguration in project ProPPR by TeamCohen.
the class Diagnostic method main.
/**
 * Diagnostic driver: parses each line of the grounded training file into a
 * {@code PosNegRWExample} on one fixed-size thread pool and hands every parsed
 * example to {@code srw.trainOnExample} on a second pool, then logs the elapsed
 * wall-clock time.
 *
 * <p>Each pool gets {@code c.nthreads / 2} threads (or one thread when
 * {@code c.nthreads <= 1}). Examples that fail to parse are skipped by the trainer,
 * and failures inside a training task are logged rather than being lost in an
 * unread {@code Future}.
 *
 * <p>Any {@code Throwable} escaping the run is printed and the JVM exits with
 * status -1.
 *
 * @param args command-line arguments forwarded to {@code ModuleConfiguration}
 */
public static void main(String[] args) {
    final StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());

        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();

        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(
                new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, (float) 0.75, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");

        /* DiagSrwES: */
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        final ExecutorService trainerPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);
        final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);

        // Stage 1: submit one parsing job per input line; ids are 1-based line ordinals.
        int i = 1;
        for (String s : new ParsedFile(groundedFile)) {
            final int id = i++;
            final String in = s;
            parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {
                @Override
                public PosNegRWExample call() throws Exception {
                    try {
                        log.debug("Parsing start " + id);
                        // Each job parses into its own copy of the graph builder.
                        PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                        log.debug("Parsing done " + id);
                        return ret;
                    } catch (IllegalArgumentException e) {
                        log.error("Problem with #" + id, e);
                    }
                    // null marks an example that could not be parsed; the trainer skips it.
                    return null;
                }
            }));
        }
        parserPool.shutdown();

        // Stage 2: one training job per parse future, in submission order.
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            trainerPool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        if (ex == null) {
                            // The parser already reported the failure; do not hand null to the SRW.
                            log.warn("Skipping training for #" + id + ": example could not be parsed");
                            return;
                        }
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        // Preserve the interrupt so the pool thread can observe it.
                        Thread.currentThread().interrupt();
                        log.error("Interrupted while waiting for example #" + id, e);
                    } catch (ExecutionException e) {
                        log.error("Parsing failed for example #" + id, e);
                    } catch (RuntimeException e) {
                        // The Future returned by submit() is never read, so report failures here.
                        log.error("Training failed for example #" + id, e);
                    }
                }
            });
        }
        trainerPool.shutdown();

        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Interrupted?", e);
        }
        /* /SrwES */
        /* SrwTtwop:
        final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads>1?c.nthreads/2:1);
        Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
        m.executeJob(c.nthreads/2, new ParsedFile(groundedFile),
        new Transformer<String,PosNegRWExample>() {
        @Override
        public Callable<PosNegRWExample> transformer(final String in, final int id) {
        return new Callable<PosNegRWExample>() {
        @Override
        public PosNegRWExample call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Parsing start "+id);
        PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
        log.debug("Parsing done "+id);
        //log.debug("Job done "+id);
        return ret;
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return null;
        }};
        }}, new Cleanup<PosNegRWExample>() {
        @Override
        public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
        return new Runnable(){
        @Override
        public void run() {
        try {
        final PosNegRWExample ex = in.get();
        log.debug("Cleanup start "+id);
        trainerPool.submit(new Runnable() {
        @Override
        public void run(){
        log.debug("Training start "+id);
        srw.trainOnExample(params,ex);
        log.debug("Training done "+id);
        }
        });
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        trainerPool.shutdown();
        try {
        trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
        log.error("Interrupted?",e);
        }
        /SrwTtwop */
        /* all diag tasks except SrwO:
        Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
        m.executeJob(c.nthreads, new ParsedFile(groundedFile),
        new Transformer<String,PosNegRWExample>() {
        @Override
        public Callable<PosNegRWExample> transformer(final String in, final int id) {
        return new Callable<PosNegRWExample>() {
        @Override
        public PosNegRWExample call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Parsing start "+id);
        PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
        log.debug("Parsing done "+id);
        log.debug("Training start "+id);
        srw.trainOnExample(params,ret);
        log.debug("Training done "+id);
        //log.debug("Job done "+id);
        return ret;
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return null;
        }};
        }}, new Cleanup<PosNegRWExample>() {
        @Override
        public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
        return new Runnable(){
        //ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
        @Override
        public void run() {
        try {
        //done.add(in.get());
        in.get();
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup start "+id);
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        */
        /* SrwO:
        Multithreading<PosNegRWExample,Integer> m = new Multithreading<PosNegRWExample,Integer>(log);
        m.executeJob(c.nthreads, new PosNegRWExampleStreamer(new ParsedFile(groundedFile),new ArrayLearningGraphBuilder()),
        new Transformer<PosNegRWExample,Integer>() {
        @Override
        public Callable<Integer> transformer(final PosNegRWExample in, final int id) {
        return new Callable<Integer>() {
        @Override
        public Integer call() throws Exception {
        try {
        //log.debug("Job start "+id);
        //PosNegRWExample ret = parser.parse(in, b.copy());
        log.debug("Training start "+id);
        srw.trainOnExample(params,in);
        log.debug("Training done "+id);
        //log.debug("Job done "+id);
        } catch (IllegalArgumentException e) {
        System.err.println("Problem with #"+id);
        e.printStackTrace();
        }
        return in.length();
        }};
        }}, new Cleanup<Integer>() {
        @Override
        public Runnable cleanup(final Future<Integer> in, final int id) {
        return new Runnable(){
        //ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
        @Override
        public void run() {
        try {
        //done.add(in.get());
        in.get();
        } catch (InterruptedException e) {
        e.printStackTrace();
        } catch (ExecutionException e) {
        e.printStackTrace();
        }
        log.debug("Cleanup start "+id);
        log.debug("Cleanup done "+id);
        }};
        }}, c.throttle);
        */
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        // Top-level boundary of a command-line tool: report and exit non-zero.
        t.printStackTrace();
        System.exit(-1);
    }
}
use of edu.cmu.ml.proppr.util.ModuleConfiguration in project ProPPR by TeamCohen.
the class PropertiesConfigurationTest method test.
/**
 * Checks that {@code ModuleConfiguration} combines the properties file named by the
 * {@code Configuration.PROPFILE} system property with command-line arguments, and that
 * the command line takes precedence for the prover.
 *
 * <p>The system property is JVM-wide, so its previous value is restored afterwards to
 * keep this test from influencing others that run in the same JVM.
 */
@Test
public void test() {
    // config.properties defines train, test, params, prover, queries, force (unary), and two nonexistent options.
    final String previousPropFile = System.getProperty(Configuration.PROPFILE);
    System.setProperty(Configuration.PROPFILE, "src/testcases/config.properties");
    try {
        ModuleConfiguration c = new ModuleConfiguration("--prover dfs".split(" "), 0, Configuration.USE_PARAMS, Configuration.USE_FORCE, Configuration.USE_PROVER | Configuration.USE_SQUASHFUNCTION);
        assertTrue("Didn't fetch properties from file", c.paramsFile != null);
        assertTrue("Didn't prefer command line properties", c.prover instanceof DfsProver);
        assertTrue("Didn't fetch unary argument", c.force);
        assertEquals("Didn't fetch apr options properly", 0.01, c.apr.alpha, 1e-10);
    } finally {
        // Put the global property back exactly as it was found (unset or previous value).
        if (previousPropFile == null) {
            System.clearProperty(Configuration.PROPFILE);
        } else {
            System.setProperty(Configuration.PROPFILE, previousPropFile);
        }
    }
}
use of edu.cmu.ml.proppr.util.ModuleConfiguration in project ProPPR by TeamCohen.
the class FixedWeightRulesTest method srwTest.
/**
 * Builds a configuration with {@code --fixedWeights f(thing,pos)} and checks that the
 * rule is reported as fixed by {@code fixedWeightRules} and as not trainable by the
 * configured SRW.
 */
@Test
public void srwTest() {
    final String rule = "f(thing,pos)";
    final String[] commandLine = "--fixedWeights f(thing,pos)".split(" ");

    final ModuleConfiguration config = new ModuleConfiguration(
            commandLine,
            0,
            0,
            Configuration.USE_FIXEDWEIGHTS,
            Configuration.USE_SRW);

    final boolean fixedInRawRules = config.fixedWeightRules.isFixed(rule);
    assertTrue("Raw fixedWeightRules", fixedInRawRules);

    final boolean trainableInSrw = config.srw.trainable(rule);
    assertFalse("in an SRW", trainableInSrw);
}
use of edu.cmu.ml.proppr.util.ModuleConfiguration in project ProPPR by TeamCohen.
the class Trainer method main.
/**
 * Command-line entry point for training: builds a {@code ModuleConfiguration} from
 * {@code args}, requires the query file name to end with {@code Grounder.GROUNDED_SUFFIX},
 * loads the feature index next to the grounded file when one exists, runs
 * {@code c.trainer.train(...)}, prints the elapsed milliseconds, and saves the resulting
 * parameters when a params file is configured.
 *
 * <p>Any {@code Throwable} is printed and the JVM exits with status -1.
 *
 * @param args command-line arguments forwarded to {@code ModuleConfiguration}
 */
public static void main(String[] args) {
    try {
        final ModuleConfiguration config = new ModuleConfiguration(
                args,
                Configuration.USE_TRAIN | Configuration.USE_INIT_PARAMS,
                Configuration.USE_PARAMS,
                Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_THREADS | Configuration.USE_FIXEDWEIGHTS,
                Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION);
        log.info(config.toString());

        final String groundedPath = config.queryFile.getPath();
        final boolean alreadyGrounded = config.queryFile.getName().endsWith(Grounder.GROUNDED_SUFFIX);
        if (!alreadyGrounded) {
            throw new IllegalStateException("Run Grounder on " + config.queryFile.getName() + " first. Ground+Train in one go is not supported yet.");
        }

        // Seed the feature table from the index file beside the grounded file, if present.
        final SymbolTable<String> featureTable = new SimpleSymbolTable<String>();
        final File indexFile = new File(groundedPath + Grounder.FEATURE_INDEX_EXTENSION);
        if (indexFile.exists()) {
            log.info("Reading feature index from " + indexFile.getName() + "...");
            for (String entry : new ParsedFile(indexFile)) {
                featureTable.insert(entry.trim());
            }
        }

        log.info("Training model parameters on " + groundedPath + "...");
        final long startedAt = System.currentTimeMillis();
        final ParamVector<String, ?> learned = config.trainer.train(
                featureTable,
                new ParsedFile(groundedPath),
                new ArrayLearningGraphBuilder(),
                config.initParamsFile,
                config.epochs);
        final long elapsedMillis = System.currentTimeMillis() - startedAt;
        System.out.println("Training time: " + elapsedMillis);

        if (config.paramsFile != null) {
            log.info("Saving parameters to " + config.paramsFile + "...");
            ParamsFile.save(learned, config.paramsFile, config);
        }
    } catch (Throwable failure) {
        failure.printStackTrace();
        System.exit(-1);
    }
}
Aggregations