Search in sources :

Example 1 with InnerProductWeighter

use of edu.cmu.ml.proppr.prove.InnerProductWeighter in project ProPPR by TeamCohen.

the class QueryAnswerer method main.

/**
 * Command-line entry point: builds a {@link QueryAnswererConfiguration} from
 * {@code args}, optionally loads a parameter file into the prover's weighter,
 * answers every query in the configured query file, and prints the elapsed
 * wall-clock time in milliseconds.
 *
 * <p>After answering, if the prover's weighter is an {@link InnerProductWeighter},
 * two sanity checks are logged as warnings: one when the fraction of known
 * features actually seen falls below {@code MIN_FEATURE_TRANSFER}, and one when
 * more unknown than known features were seen.
 *
 * <p>Any {@link Throwable} raised along the way has its stack trace printed and
 * terminates the JVM with exit status -1.
 *
 * @param args command-line arguments, passed to {@link QueryAnswererConfiguration}
 * @throws IOException declared for callers; failures inside the body are caught
 *         by the catch-all handler below
 */
public static void main(String[] args) throws IOException {
    try {
        int inputFiles = Configuration.USE_QUERIES | Configuration.USE_PARAMS;
        int outputFiles = Configuration.USE_ANSWERS;
        int modules = Configuration.USE_PROVER | Configuration.USE_SQUASHFUNCTION;
        int constants = Configuration.USE_WAM | Configuration.USE_THREADS | Configuration.USE_ORDER;
        QueryAnswererConfiguration c = new QueryAnswererConfiguration(args, inputFiles, outputFiles, constants, modules);
        //			c.squashingFunction = new Exp();
        System.out.println(c.toString());
        QueryAnswerer qa = new QueryAnswerer(c.apr, c.program, c.plugins, c.prover, c.normalize, c.nthreads, c.topk);
        if (log.isInfoEnabled()) {
            log.info("Running queries from " + c.queryFile + "; saving results to " + c.solutionsFile);
        }
        if (c.paramsFile != null) {
            ParamsFile file = new ParamsFile(c.paramsFile);
            qa.addParams(c.prover, new SimpleParamVector<String>(Dictionary.load(file, new ConcurrentHashMap<String, Double>())), c.squashingFunction);
            file.check(c);
        }
        long start = System.currentTimeMillis();
        qa.findSolutions(c.queryFile, c.solutionsFile, c.maintainOrder);
        if (c.prover.getWeighter() instanceof InnerProductWeighter) {
            InnerProductWeighter w = (InnerProductWeighter) c.prover.getWeighter();
            int n = w.getWeights().size();
            // Read each counter once so the condition and the message report the same value.
            int known = w.seenKnownFeatures();
            int unknown = w.seenUnknownFeatures();
            // With n == 0 the ratio is NaN or Infinity, so the comparison is false and no warning is logged.
            if (((double) known / n) < MIN_FEATURE_TRANSFER) {
                log.warn("Only saw " + known + " of " + n + " known features (" + ((double) known / n * 100) + "%) -- test data may be too different from training data");
            }
            if (unknown > known) {
                log.warn("Saw more unknown features (" + unknown + ") than known features (" + known + ") -- test data may be too different from training data");
            }
        }
        System.out.println("Query-answering time: " + (System.currentTimeMillis() - start));
    } catch (Throwable t) {
        // Top-level boundary: report the failure and signal it through the exit status.
        t.printStackTrace();
        System.exit(-1);
    }
}
Also used : ParamsFile(edu.cmu.ml.proppr.util.ParamsFile) InnerProductWeighter(edu.cmu.ml.proppr.prove.InnerProductWeighter)

Example 2 with InnerProductWeighter

use of edu.cmu.ml.proppr.prove.InnerProductWeighter in project ProPPR by TeamCohen.

the class InnerProductWeighterTest method test.

