use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CBMGB method train.
/**
 * Trains a CBM on the given training set and stores the results under the configured output directory.
 *
 * <p>In order, this method:
 * <ol>
 *   <li>prints the labels reported by {@code DataSetUtil.unobservedLabels} (if any) and always writes
 *       their string form to {@code unobserved_labels.txt};</li>
 *   <li>builds the model and its optimizer, initializes the optimizer, and calls
 *       {@code iterate()} {@code hyperParameters.iterations} times, logging progress to stdout;</li>
 *   <li>serializes the model to {@code model} and the result of
 *       {@code DataSetUtil.gatherMultiLabels(trainSet)} to {@code support}.</li>
 * </ol>
 *
 * @param config          configuration; the {@code output.dir} entry is read here
 * @param hyperParameters supplies the iteration count and is forwarded to model/optimizer construction
 * @param trainSet        training data
 * @throws Exception anything raised while writing files, building the model, optimizing or serializing
 */
private static void train(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet) throws Exception {
    List<Integer> missingLabels = DataSetUtil.unobservedLabels(trainSet);
    String missingLabelsText = ListUtil.toSimpleString(missingLabels);
    boolean hasMissingLabels = !missingLabels.isEmpty();
    if (hasMissingLabels) {
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(missingLabelsText);
    }

    // The file is written even when the list is empty.
    String outputDir = config.getString("output.dir");
    File unobservedLabelsFile = new File(outputDir, "unobserved_labels.txt");
    FileUtils.writeStringToFile(unobservedLabelsFile, missingLabelsText);

    // Timing starts before model construction.
    StopWatch timer = new StopWatch();
    timer.start();

    CBM model = newCBM(config, trainSet, hyperParameters);
    GBCBMOptimizer gbOptimizer = getOptimizer(config, hyperParameters, model, trainSet);

    System.out.println("Initializing the model");
    gbOptimizer.initialize();
    System.out.println("Initialization done");

    // Iterations are reported 1-based.
    int round = 0;
    while (round < hyperParameters.iterations) {
        round++;
        System.out.println("Training progress: iteration " + round);
        gbOptimizer.iterate();
    }

    System.out.println("training done!");
    System.out.println("time spent on training = " + timer);

    Serialization.serialize(model, new File(outputDir, "model"));

    List<MultiLabel> trainingLabelSets = DataSetUtil.gatherMultiLabels(trainSet);
    Serialization.serialize(trainingLabelSets, new File(outputDir, "support"));
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CBMGB method reportHammingPrediction.
/**
 * Predicts with a {@code MarginalPredictor} built on the given model, prints the resulting
 * {@code MLMeasures}, and saves both the measures and the per-instance predictions.
 *
 * <p>Files are placed under {@code <output.dir>/test_predictions/instance_hamming_loss_optimal/}:
 * {@code performance.txt} holds the measures' string form, and {@code predictions.txt} holds one
 * line per data point in the form {@code <prediction>:<probability>}, where the probability comes
 * from {@code cbm.predictAssignmentProb}.
 *
 * @param config  configuration; reads {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     trained model used for prediction and for scoring the predicted sets
 * @param dataSet data set to predict on (the log messages call it the test set)
 * @throws Exception anything raised by prediction or by file writing
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    final String predictorDir = "instance_hamming_loss_optimal";

    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String outputDir = config.getString("output.dir");

    MarginalPredictor predictor = new MarginalPredictor(cbm);
    predictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(measures);

    File performanceFile = Paths.get(outputDir, "test_predictions", predictorDir, "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile.toString());

    // Per the original author's note, no approximation is used for these probabilities.
    double[] setProbabilities = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(idx -> cbm.predictAssignmentProb(dataSet.getRow(idx), predicted[idx]))
            .toArray();

    File predictionFile = Paths.get(outputDir, "test_predictions", predictorDir, "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int numPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numPoints; row++) {
            writer.write(predicted[row].toString() + ":" + setProbabilities[row]);
            writer.newLine();
        }
    }

    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(banner);
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CBMGB method reportAccPrediction.
/**
 * Predicts with an {@code AccPredictor} built on the given model, prints the resulting
 * {@code MLMeasures}, and saves both the measures and the per-instance predictions.
 *
 * <p>Files are placed under {@code <output.dir>/test_predictions/instance_accuracy_optimal/}:
 * {@code performance.txt} holds the measures' string form, and {@code predictions.txt} holds one
 * line per data point in the form {@code <prediction>:<probability>}, where the probability comes
 * from {@code cbm.predictAssignmentProb}.
 *
 * @param config  configuration; reads {@code output.dir} and {@code predict.piThreshold}
 *                (the latter is passed to {@code setComponentContributionThreshold})
 * @param cbm     trained model used for prediction and for scoring the predicted sets
 * @param dataSet data set to predict on (the log messages call it the test set)
 * @throws Exception anything raised by prediction or by file writing
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    final String predictorDir = "instance_accuracy_optimal";

    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String outputDir = config.getString("output.dir");

    AccPredictor predictor = new AccPredictor(cbm);
    predictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(measures);

    File performanceFile = Paths.get(outputDir, "test_predictions", predictorDir, "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile.toString());

    // Per the original author's note, no approximation is used for these probabilities.
    double[] setProbabilities = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(idx -> cbm.predictAssignmentProb(dataSet.getRow(idx), predicted[idx]))
            .toArray();

    File predictionFile = Paths.get(outputDir, "test_predictions", predictorDir, "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int numPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numPoints; row++) {
            writer.write(predicted[row].toString() + ":" + setProbabilities[row]);
            writer.newLine();
        }
    }

    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(banner);
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CBMGB method reportF1Prediction.
/**
 * Predicts with a {@code PluginF1} predictor built on the given model, prints the resulting
 * {@code MLMeasures}, and saves both the measures and the per-instance predictions.
 *
 * <p>The predictor's support is deserialized from {@code <output.dir>/support} before predicting.
 *
 * <p>Files are placed under {@code <output.dir>/test_predictions/instance_f1_optimal/}:
 * {@code performance.txt} holds the measures' string form, and {@code predictions.txt} holds one
 * line per data point in the form {@code <prediction>:<probability>}, where the probability comes
 * from {@code cbm.predictAssignmentProb}.
 *
 * @param config  configuration; reads {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     trained model used for prediction and for scoring the predicted sets
 * @param dataSet data set to predict on (the log messages call it the test set)
 * @throws Exception anything raised by deserialization, prediction or file writing
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    final String predictorDir = "instance_f1_optimal";

    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String outputDir = config.getString("output.dir");

    PluginF1 predictor = new PluginF1(cbm);
    File supportFile = new File(outputDir, "support");
    // The deserialized object is assumed to be a List<MultiLabel>; the cast is unchecked.
    List<MultiLabel> supportSets = (List<MultiLabel>) Serialization.deserialize(supportFile);
    predictor.setSupport(supportSets);
    predictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(measures);

    File performanceFile = Paths.get(outputDir, "test_predictions", predictorDir, "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile.toString());

    // Per the original author's note, no approximation is used for these probabilities.
    double[] setProbabilities = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(idx -> cbm.predictAssignmentProb(dataSet.getRow(idx), predicted[idx]))
            .toArray();

    File predictionFile = Paths.get(outputDir, "test_predictions", predictorDir, "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int numPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numPoints; row++) {
            writer.write(predicted[row].toString() + ":" + setProbabilities[row]);
            writer.newLine();
        }
    }

    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(banner);
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class Precision method precision.
/**
 * Computes precision per instance and averages it over all instances, following
 * "Mining Multi-label Data" by Grigorios Tsoumakas.
 *
 * <p>For each index {@code i}, the instance precision is the size of
 * {@code MultiLabel.intersection(multiLabels[i], predictions[i])} divided by the number of
 * matched labels in {@code predictions[i]}. An instance whose prediction has no matched labels
 * contributes {@code 1.0}.
 *
 * <p>Edge cases: an empty {@code multiLabels} array yields {@code NaN} (0.0 / 0), and a
 * {@code predictions} array shorter than {@code multiLabels} causes an
 * {@link ArrayIndexOutOfBoundsException}. Extra entries in {@code predictions} are ignored.
 *
 * @param multiLabels reference label sets, one per instance
 * @param predictions predicted label sets, aligned by index with {@code multiLabels}
 * @return the mean of the per-instance precision values
 * @deprecated flagged as deprecated in the original source; the intended replacement is not
 *             visible in this excerpt.
 */
@Deprecated
public static double precision(MultiLabel[] multiLabels, MultiLabel[] predictions) {
    double total = 0.0;
    for (int i = 0; i < multiLabels.length; i++) {
        MultiLabel truth = multiLabels[i];
        MultiLabel predicted = predictions[i];
        // Looked up once per instance instead of twice.
        int numPredicted = predicted.getMatchedLabels().size();
        if (numPredicted == 0) {
            // Nothing was predicted, so nothing was predicted wrongly.
            total += 1.0;
        } else {
            // Multiplying by 1.0 forces floating-point division.
            total += MultiLabel.intersection(truth, predicted).size() * 1.0 / numPredicted;
        }
    }
    return total / multiLabels.length;
}
Aggregations