Search in sources:

Example 1 with MultiLabelPredictionAnalysis

Use of edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis in the project pyramid by cheng-li.

Source: class MLLogisticRegressionInspector, method analyzePrediction.

/**
 * Builds a prediction analysis for a single data point under a multi-label logistic regression model.
 *
 * <p>The analysis records the point's internal and external id, its true labels and the model's
 * predicted labels (both as internal indices and as external label names), plus one
 * {@code ClassScoreCalculation} per requested class. This model does not supply assignment
 * probabilities here, so both probability fields are set to {@link Double#NaN}.
 *
 * @param logisticRegression trained model used to predict the point's labels
 * @param dataSet            data set that contains the point
 * @param dataPointIndex     internal index of the point inside {@code dataSet}
 * @param classes            internal class indices for which a score breakdown is produced
 * @param limit              forwarded unchanged to {@code decisionProcess} for every class
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(MLLogisticRegression logisticRegression, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    final MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    final LabelTranslator translator = dataSet.getLabelTranslator();
    final IdTranslator ids = dataSet.getIdTranslator();

    analysis.setInternalId(dataPointIndex);
    analysis.setId(ids.toExtId(dataPointIndex));

    // ground truth, as internal indices and as external label names
    MultiLabel truth = dataSet.getMultiLabels()[dataPointIndex];
    analysis.setInternalLabels(truth.getMatchedLabelsOrdered());
    analysis.setLabels(truth.getMatchedLabelsOrdered().stream()
            .map(translator::toExtLabel)
            .collect(Collectors.toList()));
    // no set probability is computed for this model
    analysis.setProbForTrueLabels(Double.NaN);

    // model prediction, in the same two representations
    List<Integer> predictedIndices = logisticRegression.predict(dataSet.getRow(dataPointIndex)).getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedIndices);
    analysis.setPrediction(predictedIndices.stream()
            .map(translator::toExtLabel)
            .collect(Collectors.toList()));
    analysis.setProbForPredictedLabels(Double.NaN);

    // one score breakdown per requested class, in the order given by the caller
    List<ClassScoreCalculation> calculations = new ArrayList<>();
    classes.forEach(k -> calculations.add(decisionProcess(logisticRegression, translator, dataSet.getRow(dataPointIndex), k, limit)));
    analysis.setClassScoreCalculations(calculations);
    return analysis;
}
Also used : MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) ClassScoreCalculation(edu.neu.ccs.pyramid.regression.ClassScoreCalculation) ArrayList(java.util.ArrayList)

Example 2 with MultiLabelPredictionAnalysis

Use of edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis in the project pyramid by cheng-li.

Source: class HMLGBInspector, method analyzePrediction.

/**
 * Builds a prediction analysis for a single data point under an {@code HMLGradientBoosting} model.
 *
 * <p>The analysis records the point's internal and external id, its true labels and the model's
 * predicted labels (internal indices and external names), and the value of
 * {@code boosting.predictAssignmentProb} for both the true and the predicted label sets. It also
 * adds one {@code ClassScoreCalculation} per requested class.
 *
 * @param boosting       trained model used for prediction and assignment probabilities
 * @param dataSet        data set that contains the point
 * @param dataPointIndex internal index of the point inside {@code dataSet}
 * @param classes        internal class indices for which a score breakdown is produced
 * @param limit          forwarded unchanged to {@code decisionProcess} for every class
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(HMLGradientBoosting boosting, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    final MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    final LabelTranslator translator = dataSet.getLabelTranslator();
    final IdTranslator ids = dataSet.getIdTranslator();

    analysis.setInternalId(dataPointIndex);
    analysis.setId(ids.toExtId(dataPointIndex));

    // ground truth and the model's assignment probability for it
    MultiLabel truth = dataSet.getMultiLabels()[dataPointIndex];
    analysis.setInternalLabels(truth.getMatchedLabelsOrdered());
    analysis.setLabels(truth.getMatchedLabelsOrdered().stream()
            .map(translator::toExtLabel)
            .collect(Collectors.toList()));
    analysis.setProbForTrueLabels(boosting.predictAssignmentProb(dataSet.getRow(dataPointIndex), truth));

    // model prediction and the model's assignment probability for it
    MultiLabel predicted = boosting.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedIndices = predicted.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedIndices);
    analysis.setPrediction(predictedIndices.stream()
            .map(translator::toExtLabel)
            .collect(Collectors.toList()));
    analysis.setProbForPredictedLabels(boosting.predictAssignmentProb(dataSet.getRow(dataPointIndex), predicted));

    // one score breakdown per requested class, in the order given by the caller
    List<ClassScoreCalculation> calculations = new ArrayList<>();
    classes.forEach(k -> calculations.add(decisionProcess(boosting, translator, dataSet.getRow(dataPointIndex), k, limit)));
    analysis.setClassScoreCalculations(calculations);
    return analysis;
}
Also used : MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis)

Example 3 with MultiLabelPredictionAnalysis

Use of edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis in the project pyramid by cheng-li.

Source: class BRPrediction, method report.

/**
 * Evaluates a trained binary-relevance model on one data set and writes the prediction reports.
 *
 * <p>The method first loads the test data and the serialized model artifacts (classifier, label
 * calibrator, set calibrator, prediction feature extractor and support set). It then builds the
 * predictor selected by {@code predict.mode} and logs classification and calibration performance.
 * Finally it writes the report files into
 * {@code <output.dir>/model_predictions/<output.modelFolder>/predictions/<dataFile>_reports}.
 *
 * @param config   application configuration
 * @param dataPath path of the test data set
 * @param logger   receives progress and evaluation messages
 * @throws IllegalArgumentException if {@code dataSetType} or {@code predict.mode} is unknown, or
 *                                  {@code report.numDocsPerFile} is not positive
 * @throws Exception                if loading an artifact or writing a report fails
 */
