Search in sources:

Example 1 with DfsProver

Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.

From class SimpleProgramProverTest, method testFill.

/**
 * Proves coworker(steve, X) with a DfsProver and checks that filling each
 * completed proof state yields exactly the answers "steve" and "sven", each
 * exactly once, with no other completed answers.
 */
@Test
public void testFill() throws IOException, LogicProgramException {
    WamProgram program = WamBaseProgram.load(new File(PROGRAM));
    Query q = new Query(new Goal("coworker", new ConstantArgument("steve"), new ConstantArgument("X")));
    System.out.println("Query: " + q.toString());
    ProofGraph p = new StateProofGraph(q, apr, program);
    Prover prover = new DfsProver(apr);
    Map<State, Double> sols = prover.prove(p, new StatusLogger());
    // sols also holds states that are not completed (they are skipped below),
    // so its size is not the number of answers; completed states are counted instead.
    HashMap<String, Integer> expected = new HashMap<String, Integer>();
    expected.put("steve", 0);
    expected.put("sven", 0);
    int completed = 0;
    for (State s : sols.keySet()) {
        if (!s.isCompleted())
            continue;
        completed++;
        System.out.println(s);
        Query a = p.fill(s);
        System.out.println(a);
        // getArg(1) is the position of X in the coworker goal built above
        String v = a.getRhs()[0].getArg(1).getName();
        System.out.println("Got solution: " + v);
        if (expected.containsKey(v))
            expected.put(v, expected.get(v) + 1);
    }
    for (Map.Entry<String, Integer> e : expected.entrySet()) {
        assertEquals(e.getKey(), 1, e.getValue().intValue());
    }
    // together with the per-answer checks above, this rules out unexpected answers
    assertEquals("number of completed solutions", expected.size(), completed);
}
Also used: StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Query(edu.cmu.ml.proppr.prove.wam.Query) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph) HashMap(java.util.HashMap) DfsProver(edu.cmu.ml.proppr.prove.DfsProver) Prover(edu.cmu.ml.proppr.prove.Prover) WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram) ConstantArgument(edu.cmu.ml.proppr.prove.wam.ConstantArgument) Goal(edu.cmu.ml.proppr.prove.wam.Goal) State(edu.cmu.ml.proppr.prove.wam.State) File(java.io.File) Map(java.util.Map) Test(org.junit.Test)

Example 2 with DfsProver

Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.

From class UngroundedSolutionsTest, method test.

/**
 * Proves grandparent(X, Y) against PROGRAM plus the FACTS graph with a
 * DfsProver and checks every completed solution's binding dictionary: no bound
 * value may start with "X" (presumably the rendering of a still-unbound
 * variable, i.e. an ungrounded solution -- TODO confirm against Argument naming).
 */
@Test
public void test() throws IOException, LogicProgramException {
    WamProgram program = WamBaseProgram.load(new File(PROGRAM));
    APROptions apr = new APROptions("depth=20");
    WamPlugin facts = LightweightGraphPlugin.load(apr, new File(FACTS));
    Query q = new Query(new Goal("grandparent", new ConstantArgument("X"), new ConstantArgument("Y")));
    StateProofGraph pg = new StateProofGraph(q, apr, program, facts);
    Prover prover = new DfsProver(apr);
    Map<State, Double> ans = prover.prove(pg, new StatusLogger());
    System.out.println("===");
    for (State s : ans.keySet()) {
        if (!s.isCompleted())
            continue;
        System.out.println(s);
        Map<Argument, String> dict = pg.asDict(s);
        // substring(1) would throw on an empty string, which buildString
        // presumably returns for an empty map, so only print when there are bindings
        if (!dict.isEmpty()) {
            System.out.println(Dictionary.buildString(dict, new StringBuilder(), "\n\t").substring(1));
        }
        for (Map.Entry<Argument, String> binding : dict.entrySet()) {
            assertFalse("Ungrounded binding " + binding.getKey() + "=" + binding.getValue() + " in " + s,
                    binding.getValue().startsWith("X"));
        }
    }
}
Also used: StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Query(edu.cmu.ml.proppr.prove.wam.Query) Argument(edu.cmu.ml.proppr.prove.wam.Argument) ConstantArgument(edu.cmu.ml.proppr.prove.wam.ConstantArgument) DfsProver(edu.cmu.ml.proppr.prove.DfsProver) Prover(edu.cmu.ml.proppr.prove.Prover) WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) WamPlugin(edu.cmu.ml.proppr.prove.wam.plugins.WamPlugin) Goal(edu.cmu.ml.proppr.prove.wam.Goal) State(edu.cmu.ml.proppr.prove.wam.State) APROptions(edu.cmu.ml.proppr.util.APROptions) File(java.io.File) Test(org.junit.Test)

