use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class CBMGB method reportAccPrediction.
/**
 * Predicts label sets for {@code dataSet} with an {@link AccPredictor} built on {@code cbm},
 * prints the resulting {@link MLMeasures}, and writes the measures plus the per-instance
 * predictions under {@code <output.dir>/test_predictions/instance_accuracy_optimal}.
 *
 * <p>Each line of {@code predictions.txt} has the form {@code <predicted set>:<probability>},
 * where the probability is {@code cbm.predictAssignmentProb} of the predicted set for that row.
 *
 * @param config  read for {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     model used both to predict and to score each predicted set
 * @param dataSet data to predict on; its own labels are the reference for the measures
 * @throws Exception propagated from prediction, evaluation, or file output
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String separator = "============================================================";
    System.out.println(separator);
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");

    String outputDir = config.getString("output.dir");

    AccPredictor predictor = new AccPredictor(cbm);
    predictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(measures);

    // Written first: the prediction file below lives in the same directory.
    File performanceFile =
            Paths.get(outputDir, "test_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile);

    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predictions[row]))
            .toArray();

    File predictionFile =
            Paths.get(outputDir, "test_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int numDataPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numDataPoints; row++) {
            writer.write(predictions[row].toString() + ":" + setProbs[row]);
            writer.newLine();
        }
    }

    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(separator);
}
use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class CBMGB method reportF1Prediction.
/**
 * Predicts label sets for {@code dataSet} with a {@link PluginF1} predictor built on {@code cbm},
 * prints the resulting {@link MLMeasures}, and writes the measures plus the per-instance
 * predictions under {@code <output.dir>/test_predictions/instance_f1_optimal}.
 *
 * <p>The predictor's support is deserialized from the file {@code support} inside
 * {@code output.dir}. Each line of {@code predictions.txt} has the form
 * {@code <predicted set>:<probability>}, where the probability is
 * {@code cbm.predictAssignmentProb} of the predicted set for that row.
 *
 * @param config  read for {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     model used both to predict and to score each predicted set
 * @param dataSet data to predict on; its own labels are the reference for the measures
 * @throws Exception propagated from deserialization, prediction, evaluation, or file output
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String separator = "============================================================";
    System.out.println(separator);
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");

    String outputDir = config.getString("output.dir");

    PluginF1 predictor = new PluginF1(cbm);
    // Unchecked cast: presumably the file was serialized as a List<MultiLabel> — TODO confirm.
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(new File(outputDir, "support"));
    predictor.setSupport(support);
    predictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(measures);

    // Written first: the prediction file below lives in the same directory.
    File performanceFile =
            Paths.get(outputDir, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile);

    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predictions[row]))
            .toArray();

    File predictionFile =
            Paths.get(outputDir, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int numDataPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numDataPoints; row++) {
            writer.write(predictions[row].toString() + ":" + setProbs[row]);
            writer.newLine();
        }
    }

    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(separator);
}
use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class App6 method test.
/**
 * Loads the test set named by {@code input.testData}, deserializes the model {@code model_crf}
 * from {@code output.folder}, prints test {@link MLMeasures} for the plugin predictor selected by
 * {@code predict.target}, writes the model's predictions to {@code test_predictions.txt} in the
 * output folder, and finally calls {@code report} for the test set.
 *
 * @param config read for {@code input.testData}, {@code output.folder} and {@code predict.target}
 * @throws IllegalArgumentException if {@code predict.target} is neither {@code subsetAccuracy}
 *                                  nor {@code instanceFMeasure}
 * @throws Exception                propagated from loading, deserialization, or file output
 */
private static void test(Config config) throws Exception {
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);

    String outputFolder = config.getString("output.folder");
    String modelName = "model_crf";
    CMLCRF cmlcrf = (CMLCRF) Serialization.deserialize(new File(outputFolder, modelName));

    // Every switch branch either assigns the predictor or throws, so no placeholder value is needed.
    PluginPredictor<CMLCRF> predictor;
    String predictTarget = config.getString("predict.target");
    switch (predictTarget) {
        case "subsetAccuracy":
            predictor = new SubsetAccPredictor(cmlcrf);
            break;
        case "instanceFMeasure":
            predictor = new InstanceF1Predictor(cmlcrf);
            break;
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }

    System.out.println("test performance:");
    System.out.println(new MLMeasures(predictor, testSet));

    // NOTE(review): the measures above are computed with `predictor`, while the predictions
    // saved below come from cmlcrf.predict directly — confirm this difference is intended.
    MultiLabel[] predictions = cmlcrf.predict(testSet);
    File predictionFile = new File(outputFolder, "test_predictions.txt");
    FileUtils.writeStringToFile(predictionFile, PrintUtil.toMutipleLines(predictions));
    System.out.println("predictions on the test set are written to " + predictionFile.getAbsolutePath());

    report(config, testSet, "testSet");
}
use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class CMLCRFTest method test9.
/**
 * Builds a {@link CMLCRF} on the synthetic {@code independentNoise} training set, seeds the
 * bias-free weights of classes 0 to 3 with fixed values, and then runs {@link LBFGS} on a
 * {@link CRFLoss} until the optimizer's terminator says to stop. Training and test
 * {@link MLMeasures} are printed before optimization and after every iteration; the model is
 * printed before and after.
 */
