Use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.
From the class FiniteDifferenceTest, method makeGradient:
public ParamVector<String, ?> makeGradient(SRW srw, ParamVector<String, ?> paramVec, TIntDoubleMap query, int[] pos, int[] neg, ExampleFactory f) {
    ParamVector<String, ?> grad = new SimpleParamVector<String>();
    srw.accumulateGradient(paramVec, f.makeExample("gradient", brGraph, query, pos, neg), grad, new StatusLogger());
    return grad;
}
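FiniteDifferenceTest presumably validates the analytic gradient returned by makeGradient against numeric estimates. The sketch below shows the numeric side of such a check; it is a generic illustration rather than code from the test class, and the loss function passed in is a stand-in:

import java.util.Map;
import java.util.function.ToDoubleFunction;

// Generic central-difference estimator: perturb one feature weight by +/-h and
// difference the resulting losses. Each entry of the ParamVector returned by
// makeGradient above would be compared against this estimate.
public class FiniteDifferenceSketch {

    static double centralDifference(ToDoubleFunction<Map<String, Double>> loss,
                                    Map<String, Double> params, String feature, double h) {
        double w = params.getOrDefault(feature, 0.0);
        params.put(feature, w + h);
        double plus = loss.applyAsDouble(params);
        params.put(feature, w - h);
        double minus = loss.applyAsDouble(params);
        params.put(feature, w);          // restore the original weight
        return (plus - minus) / (2 * h); // approximates d(loss)/d(feature)
    }
}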
Use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.
From the class SRWTest, method makeGradient:
// Test removed: We no longer compute rwr in SRW
//
// /**
//  * Uniform weights should be the same as the unparameterized basic RWR
//  */
// @Test
// public void testUniformRWR() {
//     log.debug("Test logging");
//     int maxT = 10;
//
//     TIntDoubleMap baseLineVec = myRWR(startVec, brGraph, maxT);
//     uniformParams.put("id(restart)", srw.getWeightingScheme().defaultWeight());
//     TIntDoubleMap newVec = srw.rwrUsingFeatures(brGraph, startVec, uniformParams);
//     equalScores(baseLineVec, newVec);
// }
//
// public ParamVector<String, ?> makeParams(Map<String, Double> foo) {
//     return new SimpleParamVector(foo);
// }
//
// public ParamVector<String, ?> makeParams() {
//     return new SimpleParamVector();
// }
public ParamVector<String, ?> makeGradient(SRW srw, ParamVector<String, ?> paramVec, TIntDoubleMap query, int[] pos, int[] neg) {
    ParamVector<String, ?> grad = new SimpleParamVector<String>();
    srw.accumulateGradient(paramVec, factory.makeExample("gradient", brGraph, query, pos, neg), grad, new StatusLogger());
    return grad;
}
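Both test helpers hand accumulateGradient an empty SimpleParamVector to fill in. A minimal sketch of that map-backed accumulator pattern, assuming concurrent-map semantics like those SimpleParamVector provides (this is an illustration, not the ProPPR class itself):

import java.util.concurrent.ConcurrentHashMap;

// Map-backed parameter accumulator in the style of SimpleParamVector.
class AccumulatorSketch {

    private final ConcurrentHashMap<String, Double> map = new ConcurrentHashMap<>();

    // Add delta to the current value of feature, treating missing entries as 0.0.
    void adjustValue(String feature, double delta) {
        map.merge(feature, delta, Double::sum);
    }

    double get(String feature) {
        return map.getOrDefault(feature, 0.0);
    }
}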
Use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.
From the class GradientFinder, method main:
public static void main(String[] args) {
    try {
        int inputFiles = Configuration.USE_GROUNDED | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_GRADIENT | Configuration.USE_PARAMS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        int constants = Configuration.USE_THREADS | Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_FIXEDWEIGHTS;
        CustomConfiguration c = new CustomConfiguration(args, inputFiles, outputFiles, constants, modules) {

            boolean relax;

            @Override
            protected Option checkOption(Option o) {
                if (PARAMS_FILE_OPTION.equals(o.getLongOpt()) || INIT_PARAMS_FILE_OPTION.equals(o.getLongOpt()))
                    o.setRequired(false);
                return o;
            }

            @Override
            protected void addCustomOptions(Options options, int[] flags) {
                options.addOption(Option.builder()
                        .longOpt("relaxFW")
                        .desc("Relax fixedWeight rules for gradient computation (used in ProngHorn)")
                        .optionalArg(true)
                        .build());
            }

            @Override
            protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
                if (groundedFile == null || !groundedFile.exists())
                    usageOptions(options, flags, "Must specify grounded file using --" + Configuration.GROUNDED_FILE_OPTION);
                if (gradientFile == null)
                    usageOptions(options, flags, "Must specify gradient using --" + Configuration.GRADIENT_FILE_OPTION);
                // default to 0 epochs
                if (!options.hasOption("epochs"))
                    this.epochs = 0;
                this.relax = false;
                if (options.hasOption("relaxFW"))
                    this.relax = true;
            }

            @Override
            public Object getCustomSetting(String name) {
                if ("relaxFW".equals(name))
                    return this.relax;
                return null;
            }
        };
        System.out.println(c.toString());

        ParamVector<String, ?> params = null;
        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        File featureIndex = new File(c.groundedFile.getParent(), c.groundedFile.getName() + Grounder.FEATURE_INDEX_EXTENSION);
        if (featureIndex.exists()) {
            log.info("Reading feature index from " + featureIndex.getName() + "...");
            for (String line : new ParsedFile(featureIndex)) {
                masterFeatures.insert(line.trim());
            }
        }

        if (c.epochs > 0) {
            // train first
            log.info("Training for " + c.epochs + " epochs...");
            params = c.trainer.train(masterFeatures, new ParsedFile(c.groundedFile),
                    new ArrayLearningGraphBuilder(), // create a parameter vector
                    c.initParamsFile, c.epochs);
            if (c.paramsFile != null)
                ParamsFile.save(params, c.paramsFile, c);
        } else if (c.initParamsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.initParamsFile)));
        } else if (c.paramsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.paramsFile)));
        } else {
            params = new SimpleParamVector<String>();
        }

        // this lets prongHorn hold external features fixed for training, but still compute their gradient
        if (((Boolean) c.getCustomSetting("relaxFW"))) {
            log.info("Turning off fixedWeight rules");
            c.trainer.setFixedWeightRules(new FixedWeightRules());
        }

        ParamVector<String, ?> batchGradient = c.trainer.findGradient(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), params);
        ParamsFile.save(batchGradient, c.gradientFile, c);
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
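The relaxFW branch installs an empty FixedWeightRules(), so no feature is treated as fixed while the gradient is computed. A toy illustration of that idea, assuming fixed-weight handling reduces to a name-based predicate; the class below is hypothetical, not ProPPR's FixedWeightRules API:

import java.util.Collections;
import java.util.Set;

// Hypothetical stand-in for fixed-weight handling: features named in 'fixed'
// are excluded from parameter updates.
class FixedWeightSketch {

    private final Set<String> fixed;

    FixedWeightSketch(Set<String> fixed) {
        this.fixed = fixed;
    }

    boolean isFixed(String feature) {
        return fixed.contains(feature);
    }

    // "Relaxing" fixed weights amounts to installing an empty rule set:
    // isFixed() then always returns false, so every feature gets a gradient.
    static FixedWeightSketch relaxed() {
        return new FixedWeightSketch(Collections.emptySet());
    }
}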
Use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.
From the class Trainer, method findGradient:
public ParamVector<String, ?> findGradient(SymbolTable<String> masterFeatures, Iterable<String> examples, LearningGraphBuilder builder, ParamVector<String, ?> paramVec) {
    log.info("Computing gradient on cooked examples...");
    ParamVector<String, ?> sumGradient = new SimpleParamVector<String>();
    if (paramVec == null) {
        paramVec = createParamVector();
    }
    paramVec = this.masterLearner.setupParams(paramVec);
    if (masterFeatures != null && masterFeatures.size() > 0)
        LearningGraphBuilder.setFeatures(masterFeatures);

    // //WW: accumulate example-size normalized gradient
    // for (PosNegRWExample x : examples) {
    //     // this.learner.initializeFeatures(paramVec, x.getGraph());
    //     this.learner.accumulateGradient(paramVec, x, sumGradient);
    //     k++;
    // }

    NamedThreadFactory workThreads = new NamedThreadFactory("work-");
    ExecutorService workPool, cleanPool;
    workPool = Executors.newFixedThreadPool(this.nthreads, workThreads);
    cleanPool = Executors.newSingleThreadExecutor();

    // run examples
    int id = 1;
    int countdown = -1;
    Trainer notify = null;
    status.start();
    for (String s : examples) {
        if (log.isInfoEnabled() && status.due())
            log.info(id + " examples read...");
        long queueSize = ((ThreadPoolExecutor) workPool).getTaskCount() - ((ThreadPoolExecutor) workPool).getCompletedTaskCount();
        if (log.isDebugEnabled())
            log.debug("Queue size " + queueSize);
        if (countdown > 0) {
            if (log.isDebugEnabled())
                log.debug("Countdown " + countdown);
            countdown--;
        } else if (countdown == 0) {
            if (log.isDebugEnabled())
                log.debug("Countdown " + countdown + "; throttling:");
            countdown--;
            notify = null;
            try {
                synchronized (this) {
                    if (log.isDebugEnabled())
                        log.debug("Clearing training queue...");
                    while ((((ThreadPoolExecutor) workPool).getTaskCount() - ((ThreadPoolExecutor) workPool).getCompletedTaskCount()) > this.nthreads)
                        this.wait();
                    if (log.isDebugEnabled())
                        log.debug("Queue cleared.");
                }
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } else if (queueSize > 1.5 * this.nthreads) {
            if (log.isDebugEnabled())
                log.debug("Starting countdown");
            countdown = this.nthreads;
            notify = this;
        }
        Future<PosNegRWExample> parsed = workPool.submit(new Parse(s, builder, id));
        Future<ExampleStats> gradfound = workPool.submit(new Grad(parsed, paramVec, sumGradient, id, notify));
        cleanPool.submit(new TraceLosses(gradfound, id));
        id++;
    }
    workPool.shutdown();
    try {
        workPool.awaitTermination(7, TimeUnit.DAYS);
        cleanPool.shutdown();
        cleanPool.awaitTermination(7, TimeUnit.DAYS);
    } catch (InterruptedException e) {
        log.error("Interrupted?", e);
    }
    this.masterLearner.cleanupParams(paramVec, sumGradient);

    // WW: renormalize by the total number of queries
    for (Iterator<String> it = sumGradient.keySet().iterator(); it.hasNext(); ) {
        String feature = it.next();
        double unnormf = sumGradient.get(feature);
        // query count is stored in numExamplesThisEpoch
        double norm = unnormf / this.statistics.numExamplesThisEpoch;
        sumGradient.put(feature, norm);
    }
    return sumGradient;
}
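The producer loop above applies simple backpressure: once the work queue exceeds 1.5x the thread count it starts a countdown of nthreads further submissions, and when the countdown expires it blocks in wait() until workers (which notify their owner on completion) drain the queue back down to nthreads pending tasks. A self-contained sketch of the same pattern, with a placeholder task standing in for Parse and Grad:

import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ThrottleSketch {

    public static void main(String[] args) throws InterruptedException {
        int nthreads = 4;
        ThreadPoolExecutor pool = (ThreadPoolExecutor) Executors.newFixedThreadPool(nthreads);
        Object monitor = new Object();
        int countdown = -1;

        for (int id = 0; id < 10000; id++) {
            long queued = pool.getTaskCount() - pool.getCompletedTaskCount();
            if (countdown > 0) {
                countdown--;                         // grace period: keep submitting
            } else if (countdown == 0) {
                countdown--;
                synchronized (monitor) {             // block until workers drain the queue
                    while (pool.getTaskCount() - pool.getCompletedTaskCount() > nthreads)
                        monitor.wait(100);           // timed wait; workers also notify
                }
            } else if (queued > 1.5 * nthreads) {
                countdown = nthreads;                // queue too deep: start the countdown
            }
            pool.submit(() -> {
                // placeholder for the real Parse/Grad work
                synchronized (monitor) {
                    monitor.notifyAll();             // wake the producer if it is throttled
                }
            });
        }
        pool.shutdown();
        pool.awaitTermination(1, TimeUnit.MINUTES);
    }
}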
Use of edu.cmu.ml.proppr.util.math.SimpleParamVector in project ProPPR by TeamCohen.
From the class SRW, method accumulateGradient:
public void accumulateGradient(ParamVector<String, ?> params, PosNegRWExample example, ParamVector<String, ?> accumulator, StatusLogger status) {
    log.debug("Gradient calculating on " + example);
    initializeFeatures(params, example.getGraph());
    ParamVector<String, Double> prepare = new SimpleParamVector<String>();
    regularizer.prepareForExample(params, example.getGraph(), prepare);
    load(params, example);
    inference(params, example, status);
    TIntDoubleMap gradient = gradient(params, example);
    // regularization contribution, negated and normalized by example length
    for (Map.Entry<String, Double> e : prepare.entrySet()) {
        if (trainable(e.getKey()))
            accumulator.adjustValue(e.getKey(), -e.getValue() / example.length());
    }
    // data-gradient contribution, normalized by example length
    for (TIntDoubleIterator it = gradient.iterator(); it.hasNext(); ) {
        it.advance();
        String feature = example.getGraph().featureLibrary.getSymbol(it.key());
        if (trainable(feature))
            accumulator.adjustValue(feature, it.value() / example.length());
    }
}
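Note the sign and scaling conventions above: the regularizer's prepared values enter negated, and both terms are divided by example.length(), so each example contributes a length-normalized gradient (findGradient above later renormalizes the sum by the query count). A toy calculation under those conventions, with illustrative numbers:

public class GradientContributionSketch {

    public static void main(String[] args) {
        double exampleLength = 4.0; // example.length()
        double regTerm = 0.2;       // regularizer's prepared value for one feature
        double dataTerm = 1.0;      // data-gradient entry for the same feature

        // Mirrors the two adjustValue calls above: regularization negated,
        // both terms normalized by the example length.
        double contribution = (-regTerm / exampleLength) + (dataTerm / exampleLength);
        System.out.println(contribution); // prints 0.2 (up to floating-point rounding)
    }
}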