Example 3 with DfsProver

Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.

From class ModuleConfiguration, method retrieveSettings.

/**
 * Builds the pluggable modules (prover, squashing function, grounder, trainer,
 * SRW) enabled by the module flags in {@code allFlags}, reading their settings
 * from {@code line}.
 *
 * <p>An unknown prover, squashing-function or trainer name is reported through
 * {@code usageOptions} rather than escaping as a raw IllegalArgumentException
 * from {@code Enum.valueOf}.
 *
 * @param line     parsed command line
 * @param allFlags configuration flag groups; the module group selects what is built
 * @param options  option definitions, used when printing usage
 * @throws IOException if a superclass or module setup step does
 */
@Override
protected void retrieveSettings(CommandLine line, int[] allFlags, Options options) throws IOException {
    super.retrieveSettings(line, allFlags, options);
    // modules
    int flags = modules(allFlags);
    if (isOn(flags, USE_PROVER)) {
        configureProverModule(line, allFlags, options);
    }
    if (anyOn(flags, USE_SQUASHFUNCTION | USE_PROVER | USE_SRW)) {
        configureSquashingModule(line, allFlags, options);
    }
    // the grounder is constructed with the prover chosen above
    if (isOn(flags, Configuration.USE_GROUNDER)) {
        configureGrounderModule(line);
    }
    if (isOn(flags, USE_TRAIN)) {
        this.setupSRW(line, flags, options);
        seed(line);
        // the trainer wraps the SRW set up just above
        if (isOn(flags, USE_TRAINER)) {
            configureTrainerModule(line, allFlags, options);
        }
    }
    if (isOn(flags, USE_SRW) && this.srw == null)
        this.setupSRW(line, flags, options);
}

/**
 * Looks up the constant of enum {@code type} named {@code name}.
 *
 * @return the matching constant, or null when {@code name} is null or names no constant
 */
private static <E extends Enum<E>> E lookupModule(Class<E> type, String name) {
    if (name == null)
        return null;
    try {
        return Enum.valueOf(type, name);
    } catch (IllegalArgumentException e) {
        // unknown name: the caller reports it with a usage message
        return null;
    }
}

/**
 * Sets {@code this.prover} from the prover option, whose value has the form
 * {@code name[:setting[:setting...]]}; each setting is handed to the new
 * prover's {@code configure} method. Defaults to a DprProver when the option
 * is absent.
 */
private void configureProverModule(CommandLine line, int[] allFlags, Options options) throws IOException {
    if (!line.hasOption(PROVER_MODULE_OPTION)) {
        // default:
        this.prover = new DprProver(apr);
        return;
    }
    String[] values = line.getOptionValue(PROVER_MODULE_OPTION).split(":");
    String unknownProver = "No prover definition for '" + values[0] + "'";
    PROVERS type = lookupModule(PROVERS.class, values[0]);
    if (type == null) {
        usageOptions(options, allFlags, unknownProver);
        return;
    }
    boolean proverSupportsPruning = false;
    switch (type) {
        case ippr:
            this.prover = new IdPprProver(apr);
            break;
        case ppr:
            this.prover = new PprProver(apr);
            break;
        case dpr:
            this.prover = new DprProver(apr);
            break;
        case idpr:
            this.prover = new IdDprProver(apr);
            break;
        case p_idpr:
            if (prunedPredicateRules == null)
                log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " not set");
            this.prover = new PruningIdDprProver(apr, prunedPredicateRules);
            proverSupportsPruning = true;
            break;
        case qpr:
            this.prover = new PriorityQueueProver(apr);
            break;
        case pdpr:
            this.prover = new PathDprProver(apr);
            break;
        case dfs:
            this.prover = new DfsProver(apr);
            break;
        case tr:
            this.prover = new TracingDfsProver(apr);
            if (this.nthreads > 1)
                usageOptions(options, allFlags, "Tracing prover is not multithreaded. Remove --threads option or use --threads 1.");
            break;
        default:
            usageOptions(options, allFlags, unknownProver);
            // no prover was created, so there is nothing to configure
            return;
    }
    if (prunedPredicateRules != null && !proverSupportsPruning)
        log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " is ignored by this prover");
    for (int i = 1; i < values.length; i++) {
        this.prover.configure(values[i]);
    }
}

