Search in sources:

Example 1 with ParsedFile

Use of edu.cmu.ml.proppr.util.ParsedFile in project ProPPR by TeamCohen.

Class Diagnostic, method main.

/**
 * Diagnostic driver that times parsing plus SRW training over the configured query/training file
 * ({@code c.queryFile}).
 *
 * <p>Each line of the file is parsed into a {@code PosNegRWExample} on a parser pool. Each parsed
 * example is then handed to {@code srw.trainOnExample} on a separate trainer pool. Each pool gets
 * {@code nthreads / 2} threads, or 1 thread when {@code nthreads <= 1}.
 *
 * <p>Training runs as epoch 1, with {@code id(restart)}, {@code id(trueLoop)} and
 * {@code id(trueLoopRestart)} registered as exact fixed-weight rules. Examples whose parse failed
 * are skipped with a warning rather than passed to the trainer.
 *
 * <p>Only the "DiagSrwES" strategy is live. The commented-out blocks below keep alternative
 * parse/train scheduling strategies for comparison. Any {@code Throwable} is printed and the JVM
 * exits with status -1.
 *
 * @param args command-line arguments, interpreted by {@code ModuleConfiguration}
 */
public static void main(String[] args) {
    StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();
        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, (float) 0.75, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");
        /* DiagSrwES: */
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        final ExecutorService trainerPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);
        final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);
        int i = 1;
        // Stage 1: queue one parse job per input line. The jobs capture only the line text, so the
        // file can be closed as soon as it has been read (or if reading fails part-way).
        ParsedFile groundedLines = new ParsedFile(groundedFile);
        try {
            for (String s : groundedLines) {
                final int id = i++;
                final String in = s;
                parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {

                    @Override
                    public PosNegRWExample call() throws Exception {
                        try {
                            log.debug("Parsing start " + id);
                            PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                            log.debug("Parsing done " + id);
                            return ret;
                        } catch (IllegalArgumentException e) {
                            System.err.println("Problem with #" + id);
                            e.printStackTrace();
                        }
                        // null tells the trainer stage that this example could not be parsed
                        return null;
                    }
                }));
            }
        } finally {
            groundedLines.close();
        }
        parserPool.shutdown();
        // Stage 2: queue one training job per parse job. Each job blocks until its parse finishes.
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            trainerPool.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        if (ex == null) {
                            // The parser already reported the problem; there is nothing to train on.
                            log.warn("Skipping example #" + id + ": parse failed");
                            return;
                        }
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        // Restore the interrupt flag so the pool can observe the interruption.
                        Thread.currentThread().interrupt();
                        e.printStackTrace();
                    } catch (ExecutionException e) {
                        e.printStackTrace();
                    } catch (RuntimeException e) {
                        // The Future returned by submit() is discarded, so report failures here
                        // instead of letting them vanish inside it.
                        log.error("Training failed for example #" + id, e);
                    }
                }
            });
        }
        trainerPool.shutdown();
        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Interrupted?", e);
        }
        /* /SrwES */
        /* SrwTtwop: 
			final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads>1?c.nthreads/2:1);
			Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
			m.executeJob(c.nthreads/2, new ParsedFile(groundedFile), 
					new Transformer<String,PosNegRWExample>() {
						@Override
						public Callable<PosNegRWExample> transformer(final String in, final int id) {
							return new Callable<PosNegRWExample>() {
								@Override
								public PosNegRWExample call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Parsing start "+id);
									PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
									log.debug("Parsing done "+id);
									//log.debug("Job done "+id);
									return ret;
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return null;
								}};
						}}, new Cleanup<PosNegRWExample>() {
							@Override
							public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
								return new Runnable(){
									@Override
									public void run() {
										try {
											final PosNegRWExample ex = in.get();
											log.debug("Cleanup start "+id);
											trainerPool.submit(new Runnable() {
													@Override
													public void run(){
														log.debug("Training start "+id);
														srw.trainOnExample(params,ex);
														log.debug("Training done "+id);
													}
												});
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			trainerPool.shutdown();
			try {
				trainerPool.awaitTermination(7, TimeUnit.DAYS);
			} catch (InterruptedException e) {
				log.error("Interrupted?",e);
			}

			 /SrwTtwop */
        /* all diag tasks except SrwO: 
			Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
			m.executeJob(c.nthreads, new ParsedFile(groundedFile), 
					new Transformer<String,PosNegRWExample>() {
						@Override
						public Callable<PosNegRWExample> transformer(final String in, final int id) {
							return new Callable<PosNegRWExample>() {
								@Override
								public PosNegRWExample call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Parsing start "+id);
									PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
									log.debug("Parsing done "+id);
									log.debug("Training start "+id);
									srw.trainOnExample(params,ret);
									log.debug("Training done "+id);
									//log.debug("Job done "+id);
									return ret;
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return null;
								}};
						}}, new Cleanup<PosNegRWExample>() {
							@Override
							public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
								return new Runnable(){
									//ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
									@Override
									public void run() {
										try {
											//done.add(in.get());
											in.get();
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup start "+id);
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			*/
        /* SrwO:
			   Multithreading<PosNegRWExample,Integer> m = new Multithreading<PosNegRWExample,Integer>(log);
			m.executeJob(c.nthreads, new PosNegRWExampleStreamer(new ParsedFile(groundedFile),new ArrayLearningGraphBuilder()), 
						 new Transformer<PosNegRWExample,Integer>() {
						@Override
						public Callable<Integer> transformer(final PosNegRWExample in, final int id) {
							return new Callable<Integer>() {
								@Override
								public Integer call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Training start "+id);
									srw.trainOnExample(params,in);
									log.debug("Training done "+id);
									//log.debug("Job done "+id);
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return in.length();
								}};
						}}, new Cleanup<Integer>() {
							@Override
							public Runnable cleanup(final Future<Integer> in, final int id) {
								return new Runnable(){
									//ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
									@Override
									public void run() {
										try {
											//done.add(in.get());
											in.get();
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup start "+id);
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			*/
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used: ArrayList (java.util.ArrayList), Callable (java.util.concurrent.Callable), SRW (edu.cmu.ml.proppr.learn.SRW), RWExampleParser (edu.cmu.ml.proppr.learn.tools.RWExampleParser), ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap), ExecutionException (java.util.concurrent.ExecutionException), ParsedFile (edu.cmu.ml.proppr.util.ParsedFile), StatusLogger (edu.cmu.ml.proppr.util.StatusLogger), ModuleConfiguration (edu.cmu.ml.proppr.util.ModuleConfiguration), PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample), ExecutorService (java.util.concurrent.ExecutorService), Future (java.util.concurrent.Future), ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 2 with ParsedFile

