Usage of edu.cmu.ml.proppr.examples.DprExample in the ProPPR project by TeamCohen.
Method inference of class DprSRW.
/**
 * Runs the APR ("approximate personalized PageRank") push loop for a single DPR example.
 *
 * <p>The method allocates and fills {@code ex.p} (probability) and {@code ex.r} (residual) with one
 * slot per node index in {@code [0, node_hi)}. It also allocates the per-node gradient arrays
 * {@code ex.dp} / {@code ex.dr}, leaving their entries {@code null}. Presumably {@code push()}
 * populates them on demand — TODO confirm.
 *
 * <p>All residual mass starts on the query node. The graph is then swept repeatedly. Any node
 * {@code u} whose residual per outgoing edge ({@code r[u] / degree(u)}) exceeds
 * {@code c.apr.epsilon} is pushed until that ratio no longer exceeds epsilon. A node counts as
 * "complete" in a pass only if it needed no push during that pass. The loop ends after a full
 * pass in which no node was pushed.
 *
 * @param params parameter vector, passed through unchanged to {@code push}
 * @param example the example to run inference on; must be a {@link DprExample}
 * @param status throttles the info-level progress message via {@code status.due(3)}
 * @throws UnsupportedOperationException if the query vector contains more than one node
 * @throws IllegalStateException if a push ever increases the residual of the pushed node
 * @throws ClassCastException if {@code example} is not a {@link DprExample}
 */
@Override
protected void inference(ParamVector<String, ?> params, PosNegRWExample example, StatusLogger status) {
    DprExample ex = (DprExample) example;
    // The query vector maps start node -> initial weight.
    TIntDoubleMap query = ex.getQueryVec();
    if (query.size() > 1) {
        throw new UnsupportedOperationException("Can't do multi-node queries");
    }
    // The node count is fixed for the whole inference, so read it once.
    final int nodeHi = ex.getGraph().node_hi;

    // Probability (p) and residual (r) per node. Freshly allocated Java arrays are already
    // zero-filled, so no explicit Arrays.fill is needed.
    ex.p = new double[nodeHi];
    ex.r = new double[nodeHi];
    // All residual mass starts on the query node(s).
    for (TIntDoubleIterator it = query.iterator(); it.hasNext(); ) {
        it.advance();
        ex.r[it.key()] = it.value();
    }
    // Per-node gradients of p and r. Entries start out null here.
    ex.dp = new TIntDoubleMap[nodeHi];
    ex.dr = new TIntDoubleMap[nodeHi];

    // APR algorithm: sweep until an entire pass completes without pushing any node.
    int completeCount = 0;
    while (completeCount < nodeHi) {
        if (log.isDebugEnabled()) {
            log.debug("Starting pass");
        }
        completeCount = 0;
        for (int u = 0; u < nodeHi; u++) {
            double ru = ex.r[u];
            int udeg = ex.getGraph().node_near_hi[u] - ex.getGraph().node_near_lo[u];
            // Why: a node with udeg == 0 makes this a floating-point division by zero.
            // With ru == 0 it yields NaN; every comparison with NaN is false, so the node is
            // counted complete. With ru > 0 it yields +Infinity, so the node is pushed.
            // NOTE(review): that case only terminates if push() drives r[u] to 0 — confirm.
            if (ru / (double) udeg > c.apr.epsilon) {
                do {
                    this.push(u, params, ex);
                    // Sanity check: a push must never increase the pushed node's residual.
                    if (ex.r[u] > ru) {
                        throw new IllegalStateException("r increasing! :(");
                    }
                    ru = ex.r[u];
                } while (ru / (double) udeg > c.apr.epsilon);
            } else {
                completeCount++;
                if (log.isDebugEnabled()) {
                    log.debug("Counting " + u);
                }
            }
        }
        if (log.isDebugEnabled()) {
            log.debug(completeCount + " of " + nodeHi + " completed this pass");
        } else if (log.isInfoEnabled() && status.due(3)) {
            log.info(Thread.currentThread() + " inference: " + completeCount + " of " + nodeHi + " completed this pass");
        }
    }
}
Aggregations