Use of edu.cmu.ml.proppr.util.SimpleSymbolTable in the ProPPR project by TeamCohen.
Class GradientFinder, method main:
/**
 * Command-line entry point: computes the batch gradient over a grounded file and saves it to the
 * configured gradient file.
 *
 * <p>The parameter vector at which the gradient is taken comes from, in order of precedence:
 * a training run (when {@code epochs > 0}; the result is also saved to the params file if one
 * was given), the init-params file, the params file, or an empty parameter vector. When the
 * {@code relaxFW} flag is passed, the trainer's fixed-weight rules are replaced with an empty
 * {@code FixedWeightRules} after any training and before the gradient is computed.
 *
 * @param args command-line arguments; any failure prints a stack trace and exits with status -1
 */
public static void main(String[] args) {
    try {
        int inputFiles = Configuration.USE_GROUNDED | Configuration.USE_INIT_PARAMS;
        int outputFiles = Configuration.USE_GRADIENT | Configuration.USE_PARAMS;
        int modules = Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION;
        int constants = Configuration.USE_THREADS | Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_FIXEDWEIGHTS;
        CustomConfiguration c = new CustomConfiguration(args, inputFiles, outputFiles, constants, modules) {
            boolean relax;

            @Override
            protected Option checkOption(Option o) {
                // Params files are optional here: without them the gradient is taken at an
                // empty parameter vector.
                if (PARAMS_FILE_OPTION.equals(o.getLongOpt()) || INIT_PARAMS_FILE_OPTION.equals(o.getLongOpt())) {
                    o.setRequired(false);
                }
                return o;
            }

            @Override
            protected void addCustomOptions(Options options, int[] flags) {
                options.addOption(Option.builder().longOpt("relaxFW").desc("Relax fixedWeight rules for gradient computation (used in ProngHorn)").optionalArg(true).build());
            }

            @Override
            protected void retrieveCustomSettings(CommandLine line, int[] flags, Options options) {
                if (groundedFile == null || !groundedFile.exists()) {
                    usageOptions(options, flags, "Must specify grounded file using --" + Configuration.GROUNDED_FILE_OPTION);
                }
                if (gradientFile == null) {
                    usageOptions(options, flags, "Must specify gradient using --" + Configuration.GRADIENT_FILE_OPTION);
                }
                // Options.hasOption only says whether an option is *defined* (always true for
                // both options below); whether the user actually passed it must be asked of
                // the parsed CommandLine.
                // Default to 0 epochs unless --epochs was given explicitly.
                if (!line.hasOption("epochs")) {
                    this.epochs = 0;
                }
                this.relax = line.hasOption("relaxFW");
            }

            @Override
            public Object getCustomSetting(String name) {
                if ("relaxFW".equals(name)) {
                    return this.relax;
                }
                return null;
            }
        };
        System.out.println(c.toString());

        // Seed the feature symbol table from the index written next to the grounded file, if any.
        SymbolTable<String> masterFeatures = new SimpleSymbolTable<String>();
        File featureIndex = new File(c.groundedFile.getParent(), c.groundedFile.getName() + Grounder.FEATURE_INDEX_EXTENSION);
        if (featureIndex.exists()) {
            log.info("Reading feature index from " + featureIndex.getName() + "...");
            for (String line : new ParsedFile(featureIndex)) {
                masterFeatures.insert(line.trim());
            }
        }

        // Choose the parameter vector at which to evaluate the gradient.
        ParamVector<String, ?> params;
        if (c.epochs > 0) {
            // train first
            log.info("Training for " + c.epochs + " epochs...");
            params = c.trainer.train(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), c.initParamsFile, c.epochs);
            if (c.paramsFile != null) {
                ParamsFile.save(params, c.paramsFile, c);
            }
        } else if (c.initParamsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.initParamsFile)));
        } else if (c.paramsFile != null) {
            params = new SimpleParamVector<String>(Dictionary.load(new ParsedFile(c.paramsFile)));
        } else {
            params = new SimpleParamVector<String>();
        }

        // This lets ProngHorn hold external features fixed for training, but still compute their gradient.
        if (Boolean.TRUE.equals(c.getCustomSetting("relaxFW"))) {
            log.info("Turning off fixedWeight rules");
            c.trainer.setFixedWeightRules(new FixedWeightRules());
        }
        ParamVector<String, ?> batchGradient = c.trainer.findGradient(masterFeatures, new ParsedFile(c.groundedFile), new ArrayLearningGraphBuilder(), params);
        ParamsFile.save(batchGradient, c.gradientFile, c);
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Use of edu.cmu.ml.proppr.util.SimpleSymbolTable in the ProPPR project by TeamCohen.
Class ProverTestTemplate, method testProveState:
/**
 * Proves {@code isa(elsie,X)} on the milk program with the "milk" feature weighted at 2.
 * It then checks that platypus answers outscore all non-query answers, that the total weight
 * is about 1.0, and that the weighter saw exactly 1 known and 5 unknown features.
 */
