Search in sources :

Example 1 with ArrayLearningGraphBuilder

use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.

The class RedBlueGraph, method setup.

/**
 * Builds the red/blue learning graph used by the tests.
 *
 * <p>Steps, in order: create the graph through an {@link ArrayLearningGraphBuilder},
 * size it for {@code magicNumber * 2} nodes, link {@code b0} and {@code r0} to each
 * other, let {@code addColor} populate each color, record the node names of ids
 * {@code 1 .. 2 * magicNumber} in {@code reds} / {@code blues}, run the
 * {@code moreSetup} hook, and finally freeze the graph.
 */
@Before
public void setup() {
    LearningGraphBuilder lgb = new ArrayLearningGraphBuilder();
    brGraph = (LearningGraph) lgb.create(new SimpleSymbolTable<String>());
    lgb.index(1);
    lgb.setGraphSize(brGraph, magicNumber * 2, -1, -1);

    // Connect the two seed nodes in both directions, each edge with its own feature set.
    int blueSeed = nodes.getId("b0");
    int redSeed = nodes.getId("r0");

    HashMap<String, Double> blueToRed = new HashMap<String, Double>();
    blueToRed.put("fromb", 1.0);
    blueToRed.put("tor", 1.0);
    lgb.addOutlink(brGraph, blueSeed, makeOutlink(lgb, blueToRed, redSeed));

    HashMap<String, Double> redToBlue = new HashMap<String, Double>();
    redToBlue.put("fromr", 1.0);
    redToBlue.put("tob", 1.0);
    lgb.addOutlink(brGraph, redSeed, makeOutlink(lgb, redToBlue, blueSeed));

    addColor(lgb, brGraph, magicNumber, "r");
    addColor(lgb, brGraph, magicNumber, "b");

    // Remember which node names are red and which are blue; names starting
    // with "b" are blue, everything else is counted as red.
    reds = new TreeSet<String>();
    blues = new TreeSet<String>();
    for (int nodeId = 1; nodeId <= 2 * magicNumber; nodeId++) {
        String nodeName = nodes.getSymbol(nodeId);
        if (nodeName.startsWith("b")) {
            blues.add(nodeName);
        } else {
            reds.add(nodeName);
        }
    }

    // Subclass hook runs before the graph is frozen.
    moreSetup(lgb);
    lgb.freeze(brGraph);
}
Also used : HashMap(java.util.HashMap) TObjectDoubleHashMap(gnu.trove.map.hash.TObjectDoubleHashMap) TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder) LearningGraphBuilder(edu.cmu.ml.proppr.graph.LearningGraphBuilder) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder) Before(org.junit.Before)

Example 2 with ArrayLearningGraphBuilder

use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.

The class Diagnostic, method main.

/**
 * Diagnostic driver: parses every line of the query file into a
 * {@link PosNegRWExample} on one thread pool and trains the configured
 * {@link SRW} on each parsed example on a second pool, then reports the
 * elapsed wall-clock time.
 *
 * <p>Each pool gets {@code c.nthreads / 2} threads (or 1 when
 * {@code c.nthreads <= 1}). Examples that fail to parse are reported and
 * skipped rather than handed to the trainer. Any {@link Throwable} escaping
 * the driver is printed and terminates the JVM with exit status -1.
 *
 * <p>The commented-out sections below are alternative diagnostic
 * configurations (SrwTtwop, "all diag tasks except SrwO", SrwO) retained for
 * reference.
 *
 * @param args command-line arguments, handed to {@link ModuleConfiguration}
 */
