Use of edu.cmu.ml.proppr.prove.wam.Outlink in project ProPPR by TeamCohen.
From class FactsPlugin, method outlinks.
/**
 * Expands a state whose pending call targets a predicate stored in this facts plugin, producing
 * one outlink per matching fact.
 *
 * <p>The state's {@code jumpTo} label is split at {@code WamInterpreter.JUMPTO_DELIMITER}; the
 * part after it is the call's arity. If the part before it ends with
 * {@code WamPlugin.WEIGHTED_SUFFIX}, the call is a weighted variant: its last argument must be
 * unbound, argument positions beyond a fact's stored arguments are bound to the fact's weight via
 * {@code setWt}, and the index lookup uses the label returned by {@code unweightedJumpto}.
 *
 * <p>Candidate facts come from the index matching which of the first two arguments are bound
 * (none: {@code indexJ}; first only: {@code indexJA1}; second only: {@code indexJA2}; both:
 * {@code indexJA1A2} if {@code useTernaryIndex}, otherwise the shorter of the two binary-index
 * lists). Each candidate accepted by {@code check} is applied by restoring {@code state},
 * binding the unbound arguments, returning from the call and running the interpreter until it
 * would branch; the resulting saved state becomes the outlink's child.
 *
 * @param state the state to expand; restored before each candidate fact is applied
 * @param wamInterp interpreter used to read bound arguments and to bind unbound ones
 * @param computeFeatures if true, each outlink carries {@code scaleFD(this.fd, val.wt)} for the
 *     matching fact; otherwise the outlink's feature map is {@code null}
 * @return the outlinks, one per accepted fact; empty if the chosen index has no entry
 * @throws LogicProgramException if a weighted predicate is called with its last argument bound
 */
@Override
public List<Outlink> outlinks(State state, WamInterpreter wamInterp, boolean computeFeatures) throws LogicProgramException {
    List<Outlink> result = new LinkedList<Outlink>();
    String jumpTo = state.getJumpTo();
    int delim = jumpTo.indexOf(WamInterpreter.JUMPTO_DELIMITER);
    int arity = Integer.parseInt(jumpTo.substring(delim + 1));
    boolean returnWeights = jumpTo.substring(0, delim).endsWith(WamPlugin.WEIGHTED_SUFFIX);
    if (returnWeights) {
        // the indexes are keyed by the unweighted label
        jumpTo = unweightedJumpto(state.getJumpTo());
    }

    // argConst[i] holds the constant bound to argument i+1, or null when that argument is unbound
    String[] argConst = new String[arity];
    for (int i = 0; i < arity; i++) {
        argConst[i] = wamInterp.getConstantArg(arity, i + 1);
    }
    if (returnWeights && argConst[arity - 1] != null) {
        throw new LogicProgramException("predicate " + state.getJumpTo() + " called with bound last argument!");
    }
    if (log.isDebugEnabled()) {
        log.debug("Fetching outlinks for " + jumpTo + ": " + Dictionary.buildString(argConst, new StringBuilder(), ", ").toString());
    }

    // Choose the candidate list by which of the first two arguments are bound.
    boolean firstBound = argConst[0] != null;
    boolean secondBound = argConst.length > 1 && argConst[1] != null;
    List<WeightedArgs> values;
    if (!firstBound && !secondBound) {
        values = indexJ.get(jumpTo);
    } else if (firstBound && !secondBound) {
        values = indexJA1.get(new JumpArgKey(jumpTo, argConst[0]));
    } else if (!firstBound) {
        values = indexJA2.get(new JumpArgKey(jumpTo, argConst[1]));
    } else if (useTernaryIndex) {
        values = indexJA1A2.get(new JumpArgArgKey(jumpTo, argConst[0], argConst[1]));
    } else {
        // No ternary index: scan the shorter of the two binary-index lists (a missing list counts
        // as empty - wwc); every candidate still goes through check() below.
        List<WeightedArgs> byFirst = indexJA1.get(new JumpArgKey(jumpTo, argConst[0]));
        List<WeightedArgs> bySecond = indexJA2.get(new JumpArgKey(jumpTo, argConst[1]));
        if (bySecond == null) {
            bySecond = new java.util.ArrayList<WeightedArgs>();
        }
        if (byFirst == null) {
            byFirst = new java.util.ArrayList<WeightedArgs>();
        }
        values = byFirst.size() > bySecond.size() ? bySecond : byFirst;
    }
    if (values == null) {
        return result;
    }

    // Apply each accepted fact to a fresh copy of the state.
    for (WeightedArgs val : values) {
        if (!check(argConst, val.args, returnWeights)) {
            continue;
        }
        wamInterp.restoreState(state);
        for (int i = 0; i < argConst.length; i++) {
            if (argConst[i] != null) {
                continue;
            }
            if (i < val.args.length) {
                wamInterp.setArg(arity, i + 1, val.args[i]);
            } else if (returnWeights) {
                // Guarded: this runs once per matching fact, so skip building the message
                // when debug logging is off.
                if (log.isDebugEnabled()) {
                    log.debug("Using facts weight " + val.wt);
                }
                wamInterp.setWt(arity, i + 1, val.wt);
            }
        }
        wamInterp.returnp();
        wamInterp.executeWithoutBranching();
        if (computeFeatures) {
            result.add(new Outlink(scaleFD(this.fd, val.wt), wamInterp.saveState()));
        } else {
            result.add(new Outlink(null, wamInterp.saveState()));
        }
    }
    return result;
}
Use of edu.cmu.ml.proppr.prove.wam.Outlink in project ProPPR by TeamCohen.
From class DfsProver, method dfs.
/**
 * Walks the proof tree depth-first from {@code s}, appending one {@link Entry} per visited state
 * to {@code tail}.
 *
 * <p>Each entry starts with the weight of the edge used to reach it. Completed states and states
 * at {@code apr.maxDepth} are leaves. For any other state, every outlink is weighted with
 * {@code weighter}, and those weights are normalized by their sum. The entry's weight is then
 * reduced by the normalized weight of each child that is not the graph's start state, and the
 * walk recurses into that child with the normalized weight as its incoming weight.
 *
 * @param pg proof graph supplying outlinks and the start state
 * @param s state to visit
 * @param depth depth of {@code s} in the walk
 * @param incomingEdgeWeight normalized weight of the edge that led to {@code s}
 * @param tail output list that receives the visited entries in depth-first order
 * @throws LogicProgramException if computing outlinks fails
 */
