Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.
Class SimpleProgramProverTest, method testFill.
/**
 * Proves coworker(steve, X) with a DfsProver, fills each completed state back into the
 * query, and checks that the value read from the filled goal is "steve" once and "sven" once.
 */
@Test
public void testFill() throws IOException, LogicProgramException {
    WamProgram program = WamBaseProgram.load(new File(PROGRAM));
    Query q = new Query(new Goal("coworker", new ConstantArgument("steve"), new ConstantArgument("X")));
    System.out.println("Query: " + q.toString());
    ProofGraph p = new StateProofGraph(q, apr, program);
    Prover prover = new DfsProver(apr);
    // prove() can return states that are not completed (they are skipped below), so the raw
    // map size is not a solution count and is deliberately not asserted.
    Map<State, Double> sols = prover.prove(p, new StatusLogger());
    HashMap<String, Integer> expected = new HashMap<String, Integer>();
    expected.put("steve", 0);
    expected.put("sven", 0);
    for (State s : sols.keySet()) {
        if (!s.isCompleted())
            continue;
        System.out.println(s);
        Query a = p.fill(s);
        System.out.println(a);
        // argument index 1 of the first goal is where X sat in the query; presumably its binding
        String v = a.getRhs()[0].getArg(1).getName();
        System.out.println("Got solution: " + v);
        Integer count = expected.get(v);
        if (count != null)
            expected.put(v, count + 1);
    }
    for (Map.Entry<String, Integer> e : expected.entrySet())
        assertEquals(e.getKey(), 1, e.getValue().intValue());
}
Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.
Class UngroundedSolutionsTest, method test.
/**
 * Proves grandparent(X, Y) against the program plus the lightweight-graph facts and, for every
 * completed state, asserts that no value in its binding dictionary starts with "X"
 * (presumably guarding against unbound variable names leaking into solutions).
 */
@Test
public void test() throws IOException, LogicProgramException {
    WamProgram prog = WamBaseProgram.load(new File(PROGRAM));
    APROptions aprOptions = new APROptions("depth=20");
    WamPlugin factsPlugin = LightweightGraphPlugin.load(aprOptions, new File(FACTS));
    Goal goal = new Goal("grandparent", new ConstantArgument("X"), new ConstantArgument("Y"));
    StateProofGraph graph = new StateProofGraph(new Query(goal), aprOptions, prog, factsPlugin);
    Prover dfs = new DfsProver(aprOptions);
    Map<State, Double> proved = dfs.prove(graph, new StatusLogger());
    System.out.println("===");
    for (State state : proved.keySet()) {
        // only completed states carry a full set of bindings worth checking
        if (!state.isCompleted()) {
            continue;
        }
        System.out.println(state);
        Map<Argument, String> bindings = graph.asDict(state);
        System.out.println(Dictionary.buildString(bindings, new StringBuilder(), "\n\t").substring(1));
        for (String value : bindings.values()) {
            assertFalse(value.startsWith("X"));
        }
    }
}
Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.
Class ModuleConfiguration, method retrieveSettings.
/**
 * Reads the module-selection options (prover, squashing function, grounder, trainer, SRW)
 * after the base configuration has read its own settings.
 * <p>
 * Each module is configured only when its flag is enabled in the module flags derived from
 * {@code allFlags}. Module names given on the command line that do not match a known name
 * are reported through {@code usageOptions} with a descriptive message, instead of escaping
 * as a bare IllegalArgumentException from the enum lookup.
 *
 * @param line     parsed command line
 * @param allFlags flag groups describing which options are accepted
 * @param options  option definitions, passed to {@code usageOptions} for help output
 * @throws IOException propagated from the underlying configuration steps
 */
@Override
protected void retrieveSettings(CommandLine line, int[] allFlags, Options options) throws IOException {
    super.retrieveSettings(line, allFlags, options);
    int flags = modules(allFlags);
    // order matters: the grounder uses the prover, and the trainer uses the SRW set up just before it
    if (isOn(flags, USE_PROVER)) {
        configureProverModule(line, allFlags, options);
    }
    if (anyOn(flags, USE_SQUASHFUNCTION | USE_PROVER | USE_SRW)) {
        configureSquashingFunctionModule(line, allFlags, options);
    }
    if (isOn(flags, Configuration.USE_GROUNDER)) {
        configureGrounderModule(line);
    }
    if (isOn(flags, USE_TRAIN)) {
        this.setupSRW(line, flags, options);
        seed(line);
        if (isOn(flags, USE_TRAINER)) {
            configureTrainerModule(line, allFlags, options);
        }
    }
    if (isOn(flags, USE_SRW) && this.srw == null)
        this.setupSRW(line, flags, options);
}

/**
 * Looks up {@code name} among the constants of enum {@code type}. An unknown name is
 * reported via {@code usageOptions} with {@code message}; if that call returns, the original
 * IllegalArgumentException is rethrown so configuration never continues without a module.
 */
