Search in sources :

Example 1 with ModuleConfiguration

use of edu.cmu.ml.proppr.util.ModuleConfiguration in project ProPPR by TeamCohen.

the class Diagnostic method main.

public static void main(String[] args) {
    StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();
        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, (float) 0.75, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");
        /* DiagSrwES: */
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        final ExecutorService trainerPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);
        final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);
        int i = 1;
        for (String s : new ParsedFile(groundedFile)) {
            final int id = i++;
            final String in = s;
            parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {

                @Override
                public PosNegRWExample call() throws Exception {
                    try {
                        //log.debug("Job start "+id);
                        //PosNegRWExample ret = parser.parse(in, b.copy());
                        log.debug("Parsing start " + id);
                        PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                        log.debug("Parsing done " + id);
                        //log.debug("Job done "+id);
                        return ret;
                    } catch (IllegalArgumentException e) {
                        System.err.println("Problem with #" + id);
                        e.printStackTrace();
                    }
                    return null;
                }
            }));
        }
        parserPool.shutdown();
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            trainerPool.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    } catch (ExecutionException e) {
                        e.printStackTrace();
                    }
                }
            });
        }
        trainerPool.shutdown();
        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            log.error("Interrupted?", e);
        }
        /* /SrwES */
        /* SrwTtwop: 
			final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads>1?c.nthreads/2:1);
			Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
			m.executeJob(c.nthreads/2, new ParsedFile(groundedFile), 
					new Transformer<String,PosNegRWExample>() {
						@Override
						public Callable<PosNegRWExample> transformer(final String in, final int id) {
							return new Callable<PosNegRWExample>() {
								@Override
								public PosNegRWExample call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Parsing start "+id);
									PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
									log.debug("Parsing done "+id);
									//log.debug("Job done "+id);
									return ret;
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return null;
								}};
						}}, new Cleanup<PosNegRWExample>() {
							@Override
							public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
								return new Runnable(){
									@Override
									public void run() {
										try {
											final PosNegRWExample ex = in.get();
											log.debug("Cleanup start "+id);
											trainerPool.submit(new Runnable() {
													@Override
													public void run(){
														log.debug("Training start "+id);
														srw.trainOnExample(params,ex);
														log.debug("Training done "+id);
													}
												});
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			trainerPool.shutdown();
			try {
				trainerPool.awaitTermination(7, TimeUnit.DAYS);
			} catch (InterruptedException e) {
				log.error("Interrupted?",e);
			}

			 /SrwTtwop */
        /* all diag tasks except SrwO: 
			Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
			m.executeJob(c.nthreads, new ParsedFile(groundedFile), 
					new Transformer<String,PosNegRWExample>() {
						@Override
						public Callable<PosNegRWExample> transformer(final String in, final int id) {
							return new Callable<PosNegRWExample>() {
								@Override
								public PosNegRWExample call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Parsing start "+id);
									PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
									log.debug("Parsing done "+id);
									log.debug("Training start "+id);
									srw.trainOnExample(params,ret);
									log.debug("Training done "+id);
									//log.debug("Job done "+id);
									return ret;
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return null;
								}};
						}}, new Cleanup<PosNegRWExample>() {
							@Override
							public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
								return new Runnable(){
									//ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
									@Override
									public void run() {
										try {
											//done.add(in.get());
											in.get();
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup start "+id);
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			*/
        /* SrwO:
			   Multithreading<PosNegRWExample,Integer> m = new Multithreading<PosNegRWExample,Integer>(log);
			m.executeJob(c.nthreads, new PosNegRWExampleStreamer(new ParsedFile(groundedFile),new ArrayLearningGraphBuilder()), 
						 new Transformer<PosNegRWExample,Integer>() {
						@Override
						public Callable<Integer> transformer(final PosNegRWExample in, final int id) {
							return new Callable<Integer>() {
								@Override
								public Integer call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Training start "+id);
									srw.trainOnExample(params,in);
									log.debug("Training done "+id);
									//log.debug("Job done "+id);
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return in.length();
								}};
						}}, new Cleanup<Integer>() {
							@Override
							public Runnable cleanup(final Future<Integer> in, final int id) {
								return new Runnable(){
									//ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
									@Override
									public void run() {
										try {
											//done.add(in.get());
											in.get();
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup start "+id);
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			*/
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used : ArrayList(java.util.ArrayList) Callable(java.util.concurrent.Callable) SRW(edu.cmu.ml.proppr.learn.SRW) RWExampleParser(edu.cmu.ml.proppr.learn.tools.RWExampleParser) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) ExecutionException(java.util.concurrent.ExecutionException) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) ExecutorService(java.util.concurrent.ExecutorService) Future(java.util.concurrent.Future) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 2 with ModuleConfiguration

