Search in sources :

Example 1 with SimpleParamVector

use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.

In class FiniteDifferenceTest, the method makeGradient:

public ParamVector<String, ?> makeGradient(SRW srw, ParamVector<String, ?> paramVec, TIntDoubleMap query, int[] pos, int[] neg, ExampleFactory f) {
    // Build a labeled example over the shared test graph and accumulate its
    // gradient (w.r.t. paramVec) into a fresh, empty parameter vector.
    ParamVector<String, ?> gradient = new SimpleParamVector<String>();
    srw.accumulateGradient(
            paramVec,
            f.makeExample("gradient", brGraph, query, pos, neg),
            gradient,
            new StatusLogger());
    return gradient;
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector)

Example 2 with SimpleParamVector

use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.

In class SRWTest, the method makeGradient:

//	 Test removed: We no longer compute rwr in SRW
//	
//	/**
//	 * Uniform weights should be the same as the unparameterized basic RWR
//	 */
//	@Test
//	public void testUniformRWR() {
//		log.debug("Test logging");
//		int maxT = 10;
//		
//		TIntDoubleMap baseLineVec = myRWR(startVec,brGraph,maxT);
//		uniformParams.put("id(restart)",srw.getWeightingScheme().defaultWeight());
//		TIntDoubleMap newVec = srw.rwrUsingFeatures(brGraph, startVec, uniformParams);
//		equalScores(baseLineVec,newVec);
//	}
//	
//	public ParamVector<String,?> makeParams(Map<String,Double> foo) {
//		return new SimpleParamVector(foo);
//	}
//	
//	public ParamVector<String,?> makeParams() {
//		return new SimpleParamVector();
//	}
public ParamVector<String, ?> makeGradient(SRW srw, ParamVector<String, ?> paramVec, TIntDoubleMap query, int[] pos, int[] neg) {
    // Collect the gradient of one synthetic example (built by the test's
    // shared factory over brGraph) into an initially empty vector.
    ParamVector<String, ?> result = new SimpleParamVector<String>();
    srw.accumulateGradient(
            paramVec,
            factory.makeExample("gradient", brGraph, query, pos, neg),
            result,
            new StatusLogger());
    return result;
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector)

Example 3 with SimpleParamVector

use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.

In class GradientFinder, the method main:

/**
 * Entry point: computes (and optionally trains then saves) the batch gradient
 * over a grounded example file, writing the result to the gradient file.
 * Exits with status -1 on any failure.
 */
public static void main(String[] args) {
    try {
        // Flag sets declaring which CLI options this tool reads and writes.
        int inputFiles = Configuration.USE_GROUNDED | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_GRADIENT | Configuration.USE_PARAMS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        int constants = Configuration.USE_THREADS | Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_FIXEDWEIGHTS;
        CustomConfiguration c = new CustomConfiguration(args, inputFiles, outputFiles, constants, modules) {

            // Whether fixedWeight rules are relaxed for gradient computation (ProngHorn).
            boolean relax;

            @Override
            protected Option checkOption(Option o) {
                // Params files are optional for gradient finding, unlike training.
                if (PARAMS_FILE_OPTION.equals(o.getLongOpt()) || INIT_PARAMS_FILE_OPTION.equals(o.getLongOpt()))
                    o.setRequired(false);
                return o;
            }

            @Override
            protected void addCustomOptions(Options options, int[] flags) {
                options.addOption(Option.builder().longOpt("relaxFW").desc("Relax fixedWeight rules for gradient computation (used in ProngHorn)").optionalArg(true).build());
            }

            @Override
            protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
                if (groundedFile == null || !groundedFile.exists())
                    usageOptions(options, flags, "Must specify grounded file using --" + Configuration.GROUNDED_FILE_OPTION);
                if (gradientFile == null)
                    usageOptions(options, flags, "Must specify gradient using --" + Configuration.GRADIENT_FILE_OPTION);
                // BUG FIX: query the parsed CommandLine, not the Options definition.
                // Options.hasOption() reports whether an option is *defined* (always
                // true for options registered above), so the epochs default never
                // applied and relaxFW was treated as always present.
                // default to 0 epochs
                if (!line.hasOption("epochs"))
                    this.epochs = 0;
                this.relax = line.hasOption("relaxFW");
            }

            @Override
            public Object getCustomSetting(String name) {
                // Only "relaxFW" is recognized; anything else yields null.
                if ("relaxFW".equals(name))
                    return this.relax;
                return null;
            }
        };
        System.out.println(c.toString());
        ParamVector<String, ?> params = null;
        // Load the master feature index written alongside the grounded file, if any.
        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        File featureIndex = new File(c.groundedFile.getParent(), c.groundedFile.getName() + Grounder.FEATURE_INDEX_EXTENSION);
        if (featureIndex.exists()) {
            log.info("Reading feature index from " + featureIndex.getName() + "...");
            for (String line : new ParsedFile(featureIndex)) {
                masterFeatures.insert(line.trim());
            }
        }
        if (c.epochs > 0) {
            // train first
            log.info("Training for " + c.epochs + " epochs...");
            params = c.trainer.train(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), // create a parameter vector
            c.initParamsFile, c.epochs);
            if (c.paramsFile != null)
                ParamsFile.save(params, c.paramsFile, c);
        } else if (c.initParamsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.initParamsFile)));
        } else if (c.paramsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.paramsFile)));
        } else {
            // No params supplied: start from an empty (all-default) vector.
            params = new SimpleParamVector<String>();
        }
        // this lets prongHorn hold external features fixed for training, but still compute their gradient
        // Boolean.TRUE.equals avoids an unboxing NPE if the setting is absent.
        if (Boolean.TRUE.equals(c.getCustomSetting("relaxFW"))) {
            log.info("Turning off fixedWeight rules");
            c.trainer.setFixedWeightRules(new FixedWeightRules());
        }
        ParamVector<String, ?> batchGradient = c.trainer.findGradient(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), params);
        ParamsFile.save(batchGradient, c.gradientFile, c);
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used : Options(org.apache.commons.cli.Options) CustomConfiguration(edu.cmu.ml.proppr.util.CustomConfiguration) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector) CommandLine(org.apache.commons.cli.CommandLine) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) FixedWeightRules(edu.cmu.ml.proppr.learn.tools.FixedWeightRules) Option(org.apache.commons.cli.Option) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) File(java.io.File) ParamsFile(edu.cmu.ml.proppr.util.ParamsFile) ParsedFile(edu.cmu.ml.proppr.util.ParsedFile) ArrayLearningGraphBuilder(edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)

