Usage of gnu.trove.map.TIntDoubleMap in project ProPPR by TeamCohen:
class SRW, method gradient().
/**
 * Computes the loss gradient (plus regularization) for a single example,
 * normalized by the example length.
 *
 * @param params current parameter vector (read-only here; regularization may consult it)
 * @param example the labeled random-walk example to differentiate against
 * @return a map from feature id to its (regularized, length-normalized) gradient component
 */
protected TIntDoubleMap gradient(ParamVector<String, ?> params, PosNegRWExample example) {
    // NOTE(review): the original cast `(PosNegRWExample) example` was redundant —
    // the parameter is already declared PosNegRWExample — so it has been removed.
    Set<String> features = this.regularizer.localFeatures(params, example.getGraph());
    TIntDoubleMap gradient = new TIntDoubleHashMap(features.size());
    // add regularization term first; the loss gradient is accumulated on top of it
    regularization(params, example, gradient);
    int nonzero = lossf.computeLossGradient(params, example, gradient, this.cumloss, c);
    // normalize every component by the example length
    for (int i : gradient.keys()) {
        gradient.put(i, gradient.get(i) / example.length());
    }
    // bookkeeping: record (a bounded number of) examples that produced an all-zero loss gradient
    if (nonzero == 0) {
        this.zeroGradientData.numZero++;
        if (this.zeroGradientData.numZero < MAX_ZERO_LOGS) {
            this.zeroGradientData.examples.append("\n").append(example);
        }
    }
    return gradient;
}
Usage of gnu.trove.map.TIntDoubleMap in project ProPPR by TeamCohen:
class SRW, method sgd().
/**
 * Performs one stochastic-gradient-descent step for a single example,
 * updating {@code params} in place (edits params).
 *
 * @param params parameter vector to be adjusted in place
 * @param ex the example whose gradient drives this update
 */
protected void sgd(ParamVector<String, ?> params, PosNegRWExample ex) {
    TIntDoubleMap grad = gradient(params, ex);
    // walk every gradient entry; guard clauses skip zero components and frozen features
    TIntDoubleIterator entry = grad.iterator();
    while (entry.hasNext()) {
        entry.advance();
        double g = entry.value();
        if (g == 0)
            continue;
        // map the integer feature id back to its symbolic name
        String feature = ex.getGraph().featureLibrary.getSymbol(entry.key());
        if (!trainable(feature))
            continue;
        // descend: move against the gradient, scaled by the per-feature learning rate
        params.adjustValue(feature, -learningRate(feature) * g);
        if (params.get(feature).isInfinite()) {
            log.warn("Infinity at " + feature + "; gradient " + g);
        }
    }
}
Usage of gnu.trove.map.TIntDoubleMap in project ProPPR by TeamCohen:
class DprSRW, method inference().
@Override
// Runs approximate personalized PageRank (APR) inference: repeatedly "pushes" residual
// weight from nodes until every node's residual-per-degree falls below c.apr.epsilon.
// Fills ex.p (probability), ex.r (residual), ex.dp / ex.dr (their per-feature gradients).
protected void inference(ParamVector<String, ?> params, PosNegRWExample example, StatusLogger status) {
DprExample ex = (DprExample) example;
// startNode maps node->weight
TIntDoubleMap query = ex.getQueryVec();
// this implementation only supports a single start node
if (query.size() > 1)
throw new UnsupportedOperationException("Can't do multi-node queries");
// maps storing the probability and remainder weights of the nodes:
ex.p = new double[ex.getGraph().node_hi];
ex.r = new double[ex.getGraph().node_hi];
// initializing the above maps:
Arrays.fill(ex.p, 0.0);
Arrays.fill(ex.r, 0.0);
// seed the residual vector with the query weight(s)
for (TIntDoubleIterator it = query.iterator(); it.hasNext(); ) {
it.advance();
ex.r[it.key()] = it.value();
}
// maps storing the gradients of p and r for each node:
// entries are lazily allocated inside push() as nodes are touched
ex.dp = new TIntDoubleMap[ex.getGraph().node_hi];
ex.dr = new TIntDoubleMap[ex.getGraph().node_hi];
// initializing the above maps:
// for(int node : example.getGraph().getNodes()) {
// dp.put(node, new TObjectDoubleHashMap<String>());
// dr.put(node, new TObjectDoubleHashMap<String>());
// for(String feature : (example.getGraph().getFeatureSet()))
// {
// dp.get(node).put(feature, 0.0);
// dr.get(node).put(feature, 0.0);
// }
// }
// APR Algorithm:
// sweep all nodes until a full pass finds every node already converged
int completeCount = 0;
while (completeCount < ex.getGraph().node_hi) {
if (log.isDebugEnabled())
log.debug("Starting pass");
completeCount = 0;
for (int u = 0; u < ex.getGraph().node_hi; u++) {
double ru = ex.r[u];
// udeg = out-degree of u (edge index range in the adjacency arrays)
int udeg = ex.getGraph().node_near_hi[u] - ex.getGraph().node_near_lo[u];
// push u until its residual-per-degree drops below epsilon;
// the else below pairs with this if (while is the if's single statement)
if (ru / (double) udeg > c.apr.epsilon)
while (ru / udeg > c.apr.epsilon) {
this.push(u, params, ex);
// sanity check: each push must strictly shrink u's residual, else we'd loop forever
if (ex.r[u] > ru)
throw new IllegalStateException("r increasing! :(");
ru = ex.r[u];
}
else {
completeCount++;
if (log.isDebugEnabled())
log.debug("Counting " + u);
}
}
if (log.isDebugEnabled())
log.debug(completeCount + " of " + ex.getGraph().node_hi + " completed this pass");
else if (log.isInfoEnabled() && status.due(3))
log.info(Thread.currentThread() + " inference: " + completeCount + " of " + ex.getGraph().node_hi + " completed this pass");
}
// GradientComponents g = new GradientComponents();
// g.p = p;
// g.d = dp;
// return g;
}
Usage of gnu.trove.map.TIntDoubleMap in project ProPPR by TeamCohen:
class DprSRW, method push().
/**
 * Simulates a single lazy random walk step on the input vertex: moves an
 * alpha-fraction of u's residual weight into its probability, spreads the rest
 * to u's out-neighbors, and propagates the corresponding per-feature gradients
 * (dp, dr) alongside.
 * @param u the vertex to be 'pushed'
 * @param paramVec current parameters, used to weight edges via the squashing function
 * @param ex the example holding p, r, dp, dr and the graph being walked
 */