private static void report(Config config, String dataPath, Logger logger) throws Exception {
    DataSetType dataSetType;
    switch(config.getString("dataSetType")) {
        case "sparse_random":
            dataSetType = DataSetType.ML_CLF_SPARSE;
            break;
        case "sparse_sequential":
            dataSetType = DataSetType.ML_CLF_SEQ_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("unknown dataSetType");
    }
    MultiLabelClfDataSet test = TRECFormat.loadMultiLabelClfDataSet(dataPath, dataSetType, true);
    // every artifact of this model lives under <output.dir>/model_predictions/<output.modelFolder>
    String modelPredictionsDir = Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder")).toString();
    String modelsDir = Paths.get(modelPredictionsDir, "models").toString();
    String calibratorsDir = Paths.get(modelsDir, "calibrators", config.getString("output.calibratorFolder")).toString();
    MultiLabelClassifier.ClassProbEstimator classProbEstimator = (MultiLabelClassifier.ClassProbEstimator) Serialization.deserialize(Paths.get(modelsDir, "classifier"));
    LabelCalibrator labelCalibrator = (LabelCalibrator) Serialization.deserialize(Paths.get(calibratorsDir, "label_calibrator").toFile());
    VectorCalibrator setCalibrator = (VectorCalibrator) Serialization.deserialize(Paths.get(calibratorsDir, "set_calibrator").toFile());
    PredictionFeatureExtractor predictionFeatureExtractor = (PredictionFeatureExtractor) Serialization.deserialize(Paths.get(calibratorsDir, "prediction_feature_extractor").toFile());
    File testDataFile = new File(dataPath);
    // deserialization returns Object; the element type of the stored list cannot be checked at runtime
    @SuppressWarnings("unchecked")
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(Paths.get(modelsDir, "support").toFile());
    String reportFolder = Paths.get(modelPredictionsDir, "predictions", testDataFile.getName() + "_reports").toString();
    MultiLabelClassifier classifier = null;
    switch(config.getString("predict.mode")) {
        case "independent":
            classifier = new IndependentPredictor(classProbEstimator, labelCalibrator);
            break;
        case "support":
            classifier = new edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor(classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, support);
            break;
        case "reranker":
            Reranker reranker = (Reranker) setCalibrator;
            reranker.setMinPredictionSize(config.getInt("predict.minSize"));
            reranker.setMaxPredictionSize(config.getInt("predict.maxSize"));
            classifier = reranker;
            break;
        default:
            throw new IllegalArgumentException("illegal predict.mode");
    }
    MultiLabel[] predictions = classifier.predict(test);
    logger.info("classification performance on dataset " + testDataFile.getName());
    MLMeasures mlMeasures = new MLMeasures(test.getNumClasses(), test.getMultiLabels(), predictions);
    mlMeasures.getMacroAverage().updateAveragePrecision(classProbEstimator, test);
    logger.info(mlMeasures.toString());
    CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);
    boolean calibrationPerformance = true;
    if (calibrationPerformance) {
        logger.info("calibration performance on dataset " + testDataFile.getName());
        List<CalibrationDataGenerator.CalibrationInstance> instances = IntStream.range(0, test.getNumDataPoints()).parallel().boxed().map(i -> calibrationDataGenerator.createInstance(classProbEstimator, test.getRow(i), predictions[i], test.getMultiLabels()[i], config.getString("calibrate.target"))).collect(Collectors.toList());
        BRCalibration.eval(instances, setCalibrator, logger, config.getString("calibrate.target"));
    }
    // effectively-final alias so the lambdas below can capture the classifier
    MultiLabelClassifier fClassifier = classifier;
    // reused for every JSON file written below
    ObjectMapper jsonMapper = new ObjectMapper();
    boolean simpleCSV = true;
    if (simpleCSV) {
        File csv = Paths.get(reportFolder, "report.csv").toFile();
        csv.getParentFile().mkdirs();
        // writing the header truncates any previous report.csv; the rows are then added by
        // ParallelFileWriter (presumably in append mode via its last argument — TODO confirm)
        String header = String.join("\t", "doc_id", "predictions", "prediction_type", "confidence", "truth", "ground_truth", "precision", "recall", "F1") + "\n";
        FileUtils.writeStringToFile(csv, header, java.nio.charset.StandardCharsets.UTF_8);
        List<Integer> list = IntStream.range(0, test.getNumDataPoints()).boxed().collect(Collectors.toList());
        ParallelStringMapper<Integer> mapper = (list1, i) -> simplePredictionAnalysisCalibrated(classProbEstimator, labelCalibrator, setCalibrator, test, i, fClassifier, predictionFeatureExtractor);
        ParallelFileWriter.mapToString(mapper, list, csv, 100, true);
    }
    boolean topSets = true;
    if (topSets) {
        File csv = Paths.get(reportFolder, "top_sets.csv").toFile();
        csv.getParentFile().mkdirs();
        List<Integer> list = IntStream.range(0, test.getNumDataPoints()).boxed().collect(Collectors.toList());
        ParallelStringMapper<Integer> mapper = (list1, i) -> topKSets(config, classProbEstimator, labelCalibrator, setCalibrator, test, i, fClassifier, predictionFeatureExtractor);
        ParallelFileWriter.mapToString(mapper, list, csv, 100);
    }
    boolean rulesToJson = config.getBoolean("report.showPredictionDetail");
    if (rulesToJson) {
        logger.info("start writing rules to json");
        int ruleLimit = config.getInt("report.rule.limit");
        int numDocsPerFile = config.getInt("report.numDocsPerFile");
        // a non-positive page size would make numFiles explode (0 -> Integer.MAX_VALUE empty files)
        if (numDocsPerFile <= 0) {
            throw new IllegalArgumentException("report.numDocsPerFile must be positive but was " + numDocsPerFile);
        }
        int numFiles = (int) Math.ceil((double) test.getNumDataPoints() / numDocsPerFile);
        double probThreshold = config.getDouble("report.classProbThreshold");
        int labelSetLimit = config.getInt("report.labelSetLimit");
        IntStream.range(0, numFiles).forEach(i -> {
            int start = i * numDocsPerFile;
            int end = Math.min(start + numDocsPerFile, test.getNumDataPoints());
            List<MultiLabelPredictionAnalysis> partition = IntStream.range(start, end).parallel().mapToObj(a -> BRInspector.analyzePrediction(classProbEstimator, labelCalibrator, setCalibrator, test, fClassifier, predictionFeatureExtractor, a, ruleLimit, labelSetLimit, probThreshold)).collect(Collectors.toList());
            File jsonFile = Paths.get(reportFolder, "report_" + (i + 1) + ".json").toFile();
            try {
                jsonMapper.writeValue(jsonFile, partition);
            } catch (IOException e) {
                // best effort: keep writing the remaining partitions, but surface the failure in the log
                logger.log(java.util.logging.Level.SEVERE, "failed to write " + jsonFile, e);
            }
            logger.info("progress = " + Progress.percentage(i + 1, numFiles));
        });
        logger.info("finish writing rules to json");
    }
    boolean individualPerformance = true;
    if (individualPerformance) {
        logger.info("start writing individual label performance to json");
        jsonMapper.writeValue(Paths.get(reportFolder, "individual_performance.json").toFile(), mlMeasures.getMacroAverage());
        logger.info("finish writing individual label performance to json");
    }
    boolean dataConfigToJson = true;
    if (dataConfigToJson) {
        logger.info("start writing data config to json");
        File dataConfigFile = Paths.get(dataPath, "data_config.json").toFile();
        if (dataConfigFile.exists()) {
            FileUtils.copyFileToDirectory(dataConfigFile, new File(reportFolder));
        }
        logger.info("finish writing data config to json");
    }
    boolean dataInfoToJson = true;
    if (dataInfoToJson) {
        logger.info("start writing data info to json");
        Set<String> modelLabels = IntStream.range(0, classifier.getNumClasses()).mapToObj(i -> classProbEstimator.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(test).stream().map(i -> test.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
        // try-with-resources so the file handle is released even if a write fails
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(Paths.get(reportFolder, "data_info.json").toFile(), JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", testDataFile.getName());
            jsonGenerator.writeNumberField("numClassesInModel", classifier.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", test.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", test.labelCardinality());
            jsonGenerator.writeEndObject();
        }
        logger.info("finish writing data info to json");
    }
    boolean performanceToJson = true;
    if (performanceToJson) {
        jsonMapper.writeValue(Paths.get(reportFolder, "performance.json").toFile(), mlMeasures);
    }
    if (config.getBoolean("report.produceHTML")) {
        logger.info("start producing html files");
        Config savedApp1Config = new Config(Paths.get(config.getString("output.dir"), "meta_data", "saved_config_app1").toFile());
        List<String> hosts = savedApp1Config.getStrings("index.hosts");
        List<Integer> ports = savedApp1Config.getIntegers("index.ports");
        // todo make it better
        if ("node".equals(savedApp1Config.getString("index.clientType"))) {
            // work on a private copy: the list returned by the config may be shared or unmodifiable
            ports = new ArrayList<>(ports);
            // one "localhost" host per configured port
            hosts = new ArrayList<>();
            for (int p = 0; p < ports.size(); p++) {
                hosts.add("localhost");
            }
            // default setting
            hosts.add("localhost");
            ports.add(9200);
        }
        try (Visualizer visualizer = new Visualizer(logger, hosts, ports)) {
            visualizer.produceHtml(new File(reportFolder));
            logger.info("finish producing html files");
        }
    }
}
Also used : IntStream(java.util.stream.IntStream) Visualizer(edu.neu.ccs.pyramid.visualization.Visualizer) edu.neu.ccs.pyramid.util(edu.neu.ccs.pyramid.util) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) SimpleFormatter(java.util.logging.SimpleFormatter) BRInspector(edu.neu.ccs.pyramid.multilabel_classification.BRInspector) ArrayList(java.util.ArrayList) FileHandler(java.util.logging.FileHandler) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) FMeasure(edu.neu.ccs.pyramid.eval.FMeasure) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) Precision(edu.neu.ccs.pyramid.eval.Precision) Recall(edu.neu.ccs.pyramid.eval.Recall) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) List(java.util.List) JsonFactory(com.fasterxml.jackson.core.JsonFactory) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) Comparator(java.util.Comparator) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) Config(edu.neu.ccs.pyramid.configuration.Config) JsonFactory(com.fasterxml.jackson.core.JsonFactory) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) 
SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) ArrayList(java.util.ArrayList) List(java.util.List) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) IOException(java.io.IOException) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor) Visualizer(edu.neu.ccs.pyramid.visualization.Visualizer) File(java.io.File)