public static void main(String[] args) {
    StatusLogger status = new StatusLogger();
    try {
        int inputFiles = Configuration.USE_TRAIN;
        int outputFiles = 0;
        int constants = Configuration.USE_THREADS | Configuration.USE_THROTTLE;
        int modules = Configuration.USE_SRW;
        ModuleConfiguration c = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(c.toString());
        String groundedFile = c.queryFile.getPath();
        log.info("Parsing " + groundedFile + "...");
        long start = System.currentTimeMillis();
        final ArrayLearningGraphBuilder b = new ArrayLearningGraphBuilder();
        final SRW srw = c.srw;
        final ParamVector<String, ?> params = srw.setupParams(new SimpleParamVector<String>(new ConcurrentHashMap<String, Double>(16, (float) 0.75, 24)));
        srw.setEpoch(1);
        srw.clearLoss();
        srw.fixedWeightRules().addExact("id(restart)");
        srw.fixedWeightRules().addExact("id(trueLoop)");
        srw.fixedWeightRules().addExact("id(trueLoopRestart)");
        /* DiagSrwES: */
        ArrayList<Future<PosNegRWExample>> parsed = new ArrayList<Future<PosNegRWExample>>();
        final ExecutorService trainerPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);
        final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads > 1 ? c.nthreads / 2 : 1);
        int i = 1;
        for (String s : new ParsedFile(groundedFile)) {
            final int id = i++;
            final String in = s;
            parsed.add(parserPool.submit(new Callable<PosNegRWExample>() {

                @Override
                public PosNegRWExample call() throws Exception {
                    try {
                        //log.debug("Job start "+id);
                        //PosNegRWExample ret = parser.parse(in, b.copy());
                        log.debug("Parsing start " + id);
                        // Each job parses into its own copy of the builder.
                        PosNegRWExample ret = new RWExampleParser().parse(in, b.copy(), srw);
                        log.debug("Parsing done " + id);
                        //log.debug("Job done "+id);
                        return ret;
                    } catch (IllegalArgumentException e) {
                        System.err.println("Problem with #" + id);
                        e.printStackTrace();
                    }
                    // Signals "no example" to the training stage, which skips it.
                    return null;
                }
            }));
        }
        parserPool.shutdown();
        i = 1;
        for (Future<PosNegRWExample> future : parsed) {
            final int id = i++;
            final Future<PosNegRWExample> in = future;
            trainerPool.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        PosNegRWExample ex = in.get();
                        if (ex == null) {
                            // The parser already reported the problem; do not
                            // hand a null example to the trainer.
                            log.error("Skipping training for #" + id + ": no example was parsed");
                            return;
                        }
                        log.debug("Training start " + id);
                        srw.trainOnExample(params, ex, status);
                        log.debug("Training done " + id);
                    } catch (InterruptedException e) {
                        // Preserve the interrupt for the pool's worker thread.
                        Thread.currentThread().interrupt();
                        e.printStackTrace();
                    } catch (ExecutionException e) {
                        e.printStackTrace();
                    } catch (RuntimeException e) {
                        // The Future returned by submit() is never inspected, so an
                        // unlogged failure here would disappear without a trace.
                        log.error("Training failed for #" + id, e);
                    }
                }
            });
        }
        trainerPool.shutdown();
        try {
            parserPool.awaitTermination(7, TimeUnit.DAYS);
            trainerPool.awaitTermination(7, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Interrupted?", e);
        }
        /* /SrwES */
        /* SrwTtwop: 
			final ExecutorService parserPool = Executors.newFixedThreadPool(c.nthreads>1?c.nthreads/2:1);
			Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
			m.executeJob(c.nthreads/2, new ParsedFile(groundedFile), 
					new Transformer<String,PosNegRWExample>() {
						@Override
						public Callable<PosNegRWExample> transformer(final String in, final int id) {
							return new Callable<PosNegRWExample>() {
								@Override
								public PosNegRWExample call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Parsing start "+id);
									PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
									log.debug("Parsing done "+id);
									//log.debug("Job done "+id);
									return ret;
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return null;
								}};
						}}, new Cleanup<PosNegRWExample>() {
							@Override
							public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
								return new Runnable(){
									@Override
									public void run() {
										try {
											final PosNegRWExample ex = in.get();
											log.debug("Cleanup start "+id);
											trainerPool.submit(new Runnable() {
													@Override
													public void run(){
														log.debug("Training start "+id);
														srw.trainOnExample(params,ex);
														log.debug("Training done "+id);
													}
												});
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			trainerPool.shutdown();
			try {
				trainerPool.awaitTermination(7, TimeUnit.DAYS);
			} catch (InterruptedException e) {
				log.error("Interrupted?",e);
			}

			 /SrwTtwop */
        /* all diag tasks except SrwO: 
			Multithreading<String,PosNegRWExample> m = new Multithreading<String,PosNegRWExample>(log);
			m.executeJob(c.nthreads, new ParsedFile(groundedFile), 
					new Transformer<String,PosNegRWExample>() {
						@Override
						public Callable<PosNegRWExample> transformer(final String in, final int id) {
							return new Callable<PosNegRWExample>() {
								@Override
								public PosNegRWExample call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Parsing start "+id);
									PosNegRWExample ret = new GroundedExampleParser().parse(in, b.copy());
									log.debug("Parsing done "+id);
									log.debug("Training start "+id);
									srw.trainOnExample(params,ret);
									log.debug("Training done "+id);
									//log.debug("Job done "+id);
									return ret;
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return null;
								}};
						}}, new Cleanup<PosNegRWExample>() {
							@Override
							public Runnable cleanup(final Future<PosNegRWExample> in, final int id) {
								return new Runnable(){
									//ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
									@Override
									public void run() {
										try {
											//done.add(in.get());
											in.get();
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup start "+id);
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			*/
        /* SrwO:
			   Multithreading<PosNegRWExample,Integer> m = new Multithreading<PosNegRWExample,Integer>(log);
			m.executeJob(c.nthreads, new PosNegRWExampleStreamer(new ParsedFile(groundedFile),new ArrayLearningGraphBuilder()), 
						 new Transformer<PosNegRWExample,Integer>() {
						@Override
						public Callable<Integer> transformer(final PosNegRWExample in, final int id) {
							return new Callable<Integer>() {
								@Override
								public Integer call() throws Exception {
									try {
									//log.debug("Job start "+id);
									//PosNegRWExample ret = parser.parse(in, b.copy());
									log.debug("Training start "+id);
									srw.trainOnExample(params,in);
									log.debug("Training done "+id);
									//log.debug("Job done "+id);
									} catch (IllegalArgumentException e) {
										System.err.println("Problem with #"+id);
										e.printStackTrace();
									}
									return in.length();
								}};
						}}, new Cleanup<Integer>() {
							@Override
							public Runnable cleanup(final Future<Integer> in, final int id) {
								return new Runnable(){
									//ArrayList<PosNegRWExample> done = new ArrayList<PosNegRWExample>();
									@Override
									public void run() {
										try {
											//done.add(in.get());
											in.get();
										} catch (InterruptedException e) {
										    e.printStackTrace(); 
										} catch (ExecutionException e) {
										    e.printStackTrace();
										}
										log.debug("Cleanup start "+id);
										log.debug("Cleanup done "+id);
									}};
							}}, c.throttle);
			*/
        srw.cleanupParams(params, params);
        log.info("Finished diagnostic in " + (System.currentTimeMillis() - start) + " ms");
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used : ArrayList(java.util.ArrayList) Callable(java.util.concurrent.Callable) SRW(edu.cmu.ml.proppr.learn.SRW) RWExampleParser(edu.cmu.ml.proppr.learn.tools.RWExampleParser) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) ExecutionException(java.util.concurrent.ExecutionException) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) ExecutorService(java.util.concurrent.ExecutorService) Future(java.util.concurrent.Future) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 3 with ArrayLearningGraphBuilder