Use of edu.cmu.ml.proppr.util.ParsedFile in project ProPPR by TeamCohen.

Class GradientFinder, method main.

/**
 * Computes the batch gradient over a grounded examples file and saves it to the gradient file.
 *
 * <p>If a feature index file sits next to the grounded file, its trimmed lines seed the feature
 * symbol table. The parameter vector used for the gradient is chosen in this order of preference:
 * <ol>
 *   <li>the result of training for {@code epochs} epochs, when {@code epochs > 0}; the trained
 *       vector is also saved to the params file if one was given;</li>
 *   <li>the initial-params file;</li>
 *   <li>the params file;</li>
 *   <li>an empty vector.</li>
 * </ol>
 * {@code epochs} defaults to 0 when it is not given on the command line.
 *
 * <p>When {@code --relaxFW} is given, the trainer's fixed-weight rules are replaced with an empty
 * set before the gradient is computed. This keeps features that are held fixed in training
 * eligible for a gradient.
 *
 * <p>Any {@code Throwable} is printed and the JVM exits with status -1.
 *
 * @param args command-line arguments
 */
public static void main(String[] args) {
    try {
        int inputFiles = Configuration.USE_GROUNDED | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_GRADIENT | Configuration.USE_PARAMS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        int constants = Configuration.USE_THREADS | Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_FIXEDWEIGHTS;
        CustomConfiguration c = new CustomConfiguration(args, inputFiles, outputFiles, constants, modules) {

            boolean relax;

            @Override
            protected Option checkOption(Option o) {
                // Neither params file is mandatory for this tool.
                if (PARAMS_FILE_OPTION.equals(o.getLongOpt()) || INIT_PARAMS_FILE_OPTION.equals(o.getLongOpt()))
                    o.setRequired(false);
                return o;
            }

            @Override
            protected void addCustomOptions(Options options, int[] flags) {
                options.addOption(Option.builder().longOpt("relaxFW").desc("Relax fixedWeight rules for gradient computation (used in ProngHorn)").optionalArg(true).build());
            }

            @Override
            protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
                if (groundedFile == null || !groundedFile.exists())
                    usageOptions(options, flags, "Must specify grounded file using --" + Configuration.GROUNDED_FILE_OPTION);
                if (gradientFile == null)
                    usageOptions(options, flags, "Must specify gradient using --" + Configuration.GRADIENT_FILE_OPTION);
                // Ask the parsed CommandLine whether the user supplied each option. Options.hasOption()
                // only reports whether an option is *defined*. "relaxFW" is always defined (see
                // addCustomOptions), and "epochs" is defined because USE_EPOCHS is requested, so
                // Options.hasOption() would always return true here.
                // default to 0 epochs
                if (!line.hasOption("epochs"))
                    this.epochs = 0;
                this.relax = line.hasOption("relaxFW");
            }

            @Override
            public Object getCustomSetting(String name) {
                if ("relaxFW".equals(name))
                    return this.relax;
                return null;
            }
        };
        System.out.println(c.toString());
        ParamVector<String, ?> params = null;
        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        File featureIndex = new File(c.groundedFile.getParent(), c.groundedFile.getName() + Grounder.FEATURE_INDEX_EXTENSION);
        if (featureIndex.exists()) {
            log.info("Reading feature index from " + featureIndex.getName() + "...");
            ParsedFile featureLines = new ParsedFile(featureIndex);
            try {
                for (String line : featureLines) {
                    masterFeatures.insert(line.trim());
                }
            } finally {
                featureLines.close();
            }
        }
        if (c.epochs > 0) {
            // train first
            log.info("Training for " + c.epochs + " epochs...");
            params = c.trainer.train(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), // create a parameter vector
            c.initParamsFile, c.epochs);
            if (c.paramsFile != null)
                ParamsFile.save(params, c.paramsFile, c);
        } else if (c.initParamsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.initParamsFile)));
        } else if (c.paramsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.paramsFile)));
        } else {
            params = new SimpleParamVector<String>();
        }
        // this lets prongHorn hold external features fixed for training, but still compute their gradient
        if (((Boolean) c.getCustomSetting("relaxFW"))) {
            log.info("Turning off fixedWeight rules");
            c.trainer.setFixedWeightRules(new FixedWeightRules());
        }
        ParamVector<String, ?> batchGradient = c.trainer.findGradient(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), params);
        ParamsFile.save(batchGradient, c.gradientFile, c);
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used: Options (org.apache.commons.cli.Options), CustomConfiguration (edu.cmu.ml.proppr.util.CustomConfiguration), SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector), CommandLine (org.apache.commons.cli.CommandLine), SimpleSymbolTable (edu.cmu.ml.proppr.util.SimpleSymbolTable), FixedWeightRules (edu.cmu.ml.proppr.learn.tools.FixedWeightRules), Option (org.apache.commons.cli.Option), ParsedFile (edu.cmu.ml.proppr.util.ParsedFile), File (java.io.File), ParamsFile (edu.cmu.ml.proppr.util.ParamsFile), ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 3 with ParsedFile

