Search in sources:

Example 1 with DprProver

Use of DprProver in project ProPPR by TeamCohen.

From the class WeightedEdgeTest, method testOne.

/**
 * Grounds two queries against the RULES program using a single caller-supplied plugin and
 * checks that the weights "0.9" and "0.1" appear in each serialized ground graph.
 *
 * <p>NOTE(review): excerpt from a usage-example listing; the method's closing brace (and any
 * annotations) are not shown here.
 *
 * @param apr  options shared by the prover, the grounder and both proof graphs
 * @param plug the plugin under test; it is asserted to claim the functor {@code hasWord#/3}
 * @throws IOException           declared by the signature (presumably from loading RULES)
 * @throws LogicProgramException declared by the signature
 */
public void testOne(APROptions apr, WamPlugin plug) throws IOException, LogicProgramException {
    Prover p = new DprProver(apr);
    WamProgram program = WamBaseProgram.load(RULES);
    WamPlugin[] plugins = new WamPlugin[] { plug };
    Grounder grounder = new Grounder(apr, p, program, plugins);
    // The plugin under test must claim the weighted functor before anything is grounded.
    assertTrue("Missing weighted functor", plugins[0].claim("hasWord#/3"));
    // Case 1: query words(p1,W); positive label words(p1,good), negative label words(p1,thing).
    Query query = Query.parse("words(p1,W)");
    ProofGraph pg = new StateProofGraph(new InferenceExample(query, new Query[] { Query.parse("words(p1,good)") }, new Query[] { Query.parse("words(p1,thing)") }), apr, new SimpleSymbolTable<Feature>(), program, plugins);
    //			Map<String,Double> m =;
    //			System.out.println(Dictionary.buildString(m, new StringBuilder(), "\n").toString());
    GroundedExample ex = grounder.groundExample(p, pg);
    // Serialize the ground graph, replacing every tab with a newline before searching it.
    String serialized = ex.getGraph().serialize(true).replaceAll("\t", "\n");
    //String serialized = grounder.serializeGroundedExample(pg, ex).replaceAll("\t", "\n");
    // NOTE(review): these are plain substring checks -- "0.9" would also match e.g. "10.95";
    // a stricter test would parse the edge weights.
    assertTrue("Label weights must appear in ground graph (0.9)", serialized.indexOf("0.9") >= 0);
    assertTrue("Label weights must appear in ground graph (0.1)", serialized.indexOf("0.1") >= 0);
    //			Map<String,Double> m = p.solvedQueries(pg);
    //			System.out.println(Dictionary.buildString(m, new StringBuilder(), "\n"));
    // Case 2: the same checks for query words2(p1,W), reusing the same grounder and plugins.
    // NOTE(review): the labels below still name words(...), not words2(...) -- confirm intended.
    Query query2 = Query.parse("words2(p1,W)");
    ProofGraph pg2 = new StateProofGraph(new InferenceExample(query2, new Query[] { Query.parse("words(p1,good)") }, new Query[] { Query.parse("words(p1,thing)") }), apr, new SimpleSymbolTable<Feature>(), program, plugins);
    //			Map<String,Double> m =;
    //			System.out.println(Dictionary.buildString(m, new StringBuilder(), "\n").toString());
    GroundedExample ex2 = grounder.groundExample(p, pg2);
    String serialized2 = ex2.getGraph().serialize(true).replaceAll("\t", "\n");
    // NOTE(review): the commented-out line below is a stale copy that still refers to pg/ex of case 1.
    //String serialized = grounder.serializeGroundedExample(pg, ex).replaceAll("\t", "\n");
    assertTrue("Label weights must appear in ground graph (0.9)", serialized2.indexOf("0.9") >= 0);
    assertTrue("Label weights must appear in ground graph (0.1)", serialized2.indexOf("0.1") >= 0);
Also used: GroundedExample, Query, StateProofGraph, ProofGraph, DprProver, Prover, WamProgram, Feature, InferenceExample, Grounder.

Example 2 with DprProver

Use of DprProver in project ProPPR by TeamCohen.

From the class WeightedFeaturesTest, method testAsGraph.

/**
 * Grounds predict(p1,Y) using a facts plugin loaded from LABELS and a lightweight graph plugin
 * loaded from WORDSGRAPH, then checks that the weights "0.9" and "0.1" appear in the serialized
 * grounded example.
 *
 * <p>NOTE(review): excerpt from a usage-example listing; the method's closing brace (and any
 * annotations) are not shown here.
 *
 * @throws IOException           declared by the signature (presumably from loading inputs)
 * @throws LogicProgramException declared by the signature
 */
public void testAsGraph() throws IOException, LogicProgramException {
    APROptions apr = new APROptions();
    Prover p = new DprProver(apr);
    WamProgram program = WamBaseProgram.load(RULES);
    // NOTE(review): the meaning of the `false` and `-1` load arguments is not visible here --
    // check the FactsPlugin / LightweightGraphPlugin APIs.
    WamPlugin[] plugins = new WamPlugin[] { FactsPlugin.load(apr, LABELS, false), LightweightGraphPlugin.load(apr, WORDSGRAPH, -1) };
    Grounder grounder = new Grounder(apr, p, program, plugins);
    // Query predict(p1,Y); positive label predict(p1,pos), negative label predict(p1,neg).
    Query query = Query.parse("predict(p1,Y)");
    ProofGraph pg = new StateProofGraph(new InferenceExample(query, new Query[] { Query.parse("predict(p1,pos)") }, new Query[] { Query.parse("predict(p1,neg)") }), apr, new SimpleSymbolTable<Feature>(), program, plugins);
    GroundedExample ex = grounder.groundExample(p, pg);
    // Serialize through the grounder, replacing every tab with a newline before searching it.
    String serialized = grounder.serializeGroundedExample(pg, ex).replaceAll("\t", "\n");
    // hack
    // NOTE(review): `indexOf(...) > 0` does not count a match at index 0; `>= 0` (or
    // String.contains) is the exact "appears anywhere" test. These are also plain substring checks.
    assertTrue("Word weights must appear in ground graph", serialized.indexOf("0.9") > 0);
    assertTrue("Word weights must appear in ground graph", serialized.indexOf("0.1") > 0);
Also used: GroundedExample, Query, StateProofGraph, ProofGraph, DprProver, Prover, WamProgram, Feature, InferenceExample, APROptions, Grounder, Test (org.junit.Test).

Example 3 with DprProver

Use of DprProver in project ProPPR by TeamCohen.

From the class BoundVariableGraphTest, method test.

/**
 * Grounds the fully bound query hasWord(p1,good) against a lightweight graph plugin loaded from
 * GRAPH and checks that the resulting ground graph has exactly 4 edges.
 *
 * <p>NOTE(review): excerpt from a usage-example listing; the method's closing brace (and any
 * annotations) are not shown here.
 *
 * @throws IOException           declared by the signature (presumably from loading inputs)
 * @throws LogicProgramException declared by the signature
 */
public void test() throws IOException, LogicProgramException {
    APROptions apr = new APROptions();
    Prover p = new DprProver(apr);
    //			Prover p = new TracingDfsProver(apr);
    WamProgram program = WamBaseProgram.load(RULES);
    WamPlugin[] plugins = new WamPlugin[] { LightweightGraphPlugin.load(apr, GRAPH) };
    Grounder grounder = new Grounder(apr, p, program, plugins);
    // The query has no free variables; its only label is itself (positive), with no negatives.
    Query query = Query.parse("hasWord(p1,good)");
    ProofGraph pg = new StateProofGraph(new InferenceExample(query, new Query[] { Query.parse("hasWord(p1,good)") }, new Query[0]), apr, new SimpleSymbolTable<Feature>(), program, plugins);
    //			Map<String,Double> m =;
    //			System.out.println(Dictionary.buildString(m, new StringBuilder(), "\n").toString());
    GroundedExample ex = grounder.groundExample(p, pg);
    // NOTE(review): `serialized` is not used in the lines shown in this excerpt.
    String serialized = grounder.serializeGroundedExample(pg, ex).replaceAll("\t", "\n");
    // assertEquals fails on any edge count other than 4, so "Too many edges" describes only one
    // of the possible failures.
    assertEquals("Too many edges", 4, ex.getGraph().edgeSize());
//			Map<String,Double> m = p.solvedQueries(pg);
//			System.out.println(Dictionary.buildString(m, new StringBuilder(), "\n"));
Also used: GroundedExample, Query, StateProofGraph, ProofGraph, DprProver, Prover, TracingDfsProver, WamProgram, Feature, InferenceExample, APROptions, Grounder, Test (org.junit.Test).

Example 4 with DprProver

Use of DprProver in project ProPPR by TeamCohen.

From the class TestNeqPlugin, method test.

/**
 * Checks solution counts for the "different" predicate of the program loaded from PROGRAM:
 * different(door,cat) is expected to have 1 solution and different(lake,lake) none.
 *
 * <p>NOTE(review): excerpt from a usage-example listing. The method's closing brace is not shown,
 * and both assertEquals lines have unbalanced parentheses (one more ')' than '('). Part of each
 * expression appears to have been lost; the prover `p` is declared but never referenced in what
 * remains. Restore these lines from the original ProPPR source before compiling.
 *
 * @throws IOException           declared by the signature (presumably from loading PROGRAM)
 * @throws LogicProgramException declared by the signature
 */
public void test() throws IOException, LogicProgramException {
    APROptions apr = new APROptions();
    WamProgram program = WamProgram.load(new File(PROGRAM));
    // Distinct constants vs. identical constants for the same predicate.
    Query different = Query.parse("different(door,cat)");
    Query same = Query.parse("different(lake,lake)");
    Prover p = new DprProver(apr);
    StatusLogger s = new StatusLogger();
    assertEquals("different should have 1 solution", 1, StateProofGraph(different, apr, program), s).size());
    assertEquals("same should have no solution", 0, StateProofGraph(same, apr, program), s).size());
Also used: StatusLogger, Query, DprProver, Prover, WamProgram, StateProofGraph, APROptions, File, Test (org.junit.Test).

Example 5 with DprProver

Use of DprProver in project ProPPR by TeamCohen.

From the class ModuleConfiguration, method retrieveSettings.

/**
 * Configures the prover, squashing function, grounder, trainer and SRW modules from the parsed
 * command line, after the superclass has processed its own settings.
 *
 * <p>Which sections run is decided by the flags returned from {@code modules(allFlags)}. When the
 * corresponding option is absent, these defaults apply:
 * <ul>
 *   <li>prover: a {@code DprProver};</li>
 *   <li>squashing function: {@code SRW.DEFAULT_SQUASHING_FUNCTION()};</li>
 *   <li>grounder: a {@code Grounder} built with {@code nthreads} and
 *       {@code Multithreading.DEFAULT_THROTTLE};</li>
 *   <li>trainer: {@code TRAINERS.cached}.</li>
 * </ul>
 *
 * <p>NOTE(review): this is an excerpt from a usage-example listing and several lines are missing:
 * <ul>
 *   <li>the closing braces;</li>
 *   <li>the loop body at the end of the prover section;</li>
 *   <li>judging by the visible switches, the {@code break;} statements and {@code default:}
 *       labels.</li>
 * </ul>
 * As printed, every {@code case} falls through into the following cases and finally into the
 * usage-error call. Restore the elided lines from the original ProPPR source before relying on
 * this text.
 *
 * @param line     the parsed command line
 * @param allFlags flag words; the module flags are extracted with {@code modules(allFlags)}
 * @param options  the option set, passed to {@code usageOptions} when reporting an error
 * @throws IOException declared by the signature
 */
protected void retrieveSettings(CommandLine line, int[] allFlags, Options options) throws IOException {
    super.retrieveSettings(line, allFlags, options);
    int flags;
    // modules
    flags = modules(allFlags);
    // ---- Prover ----
    if (isOn(flags, USE_PROVER)) {
        if (!line.hasOption(PROVER_MODULE_OPTION)) {
            // default:
            this.prover = new DprProver(apr);
        } else {
            // The option value is split on ':'. values[0] names the prover; any further parts
            // are iterated over at the end of this section (loop body not shown).
            String[] values = line.getOptionValue(PROVER_MODULE_OPTION).split(":");
            // Set to true only in the p_idpr case below.
            boolean proverSupportsPruning = false;
            // Enum.valueOf throws IllegalArgumentException for a name that is not a PROVERS
            // constant, so the "No prover definition" message is not reached for such names.
            switch(PROVERS.valueOf(values[0])) {
                case ippr:
                    this.prover = new IdPprProver(apr);
                case ppr:
                    this.prover = new PprProver(apr);
                case dpr:
                    this.prover = new DprProver(apr);
                case idpr:
                    this.prover = new IdDprProver(apr);
                case p_idpr:
                    // The pruning prover is built even when no pruning rules were given; it only warns.
                    if (prunedPredicateRules == null)
                        log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " not set");
                    this.prover = new PruningIdDprProver(apr, prunedPredicateRules);
                    proverSupportsPruning = true;
                case qpr:
                    this.prover = new PriorityQueueProver(apr);
                case pdpr:
                    this.prover = new PathDprProver(apr);
                case dfs:
                    this.prover = new DfsProver(apr);
                case tr:
                    this.prover = new TracingDfsProver(apr);
                    // The tracing prover must not be combined with more than one thread.
                    if (this.nthreads > 1)
                        usageOptions(options, allFlags, "Tracing prover is not multithreaded. Remove --threads option or use --threads 1.");
                    // NOTE(review): presumably the body of an elided `default:` label.
                    usageOptions(options, allFlags, "No prover definition for '" + values[0] + "'");
            // Pruning rules were supplied, but the selected prover does not use them.
            if (prunedPredicateRules != null && !proverSupportsPruning)
                log.warn("option --" + PRUNEDPREDICATE_CONST_OPTION + " is ignored by this prover");
            if (values.length > 1) {
                // NOTE(review): the body of this loop over values[1..] is not shown in this excerpt.
                for (int i = 1; i < values.length; i++) {
    // ---- Squashing function: configured when any of these three module flags is on ----
    if (anyOn(flags, USE_SQUASHFUNCTION | USE_PROVER | USE_SRW)) {
        if (!line.hasOption(SQUASHFUNCTION_MODULE_OPTION)) {
            // default:
            this.squashingFunction = SRW.DEFAULT_SQUASHING_FUNCTION();
        } else {
            // As with the prover, Enum.valueOf rejects names that are not SQUASHFUNCTIONS constants.
            switch(SQUASHFUNCTIONS.valueOf(line.getOptionValue(SQUASHFUNCTION_MODULE_OPTION))) {
                case linear:
                    squashingFunction = new Linear();
                case sigmoid:
                    squashingFunction = new Sigmoid();
                case tanh:
                    squashingFunction = new Tanh();
                case tanh1:
                    squashingFunction = new Tanh1();
                case ReLU:
                    squashingFunction = new ReLU();
                case LReLU:
                    squashingFunction = new LReLU();
                case exp:
                    squashingFunction = new Exp();
                case clipExp:
                    squashingFunction = new ClippedExp();
                    // NOTE(review): presumably the body of an elided `default:` label.
                    this.usageOptions(options, allFlags, "Unrecognized squashing function " + line.getOptionValue(SQUASHFUNCTION_MODULE_OPTION));
    // ---- Grounder ----
    if (isOn(flags, Configuration.USE_GROUNDER)) {
        if (!line.hasOption(GROUNDER_MODULE_OPTION)) {
            this.grounder = new Grounder(nthreads, Multithreading.DEFAULT_THROTTLE, apr, prover, program, plugins);
        } else {
            // values[1] (optional) overrides the thread count and values[2] (optional) the throttle;
            // values[0] is not read here. Integer.parseInt throws NumberFormatException on a
            // non-integer value.
            String[] values = line.getOptionValues(GROUNDER_MODULE_OPTION);
            int threads = nthreads;
            if (values.length > 1)
                threads = Integer.parseInt(values[1]);
            int throttle = Multithreading.DEFAULT_THROTTLE;
            if (values.length > 2)
                throttle = Integer.parseInt(values[2]);
            this.grounder = new Grounder(threads, throttle, apr, prover, program, plugins);
    // ---- Training: SRW first, then the trainer and its stopping criteria ----
    if (isOn(flags, USE_TRAIN)) {
        this.setupSRW(line, flags, options);
        if (isOn(flags, USE_TRAINER)) {
            // set default stopping criteria
            double percent = StoppingCriterion.DEFAULT_MAX_PCT_IMPROVEMENT;
            int stableEpochs = StoppingCriterion.DEFAULT_MIN_STABLE_EPOCHS;
            TRAINERS type = TRAINERS.cached;
            // Only the first option value selects the trainer type. All values, the first
            // included, are scanned below for key=value settings (shuff, pct, stableEpochs).
            if (line.hasOption(TRAINER_MODULE_OPTION))
                type = TRAINERS.valueOf(line.getOptionValues(TRAINER_MODULE_OPTION)[0]);
            switch(type) {
                case streaming:
                    this.trainer = new Trainer(this.srw, this.nthreads, this.throttle);
                // 'caching' and 'cached' share the same branch.
                case caching:
                case cached:
                    boolean shuff = CachingTrainer.DEFAULT_SHUFFLE;
                    if (line.hasOption(TRAINER_MODULE_OPTION)) {
                        for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
                            // Value after '='. With no '=', indexOf returns -1, so the whole string
                            // is parsed, and Boolean.parseBoolean yields false.
                            if (val.startsWith("shuff"))
                                shuff = Boolean.parseBoolean(val.substring(val.indexOf("=") + 1));
                    this.trainer = new CachingTrainer(this.srw, this.nthreads, this.throttle, shuff);
                case adagrad:
                    this.usageOptions(options, allFlags, "Trainer 'adagrad' no longer necessary. Use '--srw adagrad' for adagrad descent method.");
                    // NOTE(review): presumably the body of an elided `default:` label.
                    this.usageOptions(options, allFlags, "Unrecognized trainer " + line.getOptionValue(TRAINER_MODULE_OPTION));
            // Applied before the command-line values below, so an explicit stableEpochs= still wins.
            if (this.srw instanceof AdaGradSRW)
                // override default
                stableEpochs = 2;
            // now get stopping criteria from command line
            if (line.hasOption(TRAINER_MODULE_OPTION)) {
                for (String val : line.getOptionValues(TRAINER_MODULE_OPTION)) {
                    // NOTE(review): a bare "pct" or "stableEpochs" with no '=' makes indexOf return -1,
                    // so the whole token is parsed and a NumberFormatException is thrown.
                    if (val.startsWith("pct"))
                        percent = Double.parseDouble(val.substring(val.indexOf("=") + 1));
                    else if (val.startsWith("stableEpochs"))
                        stableEpochs = Integer.parseInt(val.substring(val.indexOf("=") + 1));
            this.trainer.setStoppingCriteria(stableEpochs, percent);
    // SRW requested without training, or not yet created above.
    if (isOn(flags, USE_SRW) && this.srw == null)
        this.setupSRW(line, flags, options);
Also used: Tanh1, IdDprProver, DprProver, PathDprProver, PruningIdDprProver, CachingTrainer, Trainer, TracingDfsProver, DfsProver, Grounder, Sigmoid, Tanh, PprProver, IdPprProver, PriorityQueueProver, Linear, ClippedExp, LReLU, ReLU, Exp.


DprProver, Prover, WamProgram, APROptions, Query, StateProofGraph, Test (org.junit.Test) 10, GroundedExample, InferenceExample, ProofGraph, Grounder, Feature, File, IdDprProver, IdPprProver, PprProver, StatusLogger, WamPlugin, TracingDfsProver, CachingTrainer