Search in sources:

Example 16 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class SRW, method gradient.

/**
 * Builds the gradient contributed by a single training example.
 *
 * The regularization term is written into a fresh map first, and the loss gradient is then
 * accumulated on top of it. Every resulting entry is divided by {@code example.length()}.
 * When {@code computeLossGradient} returns 0, the zero-gradient counter is bumped. The example
 * is appended to the zero-gradient log only while the bumped counter is still below
 * {@code MAX_ZERO_LOGS}.
 *
 * @param params current parameter vector
 * @param example the example whose gradient is computed
 * @return map from feature id to averaged gradient value
 */
protected TIntDoubleMap gradient(ParamVector<String, ?> params, PosNegRWExample example) {
    // The parameter already has the static type PosNegRWExample, so no cast is required.
    PosNegRWExample posNeg = example;
    // Presize the map for the features local to this example's graph.
    TIntDoubleMap grad = new TIntDoubleHashMap(
            this.regularizer.localFeatures(params, posNeg.getGraph()).size());
    // Regularization goes in first; the loss gradient is accumulated on top of it.
    regularization(params, posNeg, grad);
    int nonzero = lossf.computeLossGradient(params, example, grad, this.cumloss, c);
    // keys() returns a snapshot array, so overwriting values while walking it is safe.
    int[] featureIds = grad.keys();
    for (int k = 0; k < featureIds.length; k++) {
        int id = featureIds[k];
        grad.put(id, grad.get(id) / example.length());
    }
    if (nonzero != 0) {
        return grad;
    }
    // Zero-gradient bookkeeping: count every occurrence, keep a bounded sample for logging.
    this.zeroGradientData.numZero++;
    if (this.zeroGradientData.numZero < MAX_ZERO_LOGS) {
        this.zeroGradientData.examples.append("\n").append(posNeg);
    }
    return grad;
}
Also used: TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap), PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample), TIntDoubleMap (gnu.trove.map.TIntDoubleMap)

Example 17 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class DprSRW, method push.

/**
 * Simulates a single lazy random walk step on the input vertex.
 *
 * <p>A {@code c.apr.alpha} fraction of u's residual {@code r[u]} is moved into {@code p[u]}.
 * Of the remaining residual, a {@code stayProb} share stays at u. The rest is spread over
 * u's out-edges in proportion to each edge's squashed weight, divided by {@code rowSum}
 * (from {@code totalEdgeProbWeight}). Alongside p and r, the per-feature derivatives
 * {@code dp} (trainable features only) and {@code dr} (all features of the example's graph)
 * are updated.
 *
 * @param u the vertex to be 'pushed'
 * @param paramVec parameter vector used to compute the edge dot products and row sum
 * @param ex example holding the graph and the {@code p}, {@code r}, {@code dp} and
 *     {@code dr} vectors; all four are updated in place
 */
public void push(int u, ParamVector<String, ?> paramVec, DprExample ex) {
    log.debug("Pushing " + u);
    // update p for the pushed node:
    ex.p[u] += c.apr.alpha * ex.r[u];
    if (ex.dr[u] == null)
        ex.dr[u] = new TIntDoubleHashMap();
    TIntDoubleMap dru = ex.dr[u];
    // u's out-edges occupy edge ids [lo, hi); the graph is not modified during a push.
    final int lo = ex.getGraph().node_near_lo[u];
    final int hi = ex.getGraph().node_near_hi[u];
    // Unsquashed dot product for each out-edge, keyed by destination vertex.
    TIntDoubleMap unwrappedDotP = new TIntDoubleHashMap();
    for (int eid = lo; eid < hi; eid++) {
        int v = ex.getGraph().edge_dest[eid];
        unwrappedDotP.put(v, dotP(ex.getGraph(), eid, paramVec));
    }
    // calculate the sum of the weights (raised to exp) of the edges adjacent to the input node:
    double rowSum = this.totalEdgeProbWeight(ex.getGraph(), u, paramVec);
    // calculate the gradients of the rowSums (needed for the calculation of the gradient of r):
    TIntDoubleMap drowSums = new TIntDoubleHashMap();
    TIntDoubleMap prevdr = new TIntDoubleHashMap();
    Set<String> exampleFeatures = ex.getGraph().getFeatureSet();
    // Feature ids are cached here so the per-edge loop below does not repeat the
    // symbol-table lookup once per (edge, feature) pair.
    int[] featureIds = new int[exampleFeatures.size()];
    int numFeatures = 0;
    for (String feature : exampleFeatures) {
        int flid = ex.getGraph().featureLibrary.getId(feature);
        featureIds[numFeatures++] = flid;
        // simultaneously update the dp for the pushed node (uses dr[u] BEFORE it is scaled below):
        if (trainable(feature)) {
            if (ex.dp[u] == null)
                ex.dp[u] = new TIntDoubleHashMap();
            Dictionary.increment(ex.dp[u], flid, c.apr.alpha * dru.get(flid));
        }
        // Sum of squashing-function derivatives over the out-edges that carry this feature.
        double drowSum = 0;
        for (int eid = lo; eid < hi; eid++) {
            int v = ex.getGraph().edge_dest[eid];
            if (hasFeature(ex.getGraph(), eid, flid)) {
                drowSum += c.squashingFunction.computeDerivative(unwrappedDotP.get(v));
            }
        }
        drowSums.put(flid, drowSum);
        // update dr for the pushed vertex, storing dr temporarily for the calculation of dr for the other vertices:
        prevdr.put(flid, dru.get(flid));
        dru.put(flid, dru.get(flid) * (1 - c.apr.alpha) * stayProb);
    }
    // update dr for other vertices (if v == u, this adds onto the already-scaled dr[u]):
    for (int eid = lo; eid < hi; eid++) {
        int v = ex.getGraph().edge_dest[eid];
        double dotP = c.squashingFunction.edgeWeight(unwrappedDotP.get(v));
        double ddotP = c.squashingFunction.computeDerivative(unwrappedDotP.get(v));
        for (int flid : featureIds) {
            int contained = hasFeature(ex.getGraph(), eid, flid) ? 1 : 0;
            if (ex.dr[v] == null)
                ex.dr[v] = new TIntDoubleHashMap();
            double vdr = Dictionary.safeGet(ex.dr[v], flid, 0.0);
            // Derivative of the mass pushed along u->v, (1-stayProb)(1-alpha) * r[u] * dotP / rowSum:
            // the propagated prevdr term plus r[u] times the quotient rule applied to dotP / rowSum.
            vdr += (1 - stayProb) * (1 - c.apr.alpha) * ((prevdr.get(flid) * dotP / rowSum) + (ex.r[u] * ((contained * ddotP * rowSum) - (dotP * drowSums.get(flid))) / (rowSum * rowSum)));
            ex.dr[v].put(flid, vdr);
        }
    }
    // update r for all affected vertices (r[u] is overwritten before neighbours, including a self-loop, add to it):
    double ru = ex.r[u];
    ex.r[u] = ru * stayProb * (1 - c.apr.alpha);
    for (int eid = lo; eid < hi; eid++) {
        int v = ex.getGraph().edge_dest[eid];
        // calculate edge weight on v:
        double dotP = c.squashingFunction.edgeWeight(unwrappedDotP.get(v));
        ex.r[v] += (1 - stayProb) * (1 - c.apr.alpha) * (dotP / rowSum) * ru;
    }
}
Also used: TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap), TIntDoubleMap (gnu.trove.map.TIntDoubleMap)