Example 4 with SimpleParamVector

use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.

In class Trainer, the method findGradient:

/**
 * Computes the gradient of the loss over all cooked examples, parsing and
 * processing them on a thread pool, then normalizes the accumulated gradient
 * by the number of examples processed.
 *
 * Thread-safety: uses this Trainer instance as the monitor for throttling
 * (wait/notify) — presumably the Grad workers notify it when tasks complete;
 * TODO confirm against the Grad implementation.
 *
 * @param masterFeatures optional global feature table; installed into the
 *                       graph builder when non-empty
 * @param examples       one serialized grounded example per string
 * @param builder        used to parse each example into a learning graph
 * @param paramVec       starting parameters; a fresh vector is created if null
 * @return the example-count-normalized summed gradient
 */
public ParamVector<String, ?> findGradient(SymbolTable<String> masterFeatures, Iterable<String> examples, LearningGraphBuilder builder, ParamVector<String, ?> paramVec) {
    log.info("Computing gradient on cooked examples...");
    ParamVector<String, ?> sumGradient = new SimpleParamVector<String>();
    if (paramVec == null) {
        paramVec = createParamVector();
    }
    paramVec = this.masterLearner.setupParams(paramVec);
    if (masterFeatures != null && masterFeatures.size() > 0)
        LearningGraphBuilder.setFeatures(masterFeatures);
    //		
    //		//WW: accumulate example-size normalized gradient
    //		for (PosNegRWExample x : examples) {
    ////			this.learner.initializeFeatures(paramVec,x.getGraph());
    //			this.learner.accumulateGradient(paramVec, x, sumGradient);
    //			k++;
    //		}
    // Two pools: workers parse + compute gradients; a single cleaner thread
    // drains results in order via TraceLosses.
    NamedThreadFactory workThreads = new NamedThreadFactory("work-");
    ExecutorService workPool, cleanPool;
    workPool = Executors.newFixedThreadPool(this.nthreads, workThreads);
    cleanPool = Executors.newSingleThreadExecutor();
    // run examples
    int id = 1;
    // countdown implements backpressure: once the work queue grows past
    // 1.5x nthreads, submit nthreads more tasks, then block until the
    // queue drains to nthreads before submitting again.
    int countdown = -1;
    Trainer notify = null;
    status.start();
    for (String s : examples) {
        if (log.isInfoEnabled() && status.due())
            log.info(id + " examples read...");
        // Pending tasks = submitted minus completed.
        long queueSize = (((ThreadPoolExecutor) workPool).getTaskCount() - ((ThreadPoolExecutor) workPool).getCompletedTaskCount());
        if (log.isDebugEnabled())
            log.debug("Queue size " + queueSize);
        if (countdown > 0) {
            // Still in the grace window: keep submitting without blocking.
            if (log.isDebugEnabled())
                log.debug("Countdown " + countdown);
            countdown--;
        } else if (countdown == 0) {
            // Grace window exhausted: block until the workers catch up.
            if (log.isDebugEnabled())
                log.debug("Countdown " + countdown + "; throttling:");
            countdown--;
            notify = null;
            try {
                synchronized (this) {
                    if (log.isDebugEnabled())
                        log.debug("Clearing training queue...");
                    while ((((ThreadPoolExecutor) workPool).getTaskCount() - ((ThreadPoolExecutor) workPool).getCompletedTaskCount()) > this.nthreads) this.wait();
                    if (log.isDebugEnabled())
                        log.debug("Queue cleared.");
                }
            } catch (InterruptedException e) {
                // NOTE(review): interrupt is logged but not re-asserted;
                // consider Thread.currentThread().interrupt().
                e.printStackTrace();
            }
        } else if (queueSize > 1.5 * this.nthreads) {
            // Queue is getting long: start the countdown and hand workers a
            // reference to notify when tasks finish.
            if (log.isDebugEnabled())
                log.debug("Starting countdown");
            countdown = this.nthreads;
            notify = this;
        }
        // Pipeline: parse the example, compute its gradient into sumGradient,
        // then record losses on the cleaner thread.
        Future<PosNegRWExample> parsed = workPool.submit(new Parse(s, builder, id));
        Future<ExampleStats> gradfound = workPool.submit(new Grad(parsed, paramVec, sumGradient, id, notify));
        cleanPool.submit(new TraceLosses(gradfound, id));
        id++;
    }
    workPool.shutdown();
    try {
        // Wait (effectively unbounded) for all gradient work, then for the cleaner.
        workPool.awaitTermination(7, TimeUnit.DAYS);
        cleanPool.shutdown();
        cleanPool.awaitTermination(7, TimeUnit.DAYS);
    } catch (InterruptedException e) {
        log.error("Interrupted?", e);
    }
    this.masterLearner.cleanupParams(paramVec, sumGradient);
    //WW: renormalize by the total number of queries
    for (Iterator<String> it = sumGradient.keySet().iterator(); it.hasNext(); ) {
        String feature = it.next();
        double unnormf = sumGradient.get(feature);
        // query count stored in numExamplesThisEpoch, as noted above
        double norm = unnormf / this.statistics.numExamplesThisEpoch;
        sumGradient.put(feature, norm);
    }
    return sumGradient;
}
Also used : NamedThreadFactory(edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector) ExecutorService(java.util.concurrent.ExecutorService) ThreadPoolExecutor(java.util.concurrent.ThreadPoolExecutor)