/**
 * Scores a feature dictionary containing one feature ("hair") that is absent
 * from the weights the weighter was built with, and expects that feature to
 * show up in {@code unknownFeatures} only after {@code w(...)} has been called.
 */
@Test
public void test() {
    BasicConfigurator.configure();
    Logger.getRootLogger().setLevel(Level.INFO);

    // Weights handed to the weighter at construction time.
    HashMap<Feature, Double> knownWeights = new HashMap<Feature, Double>();
    knownWeights.put(new Feature("feathers"), 0.5);
    knownWeights.put(new Feature("scales"), 0.3);
    knownWeights.put(new Feature("fur"), 0.7);
    FeatureDictWeighter weighter = new InnerProductWeighter(knownWeights);

    // Dictionary to score: every constructed weight plus one extra feature.
    Feature unseen = new Feature("hair");
    HashMap<Feature, Double> toScore = new HashMap<Feature, Double>();
    toScore.put(unseen, 0.9);
    toScore.putAll(knownWeights);

    assertFalse("Should start empty!", weighter.unknownFeatures.contains(unseen));

    // Lower every value by a small random amount (less than 0.1) before scoring.
    for (Map.Entry<Feature, Double> entry : toScore.entrySet()) {
        double current = entry.getValue();
        double jitter = Math.random() / 10;
        entry.setValue(current - jitter);
    }

    weighter.w(toScore);
    assertTrue("Wasn't added!", weighter.unknownFeatures.contains(unseen));
}
Also used : HashMap(java.util.HashMap) Feature(edu.cmu.ml.proppr.prove.wam.Feature) Map(java.util.Map) HashMap(java.util.HashMap) InnerProductWeighter(edu.cmu.ml.proppr.prove.InnerProductWeighter) Test(org.junit.Test)

Example 3 with InnerProductWeighter

use of edu.cmu.ml.proppr.prove.InnerProductWeighter in project ProPPR by TeamCohen.

the class ProverTestTemplate method testProveState.

/**
 * Proves {@code isa(elsie,X)} with a weighter that gives the "milk" feature a
 * weight of 2, then checks the resulting state distribution: the best
 * "platypus" state must outweigh the best state of any other non-query
 * answer, the weights must sum to roughly 1.0, and the weighter's
 * known/unknown feature counters must be 1 and 5 respectively.
 *
 * @throws LogicProgramException declared for the prover calls made in the body
 */
