Usage of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.
Class MLLogisticTrainerTest, method test1.
/**
 * Wraps the binary spam training set as a 2-class multi-label data set and fits an
 * {@code MLLogisticRegression} to it with LBFGS for a fixed 100 iterations.
 * <p>
 * The objective value is printed before every iteration; accuracy and overlap on the
 * training data are printed at the end. The admissible assignments are {0} and {1}.
 *
 * @throws Exception if loading the data or training fails
 */
private static void test1() throws Exception {
    ClfDataSet spamData = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_DENSE, true);
    int rows = spamData.getNumDataPoints();
    int cols = spamData.getNumFeatures();
    MultiLabelClfDataSet multiData = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(rows)
            .numFeatures(cols)
            .numClasses(2)
            .build();
    int[] singleLabels = spamData.getLabels();

    // Copy each point's single label and its full feature row into the multi-label set.
    int row = 0;
    while (row < rows) {
        multiData.addLabel(row, singleLabels[row]);
        for (int col = 0; col < cols; col++) {
            multiData.setFeatureValue(row, col, spamData.getRow(row).get(col));
        }
        row++;
    }

    // Admissible label assignments: {0} and {1}.
    List<MultiLabel> assignments = new ArrayList<>();
    for (int label = 0; label <= 1; label++) {
        assignments.add(new MultiLabel().addLabel(label));
    }

    MLLogisticRegression model = new MLLogisticRegression(multiData.getNumClasses(), multiData.getNumFeatures(), assignments);
    // NOTE(review): the meaning of 10000 is not visible here (presumably a prior variance) — confirm in MLLogisticLoss.
    MLLogisticLoss loss = new MLLogisticLoss(model, multiData, 10000);
    LBFGS lbfgs = new LBFGS(loss);
    lbfgs.getTerminator().setRelativeEpsilon(0.01);
    lbfgs.setHistory(5);
    // Fixed step budget: this loop never checks the terminator configured above.
    for (int step = 0; step < 100; step++) {
        System.out.println(loss.getValue());
        lbfgs.iterate();
    }
    System.out.println(Accuracy.accuracy(model, multiData));
    System.out.println(Overlap.overlap(model, multiData));
}
Usage of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.
Class MLLogisticTrainerTest, method test3.
/**
 * Builds a 3-class multi-label version of the spam training set and fits an
 * {@code MLLogisticRegression} to it with LBFGS for a fixed 1000 iterations.
 * <p>
 * Each point keeps its original label. In addition, a point whose original label
 * is 1 (described as "spam" in the earlier note) and whose first feature value is
 * below 0.1 also receives the synthetic label 2. The admissible assignments are
 * {0}, {1} and {1, 2}. Training accuracy is printed before every iteration, and
 * accuracy and overlap are printed at the end.
 *
 * @throws Exception if loading the data or training fails
 */
private static void test3() throws Exception {
    ClfDataSet spamData = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_DENSE, true);
    int rows = spamData.getNumDataPoints();
    int cols = spamData.getNumFeatures();
    MultiLabelClfDataSet multiData = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(rows)
            .numFeatures(cols)
            .numClasses(3)
            .build();
    int[] originalLabels = spamData.getLabels();

    int row = 0;
    while (row < rows) {
        int label = originalLabels[row];
        multiData.addLabel(row, label);
        // Short-circuit: the feature row is only inspected for label-1 points.
        boolean addFakeLabel = label == 1 && spamData.getRow(row).get(0) < 0.1;
        if (addFakeLabel) {
            multiData.addLabel(row, 2);
        }
        for (int col = 0; col < cols; col++) {
            multiData.setFeatureValue(row, col, spamData.getRow(row).get(col));
        }
        row++;
    }

    List<MultiLabel> allowed = new ArrayList<>();
    allowed.add(new MultiLabel().addLabel(0));
    allowed.add(new MultiLabel().addLabel(1));
    allowed.add(new MultiLabel().addLabel(1).addLabel(2));

    MLLogisticRegression model = new MLLogisticRegression(multiData.getNumClasses(), multiData.getNumFeatures(), allowed);
    // NOTE(review): the meaning of 10000 is not visible here (presumably a prior variance) — confirm in MLLogisticLoss.
    MLLogisticLoss loss = new MLLogisticLoss(model, multiData, 10000);
    LBFGS lbfgs = new LBFGS(loss);
    lbfgs.getTerminator().setRelativeEpsilon(0.01);
    lbfgs.setHistory(5);
    // Fixed step budget: this loop never checks the terminator configured above.
    for (int step = 0; step < 1000; step++) {
        System.out.println(Accuracy.accuracy(model, multiData));
        lbfgs.iterate();
    }
    System.out.println(Accuracy.accuracy(model, multiData));
    System.out.println(Overlap.overlap(model, multiData));
}
Usage of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.
Class CMLCRFTest, method test2.
/**
 * Trains a CMLCRF, or loads one when {@code train.warmStart} is set, and reports
 * train/test accuracy and overlap. Optionally serializes the model afterwards.
 * <p>
 * Settings are read from the shared {@code config}: {@code input.trainData},
 * {@code input.testData}, {@code gaussianVariance}, {@code output},
 * {@code modelName}, {@code train.warmStart}, {@code isLBFGS}, {@code numRounds}
 * and {@code saveModel}. When warm-starting, no training rounds are run.
 *
 * @throws Exception if data loading, (de)serialization or training fails
 */