/**
 * Sets {@code this.squashingFunction} from the squashing-function option,
 * defaulting to {@code SRW.DEFAULT_SQUASHING_FUNCTION()} when it is absent.
 */
private void configureSquashingModule(CommandLine line, int[] allFlags, Options options) throws IOException {
    if (!line.hasOption(SQUASHFUNCTION_MODULE_OPTION)) {
        // default:
        this.squashingFunction = SRW.DEFAULT_SQUASHING_FUNCTION();
        return;
    }
    String name = line.getOptionValue(SQUASHFUNCTION_MODULE_OPTION);
    SQUASHFUNCTIONS type = lookupModule(SQUASHFUNCTIONS.class, name);
    if (type == null) {
        this.usageOptions(options, allFlags, "Unrecognized squashing function " + name);
        return;
    }
    switch (type) {
        case linear:
            squashingFunction = new Linear();
            break;
        case sigmoid:
            squashingFunction = new Sigmoid();
            break;
        case tanh:
            squashingFunction = new Tanh();
            break;
        case tanh1:
            squashingFunction = new Tanh1();
            break;
        case ReLU:
            squashingFunction = new ReLU();
            break;
        case LReLU:
            squashingFunction = new LReLU();
            break;
        case exp:
            squashingFunction = new Exp();
            break;
        case clipExp:
            squashingFunction = new ClippedExp();
            break;
        default:
            this.usageOptions(options, allFlags, "Unrecognized squashing function " + name);
    }
}

/**
 * Builds {@code this.grounder} from the current prover, program and plugins.
 * When the grounder option is given, its second and third values (if present)
 * override the thread count (default {@code nthreads}) and the throttle
 * (default {@code Multithreading.DEFAULT_THROTTLE}).
 */
private void configureGrounderModule(CommandLine line) throws IOException {
    int threads = nthreads;
    int grounderThrottle = Multithreading.DEFAULT_THROTTLE;
    if (line.hasOption(GROUNDER_MODULE_OPTION)) {
        String[] values = line.getOptionValues(GROUNDER_MODULE_OPTION);
        if (values.length > 1)
            threads = Integer.parseInt(values[1]);
        if (values.length > 2)
            grounderThrottle = Integer.parseInt(values[2]);
    }
    this.grounder = new Grounder(threads, grounderThrottle, apr, prover, program, plugins);
    this.grounder.includeUnlabeledGraphs(includeEmptyGraphs);
}

/**
 * Builds {@code this.trainer} around {@code this.srw} from the trainer option
 * (default: cached), then applies stopping criteria. The option's
 * {@code pct=} and {@code stableEpochs=} values override the defaults, and
 * {@code shuff...=} sets shuffling for the caching trainer.
 */
