Search in sources :

Example 11 with TIntDoubleMap

use of gnu.trove.map.TIntDoubleMap in project ProPPR by TeamCohen.

the class SRW method gradient.

/**
 * Builds the gradient for a single example: the regularization term plus the
 * loss gradient, with every entry divided by {@code example.length()}.
 *
 * <p>When the loss computation reports zero nonzero entries, the zero-gradient
 * counter is bumped and, while the counter is below {@code MAX_ZERO_LOGS}, the
 * example is appended to the zero-gradient log buffer.
 *
 * @param params  current parameter vector
 * @param example the example to differentiate
 * @return map from feature id to the length-normalized gradient value
 */
protected TIntDoubleMap gradient(ParamVector<String, ?> params, PosNegRWExample example) {
    // Presize the result from the features the regularizer considers local to this graph.
    Set<String> localFeatures = this.regularizer.localFeatures(params, example.getGraph());
    TIntDoubleMap result = new TIntDoubleHashMap(localFeatures.size());

    // Regularization contribution first, then the loss contribution on top of it.
    regularization(params, example, result);
    int nonzeroCount = lossf.computeLossGradient(params, example, result, this.cumloss, c);

    // Normalize each entry by the example length.
    for (int featureId : result.keys()) {
        double normalized = result.get(featureId) / example.length();
        result.put(featureId, normalized);
    }

    if (nonzeroCount == 0) {
        this.zeroGradientData.numZero++;
        boolean stillLogging = this.zeroGradientData.numZero < MAX_ZERO_LOGS;
        if (stillLogging) {
            this.zeroGradientData.examples.append("\n").append(example);
        }
    }
    return result;
}
Also used : TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) PosNegRWExample(edu.cmu.ml.proppr.examples.PosNegRWExample) TIntDoubleMap(gnu.trove.map.TIntDoubleMap)

Example 12 with TIntDoubleMap

use of gnu.trove.map.TIntDoubleMap in project ProPPR by TeamCohen.

the class SRW method sgd.

/**
 * Applies one gradient step for a single example, editing {@code params} in place.
 *
 * <p>Entries whose gradient is exactly zero and features that are not trainable
 * are left untouched. A warning is logged if an adjusted parameter becomes infinite.
 *
 * @param params parameter vector to update
 * @param ex     example whose gradient is applied
 */
protected void sgd(ParamVector<String, ?> params, PosNegRWExample ex) {
    TIntDoubleMap exampleGradient = gradient(params, ex);
    TIntDoubleIterator entry = exampleGradient.iterator();
    while (entry.hasNext()) {
        entry.advance();
        double partial = entry.value();
        if (partial != 0) {
            // Gradient keys are feature ids; parameters are keyed by feature symbol.
            String name = ex.getGraph().featureLibrary.getSymbol(entry.key());
            if (trainable(name)) {
                params.adjustValue(name, -learningRate(name) * partial);
                if (params.get(name).isInfinite()) {
                    log.warn("Infinity at " + name + "; gradient " + partial);
                }
            }
        }
    }
}
Also used : TIntDoubleMap(gnu.trove.map.TIntDoubleMap) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator)

Example 13 with TIntDoubleMap

use of gnu.trove.map.TIntDoubleMap in project ProPPR by TeamCohen.

the class DprSRW method inference.

/**
 * Runs the push-based approximation loop for a single-node query, filling in
 * {@code p}, {@code r}, {@code dp} and {@code dr} on the example.
 *
 * <p>Each pass visits every node; a node whose residual-to-degree ratio exceeds
 * {@code c.apr.epsilon} is pushed until it no longer does. Passes repeat until
 * a pass finds every node already at or below the threshold.
 *
 * @param params  current parameter vector, forwarded to {@code push}
 * @param example the example to run inference on; must be a {@code DprExample}
 * @param status  used to rate-limit info-level progress logging
 * @throws UnsupportedOperationException if the query vector has more than one node
 * @throws IllegalStateException if a push increases the residual of the pushed node
 */
