Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
the main method of class ClusterLabels.
/**
 * Entry point: loads the properties file named by the single command-line
 * argument, echoes the parsed configuration, then fits the clustering model
 * and produces its plot.
 *
 * @param args exactly one element: the path to a properties file
 * @throws Exception if configuration loading, model fitting, or plotting fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    String propertiesPath = args[0];
    Config config = new Config(propertiesPath);
    System.out.println(config);
    fitModel(config);
    plot(config);
}
Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
the main method of class CBMLR.
/**
 * Entry point for the CBM-LR application. Behavior is driven entirely by
 * boolean switches in the given properties file: optional hyper-parameter
 * tuning ("tune"), training ("train"), and testing ("test"), in that order.
 *
 * @param args exactly one element: the path to a properties file
 * @throws Exception if any stage fails (I/O, parsing, training, testing)
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    VERBOSE = config.getBoolean("output.verbose");
    new File(config.getString("output.dir")).mkdirs();
    if (config.getBoolean("tune")) {
        tuneHyperParameters(config);
    }
    if (config.getBoolean("train")) {
        trainModel(config);
    }
    if (config.getBoolean("test")) {
        System.out.println("============================================================");
        test(config);
        System.out.println("============================================================");
    }
}

/**
 * Grid-searches every (variance, numComponents) candidate pair, tunes the
 * iteration count for each via {@code tune}, then stores the best combination
 * to tuned_hyper_parameters.properties in the output directory.
 */
private static void tuneHyperParameters(Config config) throws Exception {
    System.out.println("============================================================");
    System.out.println("Start hyper parameter tuning");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    List<TuneResult> tuneResults = new ArrayList<>();
    // dataSets.get(0) is the training split, dataSets.get(1) the validation split.
    List<MultiLabelClfDataSet> dataSets = loadTrainValidData(config);
    List<Double> variances = config.getDoubles("tune.variance.candidates");
    List<Integer> components = config.getIntegers("tune.numComponents.candidates");
    for (double variance : variances) {
        for (int component : components) {
            StopWatch trialWatch = new StopWatch();
            trialWatch.start();
            HyperParameters hyperParameters = new HyperParameters();
            hyperParameters.numComponents = component;
            hyperParameters.variance = variance;
            System.out.println("---------------------------");
            System.out.println("Trying hyper parameters:");
            System.out.println("train.numComponents = " + hyperParameters.numComponents);
            System.out.println("train.variance = " + hyperParameters.variance);
            TuneResult tuneResult = tune(config, hyperParameters, dataSets.get(0), dataSets.get(1));
            System.out.println("Found optimal train.iterations = " + tuneResult.hyperParameters.iterations);
            System.out.println("Validation performance = " + tuneResult.performance);
            tuneResults.add(tuneResult);
            System.out.println("Time spent on trying this set of hyper parameters = " + trialWatch);
        }
    }
    TuneResult best = selectBest(config, tuneResults);
    System.out.println("---------------------------");
    System.out.println("Hyper parameter tuning done.");
    System.out.println("Time spent on entire hyper parameter tuning = " + stopWatch);
    System.out.println("Best validation performance = " + best.performance);
    System.out.println("Best hyper parameters:");
    System.out.println("train.numComponents = " + best.hyperParameters.numComponents);
    System.out.println("train.variance = " + best.hyperParameters.variance);
    System.out.println("train.iterations = " + best.hyperParameters.iterations);
    Config tunedHypers = best.hyperParameters.asConfig();
    // Build the target file once; it was previously constructed twice.
    File tunedFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
    tunedHypers.store(tunedFile);
    System.out.println("Tuned hyper parameters saved to " + tunedFile.getAbsolutePath());
    System.out.println("============================================================");
}

/**
 * Picks the best tuning result according to tune.targetMetric.
 * Accuracy and F1 are maximized; Hamming loss is minimized.
 *
 * @throws IllegalArgumentException if the configured metric is unknown
 */
private static TuneResult selectBest(Config config, List<TuneResult> tuneResults) {
    Comparator<TuneResult> comparator = Comparator.comparing(res -> res.performance);
    String predictTarget = config.getString("tune.targetMetric");
    switch (predictTarget) {
        // Both metrics below are "higher is better"; the duplicated arms
        // are merged into a single fall-through.
        case "instance_set_accuracy":
        case "instance_f1":
            return tuneResults.stream().max(comparator).get();
        case "instance_hamming_loss":
            return tuneResults.stream().min(comparator).get();
        default:
            throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
    }
}

/**
 * Trains the model with either the tuned hyper parameters saved by
 * {@link #tuneHyperParameters} or hyper parameters given directly in the
 * config, depending on train.useTunedHyperParameters. Exits the JVM with
 * status 1 if tuned parameters are requested but absent.
 */
