Usage of `edu.neu.ccs.pyramid.optimization.LBFGS` in the pyramid project by cheng-li.
Class `AugmentedLRLossTest`, method `main`.
/**
 * Smoke test for {@code AugmentedLRLoss}: builds the loss for an {@code AugmentedLR} on the
 * "scene" training split and runs 100 L-BFGS iterations, printing the loss value after each one.
 *
 * @param args ignored
 * @throws Exception propagated from dataset loading or from the optimizer
 */
public static void main(String[] args) throws Exception {
    // Raise the root logger to DEBUG so debug-level output is shown during the run.
    LoggerContext logContext = (LoggerContext) LogManager.getContext(false);
    Configuration logConfiguration = logContext.getConfiguration();
    logConfiguration.getLoggerConfig(LogManager.ROOT_LOGGER_NAME).setLevel(Level.DEBUG);
    logContext.updateLoggers();

    MultiLabelClfDataSet trainData = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "scene/train"), DataSetType.ML_CLF_DENSE, true);
    // NOTE(review): the test split is loaded but never used in this method.
    MultiLabelClfDataSet testData = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "scene/test"), DataSetType.ML_CLF_DENSE, true);

    AugmentedLR model = new AugmentedLR(trainData.getNumFeatures(), 1);

    // One single-element gamma row per training data point, every entry set to 1.
    double[][] gammas = new double[trainData.getNumDataPoints()][1];
    for (double[] gammaRow : gammas) {
        gammaRow[0] = 1;
    }

    AugmentedLRLoss loss = new AugmentedLRLoss(trainData, 0, gammas, model, 1, 1);
    LBFGS optimizer = new LBFGS(loss);
    final int numIterations = 100;
    for (int iter = 0; iter < numIterations; iter++) {
        optimizer.iterate();
        System.out.println(loss.getValue());
    }
}
Usage of `edu.neu.ccs.pyramid.optimization.LBFGS` in the pyramid project by cheng-li.
Class `App6`, method `train`.
/**
 * Trains a CMLCRF model with L-BFGS using settings read from {@code config}. During training it
 * periodically prints progress and re-saves the model. Afterwards it writes the final model and
 * the training-set predictions to {@code output.folder}, and optionally generates reports.
 *
 * <p>Config keys read: {@code input.trainData}, {@code input.testData},
 * {@code train.gaussianVariance}, {@code train.maxIteration}, {@code predict.target},
 * {@code train.showProgress.interval}, {@code output.folder}, {@code train.generateReports}.
 *
 * <p>A non-positive {@code train.showProgress.interval} disables the periodic progress
 * reports and intermediate model saves. The final report and model are still produced.
 *
 * @param config application configuration
 * @throws IllegalArgumentException if {@code predict.target} is neither
 *     {@code subsetAccuracy} nor {@code instanceFMeasure}
 * @throws Exception propagated from data loading, training, serialization or file output
 */
private static void train(Config config) throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(trainSet);
    double gaussianVariance = config.getDouble("train.gaussianVariance");
    cmlcrf.setConsiderPair(true);
    CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
    int maxIteration = config.getInt("train.maxIteration");
    crfLoss.setRegularizeAll(true);
    LBFGS optimizer = new LBFGS(crfLoss);
    optimizer.getTerminator().setMaxIteration(maxIteration);

    PluginPredictor<CMLCRF> predictor = null;
    String predictTarget = config.getString("predict.target");
    switch (predictTarget) {
        case "subsetAccuracy":
            predictor = new SubsetAccPredictor(cmlcrf);
            break;
        case "instanceFMeasure":
            predictor = new InstanceF1Predictor(cmlcrf);
            break;
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }

    int progressInterval = config.getInt("train.showProgress.interval");
    // Read once, before training starts, rather than at every checkpoint.
    String outputFolder = config.getString("output.folder");

    System.out.println("start training");
    int iteration = 0;
    while (true) {
        optimizer.iterate();
        iteration += 1;
        // Guard the modulus: an interval of 0 would throw ArithmeticException here.
        if (progressInterval > 0 && iteration % progressInterval == 0) {
            printTrainingStatus(iteration, optimizer, predictor, trainSet, testSet);
            // Overwrite the saved model with the current parameters.
            saveCrfModel(cmlcrf, outputFolder);
        }
        if (optimizer.getTerminator().shouldTerminate()) {
            printTrainingStatus(iteration, optimizer, predictor, trainSet, testSet);
            System.out.println("training done!");
            break;
        }
    }

    saveCrfModel(cmlcrf, outputFolder);
    MultiLabel[] predictions = cmlcrf.predict(trainSet);
    File predictionFile = new File(outputFolder, "train_predictions.txt");
    // NOTE(review): this FileUtils overload does not take a charset (commons-io uses the
    // platform default); consider passing an explicit charset.
    FileUtils.writeStringToFile(predictionFile, PrintUtil.toMutipleLines(predictions));
    System.out.println("predictions on the training set are written to " + predictionFile.getAbsolutePath());
    if (config.getBoolean("train.generateReports")) {
        report(config, trainSet, "trainSet");
    }
}

/**
 * Prints the iteration number, the terminator's last objective value, and {@code MLMeasures}
 * for the training and test sets.
 */
private static void printTrainingStatus(int iteration, LBFGS optimizer, PluginPredictor<CMLCRF> predictor,
                                        MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet) {
    System.out.println("iteration " + iteration);
    System.out.println("training objective = " + optimizer.getTerminator().getLastValue());
    System.out.println("training performance:");
    System.out.println(new MLMeasures(predictor, trainSet));
    System.out.println("test performance:");
    System.out.println(new MLMeasures(predictor, testSet));
}

/**
 * Serializes {@code cmlcrf} to {@code <outputFolder>/model_crf}, creating the folder first
 * if it does not exist.
 */
private static void saveCrfModel(CMLCRF cmlcrf, String outputFolder) throws Exception {
    new File(outputFolder).mkdirs();
    cmlcrf.serialize(new File(outputFolder, "model_crf"));
}
Aggregations