Usage of edu.neu.ccs.pyramid.optimization.GradientDescent in the project pyramid by cheng-li.
Class LogRiskOptimizer, method updateModel.
/**
 * Builds a {@code KLLoss} from the current {@code crf}, {@code dataSet}, {@code targets} and
 * {@code variance}, then runs the optimizer selected by the {@code optimizer} field on it.
 *
 * <p>Accepted values of {@code optimizer} are {@code "LBFGS"} (uses {@code LBFGS}) and
 * {@code "GD"} (uses {@code GradientDescent}). The chosen optimizer's {@code optimize()} is
 * called once.
 *
 * @throws IllegalArgumentException if {@code optimizer} is neither {@code "LBFGS"} nor
 *     {@code "GD"}; the message names the rejected value
 * @throws NullPointerException if {@code optimizer} is {@code null} (Java switch on a null
 *     String)
 */
private void updateModel() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateModel()");
    }
    KLLoss klLoss = new KLLoss(crf, dataSet, targets, variance);
    // todo  (NOTE(review): left by the original author; intent not stated here)
    // No initializer: every switch arm either assigns opt or throws, so the compiler
    // guarantees opt is definitely assigned before use.
    Optimizer opt;
    switch (optimizer) {
        case "LBFGS":
            opt = new LBFGS(klLoss);
            break;
        case "GD":
            opt = new GradientDescent(klLoss);
            break;
        default:
            // Report the offending value and the accepted ones so misconfiguration is
            // diagnosable from the stack trace alone.
            throw new IllegalArgumentException(
                    "unknown optimizer: \"" + optimizer + "\" (expected \"LBFGS\" or \"GD\")");
    }
    opt.optimize();
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateModel()");
    }
}
Usage of edu.neu.ccs.pyramid.optimization.GradientDescent in the project pyramid by cheng-li.
Class CMLCRFTest, method test2.
/**
 * Trains a {@code CMLCRF} (or loads one when {@code train.warmStart} is set) and reports
 * train/test accuracy and overlap.
 *
 * <p>When training, the model is optimized with {@code LBFGS} (if {@code isLBFGS}) or
 * {@code GradientDescent} for {@code numRounds} iterations. Scores are printed after every
 * iteration. Final scores are always printed. If {@code saveModel} is set, the model is
 * serialized to {@code output/modelName}.
 *
 * <p>Configuration keys read: {@code input.trainData}, {@code input.testData},
 * {@code gaussianVariance}, {@code output}, {@code modelName}, {@code train.warmStart},
 * {@code isLBFGS}, {@code numRounds}, {@code saveModel}.
 *
 * @throws Exception propagated from data loading or model (de)serialization
 */
public static void test2() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_DENSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_DENSE, true);
    double gaussianVariance = config.getDouble("gaussianVariance");
    // loading or save model infos.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf;
    if (config.getBoolean("train.warmStart")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else {
        cmlcrf = new CMLCRF(trainSet);
        CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
        // Read once rather than re-querying the config on every loop test.
        int numRounds = config.getInt("numRounds");
        if (config.getBoolean("isLBFGS")) {
            LBFGS optimizer = new LBFGS(crfLoss);
            optimizer.getTerminator().setAbsoluteEpsilon(0.1);
            for (int i = 0; i < numRounds; i++) {
                optimizer.iterate();
                printTrainTestScores(cmlcrf, trainSet, testSet, "iter: " + String.format("%04d", i) + "\t");
            }
        } else {
            GradientDescent optimizer = new GradientDescent(crfLoss);
            for (int i = 0; i < numRounds; i++) {
                optimizer.iterate();
                printTrainTestScores(cmlcrf, trainSet, testSet, "iter: " + String.format("%04d", i) + "\t");
            }
        }
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    printTrainTestScores(cmlcrf, trainSet, testSet, "");
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}

/**
 * Predicts labels for both data sets with {@code cmlcrf} and prints one line:
 * {@code prefix}, then train accuracy, train overlap, test accuracy and test overlap, each
 * formatted with four decimals and separated by tabs.
 *
 * @param cmlcrf   model used for prediction
 * @param trainSet training data; its gold labels are compared with the predictions
 * @param testSet  test data; its gold labels are compared with the predictions
 * @param prefix   text printed before {@code "Train acc: "} (e.g. an iteration tag), may be empty
 */
private static void printTrainTestScores(CMLCRF cmlcrf, MultiLabelClfDataSet trainSet,
                                         MultiLabelClfDataSet testSet, String prefix) {
    MultiLabel[] predTrain = cmlcrf.predict(trainSet);
    MultiLabel[] predTest = cmlcrf.predict(testSet);
    System.out.print(prefix + "Train acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
    System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
}
Usage of edu.neu.ccs.pyramid.optimization.GradientDescent in the project pyramid by cheng-li.
Class LogisticRegressionTest, method test1.
/**
 * Fits a {@code LogisticRegression} on the sparse IMDB split in {@code DATASETS/imdb/3}
 * using gradient descent, then prints its accuracy on the training and test files.
 *
 * @throws Exception propagated from loading either TREC data set
 */
private static void test1() throws Exception {
    // Both splits are loaded up front; the training split's metadata is echoed afterwards.
    final ClfDataSet train = TRECFormat.loadClfDataSet(
            new File(DATASETS, "/imdb/3/train.trec"), DataSetType.CLF_SPARSE, false);
    final ClfDataSet heldOut = TRECFormat.loadClfDataSet(
            new File(DATASETS, "/imdb/3/test.trec"), DataSetType.CLF_SPARSE, false);
    System.out.println(train.getMetaInfo());

    // Model dimensions come from the training split.
    final int classCount = train.getNumClasses();
    final int featureCount = train.getNumFeatures();
    final LogisticRegression model = new LogisticRegression(classCount, featureCount);

    // NOTE(review): 1000 and true are passed straight to LogisticLoss; presumably a prior
    // variance and a parallelism flag — confirm against LogisticLoss's constructor.
    final LogisticLoss objective = new LogisticLoss(model, train, 1000, true);

    final GradientDescent solver = new GradientDescent(objective);
    solver.getLineSearcher().setInitialStepLength(1.0E-4);
    solver.optimize();

    System.out.println("train: " + Accuracy.accuracy(model, train));
    System.out.println("test: " + Accuracy.accuracy(model, heldOut));
}
Aggregations