Use of edu.cmu.ml.proppr.util.ParsedFile in project ProPPR by TeamCohen.

Class Trainer, method train.

/**
 * Trains for {@code numEpochs} epochs, starting either from parameters loaded from a file or from
 * a freshly created parameter vector.
 *
 * @param masterFeatures feature symbol table passed through to the main {@code train} overload
 * @param examples example lines to train on
 * @param builder graph builder used while training
 * @param initialParamVecFile file to load starting parameters from (into a
 *     {@code ConcurrentHashMap}-backed vector), or {@code null} to use {@code createParamVector()}
 * @param numEpochs number of training epochs
 * @return the parameter vector returned by the main {@code train} overload
 */
public ParamVector<String, ?> train(SymbolTable<String> masterFeatures, Iterable<String> examples, LearningGraphBuilder builder, File initialParamVecFile, int numEpochs) {
    ParamVector<String, ?> startingParams;
    if (initialParamVecFile == null) {
        startingParams = createParamVector();
    } else {
        log.info("loading initial params from " + initialParamVecFile);
        ConcurrentHashMap<String, Double> backing = new ConcurrentHashMap<String, Double>();
        startingParams = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(initialParamVecFile), backing));
    }
    return train(masterFeatures, examples, builder, startingParams, numEpochs);
}
Also used: ParsedFile (edu.cmu.ml.proppr.util.ParsedFile)

Example 4 with ParsedFile

Use of edu.cmu.ml.proppr.util.ParsedFile in project ProPPR by TeamCohen.

Class LightweightGraphPlugin, method load.

