
Example 1 with DprProver

Use of edu.cmu.ml.proppr.prove.DprProver in project ProPPR by TeamCohen.

From class WeightedEdgeTest, method testOne.

public void testOne(APROptions apr, WamPlugin plug) throws IOException, LogicProgramException {
    // Shared helper: ground two queries through the given plugin and check that
    // the label weights 0.9 and 0.1 appear in each serialized ground graph.
    Prover p = new DprProver(apr);
    WamProgram program = WamBaseProgram.load(RULES);
    WamPlugin[] plugins = new WamPlugin[] { plug };
    Grounder grounder = new Grounder(apr, p, program, plugins);
    assertTrue("Missing weighted functor", plugins[0].claim("hasWord#/3"));

    // Query 1: words(p1,W) with one positive and one negative label
    Query query = Query.parse("words(p1,W)");
    ProofGraph pg = new StateProofGraph(
        new InferenceExample(query,
            new Query[] { Query.parse("words(p1,good)") },
            new Query[] { Query.parse("words(p1,thing)") }),
        apr, new SimpleSymbolTable<Feature>(), program, plugins);
    GroundedExample ex = grounder.groundExample(p, pg);
    // Replace tabs with newlines so each serialized field prints on its own line
    String serialized = ex.getGraph().serialize(true).replaceAll("\t", "\n");
    System.out.println(serialized);
    assertTrue("Label weights must appear in ground graph (0.9)", serialized.contains("0.9"));
    assertTrue("Label weights must appear in ground graph (0.1)", serialized.contains("0.1"));

    // Query 2: words2(p1,W), grounded with the same labels as query 1
    Query query2 = Query.parse("words2(p1,W)");
    ProofGraph pg2 = new StateProofGraph(
        new InferenceExample(query2,
            new Query[] { Query.parse("words(p1,good)") },
            new Query[] { Query.parse("words(p1,thing)") }),
        apr, new SimpleSymbolTable<Feature>(), program, plugins);
    GroundedExample ex2 = grounder.groundExample(p, pg2);
    String serialized2 = ex2.getGraph().serialize(true).replaceAll("\t", "\n");
    System.out.println(serialized2);
    assertTrue("Label weights must appear in ground graph (0.9)", serialized2.contains("0.9"));
    assertTrue("Label weights must appear in ground graph (0.1)", serialized2.contains("0.1"));
}
Also used: GroundedExample(edu.cmu.ml.proppr.examples.GroundedExample), Query(edu.cmu.ml.proppr.prove.wam.Query), StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph), ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph), DprProver(edu.cmu.ml.proppr.prove.DprProver), Prover(edu.cmu.ml.proppr.prove.Prover), WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram), Feature(edu.cmu.ml.proppr.prove.wam.Feature), InferenceExample(edu.cmu.ml.proppr.examples.InferenceExample), Grounder(edu.cmu.ml.proppr.Grounder)

Example 2 with DprProver

Use of edu.cmu.ml.proppr.prove.DprProver in project ProPPR by TeamCohen.

From class WeightedFeaturesTest, method testAsGraph.

@Test
public void testAsGraph() throws IOException, LogicProgramException {
    APROptions apr = new APROptions();
    Prover p = new DprProver(apr);
    WamProgram program = WamBaseProgram.load(RULES);
    WamPlugin[] plugins = new WamPlugin[] { FactsPlugin.load(apr, LABELS, false), LightweightGraphPlugin.load(apr, WORDSGRAPH, -1) };
    Grounder grounder = new Grounder(apr, p, program, plugins);
    assertTrue(plugins[1].claim("hasWord#/3"));
    Query query = Query.parse("predict(p1,Y)");
    ProofGraph pg = new StateProofGraph(new InferenceExample(query, new Query[] { Query.parse("predict(p1,pos)") }, new Query[] { Query.parse("predict(p1,neg)") }), apr, new SimpleSymbolTable<Feature>(), program, plugins);
    GroundedExample ex = grounder.groundExample(p, pg);
    String serialized = grounder.serializeGroundedExample(pg, ex).replaceAll("\t", "\n");
    System.out.println(serialized);
    // Crude check: look for the weight values as substrings of the serialized graph
    assertTrue("Word weights must appear in ground graph", serialized.contains("0.9"));
    assertTrue("Word weights must appear in ground graph", serialized.contains("0.1"));
}
Also used: GroundedExample(edu.cmu.ml.proppr.examples.GroundedExample), Query(edu.cmu.ml.proppr.prove.wam.Query), StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph), ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph), DprProver(edu.cmu.ml.proppr.prove.DprProver), Prover(edu.cmu.ml.proppr.prove.Prover), WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram), Feature(edu.cmu.ml.proppr.prove.wam.Feature), InferenceExample(edu.cmu.ml.proppr.examples.InferenceExample), APROptions(edu.cmu.ml.proppr.util.APROptions), Grounder(edu.cmu.ml.proppr.Grounder), Test(org.junit.Test)
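Examples 1 and 2 test for the weights by searching the serialized ground graph for substrings. The two common spellings of that check differ at one boundary: `indexOf(s) > 0` is false when the match starts at position 0, while `indexOf(s) >= 0` (equivalently `String.contains`) is true. The standalone program below, using a made-up string, shows the difference; the class and method names are illustrative only and not part of ProPPR.

```java
// Illustration of substring checks on a made-up serialized string.
class SubstringCheck {
    // Misses a match that starts at index 0, because indexOf returns 0 there.
    static boolean foundStrict(String text, String needle) {
        return text.indexOf(needle) > 0;
    }

    // Finds a match anywhere; equivalent to text.contains(needle).
    static boolean found(String text, String needle) {
        return text.indexOf(needle) >= 0;
    }

    public static void main(String[] args) {
        String serialized = "0.9\nnode\n0.1"; // made-up text that starts with the weight
        System.out.println(foundStrict(serialized, "0.9")); // false
        System.out.println(found(serialized, "0.9"));       // true
        System.out.println(serialized.contains("0.9"));     // true
        // Substring checks are loose: "0.9" is also found inside "0.95".
        System.out.println(found("0.95", "0.9"));           // true
    }
}
```

Either way, a substring test only shows that the characters occur somewhere in the output, so it cannot distinguish a weight of 0.9 from one of 0.95.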

Example 3 with DprProver

Use of edu.cmu.ml.proppr.prove.DprProver in project ProPPR by TeamCohen.

From class BoundVariableGraphTest, method test.

@Test
public void test() throws IOException, LogicProgramException {
    APROptions apr = new APROptions();
    Prover p = new DprProver(apr);
    // Alternative: Prover p = new TracingDfsProver(apr);
    WamProgram program = WamBaseProgram.load(RULES);
    WamPlugin[] plugins = new WamPlugin[] { LightweightGraphPlugin.load(apr, GRAPH) };
    Grounder grounder = new Grounder(apr, p, program, plugins);
    // Fully bound query; the same atom is used as its only positive label
    Query query = Query.parse("hasWord(p1,good)");
    ProofGraph pg = new StateProofGraph(new InferenceExample(query, new Query[] { Query.parse("hasWord(p1,good)") }, new Query[0]), apr, new SimpleSymbolTable<Feature>(), program, plugins);
    GroundedExample ex = grounder.groundExample(p, pg);
    ex.getGraph().serialize();
    String serialized = grounder.serializeGroundedExample(pg, ex).replaceAll("\t", "\n");
    System.out.println(serialized);
    assertEquals("Ground graph should have exactly 4 edges", 4, ex.getGraph().edgeSize());
}
Also used: GroundedExample(edu.cmu.ml.proppr.examples.GroundedExample), Query(edu.cmu.ml.proppr.prove.wam.Query), StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph), ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph), DprProver(edu.cmu.ml.proppr.prove.DprProver), Prover(edu.cmu.ml.proppr.prove.Prover), TracingDfsProver(edu.cmu.ml.proppr.prove.TracingDfsProver), WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram), Feature(edu.cmu.ml.proppr.prove.wam.Feature), InferenceExample(edu.cmu.ml.proppr.examples.InferenceExample), APROptions(edu.cmu.ml.proppr.util.APROptions), Grounder(edu.cmu.ml.proppr.Grounder), Test(org.junit.Test)

Example 4 with DprProver

Use of edu.cmu.ml.proppr.prove.DprProver in project ProPPR by TeamCohen.

From class TestNeqPlugin, method test.

@Test
public void test() throws IOException, LogicProgramException {
    APROptions apr = new APROptions();
    WamProgram program = WamProgram.load(new File(PROGRAM));
    Query different = Query.parse("different(door,cat)");
    Query same = Query.parse("different(lake,lake)");
    Prover p = new DprProver(apr);
    StatusLogger s = new StatusLogger();
    assertEquals("different should have 1 solution", 1, p.solutions(new StateProofGraph(different, apr, program), s).size());
    assertEquals("same should have no solution", 0, p.solutions(new StateProofGraph(same, apr, program), s).size());
}
Also used: StatusLogger(edu.cmu.ml.proppr.util.StatusLogger), Query(edu.cmu.ml.proppr.prove.wam.Query), DprProver(edu.cmu.ml.proppr.prove.DprProver), Prover(edu.cmu.ml.proppr.prove.Prover), WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram), StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph), APROptions(edu.cmu.ml.proppr.util.APROptions), File(java.io.File), Test(org.junit.Test)
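Example 4 calls `p.solutions(...)` directly on a DprProver built from an `APROptions` object. As far as we understand ProPPR, DprProver scores proof-graph nodes with an approximate personalized PageRank computed by local "push" steps, governed by a restart probability (alpha) and a residual threshold (epsilon), which we assume are among the settings carried in `APROptions`. The standalone sketch below shows a generic version of that push procedure on a plain unweighted adjacency list. It is not ProPPR's implementation and ignores edge features, learned weights and WAM proof states; the names `PushPprSketch` and `push` are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;

// Generic approximate personalized PageRank via local "push" steps.
// Hypothetical sketch for illustration only; not ProPPR's DprProver code.
class PushPprSketch {
    // graph: unweighted adjacency lists; every node needs at least one out-edge.
    // alpha: restart probability; eps: a node is pushed while r[u] > eps * outDegree(u).
    // Returns {p, r}: the score estimate and the residual mass not yet pushed.
    static double[][] push(List<List<Integer>> graph, int seed, double alpha, double eps) {
        int n = graph.size();
        double[] p = new double[n];
        double[] r = new double[n];
        r[seed] = 1.0;
        Deque<Integer> queue = new ArrayDeque<>();
        queue.add(seed);
        while (!queue.isEmpty()) {
            int u = queue.poll();
            int deg = graph.get(u).size();
            if (r[u] <= eps * deg) {
                continue; // stale or below-threshold entry: nothing to push
            }
            double ru = r[u];
            p[u] += alpha * ru;                    // alpha of the residual becomes final score
            r[u] = 0.0;
            double share = (1 - alpha) * ru / deg; // the rest is spread evenly over out-edges
            for (int v : graph.get(u)) {
                r[v] += share;
                if (r[v] > eps * graph.get(v).size()) {
                    queue.add(v);
                }
            }
        }
        return new double[][] { p, r };
    }

    public static void main(String[] args) {
        // Tiny graph: 0 -> {1, 2}, 1 -> {0}, 2 -> {0}
        List<List<Integer>> graph = Arrays.asList(
            Arrays.asList(1, 2), Arrays.asList(0), Arrays.asList(0));
        double[][] result = push(graph, 0, 0.2, 1e-4);
        System.out.println("scores:   " + Arrays.toString(result[0]));
        System.out.println("residual: " + Arrays.toString(result[1]));
    }
}
```

Each push moves mass from the residual `r` into the estimate `p` or onto neighbours, so `p` and `r` together always sum to 1. With alpha = 0.2, the exact scores for the tiny graph are 5/9 for node 0 and 2/9 for nodes 1 and 2; the estimates fall short of them by at most the total leftover residual.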

Example 5 with DprProver

Use of edu.cmu.ml.proppr.prove.DprProver in project ProPPR by TeamCohen.

From class ModuleConfiguration, method retrieveSettings.

@Override
protected void retrieveSettings(CommandLine line, int[] allFlags, Options options) throws IOException {
    super.retrieveSettings(line, allFlags, options);
    int flags;
    // modules
    flags = modules(allFlags);
    if (isOn(flags, USE_PROVER)) {
        if (!line.hasOption(PROVER_MODULE_OPTION)) {
            // default:
            this.prover = new DprProver(apr);
        } else {
            String[] values = line.getOptionValue(PROVER_MODULE_OPTION).split(":");
            boolean proverSupportsPruning = false;
            switch(PROVERS.valueOf(values[0])) {
                case ippr:
                    this.prover = new IdPprProver(apr);
                    break;
                case ppr:
                    this.prover = new PprProver(apr);
                    break;
                case dpr:
                    this.prover = new DprProver(apr);
                    break;
                case idpr:
                    this.prover = new IdDprProver(apr);
                    break;
                case p_idpr:
                    if (prunedPredicateRules == null)
                        log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " not set");
                    this.prover = new PruningIdDprProver(apr, prunedPredicateRules);
                    proverSupportsPruning = true;
                    break;
                case qpr:
                    this.prover = new PriorityQueueProver(apr);
                    break;
                case pdpr:
                    this.prover = new PathDprProver(apr);
                    break;
                case dfs:
                    this.prover = new DfsProver(apr);
                    break;
                case tr:
                    this.prover = new TracingDfsProver(apr);
                    if (this.nthreads > 1)
                        usageOptions(options, allFlags, "Tracing prover is not multithreaded. Remove --threads option or use --threads 1.");
                    break;
                default:
                    usageOptions(options, allFlags, "No prover definition for '" + values[0] + "'");
            }
            if (prunedPredicateRules != null && !proverSupportsPruning)
                log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " is ignored by this prover");
            if (values.length > 1) {
                for (int i = 1; i < values.length; i++) {
                    this.prover.configure(values[i]);
                }
            }
        }
    }
    if (anyOn(flags, USE_SQUASHFUNCTION | USE_PROVER | USE_SRW)) {
        if (!line.hasOption(SQUASHFUNCTION_MODULE_OPTION)) {
            // default:
            this.squashingFunction = SRW.DEFAULT_SQUASHING_FUNCTION();
        } else {
            switch(SQUASHFUNCTIONS.valueOf(line.getOptionValue(SQUASHFUNCTION_MODULE_OPTION))) {
                case linear:
                    squashingFunction = new Linear();
                    break;
                case sigmoid:
                    squashingFunction = new Sigmoid();
                    break;
                case tanh:
                    squashingFunction = new Tanh();
                    break;
                case tanh1:
                    squashingFunction = new Tanh1();
                    break;
                case ReLU:
                    squashingFunction = new ReLU();
                    break;
                case LReLU:
                    squashingFunction = new LReLU();
                    break;
                case exp:
                    squashingFunction = new Exp();
                    break;
                case clipExp:
                    squashingFunction = new ClippedExp();
                    break;
                default:
                    this.usageOptions(options, allFlags, "Unrecognized squashing function " + line.getOptionValue(SQUASHFUNCTION_MODULE_OPTION));
            }
        }
    }
    if (isOn(flags, Configuration.USE_GROUNDER)) {
        if (!line.hasOption(GROUNDER_MODULE_OPTION)) {
            this.grounder = new Grounder(nthreads, Multithreading.DEFAULT_THROTTLE, apr, prover, program, plugins);
        } else {
            String[] values = line.getOptionValues(GROUNDER_MODULE_OPTION);
            int threads = nthreads;
            if (values.length > 1)
                threads = Integer.parseInt(values[1]);
            int throttle = Multithreading.DEFAULT_THROTTLE;
            if (values.length > 2)
                throttle = Integer.parseInt(values[2]);
            this.grounder = new Grounder(threads, throttle, apr, prover, program, plugins);
        }
        this.grounder.includeUnlabeledGraphs(includeEmptyGraphs);
    }
    if (isOn(flags, USE_TRAIN)) {
        this.setupSRW(line, flags, options);
        seed(line);
        if (isOn(flags, USE_TRAINER)) {
            // set default stopping criteria
            double percent = StoppingCriterion.DEFAULT_MAX_PCT_IMPROVEMENT;
            int stableEpochs = StoppingCriterion.DEFAULT_MIN_STABLE_EPOCHS;
            TRAINERS type = TRAINERS.cached;
            if (line.hasOption(TRAINER_MODULE_OPTION))
                type = TRAINERS.valueOf(line.getOptionValues(TRAINER_MODULE_OPTION)[0]);
            switch(type) {
                case streaming:
                    this.trainer = new Trainer(this.srw, this.nthreads, this.throttle);
                    break;
                // 'caching' and 'cached' are aliases: caching falls through to cached
                case caching:
                case cached:
                    boolean shuff = CachingTrainer.DEFAULT_SHUFFLE;
                    if (line.hasOption(TRAINER_MODULE_OPTION)) {
                        for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
                            if (val.startsWith("shuff"))
                                shuff = Boolean.parseBoolean(val.substring(val.indexOf("=") + 1));
                        }
                    }
                    this.trainer = new CachingTrainer(this.srw, this.nthreads, this.throttle, shuff);
                    break;
                case adagrad:
                    this.usageOptions(options, allFlags, "Trainer 'adagrad' no longer necessary. Use '--srw adagrad' for adagrad descent method.");
                    break;
                default:
                    this.usageOptions(options, allFlags, "Unrecognized trainer " + line.getOptionValue(TRAINER_MODULE_OPTION));
            }
            if (this.srw instanceof AdaGradSRW)
                // override default
                stableEpochs = 2;
            // now get stopping criteria from command line
            if (line.hasOption(TRAINER_MODULE_OPTION)) {
                for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
                    if (val.startsWith("pct"))
                        percent = Double.parseDouble(val.substring(val.indexOf("=") + 1));
                    else if (val.startsWith("stableEpochs"))
                        stableEpochs = Integer.parseInt(val.substring(val.indexOf("=") + 1));
                }
            }
            this.trainer.setStoppingCriteria(stableEpochs, percent);
        }
    }
    if (isOn(flags, USE_SRW) && this.srw == null)
        this.setupSRW(line, flags, options);
}
Also used: Tanh1(edu.cmu.ml.proppr.learn.tools.Tanh1), IdDprProver(edu.cmu.ml.proppr.prove.IdDprProver), DprProver(edu.cmu.ml.proppr.prove.DprProver), PathDprProver(edu.cmu.ml.proppr.prove.PathDprProver), PruningIdDprProver(edu.cmu.ml.proppr.prove.PruningIdDprProver), CachingTrainer(edu.cmu.ml.proppr.CachingTrainer), Trainer(edu.cmu.ml.proppr.Trainer), TracingDfsProver(edu.cmu.ml.proppr.prove.TracingDfsProver), DfsProver(edu.cmu.ml.proppr.prove.DfsProver), Grounder(edu.cmu.ml.proppr.Grounder), Sigmoid(edu.cmu.ml.proppr.learn.tools.Sigmoid), Tanh(edu.cmu.ml.proppr.learn.tools.Tanh), PprProver(edu.cmu.ml.proppr.prove.PprProver), IdPprProver(edu.cmu.ml.proppr.prove.IdPprProver), PriorityQueueProver(edu.cmu.ml.proppr.prove.PriorityQueueProver), Linear(edu.cmu.ml.proppr.learn.tools.Linear), ClippedExp(edu.cmu.ml.proppr.learn.tools.ClippedExp), LReLU(edu.cmu.ml.proppr.learn.tools.LReLU), ReLU(edu.cmu.ml.proppr.learn.tools.ReLU), Exp(edu.cmu.ml.proppr.learn.tools.Exp)
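`retrieveSettings` parses two shapes of option string. A prover spec such as `dpr:opt1:opt2` is split on `:`, the first field selects the prover and each remaining field is passed to `prover.configure`. Trainer values are read after the first `=` once a `startsWith` test matches (`pct`, `stableEpochs`, `shuff`). A consequence of the second pattern is that a value with no `=` (for example a bare `shuff`) makes `substring(indexOf("=") + 1)` return the whole string, so `Boolean.parseBoolean` silently yields false. The sketch below is a hypothetical standalone parser for the same two shapes that rejects such values and matches keys exactly rather than by prefix. The class `ModuleSpec`, the enum `ProverName` and the placeholder options `opt1`/`opt2` are made up for illustration and are not part of ProPPR.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical standalone sketch (not ProPPR's API) of the two option shapes
// parsed in retrieveSettings: "name:opt1:opt2" and "name key=value ...".
class ModuleSpec {
    // Hypothetical subset of prover names, echoing some of the switch cases above.
    enum ProverName { ppr, dpr, idpr, dfs }

    final ProverName prover;
    final String[] proverOptions;

    ModuleSpec(ProverName prover, String[] proverOptions) {
        this.prover = prover;
        this.proverOptions = proverOptions;
    }

    // Split "name:opt1:opt2"; the first field selects the prover and the rest
    // are kept as raw option strings (ProPPR passes each to prover.configure).
    static ModuleSpec parseProver(String spec) {
        String[] fields = spec.split(":");
        String head = fields.length > 0 ? fields[0] : "";
        ProverName name;
        try {
            name = ProverName.valueOf(head);
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("No prover definition for '" + head + "'", e);
        }
        return new ModuleSpec(name, Arrays.copyOfRange(fields, 1, fields.length));
    }

    // Read values[start..] as exact "key=value" pairs, rejecting entries with no
    // key or no '=' instead of silently treating the whole string as the value.
    static Map<String, String> parseKeyValues(String[] values, int start) {
        Map<String, String> out = new HashMap<>();
        for (int i = start; i < values.length; i++) {
            int eq = values[i].indexOf('=');
            if (eq <= 0) {
                throw new IllegalArgumentException("Expected key=value, got '" + values[i] + "'");
            }
            out.put(values[i].substring(0, eq), values[i].substring(eq + 1));
        }
        return out;
    }

    public static void main(String[] args) {
        ModuleSpec spec = parseProver("dpr:opt1:opt2");
        System.out.println(spec.prover + " " + Arrays.toString(spec.proverOptions));
        // First value is the trainer name; the rest are key=value settings.
        Map<String, String> kv = parseKeyValues(new String[] { "cached", "pct=0.5", "stableEpochs=2" }, 1);
        System.out.println(kv.get("pct") + " " + kv.get("stableEpochs"));
    }
}
```

Failing loudly on a malformed setting trades a little convenience (prefixes such as `shuff` no longer match `shuffle`) for not training with a default the user did not ask for.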

Aggregations

DprProver (edu.cmu.ml.proppr.prove.DprProver): 14
Prover (edu.cmu.ml.proppr.prove.Prover): 11
WamProgram (edu.cmu.ml.proppr.prove.wam.WamProgram): 11
APROptions (edu.cmu.ml.proppr.util.APROptions): 11
Query (edu.cmu.ml.proppr.prove.wam.Query): 10
StateProofGraph (edu.cmu.ml.proppr.prove.wam.StateProofGraph): 10
Test (org.junit.Test): 10
GroundedExample (edu.cmu.ml.proppr.examples.GroundedExample): 8
InferenceExample (edu.cmu.ml.proppr.examples.InferenceExample): 8
ProofGraph (edu.cmu.ml.proppr.prove.wam.ProofGraph): 8
Grounder (edu.cmu.ml.proppr.Grounder): 6
Feature (edu.cmu.ml.proppr.prove.wam.Feature): 6
File (java.io.File): 5
IdDprProver (edu.cmu.ml.proppr.prove.IdDprProver): 4
IdPprProver (edu.cmu.ml.proppr.prove.IdPprProver): 4
PprProver (edu.cmu.ml.proppr.prove.PprProver): 4
StatusLogger (edu.cmu.ml.proppr.util.StatusLogger): 4
WamPlugin (edu.cmu.ml.proppr.prove.wam.plugins.WamPlugin): 3
TracingDfsProver (edu.cmu.ml.proppr.prove.TracingDfsProver): 2
CachingTrainer (edu.cmu.ml.proppr.CachingTrainer): 1