protected void dfs(StateProofGraph pg, State s, int depth, double incomingEdgeWeight, List<Entry> tail) throws LogicProgramException {
    beforeDfs(s, pg, depth);
    Entry entry = new Entry(s, incomingEdgeWeight);
    tail.add(entry);

    // Leaves: finished proofs and states at the depth limit are recorded but not expanded.
    if (s.isCompleted() || depth >= this.apr.maxDepth) {
        if (log.isDebugEnabled()) {
            String reason = s.isCompleted() ? "state completed" : ("depth " + depth + ">" + this.apr.maxDepth);
            log.debug("exit dfs: " + reason);
        }
        return;
    }

    backtrace.push(s);
    List<Outlink> children = pg.pgOutlinks(s, trueLoop);
    if (children.isEmpty() && log.isDebugEnabled()) {
        log.debug("exit dfs: no outlinks for " + s);
    }

    // First pass: weight each outlink and accumulate the normalizer.
    double total = 0;
    for (Outlink link : children) {
        link.wt = this.weighter.w(link.fd);
        total += link.wt;
    }

    // Second pass: hand each non-reset child its normalized share and recurse.
    for (Outlink link : children) {
        if (link.child.equals(pg.getStartState())) {
            continue; // reset edge back to the start state: not expanded
        }
        double share = link.wt / total;
        entry.w -= share;
        dfs(pg, link.child, depth + 1, share, tail);
    }
    backtrace.pop(s);
}
Use of edu.cmu.ml.proppr.prove.wam.Outlink in project ProPPR by TeamCohen.
From class DprProver, method proveState.
/**
 * Pushes residual mass out of state {@code u}, then recurses into its children.
 *
 * <p>If {@code r[u] / degree(u)} exceeds {@code iterEpsilon}, {@code u} is pushed repeatedly
 * until the ratio no longer exceeds the threshold. Each push (see the inline comments):
 * <ul>
 *   <li>calls {@code addToP(p, u, ru)};</li>
 *   <li>sets {@code r[u]} to {@code (1 - alpha) * stayProbability * ru};</li>
 *   <li>adds {@code (1 - alpha) * moveProbability * (wt / z) * ru} to the residual of every
 *       child with nonzero edge weight, where {@code z} is the sum of the outlink weights.</li>
 * </ul>
 * After pushing, the method recurses into every child except the graph's start state and
 * zero-weight links.
 *
 * @param pg proof graph supplying degrees and outlinks
 * @param p accumulated scores; updated in place
 * @param r residuals; updated in place. Presumably always contains {@code u}, because unboxing a
 *     missing {@code r.get(u)} would throw (TODO confirm callers guarantee this)
 * @param u state to push from
 * @param pushCounter number of pushes recorded so far
 * @param depth recursion depth of {@code u}; when {@code maxTreeDepth > 0}, deeper states are
 *     rejected without pushing
 * @param iterEpsilon push threshold applied to {@code r[u] / degree(u)}
 * @param status consulted via {@code status.due(2)} to decide when to emit the info-level
 *     progress message
 * @return {@code pushCounter} plus one for every invocation (this one or a recursive one) that
 *     passed the push threshold
 */