/**
 * Builds a graph plugin from a tab-separated edge file.
 *
 * <p>Each line holds {@code edgeLabel<TAB>sourceNode<TAB>destNode}. An optional fourth field gives
 * the edge weight, parsed with {@code Double.parseDouble}. Each field is trimmed before use.
 * Lines with fewer than three fields are reported through {@code ParsedFile.parseError}. Lines
 * with more than four fields add no edge.
 *
 * @param apr options passed to the new {@code LightweightGraphPlugin}
 * @param f the edge file to read; {@code f.getName()} is also passed to the plugin constructor
 *     (presumably as the plugin's name — confirm)
 * @param duplicates if positive, enables Bloom-filter duplicate detection sized for this many
 *     lines (false-positive probability 1e-5), and repeated lines are skipped with a warning; if
 *     zero or negative, every line is loaded
 * @return the populated plugin
 */
public static WamPlugin load(APROptions apr, File f, int duplicates) {
    GraphlikePlugin p = new LightweightGraphPlugin(apr, f.getName());
    ParsedFile parsed = new ParsedFile(f);
    try {
        // Allocated only when duplicate detection is requested.
        BloomFilter<String> lines = null;
        if (duplicates > 0)
            lines = new BloomFilter<String>(1e-5, duplicates);
        boolean exceeds = false;
        for (String line : parsed) {
            String[] parts = line.split("\t");
            if (parts.length < 3)
                parsed.parseError("expected 3 tab-delimited fields; got " + parts.length);
            if (duplicates > 0) {
                if (lines.contains(line)) {
                    log.warn("Skipping duplicate fact at " + f.getName() + ":" + parsed.getAbsoluteLineNumber() + ": " + line);
                    continue;
                }
                lines.add(line);
                // Warn once when more lines have been read than the filter was sized for, since
                // Bloom-filter false positives (wrongly skipped facts) then become more likely.
                if (!exceeds && parsed.getLineNumber() > duplicates) {
                    exceeds = true;
                    log.warn("Number of graph edges exceeds " + duplicates + "; duplicate detection may encounter false positives. We should add a command line option to fix this.");
                }
            }
            if (parts.length == 3) {
                p.addEdge(parts[0].trim(), parts[1].trim(), parts[2].trim());
            } else if (parts.length == 4) {
                p.addEdge(parts[0].trim(), parts[1].trim(), parts[2].trim(), Double.parseDouble(parts[3].trim()));
            }
            // NOTE(review): lines with more than 4 fields are silently ignored here.
        }
    } finally {
        // Release the file handle even if a line fails to parse.
        parsed.close();
    }
    return p;
}
Also used: ParsedFile (edu.cmu.ml.proppr.util.ParsedFile), BloomFilter (com.skjegstad.utils.BloomFilter)

Example 5 with ParsedFile

Use of edu.cmu.ml.proppr.util.ParsedFile in project ProPPR by TeamCohen.

Class SparseGraphPlugin, method loadArgs.

/**
 * Loads an args file into {@code args}. Each trimmed line is mapped to the value of
 * {@code parsed.getLineNumber()} at the time that line is read. If the same trimmed line appears
 * more than once, the later {@code put} overwrites the earlier one.
 *
 * @param args map to populate
 * @param file the args file to read
 */
private void loadArgs(TObjectIntMap<String> args, File file) {
    log.debug("Loading args file " + file.getName() + " in String...");
    ParsedFile parsed = new ParsedFile(file);
    try {
        for (String line : parsed) {
            args.put(line.trim(), parsed.getLineNumber());
        }
    } finally {
        // Close even if reading or inserting fails part-way, so the file handle is not leaked.
        parsed.close();
    }
}
Also used: ParsedFile (edu.cmu.ml.proppr.util.ParsedFile)

Aggregations

ParsedFile (edu.cmu.ml.proppr.util.ParsedFile): 9; ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder): 3; File (java.io.File): 3; BloomFilter (com.skjegstad.utils.BloomFilter): 2; ModuleConfiguration (edu.cmu.ml.proppr.util.ModuleConfiguration): 2; ParamsFile (edu.cmu.ml.proppr.util.ParamsFile): 2; SimpleSymbolTable (edu.cmu.ml.proppr.util.SimpleSymbolTable): 2; StatusLogger (edu.cmu.ml.proppr.util.StatusLogger): 2; ArrayList (java.util.ArrayList): 2; PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample): 1; SRW (edu.cmu.ml.proppr.learn.SRW): 1; FixedWeightRules (edu.cmu.ml.proppr.learn.tools.FixedWeightRules): 1; RWExampleParser (edu.cmu.ml.proppr.learn.tools.RWExampleParser): 1; CustomConfiguration (edu.cmu.ml.proppr.util.CustomConfiguration): 1; SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector): 1; Iterator (java.util.Iterator): 1; Callable (java.util.concurrent.Callable): 1; ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 1; ExecutionException (java.util.concurrent.ExecutionException): 1; ExecutorService (java.util.concurrent.ExecutorService): 1