Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class CBMLR, method reportGeneral:
/**
 * Writes the predictor-independent test-set reports for a trained CBM model.
 *
 * <p>Produces, under {@code <output.dir>/test_predictions}:
 * <ul>
 *   <li>{@code label_probabilities.txt}: one {@code CBMInspector.topLabels} entry per test
 *       instance, using the threshold {@code report.labelProbThreshold};</li>
 *   <li>{@code ground_truth_log_likelihood.txt}: the log probability the model assigns to each
 *       instance's ground-truth label set, computed without approximation.</li>
 * </ul>
 * The average log likelihood is printed, excluding instances whose ground truth contains a label
 * listed in {@code <output.dir>/unobserved_labels.txt} (comma-separated label indices). If no
 * instance qualifies (empty test set, or every instance has an unobserved label) the average is
 * reported as {@code NaN} instead of aborting the report.
 *
 * @param config  supplies {@code output.dir} and {@code report.labelProbThreshold}
 * @param cbm     the trained model
 * @param dataSet the test set
 * @throws Exception if a report file cannot be read or written
 */
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String output = config.getString("output.dir");
    File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    // FileWriter does not create missing directories (FileUtils.writeStringToFile below does),
    // so make sure test_predictions exists before opening the writer.
    labelProbFile.getParentFile().mkdirs();
    try (BufferedWriter br = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(CBMInspector.topLabels(cbm, dataSet.getRow(i), labelProbThreshold));
            br.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // Comma-separated indices of labels that never appear in the training set.
    List<Integer> unobservedLabels = Arrays.stream(FileUtils.readFileToString(new File(output, "unobserved_labels.txt")).split(","))
            .map(String::trim)
            .filter(s -> !s.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i]))
            .toArray();
    // Instances with labels unseen during training cannot be modelled, so they are excluded from
    // the average. orElse(NaN) keeps the report going when nothing remains (empty test set, or all
    // instances contain unobserved labels) instead of throwing NoSuchElementException.
    double average = IntStream.range(0, dataSet.getNumDataPoints())
            .filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels))
            .mapToDouble(i -> logLikelihoods[i])
            .average()
            .orElse(Double.NaN);
    File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class CBMEN, method reportHammingPrediction:
/**
 * Predicts the test set with the instance-Hamming-loss optimal predictor and saves the results.
 *
 * <p>A {@code MarginalPredictor} over {@code cbm}, configured with {@code predict.piThreshold},
 * labels every test instance. The resulting {@code MLMeasures} are printed and written to
 * {@code <output.dir>/test_predictions/instance_hamming_loss_optimal/performance.txt}.
 * {@code predictions.txt} in the same directory receives one line per test instance of the form
 * {@code <predicted label set>:<probability the model assigns to that set>}.
 *
 * @param config  supplies {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     the trained model
 * @param dataSet the test set
 * @throws Exception if a report file cannot be written
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String output = config.getString("output.dir");
    MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
    marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = marginalPredictor.predict(dataSet);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    // writeStringToFile creates missing parent directories; predictions.txt below shares that directory.
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString());
    // Absolute path, consistent with every other "saved to" message in the report.
    System.out.println("test performance is saved to " + performanceFile.getAbsolutePath());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i]))
            .toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile))) {
        // Iterate over the arrays actually being indexed so the bound always matches them.
        for (int i = 0; i < predictions.length; i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write(String.valueOf(setProbs[i]));
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class CBMGB, method reportGeneral:
/**
 * Writes the predictor-independent test-set reports for a trained CBM model.
 *
 * <p>Produces, under {@code <output.dir>/test_predictions}:
 * <ul>
 *   <li>{@code label_probabilities.txt}: one {@code CBMInspector.topLabels} entry per test
 *       instance, using the threshold {@code report.labelProbThreshold};</li>
 *   <li>{@code ground_truth_log_likelihood.txt}: the log probability the model assigns to each
 *       instance's ground-truth label set, computed without approximation.</li>
 * </ul>
 * The average log likelihood is printed, excluding instances whose ground truth contains a label
 * listed in {@code <output.dir>/unobserved_labels.txt} (comma-separated label indices). If no
 * instance qualifies (empty test set, or every instance has an unobserved label) the average is
 * reported as {@code NaN} instead of aborting the report.
 *
 * @param config  supplies {@code output.dir} and {@code report.labelProbThreshold}
 * @param cbm     the trained model
 * @param dataSet the test set
 * @throws Exception if a report file cannot be read or written
 */
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String output = config.getString("output.dir");
    File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    // FileWriter does not create missing directories (FileUtils.writeStringToFile below does),
    // so make sure test_predictions exists before opening the writer.
    labelProbFile.getParentFile().mkdirs();
    try (BufferedWriter br = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(CBMInspector.topLabels(cbm, dataSet.getRow(i), labelProbThreshold));
            br.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // Comma-separated indices of labels that never appear in the training set.
    List<Integer> unobservedLabels = Arrays.stream(FileUtils.readFileToString(new File(output, "unobserved_labels.txt")).split(","))
            .map(String::trim)
            .filter(s -> !s.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i]))
            .toArray();
    // Instances with labels unseen during training cannot be modelled, so they are excluded from
    // the average. orElse(NaN) keeps the report going when nothing remains (empty test set, or all
    // instances contain unobserved labels) instead of throwing NoSuchElementException.
    double average = IntStream.range(0, dataSet.getNumDataPoints())
            .filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels))
            .mapToDouble(i -> logLikelihoods[i])
            .average()
            .orElse(Double.NaN);
    File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class App1, method loadAugmentedLabelTranslator:
/**
 * Builds a label translator over the training labels plus any new labels found in the test documents.
 *
 * <p>The label list starts with every training label, in {@code trainLabelTranslator} index order.
 * The labels returned by a term aggregation over {@code testIndexIds} that are not already present
 * are appended after them, in aggregation order. The label field, and the optional label-prefix
 * filter, come from the configuration saved during training at
 * {@code <output.folder>/meta_data/saved_config_app1}. The per-label document counts are logged
 * at INFO, and any newly discovered labels are reported with a warning.
 * NOTE(review): presumably {@code LabelTranslator} assigns indices by list position, so training
 * labels keep their indices; confirm.
 *
 * @param config               supplies {@code output.folder}
 * @param index                index queried for the label term aggregation
 * @param testIndexIds         ids of the documents whose labels are aggregated
 * @param trainLabelTranslator translator built from the training data
 * @param logger               receives the label distribution and new-label warning
 * @return a translator over the training labels followed by the newly discovered labels
 */
static LabelTranslator loadAugmentedLabelTranslator(Config config, MultiLabelIndex index, String[] testIndexIds, LabelTranslator trainLabelTranslator, Logger logger) {
    File savedConfigFile = new File(new File(config.getString("output.folder"), "meta_data"), "saved_config_app1");
    Config savedConfig = new Config(savedConfigFile);

    // Training labels first, in their original order.
    List<String> allLabels = new ArrayList<>();
    int numTrainLabels = trainLabelTranslator.getNumClasses();
    for (int label = 0; label < numTrainLabels; label++) {
        allLabels.add(trainLabelTranslator.toExtLabel(label));
    }

    String labelField = savedConfig.getString("train.label.field");
    Collection<Terms.Bucket> buckets = index.termAggregation(labelField, testIndexIds);
    if (savedConfig.getBoolean("train.label.filter")) {
        // Keep only labels carrying the configured prefix, as was done at training time.
        String prefix = savedConfig.getString("train.label.filter.prefix");
        buckets = buckets.stream()
                .filter(bucket -> bucket.getKey().startsWith(prefix))
                .collect(Collectors.toList());
    }

    logger.info("label distribution in data set:");
    StringBuilder distribution = new StringBuilder();
    List<String> novelLabels = new ArrayList<>();
    for (Terms.Bucket bucket : buckets) {
        String key = bucket.getKey();
        distribution.append(key).append(":").append(bucket.getDocCount()).append(", ");
        if (!allLabels.contains(key)) {
            allLabels.add(key);
            novelLabels.add(key);
        }
    }
    logger.info(distribution.toString());

    if (!novelLabels.isEmpty()) {
        logger.warning("found new labels in data set: " + novelLabels);
    }
    return new LabelTranslator(allLabels);
}
Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.
Class ArffFormat, method writeConfigFile:
/**
 * Stores the ARFF metadata config (number of data points and number of features) for a dataset.
 *
 * <p>The config is written as {@code ARFF_CONFIG_FILE_NAME} with {@code arffFile} as its parent,
 * so {@code arffFile} is expected to be a directory. NOTE(review): presumably the ARFF reader
 * relies on this file to size the dataset; confirm.
 *
 * @param dataSet  the dataset whose dimensions are recorded
 * @param arffFile the ARFF output location, used as the parent directory of the config file
 * @throws RuntimeException if the config cannot be stored; the original exception is the cause
 */
private static void writeConfigFile(RegDataSet dataSet, File arffFile) {
    File configFile = new File(arffFile, ARFF_CONFIG_FILE_NAME);
    Config config = new Config();
    config.setInt(ARFF_CONFIG_NUM_DATA_POINTS, dataSet.getNumDataPoints());
    config.setInt(ARFF_CONFIG_NUM_FEATURES, dataSet.getNumFeatures());
    try {
        config.store(configFile);
    } catch (Exception e) {
        // Printing and carrying on would leave the caller with incomplete ARFF output and no
        // signal that anything went wrong; surface the failure with its cause and location.
        throw new RuntimeException("failed to write ARFF config file " + configFile.getAbsolutePath(), e);
    }
}
Aggregations