Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.
Class KLLossTest, method test1.
/**
 * Fits a {@code CMLCRF} to the spam training set by running 200 LBFGS iterations on a
 * {@code KLLoss}, printing the loss value plus train/test accuracy and overlap after
 * every iteration.
 *
 * <p>The target distribution passed to the loss is an indicator matrix: entry
 * {@code [i][c]} is 1 when the c-th support combination equals the label set of data
 * point i, and stays 0 otherwise.
 *
 * @throws Exception propagated unchanged from the calls below (presumably data loading)
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "spam/trec_data/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "spam/trec_data/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(dataSet);

    // Reuse the already-fetched support list for sizing instead of asking the model again.
    List<MultiLabel> support = cmlcrf.getSupportCombinations();
    int numDataPoints = dataSet.getNumDataPoints();
    int numCombinations = support.size();
    double[][] targetDistribution = new double[numDataPoints][numCombinations];
    for (int i = 0; i < numDataPoints; i++) {
        // The label set depends only on i, so look it up once per data point.
        MultiLabel multiLabel = dataSet.getMultiLabels()[i];
        for (int c = 0; c < numCombinations; c++) {
            if (support.get(c).equals(multiLabel)) {
                targetDistribution[i][c] = 1;
            }
        }
    }

    System.out.println("start");
    KLLoss klLoss = new KLLoss(cmlcrf, dataSet, targetDistribution, 1);
    // NOTE(review): pair consideration is enabled only after the loss has been constructed;
    // confirm KLLoss does not read this flag at construction time. Order kept as in the original.
    cmlcrf.setConsiderPair(true);
    LBFGS optimizer = new LBFGS(klLoss);
    for (int i = 0; i < 200; i++) {
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println(klLoss.getValue());
        // Predictions are only needed within this iteration, so they are scoped to it.
        MultiLabel[] predTrain = cmlcrf.predict(dataSet);
        MultiLabel[] predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.
Class LogRiskOptimizerTest, method test3.
/**
 * Log-risk optimization demo on data drawn from {@code MultiLabelSynthesizer.crfSample()}.
 *
 * <p>Steps, in order: build a {@code CMLCRF} on the training sample and overwrite its
 * non-bias weights with fixed starting values; print measures for the raw model and for
 * the instance-F1 plug-in predictor; run a single iteration of an accuracy-scored
 * {@code LogRiskOptimizer}; print measures again; then iterate the F-scored optimizer
 * until its terminator reports that it should stop, printing the objective detail,
 * measures and the model after each iteration.
 */
