use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class CBMLR method reportF1Prediction.
/**
 * Runs the instance-F1 plug-in predictor over {@code dataSet}, prints the resulting
 * measures, and saves both the measures and the per-instance predictions.
 *
 * <p>Files are written under {@code <output.dir>/test_predictions/instance_f1_optimal/}:
 * {@code performance.txt} holds the measures and {@code predictions.txt} holds one
 * {@code predictedSet:probability} line per data point. The label-set support is
 * deserialized from the file {@code support} inside {@code output.dir}.
 *
 * @param config  supplies {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     model used for prediction and for scoring each predicted set
 * @param dataSet data set to predict on
 * @throws Exception if deserialization, prediction, or file output fails
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String outputDir = config.getString("output.dir");

    // Configure the predictor: support first, then the pi threshold.
    PluginF1 f1Predictor = new PluginF1(cbm);
    File supportFile = new File(outputDir, "support");
    List<MultiLabel> labelSupport = (List<MultiLabel>) Serialization.deserialize(supportFile);
    f1Predictor.setSupport(labelSupport);
    f1Predictor.setPiThreshold(config.getDouble("predict.piThreshold"));

    MultiLabel[] predictedSets = f1Predictor.predict(dataSet);
    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictedSets);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(measures);

    File measuresFile = Paths.get(outputDir, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(measuresFile, measures.toString());
    System.out.println("test performance is saved to " + measuresFile.toString());

    // Here we do not use approximation
    double[] setProbabilities = IntStream.range(0, predictedSets.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predictedSets[row]))
            .toArray();

    File predictionsFile = Paths.get(outputDir, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionsFile))) {
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            // One line per data point: predicted label set, a colon, then its probability.
            writer.write(predictedSets[row].toString() + ":" + setProbabilities[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionsFile.getAbsolutePath());
    System.out.println("============================================================");
}
use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class CBMGB method tune.
/**
 * Tunes the number of training iterations for one hyper-parameter setting.
 *
 * <p>A new CBM is trained on {@code trainSet}; every {@code tune.monitorInterval}
 * iterations the predictor selected by {@code tune.targetMetric} is evaluated on
 * {@code validSet} and the metric value is handed to the early stopper. Training
 * continues until the early stopper says to stop.
 *
 * <p>Side effect: {@code hyperParameters.iterations} is overwritten with the early
 * stopper's best iteration, and that same object is stored in the returned result.
 *
 * <p>NOTE(review): Hamming loss is a lower-is-better metric while accuracy and F1 are
 * higher-is-better. The early stopper's optimization direction is decided inside
 * {@code loadNewEarlyStopper(config)}, which is not visible here - confirm it is set
 * up to minimize when {@code tune.targetMetric} is {@code instance_hamming_loss}.
 *
 * @param config          supplies {@code tune.targetMetric}, {@code tune.monitorInterval}
 *                        and {@code predict.piThreshold}
 * @param hyperParameters setting to evaluate; its {@code iterations} field is updated
 * @param trainSet        data used to fit the model (and to gather the F1 support)
 * @param validSet        data used to monitor the target metric
 * @return the hyper-parameters together with the early stopper's best value
 * @throws IllegalArgumentException if {@code tune.targetMetric} is not one of the
 *         supported values, or if {@code tune.monitorInterval} is zero
 * @throws Exception if training or evaluation fails
 */
private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    GBCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    optimizer.initialize();

    // Pick the predictor that matches the metric being tuned.
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    switch (predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        default:
            throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
    }

    int interval = config.getInt("tune.monitorInterval");
    // Fail fast: a zero interval would otherwise only surface as a bare "/ by zero"
    // ArithmeticException from iter % interval, after a full training iteration has run.
    if (interval == 0) {
        throw new IllegalArgumentException("tune.monitorInterval must be non-zero");
    }

    // No iteration cap here: the loop ends only when the early stopper asks to stop.
    for (int iter = 1; true; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        if (iter % interval == 0) {
            MLMeasures validMeasures = new MLMeasures(classifier, validSet);
            if (VERBOSE) {
                System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                System.out.println(validMeasures);
            }
            switch (predictTarget) {
                case "instance_set_accuracy":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                    break;
                case "instance_f1":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                    break;
                case "instance_hamming_loss":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                    break;
                default:
                    // Already rejected by the first switch; message kept in sync with it.
                    throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    System.out.println("Early Stopper: the training should stop now!");
                }
                break;
            }
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }

    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class CBMEN method reportHammingPrediction.