Example 4 with MultiLabelPredictionAnalysis

Use of edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis in the project pyramid by cheng-li.

Source: class App2, method report.

/**
 * Generates evaluation reports for one data set using the saved {@code model_app3} boosting model.
 *
 * <p>Reports are written to {@code <output.folder>/reports_app3/<dataName>_reports}, which is
 * created if needed and emptied first. The predictor is chosen by {@code predict.target}. The
 * output files are:
 * <ul>
 *   <li>report.csv</li>
 *   <li>report_N.json (only when {@code report.showPredictionDetail} is set)</li>
 *   <li>data_info.json</li>
 *   <li>model_config.json</li>
 *   <li>data_config.json (a copy, when present)</li>
 *   <li>performance.json</li>
 *   <li>individual_performance.json</li>
 * </ul>
 *
 * @param config   application configuration
 * @param dataName name of the data set to evaluate
 * @param logger   receives progress and evaluation messages
 * @throws IllegalArgumentException if {@code predict.target} is unknown or
 *                                  {@code report.numDocsPerFile} is not positive
 * @throws Exception                if loading the model/data or writing a report fails
 */
static void report(Config config, String dataName, Logger logger) throws Exception {
    logger.info("generating reports for data set " + dataName);
    String output = config.getString("output.folder");
    String modelName = "model_app3";
    File analysisFolder = new File(new File(output, "reports_app3"), dataName + "_reports");
    analysisFolder.mkdirs();
    FileUtils.cleanDirectory(analysisFolder);
    IMLGradientBoosting boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
    String predictTarget = config.getString("predict.target");
    PluginPredictor<IMLGradientBoosting> pluginPredictorTmp = null;
    switch(predictTarget) {
        case "subsetAccuracy":
            pluginPredictorTmp = new SubsetAccPredictor(boosting);
            break;
        case "hammingLoss":
            pluginPredictorTmp = new HammingPredictor(boosting);
            break;
        case "instanceFMeasure":
            pluginPredictorTmp = new InstanceF1Predictor(boosting);
            break;
        case "macroFMeasure":
            TunedMarginalClassifier tunedMarginalClassifier = (TunedMarginalClassifier) Serialization.deserialize(new File(output, "predictor_macro_f"));
            pluginPredictorTmp = new MacroF1Predictor(boosting, tunedMarginalClassifier);
            break;
        default:
            throw new IllegalArgumentException("unknown prediction target measure " + predictTarget);
    }
    // just to make Lambda expressions happy
    final PluginPredictor<IMLGradientBoosting> pluginPredictor = pluginPredictorTmp;
    MultiLabelClfDataSet dataSet = loadData(config, dataName);
    MLMeasures mlMeasures = new MLMeasures(pluginPredictor, dataSet);
    mlMeasures.getMacroAverage().setLabelTranslator(boosting.getLabelTranslator());
    logger.info("performance on dataset " + dataName);
    logger.info(mlMeasures.toString());
    // reused for every JSON file written below
    ObjectMapper jsonMapper = new ObjectMapper();
    boolean simpleCSV = true;
    if (simpleCSV) {
        logger.info("start generating simple CSV report");
        double probThreshold = config.getDouble("report.classProbThreshold");
        File csv = new File(analysisFolder, "report.csv");
        List<String> strs = IntStream.range(0, dataSet.getNumDataPoints()).parallel().mapToObj(i -> IMLGBInspector.simplePredictionAnalysis(boosting, pluginPredictor, dataSet, i, probThreshold)).collect(Collectors.toList());
        // explicit UTF-8 (FileWriter would use the platform default charset), matching data_info.json
        try (BufferedWriter bw = java.nio.file.Files.newBufferedWriter(csv.toPath(), java.nio.charset.StandardCharsets.UTF_8)) {
            for (String str : strs) {
                bw.write(str);
            }
        }
        logger.info("finish generating simple CSV report");
    }
    boolean rulesToJson = config.getBoolean("report.showPredictionDetail");
    if (rulesToJson) {
        logger.info("start writing rules to json");
        int ruleLimit = config.getInt("report.rule.limit");
        int numDocsPerFile = config.getInt("report.numDocsPerFile");
        // a non-positive page size would make numFiles explode (0 -> Integer.MAX_VALUE empty files)
        if (numDocsPerFile <= 0) {
            throw new IllegalArgumentException("report.numDocsPerFile must be positive but was " + numDocsPerFile);
        }
        int numFiles = (int) Math.ceil((double) dataSet.getNumDataPoints() / numDocsPerFile);
        double probThreshold = config.getDouble("report.classProbThreshold");
        int labelSetLimit = config.getInt("report.labelSetLimit");
        IntStream.range(0, numFiles).forEach(i -> {
            int start = i * numDocsPerFile;
            int end = Math.min(start + numDocsPerFile, dataSet.getNumDataPoints());
            List<MultiLabelPredictionAnalysis> partition = IntStream.range(start, end).parallel().mapToObj(a -> IMLGBInspector.analyzePrediction(boosting, pluginPredictor, dataSet, a, ruleLimit, labelSetLimit, probThreshold)).collect(Collectors.toList());
            File jsonFile = new File(analysisFolder, "report_" + (i + 1) + ".json");
            try {
                jsonMapper.writeValue(jsonFile, partition);
            } catch (IOException e) {
                // best effort: keep writing the remaining partitions, but surface the failure in the log
                logger.log(java.util.logging.Level.SEVERE, "failed to write " + jsonFile, e);
            }
            logger.info("progress = " + Progress.percentage(i + 1, numFiles));
        });
        logger.info("finish writing rules to json");
    }
    boolean dataInfoToJson = true;
    if (dataInfoToJson) {
        logger.info("start writing data info to json");
        Set<String> modelLabels = IntStream.range(0, boosting.getNumClasses()).mapToObj(i -> boosting.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream().map(i -> dataSet.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
        // try-with-resources so the file handle is released even if a write fails
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(new File(analysisFolder, "data_info.json"), JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", dataName);
            jsonGenerator.writeNumberField("numClassesInModel", boosting.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
            jsonGenerator.writeEndObject();
        }
        logger.info("finish writing data info to json");
    }
    boolean modelConfigToJson = true;
    if (modelConfigToJson) {
        logger.info("start writing model config to json");
        jsonMapper.writeValue(new File(analysisFolder, "model_config.json"), config);
        logger.info("finish writing model config to json");
    }
    boolean dataConfigToJson = true;
    if (dataConfigToJson) {
        logger.info("start writing data config to json");
        File dataConfigFile = Paths.get(config.getString("input.folder"), "data_sets", dataName, "data_config.json").toFile();
        if (dataConfigFile.exists()) {
            FileUtils.copyFileToDirectory(dataConfigFile, analysisFolder);
        }
        logger.info("finish writing data config to json");
    }
    boolean performanceToJson = true;
    if (performanceToJson) {
        jsonMapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);
    }
    boolean individualPerformance = true;
    if (individualPerformance) {
        logger.info("start writing individual label performance to json");
        jsonMapper.writeValue(new File(analysisFolder, "individual_performance.json"), mlMeasures.getMacroAverage());
        logger.info("finish writing individual label performance to json");
    }
    logger.info("reports generated");
}
Also used : IntStream(java.util.stream.IntStream) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) SimpleFormatter(java.util.logging.SimpleFormatter) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) edu.neu.ccs.pyramid.multilabel_classification.imlgb(edu.neu.ccs.pyramid.multilabel_classification.imlgb) ArrayList(java.util.ArrayList) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) FileHandler(java.util.logging.FileHandler) FeatureDistribution(edu.neu.ccs.pyramid.feature_selection.FeatureDistribution) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) edu.neu.ccs.pyramid.eval(edu.neu.ccs.pyramid.eval) Collection(java.util.Collection) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) FileWriter(java.io.FileWriter) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) StopWatch(org.apache.commons.lang3.time.StopWatch) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) Progress(edu.neu.ccs.pyramid.util.Progress) List(java.util.List) JsonFactory(com.fasterxml.jackson.core.JsonFactory) Feature(edu.neu.ccs.pyramid.feature.Feature) MacroFMeasureTuner(edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) SetUtil(edu.neu.ccs.pyramid.util.SetUtil) 
MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) FileWriter(java.io.FileWriter) JsonFactory(com.fasterxml.jackson.core.JsonFactory) BufferedWriter(java.io.BufferedWriter) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) IOException(java.io.IOException) File(java.io.File)