private void configureTrainerModule(CommandLine line, int[] allFlags, Options options) throws IOException {
    // set default stopping criteria
    double percent = StoppingCriterion.DEFAULT_MAX_PCT_IMPROVEMENT;
    int stableEpochs = StoppingCriterion.DEFAULT_MIN_STABLE_EPOCHS;
    TRAINERS type = TRAINERS.cached;
    if (line.hasOption(TRAINER_MODULE_OPTION))
        type = lookupModule(TRAINERS.class, line.getOptionValues(TRAINER_MODULE_OPTION)[0]);
    if (type == null) {
        this.usageOptions(options, allFlags, "Unrecognized trainer " + line.getOptionValue(TRAINER_MODULE_OPTION));
        return;
    }
    switch (type) {
        case streaming:
            this.trainer = new Trainer(this.srw, this.nthreads, this.throttle);
            break;
        // caching and cached share the same implementation
        case caching:
        case cached:
            boolean shuff = CachingTrainer.DEFAULT_SHUFFLE;
            if (line.hasOption(TRAINER_MODULE_OPTION)) {
                for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
                    if (val.startsWith("shuff"))
                        shuff = Boolean.parseBoolean(val.substring(val.indexOf("=") + 1));
                }
            }
            this.trainer = new CachingTrainer(this.srw, this.nthreads, this.throttle, shuff);
            break;
        case adagrad:
            this.usageOptions(options, allFlags, "Trainer 'adagrad' no longer necessary. Use '--srw adagrad' for adagrad descent method.");
            // do not fall through into default, which would report a second, misleading error
            return;
        default:
            this.usageOptions(options, allFlags, "Unrecognized trainer " + line.getOptionValue(TRAINER_MODULE_OPTION));
            // no trainer was created, so there are no stopping criteria to set
            return;
    }
    if (this.srw instanceof AdaGradSRW)
        // override default
        stableEpochs = 2;
    // now get stopping criteria from command line
    if (line.hasOption(TRAINER_MODULE_OPTION)) {
        for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
            if (val.startsWith("pct"))
                percent = Double.parseDouble(val.substring(val.indexOf("=") + 1));
            else if (val.startsWith("stableEpochs"))
                stableEpochs = Integer.parseInt(val.substring(val.indexOf("=") + 1));
        }
    }
    this.trainer.setStoppingCriteria(stableEpochs, percent);
}
Also used: Tanh1(edu.cmu.ml.proppr.learn.tools.Tanh1) IdDprProver(edu.cmu.ml.proppr.prove.IdDprProver) DprProver(edu.cmu.ml.proppr.prove.DprProver) PathDprProver(edu.cmu.ml.proppr.prove.PathDprProver) PruningIdDprProver(edu.cmu.ml.proppr.prove.PruningIdDprProver) CachingTrainer(edu.cmu.ml.proppr.CachingTrainer) Trainer(edu.cmu.ml.proppr.Trainer) TracingDfsProver(edu.cmu.ml.proppr.prove.TracingDfsProver) DfsProver(edu.cmu.ml.proppr.prove.DfsProver) Grounder(edu.cmu.ml.proppr.Grounder) Sigmoid(edu.cmu.ml.proppr.learn.tools.Sigmoid) Tanh(edu.cmu.ml.proppr.learn.tools.Tanh) PprProver(edu.cmu.ml.proppr.prove.PprProver) IdPprProver(edu.cmu.ml.proppr.prove.IdPprProver) PriorityQueueProver(edu.cmu.ml.proppr.prove.PriorityQueueProver) Linear(edu.cmu.ml.proppr.learn.tools.Linear) ClippedExp(edu.cmu.ml.proppr.learn.tools.ClippedExp) LReLU(edu.cmu.ml.proppr.learn.tools.LReLU) ReLU(edu.cmu.ml.proppr.learn.tools.ReLU) Exp(edu.cmu.ml.proppr.learn.tools.Exp)

Example 4 with DfsProver

Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.

From class SimpleProgramProverTest, method test.

/**
 * Proves coworker(steve, X) with a DfsProver and checks, via
 * {@code Prover.solutions}, that exactly two answers are returned and that they
 * are "steve" and "sven", each exactly once.
 */
