
Example 76 with Config

Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

From class CBMLR, method reportHammingPrediction.

private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet, String name) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on " + name + " set with the instance Hamming loss optimal predictor");
    String output = config.getString("output.dir");
    MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
    marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabel[] predictions = marginalPredictor.predict(dataSet);
    System.out.println("time spent on prediction = " + stopWatch);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println(name + " performance with the instance Hamming loss optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, name + "_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString());
    System.out.println(name + " performance is saved to " + performanceFile.toString());
    // Exact probability of each predicted label set under the model (no approximation is used here)
    double[] setProbs = IntStream.range(0, predictions.length).parallel().mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
    File predictionFile = Paths.get(output, name + "_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write(String.valueOf(setProbs[i]));
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    boolean individualPerformance = true;
    if (individualPerformance) {
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.writeValue(Paths.get(output, name + "_predictions", "instance_hamming_loss_optimal", "individual_performance.json").toFile(), mlMeasures.getMacroAverage());
    }
    System.out.println("============================================================");
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) AveragePrecision(edu.neu.ccs.pyramid.eval.AveragePrecision) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) MAP(edu.neu.ccs.pyramid.eval.MAP) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Feature(edu.neu.ccs.pyramid.feature.Feature) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) LogisticRegressionInspector(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegressionInspector) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) LogLikelihood(edu.neu.ccs.pyramid.eval.LogLikelihood)
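
For intuition about the "instance Hamming loss optimal" predictor: expected Hamming loss splits into one term per label, so the prediction that minimizes it includes label j exactly when its marginal probability P(y_j = 1 | x) is above 0.5. The class name MarginalPredictor suggests it works from such per-label marginals, but that is an assumption. The standalone sketch below only illustrates the decision rule: HammingSketch and its methods are hypothetical, the marginals are assumed to be precomputed, and the predict.piThreshold setting is not modeled.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper illustrating the Hamming-loss-optimal decision rule
// (not pyramid's MarginalPredictor).
final class HammingSketch {

    // Include label j iff its marginal probability exceeds 0.5.
    static List<Integer> predict(double[] marginals) {
        List<Integer> labels = new ArrayList<>();
        for (int j = 0; j < marginals.length; j++) {
            if (marginals[j] > 0.5) {
                labels.add(j);
            }
        }
        return labels;
    }

    // Expected Hamming loss of a 0/1 prediction, averaged over the labels.
    static double expectedHammingLoss(double[] marginals, boolean[] prediction) {
        double loss = 0.0;
        for (int j = 0; j < marginals.length; j++) {
            // Predicting 1 is wrong with probability 1 - p; predicting 0 is wrong with probability p.
            loss += prediction[j] ? 1.0 - marginals[j] : marginals[j];
        }
        return loss / marginals.length;
    }

    public static void main(String[] args) {
        double[] marginals = {0.9, 0.2, 0.6, 0.5};
        System.out.println(predict(marginals)); // [0, 2]
        System.out.println(expectedHammingLoss(marginals, new boolean[]{true, false, true, false}));
    }
}
```

Because each label is decided independently, no search over label sets is needed; that independence is exactly what the set-level criteria below give up.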

Example 77 with Config

Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

From class CBMLR, method reportAccPrediction.

private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet, String name) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on " + name + " set with the instance set accuracy optimal predictor");
    String output = config.getString("output.dir");
    AccPredictor accPredictor = new AccPredictor(cbm);
    accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    // Time only the prediction step, not the predictor setup
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabel[] predictions = accPredictor.predict(dataSet);
    System.out.println("time spent on prediction = " + stopWatch);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println(name + " performance with the instance set accuracy optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, name + "_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString());
    System.out.println(name + " performance is saved to " + performanceFile.toString());
    // Exact probability of each predicted label set under the model (no approximation is used here)
    double[] setProbs = IntStream.range(0, predictions.length).parallel().mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
    File predictionFile = Paths.get(output, name + "_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
    try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write(String.valueOf(setProbs[i]));
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    boolean individualPerformance = true;
    if (individualPerformance) {
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.writeValue(Paths.get(output, name + "_predictions", "instance_accuracy_optimal", "individual_performance.json").toFile(), mlMeasures.getMacroAverage());
    }
    System.out.println("============================================================");
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) AveragePrecision(edu.neu.ccs.pyramid.eval.AveragePrecision) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) MAP(edu.neu.ccs.pyramid.eval.MAP) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Feature(edu.neu.ccs.pyramid.feature.Feature) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) LogisticRegressionInspector(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegressionInspector) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) LogLikelihood(edu.neu.ccs.pyramid.eval.LogLikelihood)
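
Instance set accuracy (exact match) counts a prediction as correct only when it equals the whole true label set, so the expected accuracy of predicting a set is simply its probability, and the optimal prediction is the most probable label set (the mode). The sketch below assumes the model's distribution has already been enumerated over a finite list of candidate sets, which is a simplification; SetAccuracySketch and its methods are hypothetical and do not reflect how pyramid's AccPredictor actually searches for the mode.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical helper: set-accuracy-optimal prediction over an enumerated distribution.
final class SetAccuracySketch {

    // Returns the candidate label set with the highest probability (the mode).
    static Set<Integer> predictMode(Map<Set<Integer>, Double> distribution) {
        Set<Integer> best = null;
        double bestProb = Double.NEGATIVE_INFINITY;
        for (Map.Entry<Set<Integer>, Double> entry : distribution.entrySet()) {
            if (entry.getValue() > bestProb) {
                bestProb = entry.getValue();
                best = entry.getKey();
            }
        }
        return best;
    }

    static Set<Integer> setOf(Integer... labels) {
        return new HashSet<>(Arrays.asList(labels));
    }

    public static void main(String[] args) {
        Map<Set<Integer>, Double> distribution = new LinkedHashMap<>();
        distribution.put(setOf(0), 0.30);
        distribution.put(setOf(0, 2), 0.45);
        distribution.put(setOf(), 0.25);
        System.out.println(predictMode(distribution)); // [0, 2]
    }
}
```

With this distribution the marginals are 0.75 for label 0 and 0.45 for label 2, so thresholding marginals at 0.5 would return {0} instead of {0, 2}: the Hamming-optimal and set-accuracy-optimal predictions can disagree, which is why the examples write them to separate output directories.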

Example 78 with Config

Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

From class CBMLR, method reportGeneral.

private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet, String name) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    System.out.println("label averaged MAP");
    System.out.println(MAP.map(cbm, dataSet));
    System.out.println("instance averaged MAP");
    System.out.println(MAP.instanceMAP(cbm, dataSet));
    System.out.println("global AP truncated at 30");
    System.out.println(AveragePrecision.globalAveragePrecisionTruncated(cbm, dataSet, 30));
    String output = config.getString("output.dir");
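    File labelProbFile = Paths.get(output, name + "_predictions", "label_probabilities.txt").toFile();
    // Make sure the output directory exists before FileWriter opens the file
    labelProbFile.getParentFile().mkdirs();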
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    try (BufferedWriter br = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(CBMInspector.topLabels(cbm, dataSet.getRow(i), labelProbThreshold));
            br.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    List<Integer> unobservedLabels = Arrays.stream(FileUtils.readFileToString(new File(output, "unobserved_labels.txt")).split(","))
            .map(String::trim)
            .filter(s -> !s.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Exact log probability of each ground-truth label set (no approximation is used here)
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel().mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i])).toArray();
    // Average only over instances without unobserved labels; NaN (instead of an exception) if none remain
    double average = IntStream.range(0, dataSet.getNumDataPoints()).filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels)).mapToDouble(i -> logLikelihoods[i]).average().orElse(Double.NaN);
    File logLikelihoodFile = Paths.get(output, name + "_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the " + name + " ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the " + name + " ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty() && name.equals("test")) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) AveragePrecision(edu.neu.ccs.pyramid.eval.AveragePrecision) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) MAP(edu.neu.ccs.pyramid.eval.MAP) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Feature(edu.neu.ccs.pyramid.feature.Feature) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) LogisticRegressionInspector(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegressionInspector) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) LogLikelihood(edu.neu.ccs.pyramid.eval.LogLikelihood)
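
reportGeneral averages the ground-truth log likelihood only over instances whose label set contains no label that was unobserved during training, since such labels cannot be learned. The helper containsNovelClass is not shown above, so the version below is a hypothetical reimplementation on plain Set<Integer> label sets rather than pyramid's MultiLabel type. It returns NaN instead of throwing when every instance is filtered out.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.IntStream;

// Hypothetical sketch of the filtered average used for the ground-truth log likelihood.
final class FilteredLogLikelihood {

    // True if the label set contains any label that was unobserved during training.
    static boolean containsNovelClass(Set<Integer> labelSet, List<Integer> unobservedLabels) {
        for (int label : unobservedLabels) {
            if (labelSet.contains(label)) {
                return true;
            }
        }
        return false;
    }

    // Mean of logLikelihoods[i] over instances without novel labels; NaN if none remain.
    static double average(double[] logLikelihoods, List<Set<Integer>> groundTruth, List<Integer> unobservedLabels) {
        return IntStream.range(0, logLikelihoods.length)
                .filter(i -> !containsNovelClass(groundTruth.get(i), unobservedLabels))
                .mapToDouble(i -> logLikelihoods[i])
                .average()
                .orElse(Double.NaN);
    }

    public static void main(String[] args) {
        double[] logLikelihoods = {-1.0, -3.0, -50.0};
        List<Set<Integer>> groundTruth = Arrays.asList(
                new HashSet<>(Arrays.asList(0)),
                new HashSet<>(Arrays.asList(1, 2)),
                new HashSet<>(Arrays.asList(7))); // label 7 never appears in training
        List<Integer> unobserved = Arrays.asList(7);
        System.out.println(average(logLikelihoods, groundTruth, unobserved)); // -2.0
    }
}
```

The third instance, whose very low log likelihood is driven by the unlearnable label 7, is excluded, so the reported average reflects only label sets the model could have learned.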

Example 79 with Config

Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

From class CBMLR, method reportF1Prediction.

private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet, String name) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on " + name + " set with the instance F1 optimal predictor");
    String output = config.getString("output.dir");
    PluginF1 pluginF1 = new PluginF1(cbm);
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(new File(output, "support"));
    pluginF1.setSupport(support);
    pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabel[] predictions = pluginF1.predict(dataSet);
    System.out.println("time spent on prediction = " + stopWatch);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println(name + " performance with the instance F1 optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, name + "_predictions", "instance_f1_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString());
    System.out.println(name + " performance is saved to " + performanceFile.toString());
    // Exact probability of each predicted label set under the model (no approximation is used here)
    double[] setProbs = IntStream.range(0, predictions.length).parallel().mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
    File predictionFile = Paths.get(output, name + "_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write(String.valueOf(setProbs[i]));
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    boolean individualPerformance = true;
    if (individualPerformance) {
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.writeValue(Paths.get(output, name + "_predictions", "instance_f1_optimal", "individual_performance.json").toFile(), mlMeasures.getMacroAverage());
    }
    System.out.println("============================================================");
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) AveragePrecision(edu.neu.ccs.pyramid.eval.AveragePrecision) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) MAP(edu.neu.ccs.pyramid.eval.MAP) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Feature(edu.neu.ccs.pyramid.feature.Feature) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) LogisticRegressionInspector(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegressionInspector) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) LogLikelihood(edu.neu.ccs.pyramid.eval.LogLikelihood)
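
PluginF1 is given a support, a list of candidate label sets deserialized from the output directory. One natural reading of a plug-in F1 predictor, assumed here rather than taken from pyramid's source, is: use the model's distribution over label sets and return the candidate from the support that maximizes expected instance F1, where F1(y, ŷ) = 2|y ∩ ŷ| / (|y| + |ŷ|). The sketch works on an explicitly enumerated distribution with plain Set<Integer> label sets; ExpectedF1Sketch and its methods are hypothetical.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical sketch: choose the support set with the highest expected instance F1.
final class ExpectedF1Sketch {

    // Instance F1 between a true set and a predicted set; 1.0 when both are empty.
    static double f1(Set<Integer> truth, Set<Integer> predicted) {
        if (truth.isEmpty() && predicted.isEmpty()) {
            return 1.0;
        }
        Set<Integer> overlap = new HashSet<>(truth);
        overlap.retainAll(predicted);
        return 2.0 * overlap.size() / (truth.size() + predicted.size());
    }

    // For each candidate, sum P(y) * F1(y, candidate) over the distribution; keep the best.
    static Set<Integer> predict(List<Set<Integer>> support, Map<Set<Integer>, Double> distribution) {
        Set<Integer> best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Set<Integer> candidate : support) {
            double score = 0.0;
            for (Map.Entry<Set<Integer>, Double> entry : distribution.entrySet()) {
                score += entry.getValue() * f1(entry.getKey(), candidate);
            }
            if (score > bestScore) {
                bestScore = score;
                best = candidate;
            }
        }
        return best;
    }

    static Set<Integer> setOf(Integer... labels) {
        return new HashSet<>(Arrays.asList(labels));
    }

    public static void main(String[] args) {
        Map<Set<Integer>, Double> distribution = new LinkedHashMap<>();
        distribution.put(setOf(0), 0.4);
        distribution.put(setOf(1), 0.3);
        distribution.put(setOf(0, 1), 0.3);
        List<Set<Integer>> support = Arrays.asList(setOf(0), setOf(1), setOf(0, 1));
        System.out.println(predict(support, distribution)); // [0, 1]
    }
}
```

Here the single most probable set is {0} (probability 0.4), yet {0, 1} has the highest expected F1 (about 0.77 versus 0.60 for {0}), because F1 gives partial credit for overlapping sets while exact-match accuracy does not.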

Example 80 with Config

Use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

From class Calibration, method main.

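public static void main(String[] args) throws Exception {
    // Expect exactly one argument: the path to the config file
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a config file.");
    }
    Config config = new Config(args[0]);
    main(config);
}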
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig)
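
The examples read their settings from Config by key: output.dir through getString, and predict.piThreshold and report.labelProbThreshold through getDouble. Assuming the config file is a plain key=value properties file (an assumption; Config's file format is not shown here), a minimal stand-in built on java.util.Properties could look like this. MiniConfig is hypothetical and is not pyramid's Config class, and the values in main are placeholders.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.Properties;

// Hypothetical stand-in for a key/value configuration reader (not pyramid's Config).
final class MiniConfig {
    private final Properties properties = new Properties();

    MiniConfig(String text) {
        try {
            properties.load(new StringReader(text));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Fails loudly on a missing key instead of returning null.
    String getString(String key) {
        String value = properties.getProperty(key);
        if (value == null) {
            throw new IllegalArgumentException("missing key: " + key);
        }
        return value.trim();
    }

    double getDouble(String key) {
        return Double.parseDouble(getString(key));
    }

    public static void main(String[] args) {
        // Placeholder values for the keys used in the examples above.
        MiniConfig config = new MiniConfig(
                "output.dir=/tmp/cbm_output\n"
                + "predict.piThreshold=0.001\n"
                + "report.labelProbThreshold=0.2\n");
        System.out.println(config.getString("output.dir"));
        System.out.println(config.getDouble("predict.piThreshold"));
    }
}
```

Keeping the typed getters thin over a single string lookup means a missing or misspelled key surfaces immediately, before any prediction work starts.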

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config): 119
File (java.io.File): 68
Collectors (java.util.stream.Collectors): 40
FileUtils (org.apache.commons.io.FileUtils): 40
Paths (java.nio.file.Paths): 39
IntStream (java.util.stream.IntStream): 37
Pair (edu.neu.ccs.pyramid.util.Pair): 36
Serialization (edu.neu.ccs.pyramid.util.Serialization): 35
StopWatch (org.apache.commons.lang3.time.StopWatch): 34
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 33
BufferedWriter (java.io.BufferedWriter): 32
FileWriter (java.io.FileWriter): 32
java.util (java.util): 32
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 31
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 29
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 28
PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil): 26
edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm): 25
ListUtil (edu.neu.ccs.pyramid.util.ListUtil): 25
ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 22