Example 18 with TIntDoubleHashMap

Use of gnu.trove.map.hash.TIntDoubleHashMap in project ProPPR by TeamCohen.

Class LightweightStateGraph, method getFeatures.

/**
 * Returns the feature weights recorded on the edge from {@code u} to {@code v}.
 *
 * <p>Both states are resolved to ids through {@code nodeTab}. If {@code edgeFeatureDict}
 * has an entry for the pair, a new map from each {@link Feature} (looked up in
 * {@code featureTab}) to its weight is built and returned. Otherwise {@code DEFAULT_FD} is
 * returned.
 *
 * <p>NOTE(review): {@code nodeTab.getId} is called for both states before any lookup. If it
 * registers unseen states, this method inherits that side effect; confirm.
 *
 * @param u source state of the edge
 * @param v destination state of the edge
 * @return a fresh feature-to-weight map, or {@code DEFAULT_FD} when no such edge is recorded
 */
public Map<Feature, Double> getFeatures(State u, State v) {
    // Resolve u first, then v, exactly as before.
    final int srcId = this.nodeTab.getId(u);
    final int dstId = this.nodeTab.getId(v);
    if (edgeFeatureDict.containsKey(srcId)) {
        TIntObjectHashMap<TIntDoubleHashMap> outgoing = edgeFeatureDict.get(srcId);
        if (outgoing.containsKey(dstId)) {
            TIntDoubleHashMap weights = outgoing.get(dstId);
            final Map<Feature, Double> result = new HashMap<Feature, Double>();
            // Translate each (feature id, weight) pair back into a Feature symbol.
            weights.forEachEntry(new TIntDoubleProcedure() {

                @Override
                public boolean execute(int featureId, double weight) {
                    result.put(featureTab.getSymbol(featureId), weight);
                    // keep iterating over every entry
                    return true;
                }
            });
            return result;
        }
    }
    // No features recorded for the edge u -> v.
    return DEFAULT_FD;
}
Also used: TIntObjectHashMap (gnu.trove.map.hash.TIntObjectHashMap), HashMap (java.util.HashMap), TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap), TIntDoubleProcedure (gnu.trove.procedure.TIntDoubleProcedure), Feature (edu.cmu.ml.proppr.prove.wam.Feature)

Aggregations

TIntDoubleHashMap (gnu.trove.map.hash.TIntDoubleHashMap): 18, TIntDoubleMap (gnu.trove.map.TIntDoubleMap): 9, PosNegRWExample (edu.cmu.ml.proppr.examples.PosNegRWExample): 3, PprExample (edu.cmu.ml.proppr.examples.PprExample): 2, RegularizationSchedule (edu.cmu.ml.proppr.learn.RegularizationSchedule): 2, RegularizeL2 (edu.cmu.ml.proppr.learn.RegularizeL2): 2, SRW (edu.cmu.ml.proppr.learn.SRW): 2, Feature (edu.cmu.ml.proppr.prove.wam.Feature): 2, TIntDoubleIterator (gnu.trove.iterator.TIntDoubleIterator): 2, TIntObjectHashMap (gnu.trove.map.hash.TIntObjectHashMap): 2, TIntDoubleProcedure (gnu.trove.procedure.TIntDoubleProcedure): 2, HashMap (java.util.HashMap): 2, Before (org.junit.Before): 2, Test (org.junit.Test): 2, GrounderTest (edu.cmu.ml.proppr.GrounderTest): 1, GraphFormatException (edu.cmu.ml.proppr.graph.GraphFormatException): 1, LearningGraph (edu.cmu.ml.proppr.graph.LearningGraph): 1, Outlink (edu.cmu.ml.proppr.prove.wam.Outlink): 1, SparseGraphPluginTest (edu.cmu.ml.proppr.prove.wam.plugins.SparseGraphPluginTest): 1, StatusLogger (edu.cmu.ml.proppr.util.StatusLogger): 1