private <E extends Enum<E>> E parseModuleName(Class<E> type, String name, Options options, int[] allFlags, String message) throws IOException {
    try {
        return Enum.valueOf(type, name);
    } catch (IllegalArgumentException e) {
        this.usageOptions(options, allFlags, message);
        throw e;
    }
}

/** Sets {@code prover} from the prover option ("name[:setting...]"), defaulting to a DprProver. */
private void configureProverModule(CommandLine line, int[] allFlags, Options options) throws IOException {
    if (!line.hasOption(PROVER_MODULE_OPTION)) {
        // default:
        this.prover = new DprProver(apr);
        return;
    }
    String[] values = line.getOptionValue(PROVER_MODULE_OPTION).split(":");
    boolean proverSupportsPruning = false;
    PROVERS type = parseModuleName(PROVERS.class, values[0], options, allFlags,
            "No prover definition for '" + values[0] + "'");
    switch (type) {
        case ippr:
            this.prover = new IdPprProver(apr);
            break;
        case ppr:
            this.prover = new PprProver(apr);
            break;
        case dpr:
            this.prover = new DprProver(apr);
            break;
        case idpr:
            this.prover = new IdDprProver(apr);
            break;
        case p_idpr:
            if (prunedPredicateRules == null)
                log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " not set");
            this.prover = new PruningIdDprProver(apr, prunedPredicateRules);
            proverSupportsPruning = true;
            break;
        case qpr:
            this.prover = new PriorityQueueProver(apr);
            break;
        case pdpr:
            this.prover = new PathDprProver(apr);
            break;
        case dfs:
            this.prover = new DfsProver(apr);
            break;
        case tr:
            this.prover = new TracingDfsProver(apr);
            if (this.nthreads > 1)
                usageOptions(options, allFlags, "Tracing prover is not multithreaded. Remove --threads option or use --threads 1.");
            break;
        default:
            // reached only for enum constants that have no case above
            usageOptions(options, allFlags, "No prover definition for '" + values[0] + "'");
    }
    if (prunedPredicateRules != null && !proverSupportsPruning)
        log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " is ignored by this prover");
    // every ':'-separated field after the prover name is handed to the prover's configure()
    for (int i = 1; i < values.length; i++) {
        this.prover.configure(values[i]);
    }
}

/** Sets {@code squashingFunction} from the squashing-function option, defaulting to SRW's default. */
private void configureSquashingFunctionModule(CommandLine line, int[] allFlags, Options options) throws IOException {
    if (!line.hasOption(SQUASHFUNCTION_MODULE_OPTION)) {
        // default:
        this.squashingFunction = SRW.DEFAULT_SQUASHING_FUNCTION();
        return;
    }
    String name = line.getOptionValue(SQUASHFUNCTION_MODULE_OPTION);
    SQUASHFUNCTIONS type = parseModuleName(SQUASHFUNCTIONS.class, name, options, allFlags,
            "Unrecognized squashing function " + name);
    switch (type) {
        case linear:
            squashingFunction = new Linear();
            break;
        case sigmoid:
            squashingFunction = new Sigmoid();
            break;
        case tanh:
            squashingFunction = new Tanh();
            break;
        case tanh1:
            squashingFunction = new Tanh1();
            break;
        case ReLU:
            squashingFunction = new ReLU();
            break;
        case LReLU:
            squashingFunction = new LReLU();
            break;
        case exp:
            squashingFunction = new Exp();
            break;
        case clipExp:
            squashingFunction = new ClippedExp();
            break;
        default:
            // reached only for enum constants that have no case above
            this.usageOptions(options, allFlags, "Unrecognized squashing function " + name);
    }
}

/**
 * Builds {@code grounder}. The grounder option's optional second and third values override
 * the thread count (default {@code nthreads}) and the throttle (default
 * Multithreading.DEFAULT_THROTTLE).
 */
private void configureGrounderModule(CommandLine line) throws IOException {
    if (!line.hasOption(GROUNDER_MODULE_OPTION)) {
        this.grounder = new Grounder(nthreads, Multithreading.DEFAULT_THROTTLE, apr, prover, program, plugins);
    } else {
        String[] values = line.getOptionValues(GROUNDER_MODULE_OPTION);
        int threads = values.length > 1 ? Integer.parseInt(values[1]) : nthreads;
        int throttle = values.length > 2 ? Integer.parseInt(values[2]) : Multithreading.DEFAULT_THROTTLE;
        this.grounder = new Grounder(threads, throttle, apr, prover, program, plugins);
    }
    this.grounder.includeUnlabeledGraphs(includeEmptyGraphs);
}

/**
 * Builds {@code trainer} from the trainer option (default: cached) and applies stopping
 * criteria. Option values starting with "shuff", "pct" or "stableEpochs" are read as
 * key=value settings.
 */