private static void trainModel(Config config) throws Exception {
    System.out.println("============================================================");
    HyperParameters hyperParameters;
    if (config.getBoolean("train.useTunedHyperParameters")) {
        File hyperFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
        if (!hyperFile.exists()) {
            System.out.println("train.useTunedHyperParameters is set to true. But no tuned hyper parameters can be found in the output directory.");
            System.out.println("Please either run hyper parameter tuning, or provide hyper parameters manually and set train.useTunedHyperParameters=false.");
            System.exit(1);
        }
        hyperParameters = new HyperParameters(new Config(hyperFile));
        System.out.println("Start training with tuned hyper parameters:");
    } else {
        hyperParameters = new HyperParameters(config);
        System.out.println("Start training with given hyper parameters:");
    }
    // Common tail that was previously duplicated in both branches.
    System.out.println("train.numComponents = " + hyperParameters.numComponents);
    System.out.println("train.variance = " + hyperParameters.variance);
    System.out.println("train.iterations = " + hyperParameters.iterations);
    MultiLabelClfDataSet trainSet = loadTrainData(config);
    train(config, hyperParameters, trainSet);
    System.out.println("============================================================");
}
Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
the reportGeneral method of class CBMLR.
/**
 * Writes predictor-independent diagnostics for the test set:
 * per-instance label probabilities and the per-instance log likelihood of the
 * ground truth label set (computed exactly, without approximation).
 * Instances containing labels never observed during training are excluded
 * from the average log likelihood.
 *
 * @param config  supplies output.dir and report.labelProbThreshold
 * @param cbm     the trained conditional Bernoulli mixture model
 * @param dataSet the test set to report on
 * @throws Exception on any I/O failure
 */
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String output = config.getString("output.dir");
    File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            writer.write(CBMInspector.topLabels(cbm, dataSet.getRow(i), labelProbThreshold));
            writer.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // Parse the comma-separated label ids written during training;
    // blank tokens (e.g. a trailing comma or an empty file) are skipped.
    List<Integer> unobservedLabels = Arrays.stream(FileUtils.readFileToString(new File(output, "unobserved_labels.txt")).split(","))
            .map(String::trim)
            .filter(s -> !s.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i]))
            .toArray();
    // Average only over instances whose label set was fully observed in
    // training; reports NaN instead of throwing when every test instance
    // contains a novel label (getAsDouble() previously crashed here).
    double average = IntStream.range(0, dataSet.getNumDataPoints())
            .filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels))
            .mapToDouble(i -> logLikelihoods[i])
            .average().orElse(Double.NaN);
    File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
the reportHammingPrediction method of class CBMEN.
/**
 * Evaluates the instance-Hamming-loss-optimal predictor (per-label marginal
 * thresholding via {@code MarginalPredictor}) on the test set, then saves the
 * performance report and the predicted label sets together with their exact
 * joint probabilities.
 *
 * @param config  supplies output.dir and predict.piThreshold
 * @param cbm     the trained conditional Bernoulli mixture model
 * @param dataSet the test set to predict on
 * @throws Exception on any I/O failure
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String output = config.getString("output.dir");
    MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
    marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = marginalPredictor.predict(dataSet);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString());
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i]))
            .toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        // One line per instance: "<label set>:<joint probability>".
        for (int pos = 0; pos < dataSet.getNumDataPoints(); pos++) {
            writer.write(predictions[pos].toString() + ":" + setProbs[pos]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
the reportGeneral method of class CBMGB.
/**
 * Writes predictor-independent diagnostics for the test set: per-instance
 * label probabilities and the exact (non-approximated) log likelihood of each
 * ground truth label set, plus their average over instances whose labels were
 * all observed during training.
 *
 * @param config  supplies output.dir and report.labelProbThreshold
 * @param cbm     the trained conditional Bernoulli mixture model
 * @param dataSet the test set to report on
 * @throws Exception on any I/O failure
 */
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String output = config.getString("output.dir");
    File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            writer.write(CBMInspector.topLabels(cbm, dataSet.getRow(row), labelProbThreshold));
            writer.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // Comma-separated label ids recorded during training; blank tokens are skipped.
    String rawUnobserved = FileUtils.readFileToString(new File(output, "unobserved_labels.txt"));
    List<Integer> unobservedLabels = Arrays.stream(rawUnobserved.split(","))
            .map(String::trim)
            .filter(token -> !token.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i]))
            .toArray();
    double average = IntStream.range(0, dataSet.getNumDataPoints())
            .filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels))
            .mapToDouble(i -> logLikelihoods[i])
            .average()
            .getAsDouble();
    File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Aggregations