@Override
protected void inference(ParamVector<String, ?> params, PosNegRWExample example, StatusLogger status) {
    DprExample dpr = (DprExample) example;

    // The query vector maps each start node to its weight.
    TIntDoubleMap seeds = dpr.getQueryVec();
    if (seeds.size() > 1) {
        throw new UnsupportedOperationException("Can't do multi-node queries");
    }

    // Per-node probability (p) and remainder (r) weights, zeroed and then seeded from the query.
    dpr.p = new double[dpr.getGraph().node_hi];
    dpr.r = new double[dpr.getGraph().node_hi];
    Arrays.fill(dpr.p, 0.0);
    Arrays.fill(dpr.r, 0.0);
    TIntDoubleIterator seed = seeds.iterator();
    while (seed.hasNext()) {
        seed.advance();
        dpr.r[seed.key()] = seed.value();
    }

    // Per-node gradient maps for p and r; slots start out null.
    dpr.dp = new TIntDoubleMap[dpr.getGraph().node_hi];
    dpr.dr = new TIntDoubleMap[dpr.getGraph().node_hi];

    // APR algorithm: sweep until one full pass needs no pushes.
    int settled = 0;
    while (settled < dpr.getGraph().node_hi) {
        if (log.isDebugEnabled()) {
            log.debug("Starting pass");
        }
        settled = 0;
        for (int node = 0; node < dpr.getGraph().node_hi; node++) {
            double residual = dpr.r[node];
            int degree = dpr.getGraph().node_near_hi[node] - dpr.getGraph().node_near_lo[node];

            // Negated '>' (not '<=') so that a NaN ratio is counted as settled.
            if (!(residual / (double) degree > c.apr.epsilon)) {
                settled++;
                if (log.isDebugEnabled()) {
                    log.debug("Counting " + node);
                }
                continue;
            }

            // The threshold test above already passed once, so push at least once.
            do {
                this.push(node, params, dpr);
                if (dpr.r[node] > residual) {
                    throw new IllegalStateException("r increasing! :(");
                }
                residual = dpr.r[node];
            } while (residual / degree > c.apr.epsilon);
        }
        if (log.isDebugEnabled()) {
            log.debug(settled + " of " + dpr.getGraph().node_hi + " completed this pass");
        } else if (log.isInfoEnabled() && status.due(3)) {
            log.info(Thread.currentThread() + " inference: " + settled + " of " + dpr.getGraph().node_hi + " completed this pass");
        }
    }
}
Also used : DprExample(edu.cmu.ml.proppr.examples.DprExample) TIntDoubleMap(gnu.trove.map.TIntDoubleMap) TIntDoubleIterator(gnu.trove.iterator.TIntDoubleIterator)

Example 14 with TIntDoubleMap

use of gnu.trove.map.TIntDoubleMap in project ProPPR by TeamCohen.

the class DprSRW method push.

/**
 * Simulates a single lazy random walk step on the input vertex.
 *
 * <p>Updates, in this order: {@code ex.p[u]}; {@code ex.dp[u]} (trainable features
 * only) and {@code ex.dr[u]}; {@code ex.dr[v]} for every out-neighbor {@code v} of
 * {@code u}; and finally {@code ex.r[u]} and {@code ex.r[v]}. The order matters:
 * the neighbor updates read the pre-push {@code ex.r[u]} and the pre-push
 * {@code dr[u]} values saved in a temporary map.
 *
 * @param u the vertex to be 'pushed'
 * @param paramVec current parameter vector, used to weight the edges leaving {@code u}
 * @param ex the example holding the graph and the p, r, dp and dr state to update
 */