public void push(int u, ParamVector<String, ?> paramVec, DprExample ex) {
log.debug("Pushing " + u);
// update p for the pushed node:
ex.p[u] += c.apr.alpha * ex.r[u];
// lazily allocate u's dr map on first touch
if (ex.dr[u] == null)
ex.dr[u] = new TIntDoubleHashMap();
TIntDoubleMap dru = ex.dr[u];
// precompute the raw (pre-squash) edge score for each out-neighbor v of u
TIntDoubleMap unwrappedDotP = new TIntDoubleHashMap();
for (int eid = ex.getGraph().node_near_lo[u], xvi = 0; eid < ex.getGraph().node_near_hi[u]; eid++, xvi++) {
int v = ex.getGraph().edge_dest[eid];
unwrappedDotP.put(v, dotP(ex.getGraph(), eid, paramVec));
}
// calculate the sum of the weights (raised to exp) of the edges adjacent to the input node:
double rowSum = this.totalEdgeProbWeight(ex.getGraph(), u, paramVec);
// calculate the gradients of the rowSums (needed for the calculation of the gradient of r):
TIntDoubleMap drowSums = new TIntDoubleHashMap();
// prevdr snapshots dru before it is scaled, since neighbors need the pre-update values
TIntDoubleMap prevdr = new TIntDoubleHashMap();
Set<String> exampleFeatures = ex.getGraph().getFeatureSet();
for (String feature : exampleFeatures) {
int flid = ex.getGraph().featureLibrary.getId(feature);
// simultaneously update the dp for the pushed node:
if (trainable(feature)) {
if (ex.dp[u] == null)
ex.dp[u] = new TIntDoubleHashMap();
Dictionary.increment(ex.dp[u], flid, c.apr.alpha * dru.get(flid));
}
// d(rowSum)/d(feature) = sum of squashing-function derivatives over edges carrying this feature
double drowSum = 0;
for (int eid = ex.getGraph().node_near_lo[u], xvi = 0; eid < ex.getGraph().node_near_hi[u]; eid++, xvi++) {
int v = ex.getGraph().edge_dest[eid];
if (hasFeature(ex.getGraph(), eid, flid)) {
//g.getFeatures(u, v).containsKey(feature)) {
drowSum += c.squashingFunction.computeDerivative(unwrappedDotP.get(v));
}
}
drowSums.put(flid, drowSum);
// update dr for the pushed vertex, storing dr temporarily for the calculation of dr for the other vertices:
prevdr.put(flid, dru.get(flid));
dru.put(flid, dru.get(flid) * (1 - c.apr.alpha) * stayProb);
}
// update dr for other vertices:
for (int eid = ex.getGraph().node_near_lo[u], xvi = 0; eid < ex.getGraph().node_near_hi[u]; eid++, xvi++) {
int v = ex.getGraph().edge_dest[eid];
double dotP = c.squashingFunction.edgeWeight(unwrappedDotP.get(v));
double ddotP = c.squashingFunction.computeDerivative(unwrappedDotP.get(v));
for (String feature : exampleFeatures) {
int flid = ex.getGraph().featureLibrary.getId(feature);
// contained = indicator [feature is on edge u->v]
int contained = hasFeature(ex.getGraph(), eid, flid) ? 1 : 0;
if (ex.dr[v] == null)
ex.dr[v] = new TIntDoubleHashMap();
double vdr = Dictionary.safeGet(ex.dr[v], flid, 0.0);
// whoa this is pretty gross.
// quotient rule on (dotP / rowSum), plus the residual term carried through prevdr
vdr += (1 - stayProb) * (1 - c.apr.alpha) * ((prevdr.get(flid) * dotP / rowSum) + (ex.r[u] * ((contained * ddotP * rowSum) - (dotP * drowSums.get(flid))) / (rowSum * rowSum)));
ex.dr[v].put(flid, vdr);
}
}
// update r for all affected vertices:
// u keeps the lazy stay-probability share; neighbors receive the normalized spread
double ru = ex.r[u];
ex.r[u] = ru * stayProb * (1 - c.apr.alpha);
for (int eid = ex.getGraph().node_near_lo[u], xvi = 0; eid < ex.getGraph().node_near_hi[u]; eid++, xvi++) {
int v = ex.getGraph().edge_dest[eid];
// calculate edge weight on v:
double dotP = c.squashingFunction.edgeWeight(unwrappedDotP.get(v));
ex.r[v] += (1 - stayProb) * (1 - c.apr.alpha) * (dotP / rowSum) * ru;
}
}
Aggregations (end of collected usage examples).