protected int proveState(StateProofGraph pg, Map<State, Double> p, Map<State, Double> r, State u, int pushCounter, int depth, double iterEpsilon, StatusLogger status) {
    // Depth limit (disabled when maxTreeDepth <= 0). Degree has not been computed yet, so -1 is logged.
    if (this.maxTreeDepth > 0 && depth > this.maxTreeDepth) {
        if (log.isDebugEnabled())
            log.debug(String.format("Rejecting eps %f @depth %d > %d ru %.6f deg %d state %s", iterEpsilon, depth, this.maxTreeDepth, r.get(u), -1, u));
        return pushCounter;
    }
    try {
        // NOTE(review): assumes deg > 0. With deg == 0 the ratio below is Infinity
        // (or NaN when r[u] == 0) - confirm pgDegree never returns 0 here.
        int deg = pg.pgDegree(u);
        if (r.get(u) / deg > iterEpsilon) {
            backtrace.push(u);
            pushCounter += 1;
            try {
                List<Outlink> outs = pg.pgOutlinks(u, TRUELOOP_ON);
                // Weight every outlink; z is the normalizer for the move step below.
                double z = 0.0;
                for (Outlink o : outs) {
                    o.wt = this.weighter.w(o.fd);
                    if (Double.isInfinite(o.wt) || Double.isNaN(o.wt))
                        log.warn("Illegal weight (" + Double.toString(o.wt) + ") at outlink " + o.child + ";" + Dictionary.buildString(o.fd, new StringBuilder(), "\n\t").toString());
                    z += o.wt;
                }
                if (z == 0) {
                    //then we're in trouble
                    // Only warns: zero-weight links are skipped below, so this does not
                    // divide by zero for them. NOTE(review): nonzero weights that cancel
                    // to z == 0 would still divide by zero in the move step.
                    log.warn("Illegal graph: weight on this node has nowhere to go");
                    for (Outlink o : outs) {
                        log.warn("Outlink: " + Dictionary.buildString(o.fd, new StringBuilder(), "; "));
                    }
                }
                // push this state as far as you can
                // Each iteration re-reads r[u], which the previous push has just reduced
                // (and which may have been increased again if u is among its own children).
                while (r.get(u) / deg > iterEpsilon) {
                    double ru = r.get(u);
                    if (log.isDebugEnabled())
                        log.debug(String.format("Pushing eps %f @depth %d ru %.6f deg %d z %.6f state %s", iterEpsilon, depth, ru, deg, z, u));
                    else if (log.isInfoEnabled() && status.due(2))
                        log.info(String.format("Pushing eps %f @depth %d ru %.6f deg %d z %.6f state %s", iterEpsilon, depth, ru, deg, z, u));
                    // p[u] += alpha * ru
                    addToP(p, u, ru);
                    // r[u] *= (1-alpha) * stay?
                    r.put(u, (1.0 - apr.alpha) * stayProbability * ru);
                    // for each v near u:
                    for (Outlink o : outs) {
                        // skip 0-weighted links
                        if (o.wt == 0)
                            continue;
                        // r[v] += (1-alpha) * move? * Muv * ru
                        Dictionary.increment(r, o.child, (1.0 - apr.alpha) * moveProbability * (o.wt / z) * ru, "(elided)");
                    }
                    if (log.isDebugEnabled()) {
                        // sanity-check r:
                        // Debug-only invariant check: total mass in p plus r should remain 1.0
                        // (within apr.epsilon). Costs O(|p| + |r|) per push.
                        double sumr = 0;
                        for (Double d : r.values()) {
                            sumr += d;
                        }
                        double sump = 0;
                        for (Double d : p.values()) {
                            sump += d;
                        }
                        if (Math.abs(sump + sumr - 1.0) > apr.epsilon) {
                            log.debug("Should be 1.0 but isn't: after push sum p + r = " + sump + " + " + sumr + " = " + (sump + sumr));
                        }
                    }
                }
                // for each v near u:
                for (Outlink o : outs) {
                    // Recurse into each child, except the reset edge back to the start
                    // state and zero-weight links (which received no residual above).
                    if (o.child.equals(pg.getStartState()))
                        continue;
                    if (o.wt == 0)
                        continue;
                    pushCounter = this.proveState(pg, p, r, o.child, pushCounter, depth + 1, iterEpsilon, status);
                }
            } catch (LogicProgramException e) {
                // NOTE(review): if backtrace.rethrow throws (as its name suggests - unconfirmed),
                // backtrace.pop(u) below is skipped and u stays on the backtrace.
                backtrace.rethrow(e);
            }
            backtrace.pop(u);
        } else {
            if (log.isDebugEnabled())
                log.debug(String.format("Rejecting eps %f @depth %d ru %.6f deg %d state %s", iterEpsilon, depth, r.get(u), deg, u));
        }
    } catch (LogicProgramException e) {
        // Handles LogicProgramException raised outside the inner try (e.g. by pgDegree).
        this.backtrace.rethrow(e);
    }
    return pushCounter;
}
Use of edu.cmu.ml.proppr.prove.wam.Outlink in project ProPPR by TeamCohen.
From class LightweightStateGraph, method setOutlinks.
/**
 * Stores the outgoing edges of state {@code u}, replacing any previously stored ones (a warning
 * is logged when that happens).
 *
 * <p>States are mapped to integer ids through {@code nodeTab} and features through
 * {@code featureTab}, so edges live in Trove primitive collections:
 * <ul>
 *   <li>{@code near[u]} lists child ids in outlink order;</li>
 *   <li>{@code edgeFeatureDict[u][v]} maps feature id to feature value for the edge u -> v.</li>
 * </ul>
 * {@code edgeCount} is kept in step: the old outlinks of {@code u} are subtracted before the
 * new ones are counted.
 *
 * <p>NOTE(review): an earlier author questioned why Trove collections are used here; presumably
 * to save space compared with boxed collections - unconfirmed.
 *
 * @param u the source state
 * @param outlinks the outgoing edges of {@code u}; an outlink whose feature map is {@code null}
 *     is stored with an empty feature map
 */
