Search in sources:

Example 1 with DprExample

Use of edu.cmu.ml.proppr.examples.DprExample in project ProPPR by TeamCohen.

Method inference of class DprSRW.

/**
 * Runs the approximate personalized-PageRank (APR) push procedure for one
 * {@link DprExample}, populating the example's {@code p}, {@code r}, {@code dp}
 * and {@code dr} fields.
 *
 * <p>The query vector seeds the residual array {@code r}. The method then sweeps
 * over every node index in {@code [0, node_hi)}. Whenever a node's residual divided
 * by its out-degree exceeds {@code c.apr.epsilon}, it calls {@code push} on that
 * node repeatedly until the ratio drops to or below the threshold. Sweeping stops
 * once a full pass finds every node already below the threshold.
 *
 * @param params  current parameter vector, passed through to {@code push}
 * @param example the example to run inference on; must be a {@link DprExample}
 * @param status  used to rate-limit info-level progress logging
 * @throws UnsupportedOperationException if the query vector contains more than one node
 * @throws IllegalStateException if a push ever increases a node's residual
 */
@Override
protected void inference(ParamVector<String, ?> params, PosNegRWExample example, StatusLogger status) {
    DprExample ex = (DprExample) example;
    // query maps node id -> initial weight
    TIntDoubleMap query = ex.getQueryVec();
    if (query.size() > 1) {
        throw new UnsupportedOperationException("Can't do multi-node queries");
    }
    final int nodeHi = ex.getGraph().node_hi;

    // p: probability mass settled at each node; r: residual mass still to be pushed.
    // New double[] arrays are zero-initialized by the JVM, so no explicit fill is needed.
    ex.p = new double[nodeHi];
    ex.r = new double[nodeHi];
    for (TIntDoubleIterator it = query.iterator(); it.hasNext(); ) {
        it.advance();
        ex.r[it.key()] = it.value();
    }

    // Per-node gradients of p and r. Every slot starts out null; presumably push()
    // fills them in lazily -- TODO confirm against push().
    ex.dp = new TIntDoubleMap[nodeHi];
    ex.dr = new TIntDoubleMap[nodeHi];

    // APR algorithm: keep sweeping until one full pass needs no pushes.
    int completeCount = 0;
    while (completeCount < nodeHi) {
        if (log.isDebugEnabled()) {
            log.debug("Starting pass");
        }
        completeCount = 0;
        for (int u = 0; u < nodeHi; u++) {
            double ru = ex.r[u];
            int udeg = ex.getGraph().node_near_hi[u] - ex.getGraph().node_near_lo[u];
            // NOTE(review): when udeg == 0 this is floating-point division by zero.
            // ru > 0 gives +Infinity, so the node is pushed until r[u] becomes exactly 0.
            // ru == 0 gives NaN, which compares false, so the node counts as complete.
            if (ru / (double) udeg > c.apr.epsilon) {
                while (ru / udeg > c.apr.epsilon) {
                    this.push(u, params, ex);
                    // Sanity check: a push must never add residual to the node it drains.
                    if (ex.r[u] > ru) {
                        throw new IllegalStateException("r increasing! :(");
                    }
                    ru = ex.r[u];
                }
            } else {
                completeCount++;
                if (log.isDebugEnabled()) {
                    log.debug("Counting " + u);
                }
            }
        }
        if (log.isDebugEnabled()) {
            log.debug(completeCount + " of " + nodeHi + " completed this pass");
        } else if (log.isInfoEnabled() && status.due(3)) {
            log.info(Thread.currentThread() + " inference: " + completeCount + " of " + nodeHi + " completed this pass");
        }
    }
}
Also used: DprExample (edu.cmu.ml.proppr.examples.DprExample), TIntDoubleMap (gnu.trove.map.TIntDoubleMap), TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator)

Aggregations

DprExample (edu.cmu.ml.proppr.examples.DprExample): 1 use; TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator): 1 use; TIntDoubleMap (gnu.trove.map.TIntDoubleMap): 1 use