@Test
public void testProveState() throws LogicProgramException {
    log.info("testProveState");

    // Register "milk" in a fresh symbol table and weight it in the prover's weighter.
    FeatureDictWeighter weighter = new InnerProductWeighter();
    SymbolTable<Feature> featureTab = new SimpleSymbolTable<Feature>();
    int milkId = featureTab.getId(new Feature("milk"));
    weighter.put(featureTab.getSymbol(milkId), 2);
    prover.setWeighter(weighter);

    InferenceExample example = new InferenceExample(Query.parse("isa(elsie,X)"), null, null);
    ProofGraph pg = prover.makeProofGraph(example, apr, featureTab, lpMilk, fMilk);
    Map<State, Double> stateWeights = prover.prove(pg, new StatusLogger());

    // Track the maximum weight per answer category, plus the overall sum.
    double queryWeight = 0.0;
    double platypusWeight = 0.0;
    double othersWeight = 0.0;
    double totalWeight = 0.0;
    for (Map.Entry<State, Double> entry : stateWeights.entrySet()) {
        Query filled = pg.fill(entry.getKey());
        String secondArg = filled.getRhs()[0].getArg(1).getName();
        double weight = entry.getValue();
        if ("platypus".equals(secondArg)) {
            platypusWeight = Math.max(platypusWeight, weight);
        } else if ("X1".equals(secondArg)) {
            // "X1" is grouped separately from the other answers as the query bucket.
            queryWeight = Math.max(queryWeight, weight);
        } else {
            othersWeight = Math.max(othersWeight, weight);
        }
        System.out.println(filled + "\t" + weight);
        totalWeight += weight;
    }

    System.out.println();
    System.out.println("query    weight: " + queryWeight);
    System.out.println("platypus weight: " + platypusWeight);
    System.out.println("others   weight: " + othersWeight);

    // Disabled check kept for reference:
    //		assertTrue("query should retain most weight",query > Math.max(platypus,others));
    assertTrue("milk-featured paths should score higher than others", platypusWeight > othersWeight);
    assertEquals("Total weight of all states should be around 1.0", 1.0, totalWeight, 10 * this.apr.epsilon);
    assertEquals("Known features", 1, prover.weighter.numKnownFeatures);
    assertEquals("Unknown features", 5, prover.weighter.numUnknownFeatures);
}
Also used : StatusLogger(edu.cmu.ml.proppr.util.StatusLogger) Query(edu.cmu.ml.proppr.prove.wam.Query) StateProofGraph(edu.cmu.ml.proppr.prove.wam.StateProofGraph) ProofGraph(edu.cmu.ml.proppr.prove.wam.ProofGraph) Feature(edu.cmu.ml.proppr.prove.wam.Feature) InferenceExample(edu.cmu.ml.proppr.examples.InferenceExample) FeatureDictWeighter(edu.cmu.ml.proppr.prove.FeatureDictWeighter) SimpleSymbolTable(edu.cmu.ml.proppr.util.SimpleSymbolTable) State(edu.cmu.ml.proppr.prove.wam.State) Map(java.util.Map) InnerProductWeighter(edu.cmu.ml.proppr.prove.InnerProductWeighter) Test(org.junit.Test)

Example 4 with InnerProductWeighter

use of edu.cmu.ml.proppr.prove.InnerProductWeighter in project ProPPR by TeamCohen.

the class QueryAnswerer method addParams.

/**
 * Builds an {@link InnerProductWeighter} from the given parameter vector and
 * squashing function, installs it on {@code prover}, and inserts every feature
 * that has a weight into this object's {@code featureTable}.
 *
 * @param prover the prover whose weighter is replaced
 * @param params parameter vector passed to {@code InnerProductWeighter.fromParamVec}
 * @param f      squashing function passed to {@code InnerProductWeighter.fromParamVec}
 */
public void addParams(Prover<P> prover, ParamVector<String, ?> params, SquashingFunction<Goal> f) {
    InnerProductWeighter weighter = InnerProductWeighter.fromParamVec(params, f);
    prover.setWeighter(weighter);
    for (Feature feature : weighter.getWeights().keySet()) {
        this.featureTable.insert(feature);
    }
}
Also used : Feature(edu.cmu.ml.proppr.prove.wam.Feature) InnerProductWeighter(edu.cmu.ml.proppr.prove.InnerProductWeighter)

Aggregations

InnerProductWeighter (edu.cmu.ml.proppr.prove.InnerProductWeighter): 4; Feature (edu.cmu.ml.proppr.prove.wam.Feature): 3; Map (java.util.Map): 2; Test (org.junit.Test): 2; InferenceExample (edu.cmu.ml.proppr.examples.InferenceExample): 1; FeatureDictWeighter (edu.cmu.ml.proppr.prove.FeatureDictWeighter): 1; ProofGraph (edu.cmu.ml.proppr.prove.wam.ProofGraph): 1; Query (edu.cmu.ml.proppr.prove.wam.Query): 1; State (edu.cmu.ml.proppr.prove.wam.State): 1; StateProofGraph (edu.cmu.ml.proppr.prove.wam.StateProofGraph): 1; ParamsFile (edu.cmu.ml.proppr.util.ParamsFile): 1; SimpleSymbolTable (edu.cmu.ml.proppr.util.SimpleSymbolTable): 1; StatusLogger (edu.cmu.ml.proppr.util.StatusLogger): 1; HashMap (java.util.HashMap): 1