private static void test3() {
    MultiLabelClfDataSet trainSet = MultiLabelSynthesizer.crfSample();
    MultiLabelClfDataSet testSet = MultiLabelSynthesizer.crfSample();
    CMLCRF crf = new CMLCRF(trainSet);

    // Fixed starting weights: row = class index, column = position within that class's
    // non-bias weight vector. Applied in row-major order.
    double[][] startingWeights = {
        {0, 10},
        {10, 10},
        {10, 0},
        {-10, -10},
    };
    for (int classIndex = 0; classIndex < startingWeights.length; classIndex++) {
        double[] row = startingWeights[classIndex];
        for (int position = 0; position < row.length; position++) {
            crf.getWeights().getWeightsWithoutBiasForClass(classIndex).set(position, row[position]);
        }
    }

    MLScorer f1Scorer = new FScorer();
    MLScorer accuracyScorer = new AccScorer();
    SubsetAccPredictor accPredictor = new SubsetAccPredictor(crf);
    LogRiskOptimizer f1Optimizer =
            new LogRiskOptimizer(trainSet, f1Scorer, crf, 1, false, false, 1, 1);
    InstanceF1Predictor f1Predictor = new InstanceF1Predictor(crf);

    // Baseline report with the hand-set weights.
    System.out.println(crf);
    System.out.println("training performance acc");
    System.out.println(new MLMeasures(crf, trainSet));
    System.out.println("test performance acc");
    System.out.println(new MLMeasures(crf, testSet));
    System.out.println("training performance f1");
    System.out.println(new MLMeasures(f1Predictor, trainSet));
    System.out.println("test performance f1");
    System.out.println(new MLMeasures(f1Predictor, testSet));

    // One step of the accuracy-scored optimizer on the same model.
    LogRiskOptimizer accuracyOptimizer =
            new LogRiskOptimizer(trainSet, accuracyScorer, crf, 1, false, false, 1, 1);
    accuracyOptimizer.iterate();

    System.out.println("after ML estimation");
    System.out.println("training with Acc predictor");
    System.out.println(new MLMeasures(accPredictor, trainSet));
    System.out.println("training with F1 predictor");
    System.out.println(new MLMeasures(f1Predictor, trainSet));
    System.out.println("test with Acc predictor");
    System.out.println(new MLMeasures(accPredictor, testSet));
    System.out.println("test with F1 predictor");
    System.out.println(new MLMeasures(f1Predictor, testSet));
    System.out.println(crf);
    System.out.println(f1Optimizer.objectiveDetail());

    // Drive the F-scored optimizer until its own terminator says to stop.
    while (!f1Optimizer.getTerminator().shouldTerminate()) {
        System.out.println("------------");
        f1Optimizer.iterate();
        System.out.println(f1Optimizer.objectiveDetail());
        System.out.println("training performance acc");
        System.out.println(new MLMeasures(crf, trainSet));
        System.out.println("test performance acc");
        System.out.println(new MLMeasures(crf, testSet));
        System.out.println("training performance f1");
        System.out.println(new MLMeasures(f1Predictor, trainSet));
        System.out.println("test performance f1");
        System.out.println(new MLMeasures(f1Predictor, testSet));
        System.out.println(crf);
    }
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.
Class LogRiskOptimizerTest, method test4.
/**
 * Log-risk optimization demo on data drawn from
 * {@code MultiLabelSynthesizer.crfArgmaxHide()}.
 *
 * <p>Steps, in order: build a {@code CMLCRF} on the training sample and overwrite its
 * non-bias weights with fixed starting values; print measures for the raw model and for
 * the instance-F1 plug-in predictor, followed by the F-scored optimizer's objective
 * detail; run a single iteration of an accuracy-scored {@code LogRiskOptimizer}; print
 * measures again; then iterate the F-scored optimizer until its terminator reports that
 * it should stop, printing the objective detail, measures and the model after each
 * iteration.
 */
private static void test4() {
    MultiLabelClfDataSet trainSet = MultiLabelSynthesizer.crfArgmaxHide();
    MultiLabelClfDataSet testSet = MultiLabelSynthesizer.crfArgmaxHide();
    CMLCRF crf = new CMLCRF(trainSet);

    // Fixed starting weights: row = class index, column = position within that class's
    // non-bias weight vector. Applied in row-major order.
    double[][] startingWeights = {
        {0, 10},
        {10, 10},
        {10, 0},
        {-10, -10},
    };
    for (int classIndex = 0; classIndex < startingWeights.length; classIndex++) {
        double[] row = startingWeights[classIndex];
        for (int position = 0; position < row.length; position++) {
            crf.getWeights().getWeightsWithoutBiasForClass(classIndex).set(position, row[position]);
        }
    }

    MLScorer f1Scorer = new FScorer();
    MLScorer accuracyScorer = new AccScorer();
    SubsetAccPredictor accPredictor = new SubsetAccPredictor(crf);
    LogRiskOptimizer f1Optimizer =
            new LogRiskOptimizer(trainSet, f1Scorer, crf, 1, false, false, 1, 1);
    InstanceF1Predictor f1Predictor = new InstanceF1Predictor(crf);

    // Baseline report with the hand-set weights, including the initial objective.
    System.out.println(crf);
    System.out.println("training performance acc");
    System.out.println(new MLMeasures(crf, trainSet));
    System.out.println("test performance acc");
    System.out.println(new MLMeasures(crf, testSet));
    System.out.println("training performance f1");
    System.out.println(new MLMeasures(f1Predictor, trainSet));
    System.out.println("test performance f1");
    System.out.println(new MLMeasures(f1Predictor, testSet));
    System.out.println(f1Optimizer.objectiveDetail());

    // One step of the accuracy-scored optimizer on the same model.
    LogRiskOptimizer accuracyOptimizer =
            new LogRiskOptimizer(trainSet, accuracyScorer, crf, 1, false, false, 1, 1);
    accuracyOptimizer.iterate();

    System.out.println("after ML estimation");
    System.out.println("training with Acc predictor");
    System.out.println(new MLMeasures(accPredictor, trainSet));
    System.out.println("training with F1 predictor");
    System.out.println(new MLMeasures(f1Predictor, trainSet));
    System.out.println("test with Acc predictor");
    System.out.println(new MLMeasures(accPredictor, testSet));
    System.out.println("test with F1 predictor");
    System.out.println(new MLMeasures(f1Predictor, testSet));
    System.out.println(crf);
    System.out.println(f1Optimizer.objectiveDetail());

    // Drive the F-scored optimizer until its own terminator says to stop.
    while (!f1Optimizer.getTerminator().shouldTerminate()) {
        System.out.println("------------");
        f1Optimizer.iterate();
        System.out.println(f1Optimizer.objectiveDetail());
        System.out.println("training performance acc");
        System.out.println(new MLMeasures(crf, trainSet));
        System.out.println("test performance acc");
        System.out.println(new MLMeasures(crf, testSet));
        System.out.println("training performance f1");
        System.out.println(new MLMeasures(f1Predictor, trainSet));
        System.out.println("test performance f1");
        System.out.println(new MLMeasures(f1Predictor, testSet));
        System.out.println(crf);
    }
}
Aggregations