use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.

The class GradientFinder, method main.

/**
 * Computes a batch gradient over a grounded file and saves it to the gradient file.
 *
 * <p>Flow: build a {@link CustomConfiguration} that adds a {@code relaxFW} flag,
 * optionally load a feature index that sits next to the grounded file, obtain a
 * parameter vector (by training when {@code epochs > 0}, otherwise from the
 * init-params file, the params file, or empty, in that order of preference),
 * optionally replace the trainer's fixed-weight rules, then call
 * {@code findGradient} and save the result.
 *
 * <p>Any {@link Throwable} is printed and terminates the JVM with exit status -1.
 *
 * @param args command-line arguments, handed to {@link CustomConfiguration}
 */
public static void main(String[] args) {
    try {
        int inputFiles = Configuration.USE_GROUNDED | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_GRADIENT | Configuration.USE_PARAMS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        int constants = Configuration.USE_THREADS | Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_FIXEDWEIGHTS;
        CustomConfiguration c = new CustomConfiguration(args, inputFiles, outputFiles, constants, modules) {

            /** True when the user passed {@code --relaxFW} on the command line. */
            boolean relax;

            @Override
            protected Option checkOption(Option o) {
                // Neither params file is mandatory for gradient computation.
                if (PARAMS_FILE_OPTION.equals(o.getLongOpt()) || INIT_PARAMS_FILE_OPTION.equals(o.getLongOpt()))
                    o.setRequired(false);
                return o;
            }

            @Override
            protected void addCustomOptions(Options options, int[] flags) {
                options.addOption(Option.builder().longOpt("relaxFW").desc("Relax fixedWeight rules for gradient computation (used in ProngHorn)").optionalArg(true).build());
            }

            @Override
            protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
                if (groundedFile == null || !groundedFile.exists())
                    usageOptions(options, flags, "Must specify grounded file using --" + Configuration.GROUNDED_FILE_OPTION);
                if (gradientFile == null)
                    usageOptions(options, flags, "Must specify gradient using --" + Configuration.GRADIENT_FILE_OPTION);
                // Ask the parsed command line, not the Options definition set:
                // Options.hasOption only says whether an option is *defined*,
                // which is always true for relaxFW because addCustomOptions
                // registers it, so it cannot tell whether the user supplied it.
                // default to 0 epochs
                if (!line.hasOption("epochs"))
                    this.epochs = 0;
                this.relax = line.hasOption("relaxFW");
            }

            @Override
            public Object getCustomSetting(String name) {
                if ("relaxFW".equals(name))
                    return this.relax;
                return null;
            }
        };
        System.out.println(c.toString());
        ParamVector<String, ?> params = null;
        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        // The feature index, when present, lives beside the grounded file.
        File featureIndex = new File(c.groundedFile.getParent(), c.groundedFile.getName() + Grounder.FEATURE_INDEX_EXTENSION);
        if (featureIndex.exists()) {
            log.info("Reading feature index from " + featureIndex.getName() + "...");
            for (String line : new ParsedFile(featureIndex)) {
                masterFeatures.insert(line.trim());
            }
        }
        if (c.epochs > 0) {
            // train first
            log.info("Training for " + c.epochs + " epochs...");
            params = c.trainer.train(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), // create a parameter vector
            c.initParamsFile, c.epochs);
            if (c.paramsFile != null)
                ParamsFile.save(params, c.paramsFile, c);
        } else if (c.initParamsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.initParamsFile)));
        } else if (c.paramsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.paramsFile)));
        } else {
            params = new SimpleParamVector<String>();
        }
        // this lets prongHorn hold external features fixed for training, but still compute their gradient
        if (((Boolean) c.getCustomSetting("relaxFW"))) {
            log.info("Turning off fixedWeight rules");
            c.trainer.setFixedWeightRules(new FixedWeightRules());
        }
        ParamVector<String, ?> batchGradient = c.trainer.findGradient(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), params);
        ParamsFile.save(batchGradient, c.gradientFile, c);
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used : Options(org.apache.commons.cli.Options) CustomConfiguration(edu.cmu.ml.proppr.util.CustomConfiguration) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector) CommandLine(org.apache.commons.cli.CommandLine) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) FixedWeightRules(edu.cmu.ml.proppr.learn.tools.FixedWeightRules) Option(org.apache.commons.cli.Option) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) File(java.io.File) ParamsFile(edu.cmu.ml.proppr.util.ParamsFile) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 4 with ArrayLearningGraphBuilder