use of edu.cmu.ml.proppr.util.ModuleConfiguration in project ProPPR by TeamCohen.

the class PropertiesConfigurationTest method test.

@Test
public void test() {
    // config.properties defines train, test, params, prover, queries, force (unary), and two nonexistant options.
    System.setProperty(Configuration.PROPFILE, "src/testcases/config.properties");
    ModuleConfiguration c = new ModuleConfiguration("--prover dfs".split(" "), 0, Configuration.USE_PARAMS, Configuration.USE_FORCE, Configuration.USE_PROVER | Configuration.USE_SQUASHFUNCTION);
    assertTrue("Didn't fetch properties from file", c.paramsFile != null);
    assertTrue("Didn't prefer command line properties", c.prover instanceof DfsProver);
    assertTrue("Didn't fetch unary argument", c.force);
    assertEquals("Didn't fetch apr options properly", 0.01, c.apr.alpha, 1e-10);
}
Also used : ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) DfsProver(edu.cmu.ml.proppr.prove.DfsProver) Test(org.junit.Test)

Example 3 with ModuleConfiguration

use of edu.cmu.ml.proppr.util.ModuleConfiguration in project ProPPR by TeamCohen.

the class FixedWeightRulesTest method srwTest.

@Test
public void srwTest() {
    ModuleConfiguration c = new ModuleConfiguration("--fixedWeights f(thing,pos)".split(" "), 0, 0, Configuration.USE_FIXEDWEIGHTS, Configuration.USE_SRW);
    assertTrue("Raw fixedWeightRules", c.fixedWeightRules.isFixed("f(thing,pos)"));
    assertFalse("in an SRW", c.srw.trainable("f(thing,pos)"));
}
Also used : ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) Test(org.junit.Test)

Example 4 with ModuleConfiguration

use of edu.cmu.ml.proppr.util.ModuleConfiguration in project ProPPR by TeamCohen.

the class Trainer method main.

public static void main(String[] args) {
    try {
        int inputFiles = Configuration.USE_TRAIN | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_PARAMS;
        int constants = Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_THREADS | Configuration.USE_FIXEDWEIGHTS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        if (!c.queryFile.getName().endsWith(Grounder.GROUNDED_SUFFIX)) {
            throw new IllegalStateException("Run Grounder on " + c.queryFile.getName() + " first. Ground+Train in one go is not supported yet.");
        }
        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        File featureIndex = new File(groundedFile + Grounder.FEATURE_INDEX_EXTENSION);
        if (featureIndex.exists()) {
            log.info("Reading feature index from " + featureIndex.getName() + "...");
            for (String line : new ParsedFile(featureIndex)) {
                masterFeatures.insert(line.trim());
            }
        }
        log.info("Training model parameters on " + groundedFile + "...");
        long start = System.currentTimeMillis();
        ParamVector<String, ?> params = c.trainer.train(masterFeatures, new ParsedFile(groundedFile), new ArrayLearningGraphBuilder(), c.initParamsFile, c.epochs);
        System.out.println("Training time: " + (System.currentTimeMillis() - start));
        if (c.paramsFile != null) {
            log.info("Saving parameters to " + c.paramsFile + "...");
            ParamsFile.save(params, c.paramsFile, c);
        }
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used : ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) ParamsFile(edu.cmu.ml.proppr.util.ParamsFile) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) File(java.io.File) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Aggregations

ModuleConfiguration (edu.cmu.ml.proppr.util.ModuleConfiguration)4 ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)2 ParsedFile (edu.cmu.ml.proppr.util.ParsedFile)2 Test (org.junit.Test)2 PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample)1 SRW (edu.cmu.ml.proppr.learn.SRW)1 RWExampleParser (edu.cmu.ml.proppr.learn.tools.RWExampleParser)1 DfsProver (edu.cmu.ml.proppr.prove.DfsProver)1 ParamsFile (edu.cmu.ml.proppr.util.ParamsFile)1 SimpleSymbolTable (edu.cmu.ml.proppr.util.SimpleSymbolTable)1 StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)1 File (java.io.File)1 ArrayList (java.util.ArrayList)1 Callable (java.util.concurrent.Callable)1 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)1 ExecutionException (java.util.concurrent.ExecutionException)1 ExecutorService (java.util.concurrent.ExecutorService)1 Future (java.util.concurrent.Future)1