@Test
public void testProveState() throws LogicProgramException {
    log.info("testProveState");
    FeatureDictWeighter weighter = new InnerProductWeighter();
    SymbolTable<Feature> featureTab = new SimpleSymbolTable<Feature>();
    // Give the "milk" feature a nonzero weight so milk-featured paths should win.
    int milkId = featureTab.getId(new Feature("milk"));
    weighter.put(featureTab.getSymbol(milkId), 2);
    prover.setWeighter(weighter);

    InferenceExample example = new InferenceExample(Query.parse("isa(elsie,X)"), null, null);
    ProofGraph pg = prover.makeProofGraph(example, apr, featureTab, lpMilk, fMilk);
    Map<State, Double> dist = prover.prove(pg, new StatusLogger());

    // Track the largest weight seen per bucket, keyed on the second argument of the first
    // goal: "platypus", "X1" (presumably the unresolved query itself), or anything else.
    double queryMax = 0.0;
    double platypusMax = 0.0;
    double othersMax = 0.0;
    double total = 0.0;
    for (Map.Entry<State, Double> entry : dist.entrySet()) {
        Query filled = pg.fill(entry.getKey());
        String answer = filled.getRhs()[0].getArg(1).getName();
        double weight = entry.getValue();
        if ("platypus".equals(answer)) {
            platypusMax = Math.max(platypusMax, weight);
        } else if ("X1".equals(answer)) {
            queryMax = Math.max(queryMax, weight);
        } else {
            othersMax = Math.max(othersMax, weight);
        }
        System.out.println(filled + "\t" + weight);
        total += weight;
    }

    System.out.println();
    System.out.println("query weight: " + queryMax);
    System.out.println("platypus weight: " + platypusMax);
    System.out.println("others weight: " + othersMax);
    // Deliberately not asserted: "query should retain most weight",
    // i.e. queryMax > Math.max(platypusMax, othersMax).
    assertTrue("milk-featured paths should score higher than others", platypusMax > othersMax);
    assertEquals("Total weight of all states should be around 1.0", 1.0, total, 10 * this.apr.epsilon);
    assertEquals("Known features", 1, prover.weighter.numKnownFeatures);
    assertEquals("Unknown features", 5, prover.weighter.numUnknownFeatures);
}
Use of edu.cmu.ml.proppr.util.SimpleSymbolTable in the ProPPR project by TeamCohen.
Class Trainer, method main:
/**
 * Command-line entry point: trains model parameters on an already-grounded query file and,
 * if a params file was configured, saves the result there.
 *
 * <p>When a feature index exists at the grounded file's path plus
 * {@code Grounder.FEATURE_INDEX_EXTENSION}, it is read first to seed the feature symbol table.
 * The query file must carry {@code Grounder.GROUNDED_SUFFIX}; grounding is not done here.
 *
 * @param args command-line arguments; any failure prints a stack trace and exits with status -1
 */
public static void main(String[] args) {
    try {
        ModuleConfiguration config = new ModuleConfiguration(
                args,
                Configuration.USE_TRAIN | Configuration.USE_INIT_PARAMS,
                Configuration.USE_PARAMS,
                Configuration.USE_EPOCHS | Configuration.USE_FORCE | Configuration.USE_THREADS | Configuration.USE_FIXEDWEIGHTS,
                Configuration.USE_TRAINER | Configuration.USE_SRW | Configuration.USE_SQUASHFUNCTION);
        log.info(config.toString());

        String groundedPath = config.queryFile.getPath();
        // Refuse raw query files: grounding must be run as a separate step beforehand.
        if (!config.queryFile.getName().endsWith(Grounder.GROUNDED_SUFFIX)) {
            throw new IllegalStateException("Run Grounder on " + config.queryFile.getName() + " first. Ground+Train in one go is not supported yet.");
        }

        SymbolTable<String> featureTable = new SimpleSymbolTable<String>();
        File indexFile = new File(groundedPath + Grounder.FEATURE_INDEX_EXTENSION);
        if (indexFile.exists()) {
            log.info("Reading feature index from " + indexFile.getName() + "...");
            for (String entry : new ParsedFile(indexFile)) {
                featureTable.insert(entry.trim());
            }
        }

        log.info("Training model parameters on " + groundedPath + "...");
        long startMillis = System.currentTimeMillis();
        ParamVector<String, ?> trained = config.trainer.train(featureTable, new ParsedFile(groundedPath), new ArrayLearningGraphBuilder(), config.initParamsFile, config.epochs);
        long elapsedMillis = System.currentTimeMillis() - startMillis;
        System.out.println("Training time: " + elapsedMillis);

        if (config.paramsFile == null) {
            return;
        }
        log.info("Saving parameters to " + config.paramsFile + "...");
        ParamsFile.save(trained, config.paramsFile, config);
    } catch (Throwable t) {
        t.printStackTrace();
        System.exit(-1);
    }
}
Aggregations