Search in sources:

Example 6 with StatusLogger

Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

From class SRWTest, method makeLoss.

/**
 * Computes the loss that {@code srw} reports for a single example.
 *
 * <p>The accumulated loss is cleared first, so only {@code example} contributes.
 * The gradient produced along the way goes into a throwaway vector and is discarded.
 *
 * @param paramVec parameter vector to evaluate the example under
 * @param example  the example whose loss is measured
 * @return the total of {@code srw}'s cumulative loss after processing {@code example}
 */
public double makeLoss(ParamVector<String, ?> paramVec, PosNegRWExample example) {
    srw.clearLoss();
    ParamVector<String, ?> discardedGradient = new SimpleParamVector<String>();
    StatusLogger status = new StatusLogger();
    srw.accumulateGradient(paramVec, example, discardedGradient, status);
    return srw.cumulativeLoss().total();
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger)

Example 7 with StatusLogger

Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

From class SRWTest, method makeGradient.

//	 Test removed: We no longer compute rwr in SRW
//	
//	/**
//	 * Uniform weights should be the same as the unparameterized basic RWR
//	 */
//	@Test
//	public void testUniformRWR() {
//		log.debug("Test logging");
//		int maxT = 10;
//		
//		TIntDoubleMap baseLineVec = myRWR(startVec,brGraph,maxT);
//		uniformParams.put("id(restart)",srw.getWeightingScheme().defaultWeight());
//		TIntDoubleMap newVec = srw.rwrUsingFeatures(brGraph, startVec, uniformParams);
//		equalScores(baseLineVec,newVec);
//	}
//	
//	public ParamVector<String,?> makeParams(Map<String,Double> foo) {
//		return new SimpleParamVector(foo);
//	}
//	
//	public ParamVector<String,?> makeParams() {
//		return new SimpleParamVector();
//	}
/**
 * Builds a fresh gradient vector for one query on {@code brGraph}.
 *
 * @param srw      the learner whose {@code accumulateGradient} fills the gradient
 * @param paramVec parameters to evaluate under
 * @param query    start-node weights for the example
 * @param pos      node ids treated as positive targets
 * @param neg      node ids treated as negative targets
 * @return a new vector holding the gradient accumulated for this single example
 */
public ParamVector<String, ?> makeGradient(SRW srw, ParamVector<String, ?> paramVec, TIntDoubleMap query, int[] pos, int[] neg) {
    ParamVector<String, ?> gradient = new SimpleParamVector<String>();
    PosNegRWExample example = factory.makeExample("gradient", brGraph, query, pos, neg);
    StatusLogger status = new StatusLogger();
    srw.accumulateGradient(paramVec, example, gradient, status);
    return gradient;
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) SimpleParamVector(edu.cmu.ml.proppr.util.math.SimpleParamVector)

Example 8 with StatusLogger

Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

From class SRWTest, method testLearn1.

/**
 * Checks that one training step on the red/blue graph reduces the loss.
 *
 * <p>The query starts at node "r0". Blue nodes are the positive targets and red
 * nodes are the negative targets. Training starts from a copy of
 * {@code uniformParams}. The loss must strictly decrease, unless it was already zero.
 */
@Test
public void testLearn1() {
    TIntDoubleMap query = new TIntDoubleHashMap();
    query.put(nodes.getId("r0"), 1.0);

    int[] pos = new int[blues.size()];
    int slot = 0;
    for (String blueNode : blues) {
        pos[slot] = nodes.getId(blueNode);
        slot++;
    }

    int[] neg = new int[reds.size()];
    slot = 0;
    for (String redNode : reds) {
        neg[slot] = nodes.getId(redNode);
        slot++;
    }

    PosNegRWExample example = factory.makeExample("learn1", brGraph, query, pos, neg);
    ParamVector<String, ?> trainedParams = uniformParams.copy();

    double preLoss = makeLoss(trainedParams, example);
    srw.clearLoss();
    srw.trainOnExample(trainedParams, example, new StatusLogger());
    double postLoss = makeLoss(trainedParams, example);

    String message = String.format("preloss %f >=? postloss %f", preLoss, postLoss);
    // A zero starting loss cannot shrink further, so it counts as success.
    boolean improved = preLoss == 0 || preLoss > postLoss;
    assertTrue(message, improved);
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) TIntDoubleMap(gnu.trove.map.TIntDoubleMap) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) Test(org.junit.Test)

Example 9 with StatusLogger

Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

From class DuplicateSignatureRuleTest, method test2.

/**
 * Proves {@code canExit(steve,X)} with a tracing DFS prover (depth 10) over the program in {@code PROGRAM}.
 *
 * <p>Checks that at least one solution is found. Also checks that every solution
 * is exactly {@code canExit(steve,kitchen).}. Each solution is printed with its
 * score for diagnostics.
 *
 * @throws LogicProgramException if proving fails
 * @throws IOException if the program file cannot be read
 */
