Use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
From class InnerProductWeighterTest, method test:
/**
 * Checks that a feature absent from the weighter's weight map is recorded in
 * {@code unknownFeatures} once the weighter scores a feature dictionary containing it.
 */
@Test
public void test() {
    // Send log4j output to the console at INFO level for this run.
    BasicConfigurator.configure();
    Logger.getRootLogger().setLevel(Level.INFO);

    // The three features the weighter is constructed with.
    HashMap<Feature, Double> knownWeights = new HashMap<Feature, Double>();
    knownWeights.put(new Feature("feathers"), 0.5);
    knownWeights.put(new Feature("scales"), 0.3);
    knownWeights.put(new Feature("fur"), 0.7);
    FeatureDictWeighter weighter = new InnerProductWeighter(knownWeights);

    // A dictionary holding every known feature plus one ("hair") the weighter was not given.
    Feature unseen = new Feature("hair");
    HashMap<Feature, Double> scoredFeatures = new HashMap<Feature, Double>();
    scoredFeatures.put(unseen, 0.9);
    scoredFeatures.putAll(knownWeights);

    assertFalse("Should start empty!", weighter.unknownFeatures.contains(unseen));

    // Lower every value by a small random amount (at most 0.1); only key membership matters below.
    for (Map.Entry<Feature, Double> entry : scoredFeatures.entrySet()) {
        double jittered = entry.getValue() - Math.random() / 10;
        entry.setValue(jittered);
    }

    // Scoring the dictionary is what should register the unseen feature.
    weighter.w(scoredFeatures);
    assertTrue("Wasn't added!", weighter.unknownFeatures.contains(unseen));
}
Use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
From class ProverTestTemplate, method testProveState:
/**
 * Proves {@code isa(elsie,X)} with only the "milk" feature weighted. It then checks three things:
 * platypus answers outscore the other non-query answers, the total weight is about 1.0, and the
 * weighter's known/unknown feature counters match.
 */
@Test
public void testProveState() throws LogicProgramException {
    log.info("testProveState");

    // Weight the single feature "milk" at 2; no other feature gets an explicit weight.
    FeatureDictWeighter weighter = new InnerProductWeighter();
    SymbolTable<Feature> features = new SimpleSymbolTable<Feature>();
    int milkId = features.getId(new Feature("milk"));
    weighter.put(features.getSymbol(milkId), 2);
    prover.setWeighter(weighter);

    // Build the proof graph for the query isa(elsie,X) and run the prover over it.
    InferenceExample example = new InferenceExample(Query.parse("isa(elsie,X)"), null, null);
    ProofGraph graph = prover.makeProofGraph(example, apr, features, lpMilk, fMilk);
    Map<State, Double> ranking = prover.prove(graph, new StatusLogger());

    // Bucket each state by the second argument of the first rhs goal of its filled query:
    // "platypus", the unbound "X1" (the query itself), or anything else.
    // Track the maximum score per bucket and the overall sum.
    double queryMax = 0.0;
    double platypusMax = 0.0;
    double otherMax = 0.0;
    double total = 0.0;
    for (Map.Entry<State, Double> entry : ranking.entrySet()) {
        Query filled = graph.fill(entry.getKey());
        String secondArg = filled.getRhs()[0].getArg(1).getName();
        double score = entry.getValue();
        if ("platypus".equals(secondArg)) {
            platypusMax = Math.max(platypusMax, score);
        } else if ("X1".equals(secondArg)) {
            queryMax = Math.max(queryMax, score);
        } else {
            otherMax = Math.max(otherMax, score);
        }
        System.out.println(filled + "\t" + score);
        total += score;
    }

    System.out.println();
    System.out.println("query weight: " + queryMax);
    System.out.println("platypus weight: " + platypusMax);
    System.out.println("others weight: " + otherMax);

    // The check that the query state keeps the most weight is deliberately disabled:
    // assertTrue("query should retain most weight",query > Math.max(platypus,others));
    assertTrue("milk-featured paths should score higher than others", platypusMax > otherMax);
    assertEquals("Total weight of all states should be around 1.0", 1.0, total, 10 * apr.epsilon);
    assertEquals("Known features", 1, prover.weighter.numKnownFeatures);
    assertEquals("Unknown features", 5, prover.weighter.numUnknownFeatures);
}
Use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
From class QueryAnswerer, method addParams:
/**
 * Installs an {@link InnerProductWeighter} built from the given parameters on the prover. Each
 * feature key of that weighter's weight map is also inserted into this object's feature table.
 *
 * @param prover the prover whose weighter is replaced
 * @param params the parameter vector passed to {@code InnerProductWeighter.fromParamVec}
 * @param f the squashing function passed to {@code InnerProductWeighter.fromParamVec}
 */
public void addParams(Prover<P> prover, ParamVector<String, ?> params, SquashingFunction<Goal> f) {
    InnerProductWeighter weighter = InnerProductWeighter.fromParamVec(params, f);
    prover.setWeighter(weighter);
    for (Feature weighted : weighter.getWeights().keySet()) {
        featureTable.insert(weighted);
    }
}
Use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
From class LightweightStateGraph, method getOutlinks:
/**
 * Returns the neighbors of node {@code u}. Each neighbor {@code v} from {@code near(u)} becomes
 * one {@link Outlink} holding the feature map from {@code getFeatures(u, v)} and {@code v}.
 * A new list is built on every call.
 *
 * @param u the source state
 * @return a freshly built list of outlinks, in the iteration order of {@code near(u)}
 */
public List<Outlink> getOutlinks(State u) {
    // wwc: why do we need to recompute these each time?
    List<Outlink> outlinks = new ArrayList<Outlink>();
    for (State neighbor : near(u)) {
        outlinks.add(new Outlink(getFeatures(u, neighbor), neighbor));
    }
    return outlinks;
}
Use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
From class LightweightStateGraph, method getFeatures:
/**
 * Looks up the features recorded on the edge from {@code u} to {@code v}.
 *
 * <p>Both states are first mapped to integer ids through {@code nodeTab.getId}. If nothing is
 * stored for {@code u}, or nothing for the pair {@code (u, v)}, {@code DEFAULT_FD} is returned
 * itself (not a copy). NOTE(review): that object is shared, so callers presumably must not mutate
 * it; confirm. Otherwise a new HashMap is built that maps each stored feature id, translated back
 * through {@code featureTab.getSymbol}, to its stored weight.
 *
 * @param u the source state
 * @param v the destination state
 * @return the feature-to-weight map for the edge, or {@code DEFAULT_FD} when none is stored
 */
public Map<Feature, Double> getFeatures(State u, State v) {
    // Resolve both ids before any lookup, u first then v.
    // NOTE(review): getId may register unseen states; confirm against the symbol table's contract.
    int srcId = this.nodeTab.getId(u);
    int dstId = this.nodeTab.getId(v);
    if (edgeFeatureDict.containsKey(srcId)) {
        TIntObjectHashMap<TIntDoubleHashMap> byDestination = edgeFeatureDict.get(srcId);
        if (byDestination.containsKey(dstId)) {
            TIntDoubleHashMap weightsById = byDestination.get(dstId);
            final HashMap<Feature, Double> features = new HashMap<Feature, Double>();
            weightsById.forEachEntry(new TIntDoubleProcedure() {
                @Override
                public boolean execute(int featureId, double weight) {
                    features.put(featureTab.getSymbol(featureId), weight);
                    // Returning true continues the iteration over every entry.
                    return true;
                }
            });
            return features;
        }
    }
    // Nothing recorded for this edge.
    return DEFAULT_FD;
}
Aggregations