private void configureTrainerModule(CommandLine line, int[] allFlags, Options options) throws IOException {
    // set default stopping criteria
    double percent = StoppingCriterion.DEFAULT_MAX_PCT_IMPROVEMENT;
    int stableEpochs = StoppingCriterion.DEFAULT_MIN_STABLE_EPOCHS;
    TRAINERS type = TRAINERS.cached;
    if (line.hasOption(TRAINER_MODULE_OPTION)) {
        String name = line.getOptionValues(TRAINER_MODULE_OPTION)[0];
        type = parseModuleName(TRAINERS.class, name, options, allFlags, "Unrecognized trainer " + name);
    }
    switch (type) {
        case streaming:
            this.trainer = new Trainer(this.srw, this.nthreads, this.throttle);
            break;
        case caching:
            // fall through: "caching" is handled exactly like "cached"
        case cached:
            boolean shuff = CachingTrainer.DEFAULT_SHUFFLE;
            if (line.hasOption(TRAINER_MODULE_OPTION)) {
                for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
                    if (val.startsWith("shuff"))
                        shuff = Boolean.parseBoolean(val.substring(val.indexOf("=") + 1));
                }
            }
            this.trainer = new CachingTrainer(this.srw, this.nthreads, this.throttle, shuff);
            break;
        case adagrad:
            this.usageOptions(options, allFlags, "Trainer 'adagrad' no longer necessary. Use '--srw adagrad' for adagrad descent method.");
            // break added: previously fell into default and emitted a second, contradictory message
            break;
        default:
            this.usageOptions(options, allFlags, "Unrecognized trainer " + line.getOptionValue(TRAINER_MODULE_OPTION));
    }
    if (this.srw instanceof AdaGradSRW)
        // override default
        stableEpochs = 2;
    // now get stopping criteria from command line
    if (line.hasOption(TRAINER_MODULE_OPTION)) {
        for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
            if (val.startsWith("pct"))
                percent = Double.parseDouble(val.substring(val.indexOf("=") + 1));
            else if (val.startsWith("stableEpochs"))
                stableEpochs = Integer.parseInt(val.substring(val.indexOf("=") + 1));
        }
    }
    this.trainer.setStoppingCriteria(stableEpochs, percent);
}
Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.
Class SimpleProgramProverTest, method test.
/**
 * Proves coworker(steve, X) via Prover.solutions and checks that exactly two solutions are
 * returned and that "steve" and "sven" each appear once as the solution value.
 */
@Test
public void test() throws IOException, LogicProgramException {
    WamProgram program = WamBaseProgram.load(new File(PROGRAM));
    Query q = new Query(new Goal("coworker", new ConstantArgument("steve"), new ConstantArgument("X")));
    System.out.println("Query: " + q.toString());
    ProofGraph p = new StateProofGraph(q, apr, program);
    Prover prover = new DfsProver(apr);
    Map<String, Double> sols = prover.solutions(p, new StatusLogger());
    assertEquals("number of solutions", 2, sols.size());
    HashMap<String, Integer> expected = new HashMap<String, Integer>();
    expected.put("steve", 0);
    expected.put("sven", 0);
    for (String pair : sols.keySet()) {
        System.out.println(pair);
        // keys are assumed to look like "<lhs>:<value>"; the value is the field after the first ':'
        String[] parts = pair.split(":");
        String v = parts[1];
        System.out.println("Got solution: " + v);
        Integer count = expected.get(v);
        if (count != null)
            expected.put(v, count + 1);
    }
    for (Map.Entry<String, Integer> e : expected.entrySet())
        assertEquals(e.getKey(), 1, e.getValue().intValue());
}
Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.
Class PropertiesConfigurationTest, method test.
/**
 * Checks that ModuleConfiguration reads settings from the properties file named by
 * Configuration.PROPFILE and that command-line options take precedence over it.
 * The PROPFILE system property is restored afterwards so it does not leak into other tests.
 */
@Test
public void test() {
    // config.properties defines train, test, params, prover, queries, force (unary), and two nonexistent options.
    String previousPropFile = System.getProperty(Configuration.PROPFILE);
    System.setProperty(Configuration.PROPFILE, "src/testcases/config.properties");
    try {
        ModuleConfiguration c = new ModuleConfiguration("--prover dfs".split(" "), 0, Configuration.USE_PARAMS, Configuration.USE_FORCE, Configuration.USE_PROVER | Configuration.USE_SQUASHFUNCTION);
        assertTrue("Didn't fetch properties from file", c.paramsFile != null);
        assertTrue("Didn't prefer command line properties", c.prover instanceof DfsProver);
        assertTrue("Didn't fetch unary argument", c.force);
        assertEquals("Didn't fetch apr options properly", 0.01, c.apr.alpha, 1e-10);
    } finally {
        // System properties are JVM-global; put back whatever was there before this test
        if (previousPropFile == null)
            System.clearProperty(Configuration.PROPFILE);
        else
            System.setProperty(Configuration.PROPFILE, previousPropFile);
    }
}
Aggregations