public void setOutlinks(State u, List<Outlink> outlinks) {
    int ui = this.nodeTab.getId(u);
    if (near.containsKey(ui)) {
        log.warn("Overwriting previous outlinks for state " + u);
        edgeCount -= near.get(ui).size();
    }
    TIntArrayList nearui = new TIntArrayList(outlinks.size());
    near.put(ui, nearui);
    TIntObjectHashMap<TIntDoubleHashMap> fui = new TIntObjectHashMap<TIntDoubleHashMap>();
    edgeFeatureDict.put(ui, fui);
    for (Outlink o : outlinks) {
        int vi = this.nodeTab.getId(o.child);
        nearui.add(vi);
        edgeCount++;
        // Outlinks built without features carry a null feature map (e.g. new Outlink(null, ...));
        // store an empty map for them instead of failing with a NullPointerException.
        TIntDoubleHashMap fvui = new TIntDoubleHashMap(o.fd == null ? 0 : o.fd.size());
        if (o.fd != null) {
            for (Map.Entry<Feature, Double> e : o.fd.entrySet()) {
                fvui.put(this.featureTab.getId(e.getKey()), e.getValue());
            }
        }
        fui.put(vi, fvui);
    }
}
Use of edu.cmu.ml.proppr.prove.wam.Outlink in project ProPPR by TeamCohen.
From class WamInterpreterTest, method outlinks.
/**
 * Test helper that expands the interpreter's current state by trying every rule registered under
 * the pending call's label.
 *
 * <p>For each rule address, the interpreter is reset to a snapshot of the starting state and the
 * rule head is executed. If that produced a non-empty feature map without failing,
 * {@code executeWithoutBranching()} is run. If the state still has not failed, a snapshot of the
 * result becomes an outlink carrying the head's features.
 *
 * @param wamInterp interpreter positioned at a call; afterwards its state is whatever the last
 *     attempted rule left behind
 * @return one outlink per successful rule; empty if the current state is completed or failed
 */
public List<Outlink> outlinks(WamInterpreter wamInterp) {
    List<Outlink> result = new ArrayList<Outlink>();
    if (wamInterp.getState().isCompleted() || wamInterp.getState().isFailed()) {
        return result;
    }
    String jumpTo = wamInterp.getState().getJumpTo();
    assertTrue("not at a call", jumpTo != null);
    assertTrue("no definition for " + jumpTo, wamInterp.getProgram().hasLabel(jumpTo));
    State savedState = wamInterp.saveState();
    for (Integer addr : wamInterp.getProgram().getAddresses(jumpTo)) {
        wamInterp.restoreState(savedState);
        //try and match the rule head
        Map<Feature, Double> features = wamInterp.executeWithoutBranching(addr);
        if (features.isEmpty() || wamInterp.getState().isFailed()) {
            continue;
        }
        wamInterp.executeWithoutBranching();
        if (!wamInterp.getState().isFailed()) {
            // Store a snapshot, as the rest of the prover does, rather than the interpreter's
            // working state, which presumably keeps changing as the interpreter is reused.
            result.add(new Outlink(features, wamInterp.saveState()));
        }
    }
    return result;
}
Aggregations