use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.

The class SRWRestartTest, method moreOutlinks.

/**
 * Ensures node {@code u} has an outlink to {@code r0}, adding one if none exists.
 *
 * <p>If any outlink already recorded for {@code u} in the builder targets
 * {@code r0} (e.g. one created by {@code makeOutlink()}), nothing is added.
 * Otherwise a single-feature outlink labelled {@code id(restart)} is appended,
 * weighted with the squashing function's default value.
 *
 * @param lgb   builder being populated; must be an {@link ArrayLearningGraphBuilder}
 * @param graph graph the outlink is added to
 * @param u     id of the source node
 */
@Override
public void moreOutlinks(LearningGraphBuilder lgb, LearningGraph graph, int u) {
    ArrayLearningGraphBuilder arrayBuilder = (ArrayLearningGraphBuilder) lgb;
    int restartTarget = nodes.getId("r0");

    // Nothing to do when a link back to r0 is already present.
    if (arrayBuilder.outlinks[u] != null) {
        for (RWOutlink existing : arrayBuilder.outlinks[u]) {
            if (existing.nodeid == restartTarget) {
                return;
            }
        }
    }

    // No such link yet: build one carrying only the restart feature.
    int[] featureIds = new int[] { lgb.getFeatureLibrary().getId("id(restart)") };
    double[] featureWeights = new double[] { this.srw.getSquashingFunction().defaultValue() };
    RWOutlink restartLink = new RWOutlink(featureIds, featureWeights, restartTarget);
    lgb.addOutlink(graph, u, restartLink);
}
Also used : RWOutlink(edu.cmu.ml.proppr.graph.RWOutlink) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 5 with ArrayLearningGraphBuilder

