Example usage of edu.neu.ccs.pyramid.configuration.Config in the project pyramid (by cheng-li): class CBMEN, method reportGeneral.
/**
 * Reports predictor-independent metrics on the test set: the per-instance label
 * probabilities above a threshold, and the (exact, non-approximated) log
 * likelihood of each ground-truth label set under the CBM model.
 *
 * Reads "output.dir" and "report.labelProbThreshold" from the config, and the
 * comma-separated "unobserved_labels.txt" produced earlier in output.dir.
 *
 * @param config  run configuration
 * @param cbm     trained CBM model
 * @param dataSet test set with ground-truth multi-labels
 * @throws Exception on I/O or parsing failure
 */
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String output = config.getString("output.dir");
    File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
    // FileWriter does not create missing parent directories; without this the
    // write below fails when "test_predictions" has not been created yet
    labelProbFile.getParentFile().mkdirs();
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    try (BufferedWriter br = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(CBMInspector.topLabels(cbm, dataSet.getRow(i), labelProbThreshold));
            br.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // labels never seen during training cannot be learned; instances containing
    // them are excluded from the average log likelihood below
    List<Integer> unobservedLabels = Arrays.stream(FileUtils.readFileToString(new File(output, "unobserved_labels.txt")).split(","))
            .map(String::trim)
            .filter(s -> !s.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i]))
            .toArray();
    // orElse(NaN) guards the corner case where every test instance contains an
    // unobserved label; the previous getAsDouble() threw NoSuchElementException
    double average = IntStream.range(0, dataSet.getNumDataPoints())
            .filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels))
            .mapToDouble(i -> logLikelihoods[i])
            .average().orElse(Double.NaN);
    File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Example usage of edu.neu.ccs.pyramid.configuration.Config in the project pyramid (by cheng-li): class CBMEN, method reportAccPrediction.
/**
 * Evaluates the instance set-accuracy optimal predictor on the test set,
 * prints the multi-label measures, and persists both the performance summary
 * and the per-instance predicted label sets (with their probabilities under
 * the model) beneath output.dir/test_predictions/instance_accuracy_optimal/.
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String outputDir = config.getString("output.dir");
    AccPredictor predictor = new AccPredictor(cbm);
    predictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predicted = predictor.predict(dataSet);
    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(measures);
    File performanceFile = Paths.get(outputDir, "test_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbabilities = IntStream.range(0, predicted.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predicted[i]))
            .toArray();
    File predictionFile = Paths.get(outputDir, "test_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
    // one "labelSet:probability" line per test instance
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            writer.write(predicted[i].toString());
            writer.write(":");
            writer.write(String.valueOf(setProbabilities[i]));
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Example usage of edu.neu.ccs.pyramid.configuration.Config in the project pyramid (by cheng-li): class CBMLR, method reportHammingPrediction.
/**
 * Evaluates the instance Hamming-loss optimal (marginal) predictor on the
 * test set, prints the multi-label measures, and writes the performance
 * summary plus per-instance "labelSet:probability" lines under
 * output.dir/test_predictions/instance_hamming_loss_optimal/.
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String outputDir = config.getString("output.dir");
    MarginalPredictor predictor = new MarginalPredictor(cbm);
    predictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predicted = predictor.predict(dataSet);
    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(measures);
    File performanceFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbabilities = IntStream.range(0, predicted.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predicted[i]))
            .toArray();
    File predictionFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            // string concatenation of a double uses Double.toString, matching "" + d
            writer.write(predicted[i].toString() + ":" + setProbabilities[i]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Example usage of edu.neu.ccs.pyramid.configuration.Config in the project pyramid (by cheng-li): class CBMLR, method reportAccPrediction.
/**
 * Runs the instance set-accuracy optimal predictor over the test set, reports
 * the resulting multi-label measures on stdout, and saves the performance
 * summary and the per-instance predicted sets with their probabilities under
 * output.dir/test_predictions/instance_accuracy_optimal/.
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String outputFolder = config.getString("output.dir");
    AccPredictor accPredictor = new AccPredictor(cbm);
    accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictedSets = accPredictor.predict(dataSet);
    MLMeasures testMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictedSets);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(testMeasures);
    File performanceFile = Paths.get(outputFolder, "test_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, testMeasures.toString());
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] assignmentProbs = IntStream.range(0, predictedSets.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictedSets[i]))
            .toArray();
    File predictionFile = Paths.get(outputFolder, "test_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
    try (BufferedWriter out = new BufferedWriter(new FileWriter(predictionFile))) {
        int numPoints = dataSet.getNumDataPoints();
        for (int i = 0; i < numPoints; i++) {
            out.write(predictedSets[i].toString());
            out.write(":");
            // Double -> String via concatenation matches the "" + d idiom exactly
            out.write("" + assignmentProbs[i]);
            out.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Example usage of edu.neu.ccs.pyramid.configuration.Config in the project pyramid (by cheng-li): class App1, method ngramSelection.
/**
 * Selects, for every label, the top-scoring multi-word ngram features and
 * rewrites feature_list.ser / feature_list.txt so they contain only non-ngram
 * features, unigrams, and the selected multi-word ngrams. The pre-selection
 * lists are preserved as feature_list_all.ser / feature_list_all.txt, and the
 * selection report is written to selected_ngrams.txt.
 */