@Test
public void test2() throws LogicProgramException, IOException {
    APROptions apr = new APROptions("depth=10");
    WamProgram program = WamBaseProgram.load(new File(PROGRAM));
    ProofGraph pg = new StateProofGraph(Query.parse("canExit(steve,X)"), apr, program);
    Prover p = new TracingDfsProver(apr);
    Map<Query, Double> result = p.solvedQueries(pg, new StatusLogger());
    // Without this check, an empty result skips the loop below and the test passes vacuously.
    assertEquals("No solutions found for canExit(steve,X)", false, result.isEmpty());
    for (Map.Entry<Query, Double> e : result.entrySet()) {
        System.out.println(e.getValue() + "\t" + e.getKey());
        assertEquals("Steve not allowed to exit " + e.getKey() + "\n", "canExit(steve,kitchen).", e.getKey().toString());
    }
}
Also used: TracingDfsProver(edu.cmu.ml.proppr.prove.TracingDfsProver) StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Query(edu.cmu.ml.proppr.prove.wam.Query) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph) Prover(edu.cmu.ml.proppr.prove.Prover) WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram) APROptions(edu.cmu.ml.proppr.util.APROptions) File(java.io.File) Map(java.util.Map) Test(org.junit.Test)

Example 10 with StatusLogger

Use of edu.cmu.ml.proppr.util.StatusLogger in project ProPPR by TeamCohen.

From class PathDprProverTest, method test.

/**
 * Smoke test: runs a {@code PathDprProver} on {@code kids(bette,Y)}.
 *
 * <p>The rules come from {@code RULES}, plus a sparse-graph plugin. The prover
 * uses epsilon 1e-5 and alpha 0.01. The test only checks that proving completes
 * without throwing; no result is inspected.
 *
 * @throws IOException if the rules or plugin files cannot be read
 * @throws LogicProgramException if proving fails
 */
@Test
public void test() throws IOException, LogicProgramException {
    APROptions options = new APROptions();
    options.epsilon = 1e-5;
    options.alpha = 0.01;

    WamProgram rules = WamBaseProgram.load(new File(RULES));
    WamPlugin graphPlugin = SparseGraphPlugin.load(options, new File(SparseGraphPluginTest.PLUGIN));
    WamPlugin[] plugins = { graphPlugin };

    PathDprProver prover = new PathDprProver(options);
    StateProofGraph proofGraph = new StateProofGraph(Query.parse("kids(bette,Y)"), options, rules, plugins);
    prover.prove(proofGraph, new StatusLogger());
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) WamPlugin(edu.cmu.ml.proppr.prove.wam.plugins.WamPlugin) Query(edu.cmu.ml.proppr.prove.wam.Query) WamProgram(edu.cmu.ml.proppr.prove.wam.WamProgram) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) APROptions(edu.cmu.ml.proppr.util.APROptions) File(java.io.File) Test(org.junit.Test) GrounderTest(edu.cmu.ml.proppr.GrounderTest) SparseGraphPluginTest(edu.cmu.ml.proppr.prove.wam.plugins.SparseGraphPluginTest)

Aggregations

StatusLogger (edu.cmu.ml.proppr.util.StatusLogger) 23, Test (org.junit.Test) 13, Query (edu.cmu.ml.proppr.prove.wam.Query) 11, StateProofGraph (edu.cmu.ml.proppr.prove.wam.StateProofGraph) 11, File (java.io.File) 10, WamProgram (edu.cmu.ml.proppr.prove.wam.WamProgram) 9, Prover (edu.cmu.ml.proppr.prove.Prover) 8, APROptions (edu.cmu.ml.proppr.util.APROptions) 8, ProofGraph (edu.cmu.ml.proppr.prove.wam.ProofGraph) 6, PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample) 4, DprProver (edu.cmu.ml.proppr.prove.DprProver) 4, State (edu.cmu.ml.proppr.prove.wam.State) 4, Map (java.util.Map) 4, DfsProver (edu.cmu.ml.proppr.prove.DfsProver) 3, ConstantArgument (edu.cmu.ml.proppr.prove.wam.ConstantArgument) 3, Goal (edu.cmu.ml.proppr.prove.wam.Goal) 3, WamPlugin (edu.cmu.ml.proppr.prove.wam.plugins.WamPlugin) 3, ArrayList (java.util.ArrayList) 3, GrounderTest (edu.cmu.ml.proppr.GrounderTest) 2, InferenceExample (edu.cmu.ml.proppr.examples.InferenceExample) 2