Use of edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss in project pyramid by cheng-li.
Class CMLCRFTest, method test5.
/**
 * Two-phase CMLCRF training experiment on the ohsumed/3 train/test split.
 *
 * <p>Phase 1 runs 5 LBFGS iterations with {@code setConsiderPair(false)}. Phase 2 builds a new
 * {@link CRFLoss} and {@link LBFGS} on the same model and runs 50 more iterations with
 * {@code setConsiderPair(true)}. After every iteration it prints the iteration index, the current
 * loss value, and train/test accuracy and overlap.
 *
 * @throws Exception propagated from data loading, training, or prediction
 */
private static void test5() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(dataSet);

    // Phase 1: 5 iterations with pair consideration switched off.
    // NOTE(review): the loss is constructed before setConsiderPair(...) is called (same order as
    // before this refactor); confirm CRFLoss does not capture considerPair in its constructor.
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    cmlcrf.setConsiderPair(false);
    LBFGS optimizer = new LBFGS(crfLoss);
    for (int i = 0; i < 5; i++) {
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println(crfLoss.getValue());
        printCrfMetrics(cmlcrf, dataSet, testSet);
    }

    // Phase 2: continue training the same model with pair consideration switched on.
    // A fresh loss and optimizer are built here -- presumably so no LBFGS state from phase 1
    // carries over (TODO confirm intent).
    CRFLoss crfLoss2 = new CRFLoss(cmlcrf, dataSet, 1);
    cmlcrf.setConsiderPair(true);
    LBFGS optimizer2 = new LBFGS(crfLoss2);
    // Announce the phase once instead of repeating the banner before every iteration.
    System.out.println("consider pairs");
    for (int i = 0; i < 50; i++) {
        System.out.println("iter: " + i);
        optimizer2.iterate();
        System.out.println(crfLoss2.getValue());
        printCrfMetrics(cmlcrf, dataSet, testSet);
    }
}

/**
 * Predicts on both data sets with the model's current parameters and prints train/test
 * accuracy and overlap, tab-separated, on a single line.
 *
 * @param cmlcrf  the model to evaluate
 * @param dataSet the training set
 * @param testSet the held-out test set
 */
private static void printCrfMetrics(CMLCRF cmlcrf, MultiLabelClfDataSet dataSet, MultiLabelClfDataSet testSet) {
    MultiLabel[] predTrain = cmlcrf.predict(dataSet);
    MultiLabel[] predTest = cmlcrf.predict(testSet);
    System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
    System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
    System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
    System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
}
Use of edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss in project pyramid by cheng-li.
Class CMLCRFTest, method test3.
/**
 * Trains a CMLCRF on the imdb/3 train/test split with LBFGS for 50 iterations.
 *
 * <p>After each iteration it prints the iteration index, the current loss value, and train/test
 * accuracy and overlap computed from fresh predictions.
 *
 * @throws Exception propagated from data loading, training, or prediction
 */
private static void test3() throws Exception {
    File trainFile = new File(DATASETS, "/imdb/3/train.trec");
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);
    File heldOutFile = new File(DATASETS, "/imdb/3/test.trec");
    MultiLabelClfDataSet heldOutSet = TRECFormat.loadMultiLabelClfDataSet(heldOutFile, DataSetType.ML_CLF_SPARSE, true);

    CMLCRF model = new CMLCRF(trainSet);
    CRFLoss loss = new CRFLoss(model, trainSet, 1);
    LBFGS lbfgs = new LBFGS(loss);

    int iteration = 0;
    while (iteration < 50) {
        System.out.println("iter: " + iteration);
        lbfgs.iterate();
        System.out.println(loss.getValue());

        // Re-predict with the freshly updated parameters before reporting metrics.
        MultiLabel[] trainPredictions = model.predict(trainSet);
        MultiLabel[] heldOutPredictions = model.predict(heldOutSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain overlap " + Overlap.overlap(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTest acc: " + Accuracy.accuracy(heldOutSet.getMultiLabels(), heldOutPredictions));
        System.out.println("\tTest overlap " + Overlap.overlap(heldOutSet.getMultiLabels(), heldOutPredictions));

        iteration++;
    }
}
Aggregations