/**
 * Runs the marginal (instance Hamming loss) predictor over {@code dataSet}, prints the
 * resulting measures, and saves both the measures and the per-instance predictions.
 *
 * <p>Files are written under
 * {@code <output.dir>/test_predictions/instance_hamming_loss_optimal/}:
 * {@code performance.txt} holds the measures and {@code predictions.txt} holds one
 * {@code predictedSet:probability} line per data point.
 *
 * @param config  supplies {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     model used for prediction and for scoring each predicted set
 * @param dataSet data set to predict on
 * @throws Exception if prediction or file output fails
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String outputDir = config.getString("output.dir");

    MarginalPredictor hammingPredictor = new MarginalPredictor(cbm);
    hammingPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictedSets = hammingPredictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictedSets);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(measures);

    File measuresFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(measuresFile, measures.toString());
    System.out.println("test performance is saved to " + measuresFile.toString());

    // Here we do not use approximation
    double[] setProbabilities = IntStream.range(0, predictedSets.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predictedSets[row]))
            .toArray();

    File predictionsFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionsFile))) {
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            // One line per data point: predicted label set, a colon, then its probability.
            writer.write(predictedSets[row].toString() + ":" + setProbabilities[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionsFile.getAbsolutePath());
    System.out.println("============================================================");
}
use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class SparseCBMOptimzerTest method test1.
/**
 * Manual experiment on the "scene" data: builds a 10-component CBM with "lr"
 * multi-class and binary classifiers, initializes gamma via
 * {@code initalizeGammaByBM()}, and then prints test-set measures after the initial
 * fit and after each of two further gamma-update rounds.
 *
 * @throws Exception if loading the data or evaluating the model fails
 */
private static void test1() throws Exception {
    File trainDir = new File(DATASETS, "scene/train");
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(trainDir, DataSetType.ML_CLF_DENSE, true);
    File testDir = new File(DATASETS, "scene/test");
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(testDir, DataSetType.ML_CLF_DENSE, true);

    int numComponents = 10;
    CBM cbm = CBM.getBuilder()
            .setNumClasses(dataSet.getNumClasses())
            .setNumFeatures(dataSet.getNumFeatures())
            .setNumComponents(numComponents)
            .setMultiClassClassifierType("lr")
            .setBinaryClassifierType("lr")
            .build();
    SparseCBMOptimzer sparseOptimizer = new SparseCBMOptimzer(cbm, dataSet);

    // Initial round: gamma comes from the BM initializer rather than updateGamma().
    sparseOptimizer.initalizeGammaByBM();
    sparseOptimizer.updateMultiClassLR();
    sparseOptimizer.updateAllBinary();
    // Only test-set measures are reported; training-set measures are not printed.
    System.out.println("test");
    System.out.println(new MLMeasures(cbm, testSet));

    // Two further rounds, each announced by its own banner line.
    for (String roundBanner : new String[] {"update gamma", "update gamma again"}) {
        System.out.println(roundBanner);
        sparseOptimizer.updateGamma();
        sparseOptimizer.updateMultiClassLR();
        sparseOptimizer.updateAllBinary();
        System.out.println("test");
        System.out.println(new MLMeasures(cbm, testSet));
    }
}
use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.
the class CMLCRFTest method test7.
/**
 * Loads or trains a CMLCRF according to {@code train.warmStart}, then prints train
 * and test measures, the prediction time on the test set, and the measures of the
 * instance-F1 plug-in predictor. The model is serialized when {@code saveModel} is set.
 *
 * <p>{@code train.warmStart} values:
 * <ul>
 *   <li>{@code "true"}  - deserialize the saved model and use it as is;</li>
 *   <li>{@code "auto"}  - deserialize the saved model and continue training it;</li>
 *   <li>{@code "false"} - build a new model from the training set and train it.</li>
 * </ul>
 *
 * @throws IllegalArgumentException if {@code train.warmStart} is none of the above
 * @throws Exception if loading data, training, prediction, or serialization fails
 */
private static void test7() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);

    // Location of the saved model (read for warm starts, written when saveModel is set).
    String output = config.getString("output");
    String modelName = config.getString("modelName");

    String warmStart = config.getString("train.warmStart");
    CMLCRF cmlcrf;
    if (warmStart.equals("true")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else if (warmStart.equals("auto")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("retrain model:");
        CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
        train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
    } else if (warmStart.equals("false")) {
        cmlcrf = new CMLCRF(trainSet);
        cmlcrf.setConsiderPair(config.getBoolean("considerLabelPair"));
        CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
        train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
    } else {
        // Previously an unknown value left the model null and execution fell through
        // to the evaluation below; reject it here with the offending value instead.
        throw new IllegalArgumentException("train.warmStart should be true, auto or false, but was: " + warmStart);
    }

    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    MLMeasures measures = new MLMeasures(cmlcrf, trainSet);
    System.out.println("========== Train ==========\n");
    System.out.println(measures);

    System.out.println("========== Test ==========\n");
    // This call exists only to be timed; its result is not used afterwards.
    long startTimePred = System.nanoTime();
    cmlcrf.predict(testSet);
    long stopTimePred = System.nanoTime();
    long predTime = stopTimePred - startTimePred;
    System.out.println("\nprediction time: " + TimeUnit.NANOSECONDS.toSeconds(predTime) + " sec.");
    System.out.println(new MLMeasures(cmlcrf, testSet));
    System.out.println("\n\n");

    InstanceF1Predictor pluginF1 = new InstanceF1Predictor(cmlcrf);
    System.out.println("Plugin F1");
    System.out.println(new MLMeasures(pluginF1, testSet));

    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Aggregations