Use of `edu.cmu.ml.proppr.examples.PosNegRWExample` in project ProPPR by TeamCohen.
Class `PairwiseL2SqLoss`, method `computeLossGradient`.
/**
 * Accumulates the pairwise empirical-loss gradient for a single example.
 * <p>
 * For every pair of a negative node {@code b} and a positive node {@code a}, let
 * {@code delta = p[b] - p[a]}. Each feature returned by {@code getKeys(dp[b], dp[a])}
 * then receives {@code derivLoss(delta) * (dp[b].get(f) - dp[a].get(f))} in {@code gradient}.
 * {@code loss(delta)} is recorded in {@code lossdata} under {@code LOSS.L2}.
 *
 * @param params   parameter vector (not read by this implementation)
 * @param example  example supplying per-node values {@code p}, per-node feature maps {@code dp},
 *                 and the positive/negative node lists (presumably scores and their gradients — TODO confirm)
 * @param gradient accumulator; each contribution is added to the existing value for that feature,
 *                 or inserted if the feature is absent
 * @param lossdata receives one {@code LOSS.L2} term per (negative, positive) pair
 * @param c        training options (not read by this implementation)
 * @return the number of non-zero {@code dp} entries seen, summed over all pairs and features
 */
@Override
public int computeLossGradient(ParamVector params, PosNegRWExample example, TIntDoubleMap gradient, LossData lossdata, SRWOptions c) {
    int nonzero = 0;
    // Empirical loss gradient: one term for every (negative b, positive a) pair of nodes.
    for (int b : example.getNegList()) {
        for (int a : example.getPosList()) {
            double delta = example.p[b] - example.p[a];
            // delta is fixed for this pair, so derivLoss(delta) is evaluated once rather than per feature.
            double dLoss = derivLoss(delta);
            for (int feature : getKeys(example.dp[b], example.dp[a])) {
                double db = example.dp[b].get(feature);
                if (db != 0.0) {
                    nonzero++;
                }
                double da = example.dp[a].get(feature);
                if (da != 0.0) {
                    nonzero++;
                }
                double del = dLoss * (db - da);
                gradient.adjustOrPutValue(feature, del, del);
            }
            if (log.isDebugEnabled()) {
                log.debug("+pa=" + example.p[a] + " pb = " + example.p[b]);
            }
            lossdata.add(LOSS.L2, loss(delta));
        }
    }
    return nonzero;
}
Use of `edu.cmu.ml.proppr.examples.PosNegRWExample` in project ProPPR by TeamCohen.
Class `CachingTrainer`, method `train`.
/**
 * Parses every line of {@code exampleFile} into memory, then delegates training to {@code trainCached}.
 * <p>
 * If {@code masterFeatures} is non-empty, it is installed via {@code LearningGraphBuilder.setFeatures}
 * before parsing begins. A line that raises {@code GraphFormatException} is logged and skipped, and the
 * remaining lines are still processed. Reading and parsing times are recorded separately in the
 * statistics passed on to {@code trainCached}.
 * <p>
 * Note: the "Total parsed" count logged at the end counts every line visited, including lines that
 * failed to parse. It is logged only if at least one progress message was emitted.
 *
 * @return whatever {@code trainCached} returns for the parsed examples
 */
@Override
public ParamVector<String, ?> train(SymbolTable<String> masterFeatures, Iterable<String> exampleFile, LearningGraphBuilder builder, ParamVector<String, ?> initialParamVec, int numEpochs) {
    final ArrayList<PosNegRWExample> cache = new ArrayList<PosNegRWExample>();
    final RWExampleParser exampleParser = new RWExampleParser();
    if (masterFeatures.size() > 0) {
        LearningGraphBuilder.setFeatures(masterFeatures);
    }
    int lineNumber = 0;
    final StatusLogger timer = new StatusLogger();
    final TrainingStatistics stats = new TrainingStatistics();
    boolean reportedProgress = false;
    for (String line : exampleFile) {
        // Presumably the time elapsed since the previous tick, i.e. spent fetching this line — TODO confirm.
        stats.updateReadingStatistics(timer.sinceLast());
        lineNumber += 1;
        try {
            timer.tick();
            final PosNegRWExample parsedExample = exampleParser.parse(line, builder, masterLearner);
            stats.updateParsingStatistics(timer.sinceLast());
            cache.add(parsedExample);
            if (status.due()) {
                log.info("Parsed " + lineNumber + " ...");
                reportedProgress = true;
            }
        } catch (GraphFormatException e) {
            // Malformed line: log it and move on to the next one.
            log.error("Trouble with #" + lineNumber, e);
        }
        timer.tick();
    }
    if (reportedProgress) {
        log.info("Total parsed: " + lineNumber);
    }
    return trainCached(cache, builder, initialParamVec, numEpochs, stats);
}
Aggregations