private static void ngramSelection(Config config, MultiLabelIndex index, String docFilter, Logger logger) throws Exception {
    logger.info("start ngram selection");
    File metaDataFolder = new File(config.getString("output.folder"), "meta_data");
    String scoreTypeName = config.getString("train.feature.ngram.matchScoreType");
    String[] indexIds = getDocsForSplitFromQuery(index, config.getString("train.splitQuery"));
    IdTranslator idTranslator = loadIdTranslator(indexIds);
    LabelTranslator labelTranslator = (LabelTranslator) Serialization.deserialize(new File(metaDataFolder, "label_translator.ser"));
    // map the configured score-type name onto the loader enum
    final FeatureLoader.MatchScoreType matchScoreType;
    switch (scoreTypeName) {
        case "es_original":
            matchScoreType = FeatureLoader.MatchScoreType.ES_ORIGINAL;
            break;
        case "binary":
            matchScoreType = FeatureLoader.MatchScoreType.BINARY;
            break;
        case "frequency":
            matchScoreType = FeatureLoader.MatchScoreType.FREQUENCY;
            break;
        case "tfifl":
            matchScoreType = FeatureLoader.MatchScoreType.TFIFL;
            break;
        default:
            throw new IllegalArgumentException("unknown ngramMatchScoreType");
    }
    double[][] labels = loadLabels(config, index, idTranslator, labelTranslator);
    int numLabels = labels.length;
    int keepPerLabel = config.getInt("train.feature.ngram.selectPerLabel");
    // one bounded priority queue per label, ordered by match score
    Comparator<Pair<Ngram, Double>> byScore = Comparator.comparing(pair -> pair.getSecond());
    List<BoundedBlockPriorityQueue<Pair<Ngram, Double>>> perLabelQueues = new ArrayList<>();
    for (int label = 0; label < numLabels; label++) {
        perLabelQueues.add(new BoundedBlockPriorityQueue<>(keepPerLabel, byScore));
    }
    FeatureList featureList = (FeatureList) Serialization.deserialize(new File(metaDataFolder, "feature_list.ser"));
    // score every multi-word ngram against every label in parallel; the queues
    // retain only the keepPerLabel best-scoring ngrams per label
    featureList.getAll().stream().parallel()
            .filter(feature -> feature instanceof Ngram)
            .map(feature -> (Ngram) feature)
            .filter(ngram -> ngram.getN() > 1)
            .forEach(ngram -> {
                double[] scores = StumpSelector.scores(index, labels, ngram, idTranslator, matchScoreType, docFilter);
                for (int label = 0; label < numLabels; label++) {
                    perLabelQueues.get(label).add(new Pair<>(ngram, scores[label]));
                }
            });
    Set<Ngram> kept = new HashSet<>();
    StringBuilder report = new StringBuilder();
    for (int label = 0; label < numLabels; label++) {
        report.append("-------------------------").append("\n");
        report.append(labelTranslator.toExtLabel(label)).append(":").append("\n");
        BoundedBlockPriorityQueue<Pair<Ngram, Double>> queue = perLabelQueues.get(label);
        while (queue.size() > 0) {
            Ngram selected = queue.poll().getFirst();
            kept.add(selected);
            report.append(selected.getNgram()).append(", ");
        }
        report.append("\n");
    }
    File selectionFile = new File(metaDataFolder, "selected_ngrams.txt");
    FileUtils.writeStringToFile(selectionFile, report.toString());
    logger.info("finish ngram selection");
    logger.info("selected ngrams are written to " + selectionFile.getAbsolutePath());
    // after feature selection, overwrite the feature_list.ser file; rename old files
    FeatureList selectedFeatures = new FeatureList();
    for (Feature feature : featureList.getAll()) {
        // keep non-ngrams and unigrams unconditionally; keep longer ngrams only
        // if they survived per-label selection
        boolean keep;
        if (!(feature instanceof Ngram)) {
            keep = true;
        } else {
            int n = ((Ngram) feature).getN();
            keep = (n == 1) || (n > 1 && kept.contains(feature));
        }
        if (keep) {
            selectedFeatures.add(feature);
        }
    }
    FileUtils.copyFile(new File(metaDataFolder, "feature_list.ser"), new File(metaDataFolder, "feature_list_all.ser"));
    FileUtils.copyFile(new File(metaDataFolder, "feature_list.txt"), new File(metaDataFolder, "feature_list_all.txt"));
    Serialization.serialize(selectedFeatures, new File(metaDataFolder, "feature_list.ser"));
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(metaDataFolder, "feature_list.txt")))) {
        for (Feature feature : selectedFeatures.getAll()) {
            writer.write(feature.toString());
            writer.newLine();
        }
    }
}
Aggregations