public void push(int u, ParamVector<String, ?> paramVec, DprExample ex) {
    // Guarded like the debug calls in inference(): push runs in the innermost loop,
    // so avoid building the message when debug logging is off.
    if (log.isDebugEnabled())
        log.debug("Pushing " + u);
    // update p for the pushed node:
    ex.p[u] += c.apr.alpha * ex.r[u];
    if (ex.dr[u] == null)
        ex.dr[u] = new TIntDoubleHashMap();
    TIntDoubleMap dru = ex.dr[u];
    // unsquashed edge scores, keyed by destination node:
    TIntDoubleMap unwrappedDotP = new TIntDoubleHashMap();
    for (int eid = ex.getGraph().node_near_lo[u]; eid < ex.getGraph().node_near_hi[u]; eid++) {
        int v = ex.getGraph().edge_dest[eid];
        unwrappedDotP.put(v, dotP(ex.getGraph(), eid, paramVec));
    }
    // calculate the sum of the weights (raised to exp) of the edges adjacent to the input node:
    double rowSum = this.totalEdgeProbWeight(ex.getGraph(), u, paramVec);
    // calculate the gradients of the rowSums (needed for the calculation of the gradient of r):
    TIntDoubleMap drowSums = new TIntDoubleHashMap();
    TIntDoubleMap prevdr = new TIntDoubleHashMap();
    Set<String> exampleFeatures = ex.getGraph().getFeatureSet();
    for (String feature : exampleFeatures) {
        int flid = ex.getGraph().featureLibrary.getId(feature);
        // simultaneously update the dp for the pushed node:
        if (trainable(feature)) {
            if (ex.dp[u] == null)
                ex.dp[u] = new TIntDoubleHashMap();
            Dictionary.increment(ex.dp[u], flid, c.apr.alpha * dru.get(flid));
        }
        // derivative of the row sum w.r.t. this feature: sum over edges carrying the feature
        double drowSum = 0;
        for (int eid = ex.getGraph().node_near_lo[u]; eid < ex.getGraph().node_near_hi[u]; eid++) {
            int v = ex.getGraph().edge_dest[eid];
            if (hasFeature(ex.getGraph(), eid, flid)) {
                drowSum += c.squashingFunction.computeDerivative(unwrappedDotP.get(v));
            }
        }
        drowSums.put(flid, drowSum);
        // update dr for the pushed vertex, storing dr temporarily for the calculation of dr for the other vertices:
        prevdr.put(flid, dru.get(flid));
        dru.put(flid, dru.get(flid) * (1 - c.apr.alpha) * stayProb);
    }
    // update dr for other vertices:
    for (int eid = ex.getGraph().node_near_lo[u]; eid < ex.getGraph().node_near_hi[u]; eid++) {
        int v = ex.getGraph().edge_dest[eid];
        double dotP = c.squashingFunction.edgeWeight(unwrappedDotP.get(v));
        double ddotP = c.squashingFunction.computeDerivative(unwrappedDotP.get(v));
        for (String feature : exampleFeatures) {
            int flid = ex.getGraph().featureLibrary.getId(feature);
            int contained = hasFeature(ex.getGraph(), eid, flid) ? 1 : 0;
            if (ex.dr[v] == null)
                ex.dr[v] = new TIntDoubleHashMap();
            double vdr = Dictionary.safeGet(ex.dr[v], flid, 0.0);
            // Two terms: the pre-push dr[u] scaled by the normalized edge weight, plus
            // r[u] times the quotient-rule derivative of (dotP / rowSum).
            vdr += (1 - stayProb) * (1 - c.apr.alpha) * ((prevdr.get(flid) * dotP / rowSum) + (ex.r[u] * ((contained * ddotP * rowSum) - (dotP * drowSums.get(flid))) / (rowSum * rowSum)));
            ex.dr[v].put(flid, vdr);
        }
    }
    // update r for all affected vertices:
    double ru = ex.r[u];
    ex.r[u] = ru * stayProb * (1 - c.apr.alpha);
    for (int eid = ex.getGraph().node_near_lo[u]; eid < ex.getGraph().node_near_hi[u]; eid++) {
        int v = ex.getGraph().edge_dest[eid];
        // calculate edge weight on v:
        double dotP = c.squashingFunction.edgeWeight(unwrappedDotP.get(v));
        ex.r[v] += (1 - stayProb) * (1 - c.apr.alpha) * (dotP / rowSum) * ru;
    }
}
Also used : TIntDoubleHashMap(gnu.trove.map.hash.TIntDoubleHashMap) TIntDoubleMap(gnu.trove.map.TIntDoubleMap)

Aggregations

TIntDoubleMap (gnu.trove.map.TIntDoubleMap)14 TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap)10 TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator)6 PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample)2 PprExample (edu.cmu.ml.proppr.examples.PprExample)2 Test (org.junit.Test)2 DprExample (edu.cmu.ml.proppr.examples.DprExample)1 GraphFormatException (edu.cmu.ml.proppr.graph.GraphFormatException)1 LearningGraph (edu.cmu.ml.proppr.graph.LearningGraph)1 StatusLogger (edu.cmu.ml.proppr.util.StatusLogger)1 SimpleParamVector (edu.cmu.ml.proppr.util.math.SimpleParamVector)1 TIntDoubleProcedure (gnu.trove.procedure.TIntDoubleProcedure)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1