Example 5 with SimpleParamVector

use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.

In class SRW, the method accumulateGradient:

/**
 * Accumulates this example's gradient into {@code accumulator}: the
 * regularizer contribution is subtracted and the loss gradient is added,
 * both normalized by the example length. Only trainable features are updated.
 *
 * @param params      current parameter values
 * @param example     the labeled random-walk example
 * @param accumulator receives the length-normalized gradient contributions
 * @param status      progress logger passed through to inference
 */
public void accumulateGradient(ParamVector<String, ?> params, PosNegRWExample example, ParamVector<String, ?> accumulator, StatusLogger status) {
    log.debug("Gradient calculating on " + example);
    initializeFeatures(params, example.getGraph());
    // Per-feature regularization terms for this example, keyed by feature name.
    ParamVector<String, Double> prepare = new SimpleParamVector<String>();
    regularizer.prepareForExample(params, example.getGraph(), prepare);
    load(params, example);
    inference(params, example, status);
    TIntDoubleMap gradient = gradient(params, example);
    for (Map.Entry<String, Double> e : prepare.entrySet()) {
        if (trainable(e.getKey()))
            accumulator.adjustValue(e.getKey(), -e.getValue() / example.length());
    }
    for (TIntDoubleIterator it = gradient.iterator(); it.hasNext(); ) {
        it.advance();
        // Resolve the feature id to its symbol once and reuse it (the original
        // performed this featureLibrary lookup a second time in adjustValue).
        String feature = example.getGraph().featureLibrary.getSymbol(it.key());
        if (trainable(feature))
            accumulator.adjustValue(feature, it.value() / example.length());
    }
}
Also used : TIntDoubleMap(gnu.trove.map.TIntDoubleMap) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector) HashMap(java.util.HashMap) Map(java.util.Map) TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) TIntDoubleMap(gnu.trove.map.TIntDoubleMap)

Aggregations

SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector)6 StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)2 TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap)2 PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample)1 ArrayLearningGraphBuilder (edu.cmu.ml.proppr.graph.ArrayLearningGraphBuilder)1 FixedWeightRules (edu.cmu.ml.proppr.learn.tools.FixedWeightRules)1 CustomConfiguration (edu.cmu.ml.proppr.util.CustomConfiguration)1 ParamsFile (edu.cmu.ml.proppr.util.ParamsFile)1 ParsedFile (edu.cmu.ml.proppr.util.ParsedFile)1 SimpleSymbolTable (edu.cmu.ml.proppr.util.SimpleSymbolTable)1 NamedThreadFactory (edu.cmu.ml.proppr.util.multithreading.NamedThreadFactory)1 TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator)1 TIntDoubleMap (gnu.trove.map.TIntDoubleMap)1 File (java.io.File)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1 ExecutorService (java.util.concurrent.ExecutorService)1 ThreadPoolExecutor (java.util.concurrent.ThreadPoolExecutor)1 CommandLine (org.apache.commons.cli.CommandLine)1 Option (org.apache.commons.cli.Option)1