use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
the class InferenceGraphTestTemplate method test.
/**
 * Builds a small graph with the edges a -> b and c -> b2, where b2 is configured
 * identically to b, and checks the tab-separated output of {@code serialize} both
 * with and without the feature index.
 */
@Test
public void test() {
    InferenceGraph graph = getGraph();

    MutableState stateA = buildState("foo", 2, "a");
    MutableState stateB = buildState("bar", 2, "b");
    // Same settings as stateB; the expected edges below ("1->2" and "3->2") require
    // it to resolve to the same node id as stateB.
    MutableState stateB2 = buildState("bar", 2, "b");
    MutableState stateC = buildState("baz", 3, "c");

    Map<Feature, Double> features = new HashMap<Feature, Double>();
    features.put(new Feature("quite"), 1.0);

    // a -> b
    List<Outlink> fromA = new ArrayList<Outlink>();
    fromA.add(new Outlink(features, stateB));
    graph.setOutlinks(graph.getId(stateA), fromA);

    // c -> b2 (=b)
    List<Outlink> fromC = new ArrayList<Outlink>();
    fromC.add(new Outlink(features, stateB2));
    graph.setOutlinks(graph.getId(stateC), fromC);

    // With the feature index: three counts, the feature list, then the two edges.
    String[] withIndex = graph.serialize(true).split("\t");
    assertEquals(6, withIndex.length);
    assertEquals("3", withIndex[0]);
    assertEquals("2", withIndex[1]);
    assertEquals("2", withIndex[2]);
    assertEquals("quite", withIndex[3]);
    assertSerializedEdges(withIndex[4], withIndex[5]);

    // Without the feature index: three counts followed directly by the two edges.
    String[] withoutIndex = graph.serialize(false).split("\t");
    assertEquals(5, withoutIndex.length);
    assertEquals("3", withoutIndex[0]);
    assertEquals("2", withoutIndex[1]);
    assertEquals("2", withoutIndex[2]);
    assertSerializedEdges(withoutIndex[3], withoutIndex[4]);
}

/** Creates a MutableState with the given jump target, canonical hash and canonical form. */
private static MutableState buildState(String jumpTo, int canonicalHash, String canonicalForm) {
    MutableState state = new MutableState();
    state.setJumpTo(jumpTo);
    state.setCanonicalHash(canonicalHash);
    state.setCanonicalForm(canonicalForm);
    return state;
}

/** Sorts the two serialized edge fields (their order is not fixed) and checks both against the expected edges. */
private static void assertSerializedEdges(String firstField, String secondField) {
    String[] edges = new String[] { firstField, secondField };
    Arrays.sort(edges);
    assertEquals("1->2:1@1.0", edges[0]);
    assertEquals("3->2:1@1.0", edges[1]);
}
use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
the class NodeMergingTest method doTest.
/**
 * Grounds a query against the given rules and facts, prints the serialized grounded
 * example (one field per line), and checks the graph's node count and the number of
 * positive solutions.
 *
 * @param rules    file passed to {@code WamBaseProgram.load}
 * @param facts    facts file; loaded with FactsPlugin or LightweightGraphPlugin depending
 *                 on its name's extension, otherwise no plugins are created (null)
 * @param squery   query string to parse
 * @param spos     strings parsed into the positive example queries
 * @param sneg     strings parsed into the negative example queries
 * @param nodeSize expected node count of the grounded graph; negative skips the check
 * @param posSize  expected size of the positive list; negative skips the check
 */
private void doTest(File rules, File facts, String squery, String[] spos, String[] sneg, int nodeSize, int posSize) throws LogicProgramException, IOException {
    APROptions options = new APROptions();
    Prover prover = new DprProver(options);
    WamProgram program = WamBaseProgram.load(rules);

    // Pick the plugin loader from the facts file's extension.
    String factsName = facts.getName();
    WamPlugin[] plugins = null;
    if (factsName.endsWith(FactsPlugin.FILE_EXTENSION)) {
        plugins = new WamPlugin[] { FactsPlugin.load(options, facts, false) };
    } else if (factsName.endsWith(GraphlikePlugin.FILE_EXTENSION)) {
        plugins = new WamPlugin[] { LightweightGraphPlugin.load(options, facts, -1) };
    }

    Grounder grounder = new Grounder(options, prover, program, plugins);
    Query query = Query.parse(squery);

    Query[] pos = new Query[spos.length];
    int next = 0;
    for (String text : spos) {
        pos[next++] = Query.parse(text);
    }

    Query[] neg = new Query[sneg.length];
    next = 0;
    for (String text : sneg) {
        neg[next++] = Query.parse(text);
    }

    InferenceExample example = new InferenceExample(query, pos, neg);
    ProofGraph proofGraph = new StateProofGraph(example, options, new SimpleSymbolTable<Feature>(), program, plugins);
    GroundedExample grounded = grounder.groundExample(proofGraph);

    String serialized = grounder.serializeGroundedExample(proofGraph, grounded);
    System.out.println(serialized.replaceAll("\t", "\n"));

    if (nodeSize >= 0) {
        assertEquals("improper node duplication", nodeSize, grounded.getGraph().nodeSize());
    }
    if (posSize >= 0) {
        assertEquals("improper # solutions found", posSize, grounded.getPosList().size());
    }
}
use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
the class LightweightStateGraph method serialize.
/**
 * Renders this graph as one tab-separated string: node count, edge count, label
 * dependency count, optionally the feature index, then one field per edge.
 *
 * @param featureIndex when true, a field listing every symbol of the feature table
 *                     (ids 1 through its size, joined by
 *                     {@code LearningGraphBuilder.FEATURE_INDEX_DELIM}) is written
 *                     before the edge fields
 * @return the serialized graph
 */
