Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li:
the class CBMGB, method reportHammingPrediction.
/**
 * Evaluates the trained CBM on the test set with the instance-Hamming-loss-optimal
 * (marginal, label-wise) predictor, then writes the aggregate performance report and
 * the per-instance predicted label sets (with their joint probabilities) under
 * {@code <output.dir>/test_predictions/instance_hamming_loss_optimal/}.
 *
 * @param config  run configuration; reads "output.dir" and "predict.piThreshold"
 * @param cbm     the trained conditional Bernoulli mixture model
 * @param dataSet the test set to predict on
 * @throws Exception if writing the report or prediction files fails
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String output = config.getString("output.dir");
    // The marginal predictor thresholds each label's marginal probability, which
    // minimizes expected instance Hamming loss.
    MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
    marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = marginalPredictor.predict(dataSet);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    // Fix: write with an explicit UTF-8 charset; the two-argument overload uses the
    // platform default charset, so output would differ across machines/locales.
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString(), java.nio.charset.StandardCharsets.UTF_8);
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    // Files.newBufferedWriter defaults to UTF-8, unlike FileWriter which uses the
    // platform charset; one line per instance, formatted "labelSet:probability".
    try (BufferedWriter br = java.nio.file.Files.newBufferedWriter(predictionFile.toPath())) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write("" + setProbs[i]);
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li:
the class CBMGB, method main.
/**
 * Entry point. Reads a single properties file from the command line and runs, in
 * order, whichever of the three phases are enabled in the config: hyper-parameter
 * tuning ("tune"), training ("train"), and testing ("test").
 *
 * @param args exactly one argument: the path to a properties file
 * @throws Exception propagated from config parsing, tuning, training, or testing
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    VERBOSE = config.getBoolean("output.verbose");
    new File(config.getString("output.dir")).mkdirs();
    if (config.getBoolean("tune")) {
        runTuning(config);
    }
    if (config.getBoolean("train")) {
        runTraining(config);
    }
    if (config.getBoolean("test")) {
        System.out.println("============================================================");
        test(config);
        System.out.println("============================================================");
    }
}

/**
 * Grid-searches every (numLeaves, numComponents) candidate pair on a train/validation
 * split, picks the best result by the configured target metric, and saves the winning
 * hyper-parameters to {@code <output.dir>/tuned_hyper_parameters.properties}.
 */
private static void runTuning(Config config) throws Exception {
    System.out.println("============================================================");
    System.out.println("Start hyper parameter tuning");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    List<TuneResult> tuneResults = new ArrayList<>();
    // dataSets.get(0) = training split, dataSets.get(1) = validation split
    List<MultiLabelClfDataSet> dataSets = loadTrainValidData(config);
    List<Integer> leaveNums = config.getIntegers("tune.numLeaves.candidates");
    List<Integer> components = config.getIntegers("tune.numComponents.candidates");
    for (int numLeaves : leaveNums) {
        for (int component : components) {
            StopWatch stopWatch1 = new StopWatch();
            stopWatch1.start();
            HyperParameters hyperParameters = new HyperParameters();
            hyperParameters.numComponents = component;
            hyperParameters.numLeaves = numLeaves;
            System.out.println("---------------------------");
            System.out.println("Trying hyper parameters:");
            System.out.println("train.numComponents = " + hyperParameters.numComponents);
            System.out.println("train.numLeaves = " + hyperParameters.numLeaves);
            TuneResult tuneResult = tune(config, hyperParameters, dataSets.get(0), dataSets.get(1));
            System.out.println("Found optimal train.iterations = " + tuneResult.hyperParameters.iterations);
            System.out.println("Validation performance = " + tuneResult.performance);
            tuneResults.add(tuneResult);
            System.out.println("Time spent on trying this set of hyper parameters = " + stopWatch1);
        }
    }
    Comparator<TuneResult> comparator = Comparator.comparing(res -> res.performance);
    TuneResult best;
    String predictTarget = config.getString("tune.targetMetric");
    switch (predictTarget) {
        // Accuracy-style metrics: higher is better (merged fall-through cases
        // replace the previously duplicated max branches).
        case "instance_set_accuracy":
        case "instance_f1":
            best = tuneResults.stream().max(comparator).get();
            break;
        // Loss-style metric: lower is better.
        case "instance_hamming_loss":
            best = tuneResults.stream().min(comparator).get();
            break;
        default:
            throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
    }
    System.out.println("---------------------------");
    System.out.println("Hyper parameter tuning done.");
    System.out.println("Time spent on entire hyper parameter tuning = " + stopWatch);
    System.out.println("Best validation performance = " + best.performance);
    System.out.println("Best hyper parameters:");
    System.out.println("train.numComponents = " + best.hyperParameters.numComponents);
    System.out.println("train.numLeaves = " + best.hyperParameters.numLeaves);
    System.out.println("train.iterations = " + best.hyperParameters.iterations);
    Config tunedHypers = best.hyperParameters.asConfig();
    // Hoisted: the same File was previously constructed twice.
    File tunedFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
    tunedHypers.store(tunedFile);
    System.out.println("Tuned hyper parameters saved to " + tunedFile.getAbsolutePath());
    System.out.println("============================================================");
}