private static void test9() {
    MultiLabelClfDataSet train = MultiLabelSynthesizer.independentNoise();
    MultiLabelClfDataSet test = MultiLabelSynthesizer.independent();
    CMLCRF cmlcrf = new CMLCRF(train);

    // seedWeights[k][j] is the value stored at position j of class k's bias-free weight vector.
    int[][] seedWeights = {
            {0, 1},
            {1, 1},
            {1, 0},
            {1, -1},
    };
    for (int k = 0; k < seedWeights.length; k++) {
        for (int j = 0; j < seedWeights[k].length; j++) {
            cmlcrf.getWeights().getWeightsWithoutBiasForClass(k).set(j, seedWeights[k][j]);
        }
    }

    CRFLoss loss = new CRFLoss(cmlcrf, train, 1);
    System.out.println(cmlcrf);
    System.out.println("initial loss = " + loss.getValue());
    System.out.println("training performance");
    System.out.println(new MLMeasures(cmlcrf, train));
    System.out.println("test performance");
    System.out.println(new MLMeasures(cmlcrf, test));

    LBFGS lbfgs = new LBFGS(loss);
    while (!lbfgs.getTerminator().shouldTerminate()) {
        System.out.println("------------");
        lbfgs.iterate();
        System.out.println(lbfgs.getTerminator().getLastValue());
        System.out.println("training performance");
        System.out.println(new MLMeasures(cmlcrf, train));
        System.out.println("test performance");
        System.out.println(new MLMeasures(cmlcrf, test));
    }

    System.out.println(cmlcrf);
}
use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class CMLCRFTest method test8.
/**
 * Trains a {@link CMLCRF} with {@link BlockwiseCD} for 10000 iterations on the data set named by
 * {@code input.trainData}, printing the objective plus train/test accuracy, overlap and F1 after
 * every iteration. Afterwards it prints full {@link MLMeasures} for train and test, the
 * wall-clock time of one test-set prediction, and the test measures of an
 * {@link InstanceF1Predictor}. When {@code saveModel} is true, the model is serialized to
 * {@code <output>/<modelName>}.
 *
 * <p>Reads its settings from the enclosing class's {@code config}.
 *
 * @throws Exception propagated from data loading, training, or serialization
 */
private static void test8() throws Exception {
    System.out.println(config);

    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);

    // Where (and under which name) the trained model is saved.
    String output = config.getString("output");
    String modelName = config.getString("modelName");

    CMLCRF cmlcrf = new CMLCRF(trainSet);
    BlockwiseCD trainer = new BlockwiseCD(
            cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));

    for (int iter = 0; iter < 10000; iter++) {
        trainer.iterate();
        MultiLabel[] trainPredictions = cmlcrf.predict(trainSet);
        MultiLabel[] testPredictions = cmlcrf.predict(testSet);

        // One progress line per iteration; each value is formatted right before it is printed.
        String iterText = String.format("%04d", iter);
        System.out.print("iter: " + iterText);
        String objective = String.format("%.4f", trainer.getValue());
        System.out.print("\tobjective: " + objective);

        String trainAcc = String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain acc: " + trainAcc);
        String trainOverlap = String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain overlap " + trainOverlap);
        String trainF1 = String.format("%.4f", FMeasure.f1(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain F1 " + trainF1);

        String testAcc = String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), testPredictions));
        System.out.print("\tTest acc: " + testAcc);
        String testOverlap = String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), testPredictions));
        System.out.print("\tTest overlap " + testOverlap);
        String testF1 = String.format("%.4f", FMeasure.f1(testSet.getMultiLabels(), testPredictions));
        System.out.println("\tTest F1 " + testF1);
    }

    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");

    MLMeasures trainMeasures = new MLMeasures(cmlcrf, trainSet);
    System.out.println("========== Train ==========\n");
    System.out.println(trainMeasures);

    System.out.println("========== Test ==========\n");
    // The prediction result is discarded; the call is made only to time it.
    long predictionStart = System.nanoTime();
    cmlcrf.predict(testSet);
    long elapsedNanos = System.nanoTime() - predictionStart;
    System.out.println("\nprediction time: " + TimeUnit.NANOSECONDS.toSeconds(elapsedNanos) + " sec.");
    System.out.println(new MLMeasures(cmlcrf, testSet));
    System.out.println("\n\n");

    InstanceF1Predictor f1Predictor = new InstanceF1Predictor(cmlcrf);
    System.out.println("Plugin F1");
    System.out.println(new MLMeasures(f1Predictor, testSet));

    if (config.getBoolean("saveModel")) {
        new File(output).mkdirs();
        cmlcrf.serialize(new File(output, modelName));
    }
}
Aggregations