public String serialize(boolean featureIndex) {
    // Header: numNodes and numEdges. The label dependency count is appended at the
    // very end because it is only known once every edge has been visited.
    StringBuilder header = new StringBuilder();
    header.append(this.nodeSize()).append("\t");
    header.append(this.edgeCount).append("\t");

    int labelDependencies = 0;
    StringBuilder body = new StringBuilder();

    if (featureIndex) {
        body.append("\t");
        for (int id = 1; id <= this.featureTab.size(); id++) {
            if (id > 1) {
                body.append(LearningGraphBuilder.FEATURE_INDEX_DELIM);
            }
            Feature feature = this.featureTab.getSymbol(id);
            body.append(feature);
        }
    }

    // foreach src node
    TIntObjectIterator<TIntArrayList> srcIt = this.near.iterator();
    while (srcIt.hasNext()) {
        srcIt.advance();
        int src = srcIt.key();
        TIntArrayList dsts = srcIt.value();
        // distinct feature ids seen on any edge leaving src
        HashSet<Integer> featuresFromSrc = new HashSet<Integer>();

        // foreach dst from src
        for (int k = 0; k < dsts.size(); k++) {
            int dst = dsts.get(k);
            body.append("\t")
                .append(src)
                .append(LearningGraphBuilder.SRC_DST_DELIM)
                .append(dst)
                .append(LearningGraphBuilder.EDGE_DELIM);

            // foreach feature on src,dst
            TIntDoubleIterator featIt = edgeFeatureDict.get(src).get(dst).iterator();
            while (featIt.hasNext()) {
                featIt.advance();
                int fid = featIt.key();
                double weight = featIt.value();
                featuresFromSrc.add(fid);
                body.append(fid)
                    .append(LearningGraphBuilder.FEATURE_WEIGHT_DELIM)
                    .append(weight)
                    .append(LearningGraphBuilder.EDGE_FEATURE_DELIM);
            }
            // Drop the trailing feature delimiter.
            // NOTE(review): removes exactly one character unconditionally; presumably every
            // edge carries at least one feature and the delimiter is one char -- confirm.
            body.deleteCharAt(body.length() - 1);
        }
        labelDependencies += featuresFromSrc.size() * dsts.size();
    }

    return header.append(labelDependencies).append(body).toString();
}
use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
the class LightweightStateGraph method setOutlinks.
/**
 * Stores the outgoing edges of {@code u}, replacing whatever was stored for it before
 * (a warning is logged in that case) and keeping the running edge count in step.
 *
 * @param u        source state
 * @param outlinks edges leaving {@code u}; each supplies a child state and a
 *                 feature-to-weight map
 */
public void setOutlinks(State u, List<Outlink> outlinks) {
    // wwc: why are we saving these outlinks as a trove thing? space?
    final int srcId = this.nodeTab.getId(u);
    if (near.containsKey(srcId)) {
        log.warn("Overwriting previous outlinks for state " + u);
        // the old edges are about to be discarded, so stop counting them
        edgeCount -= near.get(srcId).size();
    }

    TIntArrayList childIds = new TIntArrayList(outlinks.size());
    near.put(srcId, childIds);
    TIntObjectHashMap<TIntDoubleHashMap> weightsByChild = new TIntObjectHashMap<TIntDoubleHashMap>();
    edgeFeatureDict.put(srcId, weightsByChild);

    for (Outlink link : outlinks) {
        int childId = this.nodeTab.getId(link.child);
        childIds.add(childId);
        edgeCount++;

        // feature id -> weight for the edge srcId -> childId
        TIntDoubleHashMap weights = new TIntDoubleHashMap(link.fd.size());
        for (Map.Entry<Feature, Double> entry : link.fd.entrySet()) {
            int featureId = this.featureTab.getId(entry.getKey());
            weights.put(featureId, entry.getValue());
        }
        weightsByChild.put(childId, weights);
    }
}
use of edu.cmu.ml.proppr.prove.wam.Feature in project ProPPR by TeamCohen.
the class WeightedRulesTest method test.
/**
 * Grounds {@code shoppingList(X)} with one positive and one negative example and checks
 * that the weights 2.0 and 3.0 show up in the serialized grounded example.
 */
@Test
public void test() throws IOException, LogicProgramException {
    APROptions options = new APROptions();
    Prover prover = new DprProver(options);
    WamProgram program = WamBaseProgram.load(RULES);
    WamPlugin[] plugins = new WamPlugin[] { FactsPlugin.load(options, FACTS, false) };
    Grounder grounder = new Grounder(options, prover, program, plugins);

    assertTrue(plugins[0].claim("ruleWeight#/2"));

    Query query = Query.parse("shoppingList(X)");
    Query[] positives = new Query[] { Query.parse("shoppingList(kidney_beans)") };
    Query[] negatives = new Query[] { Query.parse("shoppingList(cinnamon)") };
    InferenceExample example = new InferenceExample(query, positives, negatives);
    ProofGraph proofGraph = new StateProofGraph(example, options, new SimpleSymbolTable<Feature>(), program, plugins);

    GroundedExample grounded = grounder.groundExample(prover, proofGraph);
    // Result intentionally unused; the call itself must complete without throwing.
    grounded.getGraph().serialize();

    String serialized = grounder.serializeGroundedExample(proofGraph, grounded).replaceAll("\t", "\n");
    System.out.println(serialized);

    assertTrue("Rule weights must appear in ground graph (2.0)", serialized.contains("2.0"));
    assertTrue("Rule weights must appear in ground graph (3.0)", serialized.contains("3.0"));
}
Aggregations