@Test
public void test() throws IOException, LogicProgramException {
    WamProgram program = WamBaseProgram.load(new File(PROGRAM));
    Query q = new Query(new Goal("coworker", new ConstantArgument("steve"), new ConstantArgument("X")));
    System.out.println("Query: " + q.toString());
    ProofGraph p = new StateProofGraph(q, apr, program);
    Prover prover = new DfsProver(apr);
    Map<String, Double> sols = prover.solutions(p, new StatusLogger());
    // include the actual keys so a failure shows what was found
    assertEquals("solutions " + sols.keySet(), 2, sols.size());
    HashMap<String, Integer> expected = new HashMap<String, Integer>();
    expected.put("steve", 0);
    expected.put("sven", 0);
    for (String pair : sols.keySet()) {
        System.out.println(pair);
        // keys presumably look like "<variable>:<value>"; the value is the second field
        String[] parts = pair.split(":");
        String v = parts[1];
        System.out.println("Got solution: " + v);
        if (expected.containsKey(v))
            expected.put(v, expected.get(v) + 1);
    }
    for (Map.Entry<String, Integer> e : expected.entrySet()) {
        assertEquals(e.getKey(), 1, e.getValue().intValue());
    }
}
Also used: StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Query(edu.cmu.ml.proppr.prove.wam.Query) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph) HashMap(java.util.HashMap) DfsProver(edu.cmu.ml.proppr.prove.DfsProver) Prover(edu.cmu.ml.proppr.prove.Prover) WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram) ConstantArgument(edu.cmu.ml.proppr.prove.wam.ConstantArgument) Goal(edu.cmu.ml.proppr.prove.wam.Goal) File(java.io.File) Map(java.util.Map) Test(org.junit.Test)

Example 5 with DfsProver

Use of edu.cmu.ml.proppr.prove.DfsProver in project ProPPR by TeamCohen.

From class PropertiesConfigurationTest, method test.

/**
 * Checks that ModuleConfiguration reads settings from the properties file named
 * by the {@code Configuration.PROPFILE} system property, and that a command-line
 * option (--prover dfs) overrides the file's value. The system property is
 * restored afterwards so it does not leak into other tests in the same JVM.
 */
@Test
public void test() {
    // config.properties defines train, test, params, prover, queries, force (unary), and two nonexistent options.
    String previousPropFile = System.getProperty(Configuration.PROPFILE);
    System.setProperty(Configuration.PROPFILE, "src/testcases/config.properties");
    try {
        ModuleConfiguration c = new ModuleConfiguration("--prover dfs".split(" "), 0, Configuration.USE_PARAMS, Configuration.USE_FORCE, Configuration.USE_PROVER | Configuration.USE_SQUASHFUNCTION);
        assertTrue("Didn't fetch properties from file", c.paramsFile != null);
        // the file also defines a prover; the command-line value must win
        assertTrue("Didn't prefer command line properties", c.prover instanceof DfsProver);
        assertTrue("Didn't fetch unary argument", c.force);
        assertEquals("Didn't fetch apr options properly", 0.01, c.apr.alpha, 1e-10);
    } finally {
        if (previousPropFile == null)
            System.clearProperty(Configuration.PROPFILE);
        else
            System.setProperty(Configuration.PROPFILE, previousPropFile);
    }
}
Also used: ModuleConfiguration(edu.cmu.ml.proppr.util.ModuleConfiguration) DfsProver(edu.cmu.ml.proppr.prove.DfsProver) Test(org.junit.Test)

Aggregations

DfsProver (edu.cmu.ml.proppr.prove.DfsProver)5 Test (org.junit.Test)4 Prover (edu.cmu.ml.proppr.prove.Prover)3 ConstantArgument (edu.cmu.ml.proppr.prove.wam.ConstantArgument)3 Goal (edu.cmu.ml.proppr.prove.wam.Goal)3 Query (edu.cmu.ml.proppr.prove.wam.Query)3 StateProofGraph (edu.cmu.ml.proppr.prove.wam.StateProofGraph)3 WamProgram (edu.cmu.ml.proppr.prove.wam.WamProgram)3 StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)3 File (java.io.File)3 ProofGraph (edu.cmu.ml.proppr.prove.wam.ProofGraph)2 State (edu.cmu.ml.proppr.prove.wam.State)2 HashMap (java.util.HashMap)2 Map (java.util.Map)2 CachingTrainer (edu.cmu.ml.proppr.CachingTrainer)1 Grounder (edu.cmu.ml.proppr.Grounder)1 Trainer (edu.cmu.ml.proppr.Trainer)1 ClippedExp (edu.cmu.ml.proppr.learn.tools.ClippedExp)1 Exp (edu.cmu.ml.proppr.learn.tools.Exp)1 LReLU (edu.cmu.ml.proppr.learn.tools.LReLU)1