use of edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder in project ProPPR by TeamCohen.

The class Trainer, method main.

/**
 * Command-line entry point that trains model parameters from a grounded file.
 *
 * <p>The query file must carry the {@code Grounder.GROUNDED_SUFFIX} suffix;
 * otherwise an {@link IllegalStateException} is raised (and, like every other
 * failure here, printed before the JVM exits with status -1). A feature index
 * stored beside the grounded file is loaded when it exists. The trained
 * parameters are saved only when a params file was configured.
 *
 * @param args command-line arguments, handed to {@link ModuleConfiguration}
 */
public static void main(String[] args) {
    try {
        int inputFiles = Configuration.USE_TRAIN | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_PARAMS;
        int constants = Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_THREADS | Configuration.USE_FIXEDWEIGHTS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        ModuleConfiguration config = new ModuleConfiguration(args, inputFiles, outputFiles, constants, modules);
        log.info(config.toString());

        String groundedPath = config.queryFile.getPath();
        boolean alreadyGrounded = config.queryFile.getName().endsWith(Grounder.GROUNDED_SUFFIX);
        if (!alreadyGrounded) {
            throw new IllegalStateException("Run Grounder on " + config.queryFile.getName() + " first. Ground+Train in one go is not supported yet.");
        }

        // Pre-populate the feature table from the index file, if there is one.
        SymbolTable<String> featureTable = new SimpleSymbolTable<String>();
        File indexFile = new File(groundedPath + Grounder.FEATURE_INDEX_EXTENSION);
        if (indexFile.exists()) {
            log.info("Reading feature index from " + indexFile.getName() + "...");
            for (String feature : new ParsedFile(indexFile)) {
                featureTable.insert(feature.trim());
            }
        }

        log.info("Training model parameters on " + groundedPath + "...");
        long startedAt = System.currentTimeMillis();
        ParamVector<String, ?> trained = config.trainer.train(featureTable, new ParsedFile(groundedPath), new ArrayLearningGraphBuilder(), config.initParamsFile, config.epochs);
        long elapsed = System.currentTimeMillis() - startedAt;
        System.out.println("Training time: " + elapsed);

        if (config.paramsFile != null) {
            log.info("Saving parameters to " + config.paramsFile + "...");
            ParamsFile.save(trained, config.paramsFile, config);
        }
    } catch (Throwable failure) {
        failure.printStackTrace();
        System.exit(-1);
    }
}
Also used : ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) ParamsFile(edu.cmu.ml.proppr.util.ParamsFile) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) File(java.io.File) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Aggregations

ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)5 ParsedFile (edu.cmu.ml.proppr.util.ParsedFile)3 ModuleConfiguration (edu.cmu.ml.proppr.util.ModuleConfiguration)2 ParamsFile (edu.cmu.ml.proppr.util.ParamsFile)2 SimpleSymbolTable (edu.cmu.ml.proppr.util.SimpleSymbolTable)2 File (java.io.File)2 PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample)1 LearningGraphBuilder (edu.cmu.ml.proppr.graph.LearningGraphBuilder)1 RWOutlink (edu.cmu.ml.proppr.graph.RWOutlink)1 SRW (edu.cmu.ml.proppr.learn.SRW)1 FixedWeightRules (edu.cmu.ml.proppr.learn.tools.FixedWeightRules)1 RWExampleParser (edu.cmu.ml.proppr.learn.tools.RWExampleParser)1 CustomConfiguration (edu.cmu.ml.proppr.util.CustomConfiguration)1 StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)1 SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector)1 TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap)1 TObjectDoubleHashMap (gnu.trove.map.hash.TObjectDoubleHashMap)1 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 Callable (java.util.concurrent.Callable)1