public static void test2() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_DENSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_DENSE, true);
    double gaussianVariance = config.getDouble("gaussianVariance");
    // Model location, used both for warm-start loading and for saving.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf;
    MultiLabel[] predTrain;
    MultiLabel[] predTest;
    if (config.getBoolean("train.warmStart")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else {
        cmlcrf = new CMLCRF(trainSet);
        CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
        // Choose the optimizer once. The per-round training/reporting loop below is
        // shared instead of being duplicated for each optimizer.
        Runnable step;
        if (config.getBoolean("isLBFGS")) {
            LBFGS optimizer = new LBFGS(crfLoss);
            optimizer.getTerminator().setAbsoluteEpsilon(0.1);
            step = optimizer::iterate;
        } else {
            GradientDescent optimizer = new GradientDescent(crfLoss);
            step = optimizer::iterate;
        }
        int numRounds = config.getInt("numRounds");
        for (int i = 0; i < numRounds; i++) {
            step.run();
            predTrain = cmlcrf.predict(trainSet);
            predTest = cmlcrf.predict(testSet);
            System.out.print("iter: " + String.format("%04d", i));
            System.out.print("\tTrain acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
            System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
            System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
            System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
        }
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    predTrain = cmlcrf.predict(trainSet);
    predTest = cmlcrf.predict(testSet);
    System.out.print("Train acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
    System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Usage of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.
Class CMLCRFTest, method test5.
/**
 * Trains a CMLCRF on the ohsumed/3 sparse multi-label data in two phases.
 * <p>
 * Phase 1 runs 5 LBFGS rounds with {@code considerPair} off. Phase 2 switches it
 * on and runs 50 more LBFGS rounds with a fresh loss and optimizer, continuing
 * from the same model object. After every round the loss value and the train/test
 * accuracy and overlap are printed.
 *
 * @throws Exception if loading the data sets or training fails
 */
private static void test5() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(dataSet);
    MultiLabel[] predTrain;
    MultiLabel[] predTest;

    // Phase 1: pairs disabled. The mode is set before the loss is constructed, so
    // anything CRFLoss may read from the model at construction matches the mode
    // actually being trained.
    cmlcrf.setConsiderPair(false);
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    LBFGS optimizer = new LBFGS(crfLoss);
    for (int i = 0; i < 5; i++) {
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println(crfLoss.getValue());
        predTrain = cmlcrf.predict(dataSet);
        predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }

    // Phase 2: pairs enabled. As in phase 1, the mode is set before building the loss.
    cmlcrf.setConsiderPair(true);
    CRFLoss crfLoss2 = new CRFLoss(cmlcrf, dataSet, 1);
    LBFGS optimizer2 = new LBFGS(crfLoss2);
    for (int i = 0; i < 50; i++) {
        System.out.println("consider pairs");
        System.out.println("iter: " + i);
        optimizer2.iterate();
        System.out.println(crfLoss2.getValue());
        predTrain = cmlcrf.predict(dataSet);
        predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }
}
Usage of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.
Class CMLCRFTest, method test3.
/**
 * Trains a CMLCRF on the imdb/3 sparse multi-label data with LBFGS for 50 rounds.
 * <p>
 * After each round it prints the loss value and the train/test accuracy and overlap
 * of the model's predictions.
 *
 * @throws Exception if loading the data sets or training fails
 */
private static void test3() throws Exception {
    MultiLabelClfDataSet trainData = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "/imdb/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testData = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "/imdb/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF crf = new CMLCRF(trainData);
    CRFLoss loss = new CRFLoss(crf, trainData, 1);
    LBFGS lbfgs = new LBFGS(loss);
    int round = 0;
    while (round < 50) {
        System.out.println("iter: " + round);
        lbfgs.iterate();
        System.out.println(loss.getValue());
        // Score the freshly updated model on both splits.
        MultiLabel[] trainPred = crf.predict(trainData);
        MultiLabel[] testPred = crf.predict(testData);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(trainData.getMultiLabels(), trainPred));
        System.out.print("\tTrain overlap " + Overlap.overlap(trainData.getMultiLabels(), trainPred));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testData.getMultiLabels(), testPred));
        System.out.println("\tTest overlap " + Overlap.overlap(testData.getMultiLabels(), testPred));
        round++;
    }
}
Aggregations