Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li: class CBMSOptimizer, method updateBinaryLogisticRegression.
// todo pay attention to parallelism
/**
 * Refits the binary logistic regression for a single label by minimizing
 * the augmented loss with L-BFGS, capped at {@code numBinaryParaUpdates} iterations.
 *
 * @param labelIndex index of the label whose binary classifier is updated
 */
// todo pay attention to parallelism
private void updateBinaryLogisticRegression(int labelIndex) {
    AugmentedLRLoss augmentedLoss = new AugmentedLRLoss(dataSet, labelIndex, gammas,
            cbms.getBinaryClassifiers()[labelIndex], priorVarianceBinary, componentWeightsVariance);
    LBFGS solver = new LBFGS(augmentedLoss);
    // todo
    Terminator terminator = solver.getTerminator();
    terminator.setMaxIteration(numBinaryParaUpdates);
    terminator.setGoal(Terminator.Goal.MINIMIZE);
    solver.optimize();
}
Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li: class NoiseOptimizer, method updateModel.
/**
 * Runs a full optimization of the CRF parameters against the current target
 * distributions by minimizing the KL loss with the configured optimizer.
 *
 * @throws IllegalArgumentException if the configured optimizer name is neither
 *         {@code "LBFGS"} nor {@code "GD"}
 */
private void updateModel() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateModel()");
    }
    KLLoss klLoss = new KLLoss(crf, dataSet, targets, variance);
    // todo
    // no null init needed: every non-throwing switch branch assigns opt
    Optimizer opt;
    switch (optimizer) {
        case "LBFGS":
            opt = new LBFGS(klLoss);
            break;
        case "GD":
            opt = new GradientDescent(klLoss);
            break;
        default:
            // include the offending value so a misconfiguration is diagnosable
            throw new IllegalArgumentException("unknown optimizer: " + optimizer);
    }
    opt.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateModel()");
    }
}
Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li: class NoiseOptimizer, method updateModelPartial.
/**
 * Runs a bounded optimization of the CRF parameters against the current
 * target distributions, stopping after at most {@code modelIterations}
 * iterations of the configured optimizer.
 *
 * @param modelIterations maximum number of optimizer iterations to run
 * @throws IllegalArgumentException if the configured optimizer name is neither
 *         {@code "LBFGS"} nor {@code "GD"}
 */
private void updateModelPartial(int modelIterations) {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateModelPartial()");
    }
    KLLoss klLoss = new KLLoss(crf, dataSet, targets, variance);
    // todo
    // no null init needed: every non-throwing switch branch assigns opt
    Optimizer opt;
    switch (optimizer) {
        case "LBFGS":
            opt = new LBFGS(klLoss);
            break;
        case "GD":
            opt = new GradientDescent(klLoss);
            break;
        default:
            // include the offending value so a misconfiguration is diagnosable
            throw new IllegalArgumentException("unknown optimizer: " + optimizer);
    }
    opt.getTerminator().setMaxIteration(modelIterations);
    opt.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateModelPartial()");
    }
}
Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li: class LogRiskOptimizer, method updateModel.
/**
 * Runs a full optimization of the CRF parameters against the current target
 * distributions by minimizing the KL loss with the configured optimizer.
 *
 * @throws IllegalArgumentException if the configured optimizer name is neither
 *         {@code "LBFGS"} nor {@code "GD"}
 */
private void updateModel() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateModel()");
    }
    KLLoss klLoss = new KLLoss(crf, dataSet, targets, variance);
    // todo
    // no null init needed: every non-throwing switch branch assigns opt
    Optimizer opt;
    switch (optimizer) {
        case "LBFGS":
            opt = new LBFGS(klLoss);
            break;
        case "GD":
            opt = new GradientDescent(klLoss);
            break;
        default:
            // include the offending value so a misconfiguration is diagnosable
            throw new IllegalArgumentException("unknown optimizer: " + optimizer);
    }
    opt.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateModel()");
    }
}
Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li: class KLLossTest, method test1.
/**
 * Trains a CMLCRF on the spam TREC dataset by iterating L-BFGS on the KL loss
 * for 200 steps, printing the loss value plus train/test accuracy and overlap
 * after every iteration.
 *
 * @throws Exception if either TREC dataset cannot be loaded
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "spam/trec_data/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "spam/trec_data/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(dataSet);
    List<MultiLabel> support = cmlcrf.getSupportCombinations();
    // One-hot target distribution: all probability mass on the support
    // combination that equals the instance's true label set.
    double[][] targetDistribution = new double[dataSet.getNumDataPoints()][support.size()];
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        // hoisted out of the inner loop: invariant per data point
        MultiLabel multiLabel = dataSet.getMultiLabels()[i];
        for (int c = 0; c < support.size(); c++) {
            if (support.get(c).equals(multiLabel)) {
                targetDistribution[i][c] = 1;
            }
        }
    }
    System.out.println("start");
    KLLoss klLoss = new KLLoss(cmlcrf, dataSet, targetDistribution, 1);
    cmlcrf.setConsiderPair(true);
    LBFGS optimizer = new LBFGS(klLoss);
    for (int i = 0; i < 200; i++) {
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println(klLoss.getValue());
        // declared inside the loop: values are only used within one iteration
        MultiLabel[] predTrain = cmlcrf.predict(dataSet);
        MultiLabel[] predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }
}
Aggregations