Example 5 with MultiLabelPredictionAnalysis

Use of edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis in the project pyramid by cheng-li.

Source: class IMLGBInspector, method analyzePrediction.

/**
 * Builds a {@link MultiLabelPredictionAnalysis} describing how {@code boosting} scores one data point.
 *
 * <p>The analysis records:
 * <ul>
 *   <li>the point's internal and external ids;</li>
 *   <li>its true and predicted label sets, as internal indices and external names;</li>
 *   <li>the probability of each of those two sets;</li>
 *   <li>a per-class score breakdown from {@code decisionProcess};</li>
 *   <li>a class-probability ranking;</li>
 *   <li>a ranking of the most probable label sets.</li>
 * </ul>
 *
 * <p>How label-set probabilities are obtained depends on the plug-in predictor:
 * <ul>
 *   <li>{@code SubsetAccPredictor} and {@code InstanceF1Predictor} use the constrained assignment
 *       probabilities.</li>
 *   <li>{@code HammingPredictor} and {@code MacroF1Predictor} use the unconstrained ones, with the
 *       set ranking enumerated by {@link DynamicProgramming} over the class probabilities.</li>
 *   <li>For any other predictor type, the true/predicted set probabilities are not set and the
 *       label-set ranking is {@code null}.</li>
 * </ul>
 *
 * @param boosting the model being inspected
 * @param pluginPredictor predictor that produces the predicted label set
 * @param dataSet data set holding the point
 * @param dataPointIndex internal index of the point in {@code dataSet}
 * @param ruleLimit limit forwarded to {@code decisionProcess} (presumably the maximum number of rules
 *     reported per class; confirm)
 * @param labelSetLimit maximum number of label sets placed in the label-set ranking
 * @param classProbThreshold a class is analyzed when its probability is at least this value; classes in
 *     the true or predicted label set are always analyzed
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(IMLGradientBoosting boosting, PluginPredictor<IMLGradientBoosting> pluginPredictor, MultiLabelClfDataSet dataSet, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();

    // Fetch the row and its true labels once instead of on every use below.
    final Vector row = dataSet.getRow(dataPointIndex);
    final MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];

    // Which family of set probabilities applies. Evaluated once, checked in the original order
    // (constrained first, unconstrained second) so the same branch wins as before.
    final boolean usesConstrainedProbs = pluginPredictor instanceof SubsetAccPredictor || pluginPredictor instanceof InstanceF1Predictor;
    final boolean usesUnconstrainedProbs = pluginPredictor instanceof HammingPredictor || pluginPredictor instanceof MacroF1Predictor;

    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));
    List<Integer> internalLabels = trueLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalLabels(internalLabels);
    List<String> labels = internalLabels.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);
    if (usesConstrainedProbs) {
        predictionAnalysis.setProbForTrueLabels(boosting.predictAssignmentProbWithConstraint(row, trueLabels));
    }
    if (usesUnconstrainedProbs) {
        predictionAnalysis.setProbForTrueLabels(boosting.predictAssignmentProbWithoutConstraint(row, trueLabels));
    }

    MultiLabel predictedLabels = pluginPredictor.predict(row);
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);
    if (usesConstrainedProbs) {
        predictionAnalysis.setProbForPredictedLabels(boosting.predictAssignmentProbWithConstraint(row, predictedLabels));
    }
    if (usesUnconstrainedProbs) {
        predictionAnalysis.setProbForPredictedLabels(boosting.predictAssignmentProbWithoutConstraint(row, predictedLabels));
    }

    // Select the classes to explain: probable ones, plus every true or predicted class.
    final double[] classProbs = boosting.predictClassProbs(row);
    List<Integer> classes = new ArrayList<>();
    for (int k = 0; k < boosting.getNumClasses(); k++) {
        if (classProbs[k] >= classProbThreshold || trueLabels.matchClass(k) || predictedLabels.matchClass(k)) {
            classes.add(k);
        }
    }

    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        classScoreCalculations.add(decisionProcess(boosting, labelTranslator, row, k, ruleLimit));
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);

    // Reuse the probabilities computed above instead of re-scoring the row once per class.
    // NOTE(review): assumes predictClassProb(row, k) == predictClassProbs(row)[k]; the threshold test
    // above already relies on classProbs[k] being the class-k probability.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> labelRanking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(classProbs[label]);
        return rankInfo;
    }).collect(Collectors.toList());
    predictionAnalysis.setPredictedRanking(labelRanking);

    // Stays null when the predictor type is neither constrained nor unconstrained.
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = null;
    if (usesConstrainedProbs) {
        // Score every known assignment, then keep the labelSetLimit most probable.
        final double[] labelSetProbs = boosting.predictAllAssignmentProbsWithConstraint(row);
        labelSetRanking = IntStream.range(0, boosting.getAssignments().size())
                .mapToObj(i -> new MultiLabelPredictionAnalysis.LabelSetProbInfo(boosting.getAssignments().get(i), labelSetProbs[i], labelTranslator))
                .sorted(Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed())
                .limit(labelSetLimit)
                .collect(Collectors.toList());
    }
    if (usesUnconstrainedProbs) {
        // Enumerate label sets in decreasing probability from the independent class probabilities.
        labelSetRanking = new ArrayList<>();
        DynamicProgramming dp = new DynamicProgramming(classProbs);
        for (int c = 0; c < labelSetLimit; c++) {
            DynamicProgramming.Candidate candidate = dp.nextHighest();
            if (candidate == null) {
                // NOTE(review): defensive stop in case nextHighest() signals an exhausted
                // candidate space with null; confirm its contract.
                break;
            }
            labelSetRanking.add(new MultiLabelPredictionAnalysis.LabelSetProbInfo(candidate.getMultiLabel(), candidate.getProbability(), labelTranslator));
        }
    }
    predictionAnalysis.setPredictedLabelSetRanking(labelSetRanking);
    return predictionAnalysis;
}
Also used: MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) IntStream(java.util.stream.IntStream) java.util(java.util) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) Collectors(java.util.stream.Collectors) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Pair(edu.neu.ccs.pyramid.util.Pair)

Aggregations

MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis)7 Vector (org.apache.mahout.math.Vector)5 ArrayList (java.util.ArrayList)4 Collectors (java.util.stream.Collectors)4 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)3 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)3 Feature (edu.neu.ccs.pyramid.feature.Feature)3 TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures)3 File (java.io.File)3 IntStream (java.util.stream.IntStream)3 JsonEncoding (com.fasterxml.jackson.core.JsonEncoding)2 JsonFactory (com.fasterxml.jackson.core.JsonFactory)2 JsonGenerator (com.fasterxml.jackson.core.JsonGenerator)2 Config (edu.neu.ccs.pyramid.configuration.Config)2 DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming)2 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)2 PluginPredictor (edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor)2 edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression)2 RegTreeInspector (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector)2 RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)2