/**
 * Trains the model on the training set, using either previously tuned
 * hyper-parameters (train.useTunedHyperParameters=true) or those given directly in
 * the config. Exits with status 1 if tuned parameters are requested but missing.
 */
private static void runTraining(Config config) throws Exception {
    System.out.println("============================================================");
    HyperParameters hyperParameters;
    if (config.getBoolean("train.useTunedHyperParameters")) {
        File hyperFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
        if (!hyperFile.exists()) {
            System.out.println("train.useTunedHyperParameters is set to true. But no tuned hyper parameters can be found in the output directory.");
            System.out.println("Please either run hyper parameter tuning, or provide hyper parameters manually and set train.useTunedHyperParameters=false.");
            System.exit(1);
        }
        Config tunedHypers = new Config(hyperFile);
        hyperParameters = new HyperParameters(tunedHypers);
        System.out.println("Start training with tuned hyper parameters:");
    } else {
        hyperParameters = new HyperParameters(config);
        System.out.println("Start training with given hyper parameters:");
    }
    // Shared tail of the two formerly duplicated branches.
    System.out.println("train.numComponents = " + hyperParameters.numComponents);
    System.out.println("train.numLeaves = " + hyperParameters.numLeaves);
    System.out.println("train.iterations = " + hyperParameters.iterations);
    MultiLabelClfDataSet trainSet = loadTrainData(config);
    train(config, hyperParameters, trainSet);
    System.out.println("============================================================");
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li:
the class CBMGB, method reportAccPrediction.
/**
 * Evaluates the trained CBM on the test set with the instance-set-accuracy-optimal
 * (joint MAP) predictor, then writes the aggregate performance report and the
 * per-instance predicted label sets (with their joint probabilities) under
 * {@code <output.dir>/test_predictions/instance_accuracy_optimal/}.
 *
 * @param config  run configuration; reads "output.dir" and "predict.piThreshold"
 * @param cbm     the trained conditional Bernoulli mixture model
 * @param dataSet the test set to predict on
 * @throws Exception if writing the report or prediction files fails
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String output = config.getString("output.dir");
    // AccPredictor searches for the jointly most probable label set, which
    // minimizes expected subset 0/1 loss (i.e., maximizes instance set accuracy).
    AccPredictor accPredictor = new AccPredictor(cbm);
    accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = accPredictor.predict(dataSet);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
    // Fix: write with an explicit UTF-8 charset; the two-argument overload uses the
    // platform default charset, so output would differ across machines/locales.
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString(), java.nio.charset.StandardCharsets.UTF_8);
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
    // Files.newBufferedWriter defaults to UTF-8, unlike FileWriter which uses the
    // platform charset; one line per instance, formatted "labelSet:probability".
    try (BufferedWriter br = java.nio.file.Files.newBufferedWriter(predictionFile.toPath())) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write("" + setProbs[i]);
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li:
the class CBMGB, method reportF1Prediction.
/**
 * Evaluates the trained CBM on the test set with the instance-F1-optimal (GFM
 * plug-in) predictor, then writes the aggregate performance report and the
 * per-instance predicted label sets (with their joint probabilities) under
 * {@code <output.dir>/test_predictions/instance_f1_optimal/}.
 *
 * @param config  run configuration; reads "output.dir" and "predict.piThreshold"
 * @param cbm     the trained conditional Bernoulli mixture model
 * @param dataSet the test set to predict on
 * @throws Exception if deserializing the support file or writing outputs fails
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String output = config.getString("output.dir");
    PluginF1 pluginF1 = new PluginF1(cbm);
    // The support (candidate label sets observed during training) is serialized to
    // <output.dir>/support by the training phase and restricts the F1 search space.
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(new File(output, "support"));
    pluginF1.setSupport(support);
    pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = pluginF1.predict(dataSet);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    // Fix: write with an explicit UTF-8 charset; the two-argument overload uses the
    // platform default charset, so output would differ across machines/locales.
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString(), java.nio.charset.StandardCharsets.UTF_8);
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    // Files.newBufferedWriter defaults to UTF-8, unlike FileWriter which uses the
    // platform charset; one line per instance, formatted "labelSet:probability".
    try (BufferedWriter br = java.nio.file.Files.newBufferedWriter(predictionFile.toPath())) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write("" + setProbs[i]);
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li:
the class MultiLabelSynthesizer, method crfArgmax.
/**
 * Synthesizes a multi-label dataset of 1000 points whose labels are the subset-accuracy
 * (argmax) predictions of a fixed hand-weighted CML-CRF over uniformly random features.
 * Kept for backward compatibility; delegates to {@link #crfArgmax(int)}.
 *
 * @return a synthetic dataset with 1000 data points, 4 classes, and 10 features
 */
public static MultiLabelClfDataSet crfArgmax() {
    return crfArgmax(1000);
}

/**
 * Generalized form: synthesizes a multi-label dataset of the requested size. The class
 * count (4) and feature count (10) are fixed because the hand-set CRF weights below are
 * tied to classes 0-3 and features 0-1.
 *
 * @param numData number of data points to generate
 * @return a synthetic dataset with {@code numData} points, 4 classes, and 10 features
 */
public static MultiLabelClfDataSet crfArgmax(int numData) {
    int numClass = 4;
    int numFeature = 10;
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(numFeature)
            .numClasses(numClass)
            .numDataPoints(numData)
            .build();
    // Support = all 2^numClass label subsets.
    List<MultiLabel> support = Enumerator.enumerate(numClass);
    CMLCRF cmlcrf = new CMLCRF(numClass, numFeature, support);
    // Hand-set weights on features 0 and 1 only; remaining feature weights stay 0.
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(0).set(0, 0);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(0).set(1, 10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(1).set(0, 10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(1).set(1, 10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(2).set(0, 10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(2).set(1, 0);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(3).set(0, -10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(3).set(1, -10);
    // generate features: each value drawn uniformly from [-1, 1]
    for (int i = 0; i < numData; i++) {
        for (int j = 0; j < numFeature; j++) {
            dataSet.setFeatureValue(i, j, Sampling.doubleUniform(-1, 1));
        }
    }
    SubsetAccPredictor predictor = new SubsetAccPredictor(cmlcrf);
    // assign labels: each instance gets the CRF's most probable label subset
    for (int i = 0; i < numData; i++) {
        MultiLabel label = predictor.predict(dataSet.getRow(i));
        dataSet.setLabels(i, label);
    }
    return dataSet;
}
Aggregations