Search in sources:

Example 1 using IndependentPredictor

Usage of edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor in the project pyramid by cheng-li.

Class BRCalibration, method calibrate:

/**
 * Trains the label calibrator and the set calibrator for a trained multi-label model, serializes
 * them together with the prediction feature extractor, and logs evaluation results.
 *
 * <p>The calibration data set is split by index parity: even-indexed points train the label
 * calibrator, odd-indexed points are used to build training data for the set calibrator.
 * Classification performance is logged for the validation set; calibration performance is
 * logged for both the full calibration set and the validation set.
 *
 * @param config run configuration (data paths, model/calibrator locations, calibrator choices)
 * @param logger receives progress and evaluation messages
 * @throws IllegalArgumentException if {@code dataSetType}, {@code labelCalibrator},
 *     {@code setCalibrator} or {@code predict.mode} holds an unsupported value
 * @throws Exception propagated from data loading, (de)serialization and training
 */
private static void calibrate(Config config, Logger logger) throws Exception {
    logger.info("start training calibrators");
    DataSetType dataSetType;
    switch(config.getString("dataSetType")) {
        case "sparse_random":
            dataSetType = DataSetType.ML_CLF_SPARSE;
            break;
        case "sparse_sequential":
            dataSetType = DataSetType.ML_CLF_SEQ_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("unknown dataSetType");
    }
    MultiLabelClfDataSet train = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), dataSetType, true);
    MultiLabelClfDataSet cal = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.calibrationData"), dataSetType, true);
    MultiLabelClfDataSet valid = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.validData"), dataSetType, true);
    // All model artifacts (support, classifier, calibrators) live under this directory.
    java.nio.file.Path modelsDir = Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models");
    @SuppressWarnings("unchecked")
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(modelsDir.resolve("support").toFile());
    MultiLabelClassifier.ClassProbEstimator classProbEstimator = (MultiLabelClassifier.ClassProbEstimator) Serialization.deserialize(modelsDir.resolve("classifier"));
    // Split the calibration set by index parity: even -> label calibrator, odd -> set calibrator.
    List<Integer> labelCalIndices = IntStream.range(0, cal.getNumDataPoints()).filter(i -> i % 2 == 0).boxed().collect(Collectors.toList());
    List<Integer> setCalIndices = IntStream.range(0, cal.getNumDataPoints()).filter(i -> i % 2 == 1).boxed().collect(Collectors.toList());
    MultiLabelClfDataSet labelCalData = DataSetUtil.sampleData(cal, labelCalIndices);
    MultiLabelClfDataSet setCalData = DataSetUtil.sampleData(cal, setCalIndices);
    logger.info("start training label calibrator");
    LabelCalibrator labelCalibrator;
    switch(config.getString("labelCalibrator")) {
        case "isotonic":
            labelCalibrator = new IsoLabelCalibrator(classProbEstimator, labelCalData, false);
            break;
        case "identity":
            labelCalibrator = new IdentityLabelCalibrator();
            break;
        default:
            // Without this, an unknown value left the calibrator null and failed much later.
            throw new IllegalArgumentException("illegal labelCalibrator");
    }
    logger.info("finish training label calibrator");
    logger.info("start training set calibrator");
    List<PredictionFeatureExtractor> extractors = new ArrayList<>();
    if (config.getBoolean("brProb")) {
        // todo order matters; the first one will be used by iso, card iso
        extractors.add(new BRProbFeatureExtractor());
    }
    if (config.getBoolean("expectedF1")) {
        extractors.add(new ExpectedF1FeatureExtractor());
    }
    if (config.getBoolean("expectedPrecision")) {
        extractors.add(new ExpectedPrecisionFeatureExtractor());
    }
    if (config.getBoolean("expectedRecall")) {
        extractors.add(new ExpectedRecallFeatureExtractor());
    }
    if (config.getBoolean("setPrior")) {
        extractors.add(new PriorFeatureExtractor(train));
    }
    if (config.getBoolean("card")) {
        extractors.add(new CardFeatureExtractor());
    }
    if (config.getBoolean("encodeLabel")) {
        extractors.add(new LabelBinaryFeatureExtractor(classProbEstimator.getNumClasses(), train.getLabelTranslator()));
    }
    if (config.getBoolean("useInitialFeatures")) {
        // Select training features whose (variable) name matches one of the configured prefixes;
        // categorical features match on their variable name, n-gram features are never selected.
        Set<String> prefixes = new HashSet<>(config.getStrings("featureFieldPrefix"));
        FeatureList featureList = train.getFeatureList();
        List<Integer> featureIds = new ArrayList<>();
        for (int j = 0; j < featureList.size(); j++) {
            Feature feature = featureList.get(j);
            if (feature instanceof CategoricalFeature) {
                if (matchPrefixes(((CategoricalFeature) feature).getVariableName(), prefixes)) {
                    featureIds.add(j);
                }
            } else if (!(feature instanceof Ngram)) {
                if (matchPrefixes(feature.getName(), prefixes)) {
                    featureIds.add(j);
                }
            }
        }
        extractors.add(new InstanceFeatureExtractor(featureIds, train.getFeatureList()));
    }
    PredictionFeatureExtractor predictionFeatureExtractor = new CombinedPredictionFeatureExtractor(extractors);
    CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);
    // Read once; shared by both calibration data sets and the reranker trainer.
    int numCandidates = config.getInt("numCandidates");
    String calibrateTarget = config.getString("calibrate.target");
    CalibrationDataGenerator.TrainData caliTrainingData = calibrationDataGenerator.createCaliTrainingData(setCalData, classProbEstimator, numCandidates, calibrateTarget, support, 10);
    CalibrationDataGenerator.TrainData caliValidData = calibrationDataGenerator.createCaliTrainingData(valid, classProbEstimator, numCandidates, calibrateTarget, support, 10);
    RegDataSet calibratorTrainData = caliTrainingData.regDataSet;
    double[] weights = caliTrainingData.instanceWeights;
    VectorCalibrator setCalibrator;
    switch(config.getString("setCalibrator")) {
        case "cardinality_isotonic":
            setCalibrator = new VectorCardIsoSetCalibrator(calibratorTrainData, 0, 2, false);
            break;
        case "reranker":
            RerankerTrainer rerankerTrainer = RerankerTrainer.newBuilder().numCandidates(numCandidates).numLeaves(config.getInt("numLeaves")).monotonicityType("weak").build();
            setCalibrator = rerankerTrainer.trainWithSigmoid(calibratorTrainData, weights, classProbEstimator, predictionFeatureExtractor, labelCalibrator, caliValidData.regDataSet);
            break;
        case "isotonic":
            setCalibrator = new VectorIsoSetCalibrator(calibratorTrainData, 0, false);
            break;
        case "identity":
            setCalibrator = new VectorIdentityCalibrator(0);
            break;
        case "zero":
            setCalibrator = new ZeroCalibrator();
            break;
        default:
            throw new IllegalArgumentException("illegal setCalibrator");
    }
    logger.info("finish training set calibrator");
    java.nio.file.Path calibratorDir = modelsDir.resolve(Paths.get("calibrators", config.getString("output.calibratorFolder")));
    Serialization.serialize(labelCalibrator, calibratorDir.resolve("label_calibrator").toFile());
    Serialization.serialize(setCalibrator, calibratorDir.resolve("set_calibrator").toFile());
    Serialization.serialize(predictionFeatureExtractor, calibratorDir.resolve("prediction_feature_extractor").toFile());
    logger.info("finish training calibrators");
    MultiLabelClassifier classifier;
    switch(config.getString("predict.mode")) {
        case "independent":
            classifier = new IndependentPredictor(classProbEstimator, labelCalibrator);
            break;
        case "support":
            classifier = new edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor(classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, support);
            break;
        case "reranker":
            Reranker reranker = (Reranker) setCalibrator;
            reranker.setMinPredictionSize(config.getInt("predict.minSize"));
            reranker.setMaxPredictionSize(config.getInt("predict.maxSize"));
            classifier = reranker;
            break;
        default:
            throw new IllegalArgumentException("illegal predict.mode");
    }
    MultiLabel[] predictions = classifier.predict(cal);
    MultiLabel[] validPredictions = classifier.predict(valid);
    // NOTE(review): this key differs from "input.calibrationData" used above — confirm it exists in the config.
    logger.info("calibration performance on " + config.getString("input.calibrationFolder") + " set");
    List<CalibrationDataGenerator.CalibrationInstance> calInstances = IntStream.range(0, cal.getNumDataPoints()).parallel().boxed().map(i -> calibrationDataGenerator.createInstance(classProbEstimator, cal.getRow(i), predictions[i], cal.getMultiLabels()[i], calibrateTarget)).collect(Collectors.toList());
    eval(calInstances, setCalibrator, logger, calibrateTarget);
    logger.info("classification performance on " + config.getString("input.validFolder") + " set");
    logger.info(new MLMeasures(valid.getNumClasses(), valid.getMultiLabels(), validPredictions).toString());
    logger.info("calibration performance on " + config.getString("input.validFolder") + " set");
    List<CalibrationDataGenerator.CalibrationInstance> validInstances = IntStream.range(0, valid.getNumDataPoints()).parallel().boxed().map(i -> calibrationDataGenerator.createInstance(classProbEstimator, valid.getRow(i), validPredictions[i], valid.getMultiLabels()[i], calibrateTarget)).collect(Collectors.toList());
    eval(validInstances, setCalibrator, logger, calibrateTarget);
}
Also used : IntStream(java.util.stream.IntStream) edu.neu.ccs.pyramid.util(edu.neu.ccs.pyramid.util) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor) SimpleFormatter(java.util.logging.SimpleFormatter) CalibrationEval(edu.neu.ccs.pyramid.eval.CalibrationEval) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) FileHandler(java.util.logging.FileHandler) Ngram(edu.neu.ccs.pyramid.feature.Ngram) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) Serializable(java.io.Serializable) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) List(java.util.List) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Feature(edu.neu.ccs.pyramid.feature.Feature) Stream(java.util.stream.Stream) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) CategoricalFeature(edu.neu.ccs.pyramid.feature.CategoricalFeature) ArrayList(java.util.ArrayList) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) Feature(edu.neu.ccs.pyramid.feature.Feature) CategoricalFeature(edu.neu.ccs.pyramid.feature.CategoricalFeature) CategoricalFeature(edu.neu.ccs.pyramid.feature.CategoricalFeature) ArrayList(java.util.ArrayList) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) List(java.util.List) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) HashSet(java.util.HashSet) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor) Ngram(edu.neu.ccs.pyramid.feature.Ngram) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList)

Example 2 using IndependentPredictor

Usage of edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor in the project pyramid by cheng-li.

Class BRPrediction, method report:

/**
 * Evaluates a trained model and its calibrators on one data set and writes report files.
 *
 * <p>Loads the classifier, label calibrator, set calibrator, prediction feature extractor and
 * label-set support from the model folder. It then predicts on {@code dataPath} and logs
 * classification and calibration performance. Finally it writes reports into
 * {@code <output.dir>/model_predictions/<output.modelFolder>/predictions/<data name>_reports}:
 * {@code report.csv}, {@code top_sets.csv}, optional per-document {@code report_<n>.json} files,
 * {@code individual_performance.json}, a copy of {@code data_config.json} (if present),
 * {@code data_info.json}, {@code performance.json}, and optionally HTML files.
 *
 * @param config   run configuration
 * @param dataPath location of the data set to evaluate
 * @param logger   receives progress and evaluation messages
 * @throws IllegalArgumentException if {@code dataSetType} or {@code predict.mode} is unsupported,
 *     or if {@code report.numDocsPerFile} is not positive when detailed reports are requested
 * @throws Exception propagated from data loading, deserialization and report writing
 */
private static void report(Config config, String dataPath, Logger logger) throws Exception {
    DataSetType dataSetType;
    switch(config.getString("dataSetType")) {
        case "sparse_random":
            dataSetType = DataSetType.ML_CLF_SPARSE;
            break;
        case "sparse_sequential":
            dataSetType = DataSetType.ML_CLF_SEQ_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("unknown dataSetType");
    }
    MultiLabelClfDataSet test = TRECFormat.loadMultiLabelClfDataSet(dataPath, dataSetType, true);
    String outputDir = config.getString("output.dir");
    String modelFolder = config.getString("output.modelFolder");
    // All model artifacts (classifier, support, calibrators) live under this directory.
    java.nio.file.Path modelsDir = Paths.get(outputDir, "model_predictions", modelFolder, "models");
    MultiLabelClassifier.ClassProbEstimator classProbEstimator = (MultiLabelClassifier.ClassProbEstimator) Serialization.deserialize(modelsDir.resolve("classifier"));
    java.nio.file.Path calibratorDir = modelsDir.resolve(Paths.get("calibrators", config.getString("output.calibratorFolder")));
    LabelCalibrator labelCalibrator = (LabelCalibrator) Serialization.deserialize(calibratorDir.resolve("label_calibrator").toFile());
    VectorCalibrator setCalibrator = (VectorCalibrator) Serialization.deserialize(calibratorDir.resolve("set_calibrator").toFile());
    PredictionFeatureExtractor predictionFeatureExtractor = (PredictionFeatureExtractor) Serialization.deserialize(calibratorDir.resolve("prediction_feature_extractor").toFile());
    File testDataFile = new File(dataPath);
    @SuppressWarnings("unchecked")
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(modelsDir.resolve("support").toFile());
    String reportFolder = Paths.get(outputDir, "model_predictions", modelFolder, "predictions", testDataFile.getName() + "_reports").toString();
    MultiLabelClassifier classifier = null;
    switch(config.getString("predict.mode")) {
        case "independent":
            classifier = new IndependentPredictor(classProbEstimator, labelCalibrator);
            break;
        case "support":
            classifier = new edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor(classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, support);
            break;
        case "reranker":
            Reranker reranker = (Reranker) setCalibrator;
            reranker.setMinPredictionSize(config.getInt("predict.minSize"));
            reranker.setMaxPredictionSize(config.getInt("predict.maxSize"));
            classifier = reranker;
            break;
        default:
            throw new IllegalArgumentException("illegal predict.mode");
    }
    MultiLabel[] predictions = classifier.predict(test);
    logger.info("classification performance on dataset " + testDataFile.getName());
    MLMeasures mlMeasures = new MLMeasures(test.getNumClasses(), test.getMultiLabels(), predictions);
    mlMeasures.getMacroAverage().updateAveragePrecision(classProbEstimator, test);
    logger.info(mlMeasures.toString());
    CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);
    logger.info("calibration performance on dataset " + testDataFile.getName());
    String calibrateTarget = config.getString("calibrate.target");
    List<CalibrationDataGenerator.CalibrationInstance> instances = IntStream.range(0, test.getNumDataPoints()).parallel().boxed().map(i -> calibrationDataGenerator.createInstance(classProbEstimator, test.getRow(i), predictions[i], test.getMultiLabels()[i], calibrateTarget)).collect(Collectors.toList());
    BRCalibration.eval(instances, setCalibrator, logger, calibrateTarget);
    // Effectively-final alias so the classifier can be captured by the lambdas below.
    MultiLabelClassifier fClassifier = classifier;
    boolean simpleCSV = true;
    if (simpleCSV) {
        File csv = Paths.get(reportFolder, "report.csv").toFile();
        csv.getParentFile().mkdirs();
        if (csv.exists()) {
            csv.delete();
        }
        StringBuilder sb = new StringBuilder();
        sb.append("doc_id").append("\t").append("predictions").append("\t").append("prediction_type").append("\t").append("confidence").append("\t").append("truth").append("\t").append("ground_truth").append("\t").append("precision").append("\t").append("recall").append("\t").append("F1").append("\n");
        FileUtils.writeStringToFile(csv, sb.toString());
        List<Integer> list = IntStream.range(0, test.getNumDataPoints()).boxed().collect(Collectors.toList());
        ParallelStringMapper<Integer> mapper = (list1, i) -> simplePredictionAnalysisCalibrated(classProbEstimator, labelCalibrator, setCalibrator, test, i, fClassifier, predictionFeatureExtractor);
        // The header was written above; rows are appended after it (last argument true).
        ParallelFileWriter.mapToString(mapper, list, csv, 100, true);
    }
    boolean topSets = true;
    if (topSets) {
        File csv = Paths.get(reportFolder, "top_sets.csv").toFile();
        csv.getParentFile().mkdirs();
        List<Integer> list = IntStream.range(0, test.getNumDataPoints()).boxed().collect(Collectors.toList());
        ParallelStringMapper<Integer> mapper = (list1, i) -> topKSets(config, classProbEstimator, labelCalibrator, setCalibrator, test, i, fClassifier, predictionFeatureExtractor);
        ParallelFileWriter.mapToString(mapper, list, csv, 100);
    }
    boolean rulesToJson = config.getBoolean("report.showPredictionDetail");
    if (rulesToJson) {
        logger.info("start writing rules to json");
        int ruleLimit = config.getInt("report.rule.limit");
        int numDocsPerFile = config.getInt("report.numDocsPerFile");
        if (numDocsPerFile <= 0) {
            // A page size of 0 would make numFiles Integer.MAX_VALUE and write empty files almost forever.
            throw new IllegalArgumentException("report.numDocsPerFile must be positive, but got " + numDocsPerFile);
        }
        int numDocs = test.getNumDataPoints();
        int numFiles = (int) Math.ceil((double) numDocs / numDocsPerFile);
        double probThreshold = config.getDouble("report.classProbThreshold");
        int labelSetLimit = config.getInt("report.labelSetLimit");
        // One mapper is reused for every file rather than re-created per partition.
        ObjectMapper jsonMapper = new ObjectMapper();
        IntStream.range(0, numFiles).forEach(i -> {
            int start = i * numDocsPerFile;
            int end = Math.min(start + numDocsPerFile, numDocs);
            List<MultiLabelPredictionAnalysis> partition = IntStream.range(start, end).parallel().mapToObj(a -> BRInspector.analyzePrediction(classProbEstimator, labelCalibrator, setCalibrator, test, fClassifier, predictionFeatureExtractor, a, ruleLimit, labelSetLimit, probThreshold)).collect(Collectors.toList());
            File jsonFile = Paths.get(reportFolder, "report_" + (i + 1) + ".json").toFile();
            try {
                jsonMapper.writeValue(jsonFile, partition);
            } catch (IOException e) {
                // Best effort: keep writing the remaining partitions, but record the failure with its cause.
                logger.log(java.util.logging.Level.SEVERE, "failed to write " + jsonFile, e);
            }
            logger.info("progress = " + Progress.percentage(i + 1, numFiles));
        });
        logger.info("finish writing rules to json");
    }
    boolean individualPerformance = true;
    if (individualPerformance) {
        logger.info("start writing individual label performance to json");
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.writeValue(Paths.get(reportFolder, "individual_performance.json").toFile(), mlMeasures.getMacroAverage());
        logger.info("finish writing individual label performance to json");
    }
    boolean dataConfigToJson = true;
    if (dataConfigToJson) {
        logger.info("start writing data config to json");
        File dataConfigFile = Paths.get(dataPath, "data_config.json").toFile();
        if (dataConfigFile.exists()) {
            FileUtils.copyFileToDirectory(dataConfigFile, new File(reportFolder));
        }
        logger.info("finish writing data config to json");
    }
    boolean dataInfoToJson = true;
    if (dataInfoToJson) {
        logger.info("start writing data info to json");
        Set<String> modelLabels = IntStream.range(0, classifier.getNumClasses()).mapToObj(i -> classProbEstimator.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(test).stream().map(i -> test.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        // try-with-resources: the generator (and its file handle) is closed even if a write fails.
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(Paths.get(reportFolder, "data_info.json").toFile(), JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", testDataFile.getName());
            jsonGenerator.writeNumberField("numClassesInModel", classifier.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", test.getNumClasses());
            Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
            Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", test.labelCardinality());
            jsonGenerator.writeEndObject();
        }
        logger.info("finish writing data info to json");
    }
    boolean performanceToJson = true;
    if (performanceToJson) {
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.writeValue(Paths.get(reportFolder, "performance.json").toFile(), mlMeasures);
    }
    if (config.getBoolean("report.produceHTML")) {
        logger.info("start producing html files");
        Config savedApp1Config = new Config(Paths.get(outputDir, "meta_data", "saved_config_app1").toFile());
        List<String> hosts = savedApp1Config.getStrings("index.hosts");
        List<Integer> ports = savedApp1Config.getIntegers("index.ports");
        // todo make it better
        if (savedApp1Config.getString("index.clientType").equals("node")) {
            // Node client: every configured port is on localhost, plus the default localhost:9200.
            hosts = new ArrayList<>();
            for (int port : ports) {
                hosts.add("localhost");
            }
            // default setting
            hosts.add("localhost");
            ports.add(9200);
        }
        try (Visualizer visualizer = new Visualizer(logger, hosts, ports)) {
            visualizer.produceHtml(new File(reportFolder));
            logger.info("finish producing html files");
        }
    }
}
Also used : IntStream(java.util.stream.IntStream) Visualizer(edu.neu.ccs.pyramid.visualization.Visualizer) edu.neu.ccs.pyramid.util(edu.neu.ccs.pyramid.util) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) SimpleFormatter(java.util.logging.SimpleFormatter) BRInspector(edu.neu.ccs.pyramid.multilabel_classification.BRInspector) ArrayList(java.util.ArrayList) FileHandler(java.util.logging.FileHandler) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) FMeasure(edu.neu.ccs.pyramid.eval.FMeasure) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) Precision(edu.neu.ccs.pyramid.eval.Precision) Recall(edu.neu.ccs.pyramid.eval.Recall) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) List(java.util.List) JsonFactory(com.fasterxml.jackson.core.JsonFactory) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) Comparator(java.util.Comparator) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) Config(edu.neu.ccs.pyramid.configuration.Config) JsonFactory(com.fasterxml.jackson.core.JsonFactory) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) 
SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) ArrayList(java.util.ArrayList) List(java.util.List) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) IOException(java.io.IOException) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor) Visualizer(edu.neu.ccs.pyramid.visualization.Visualizer) File(java.io.File)

Example 3 using IndependentPredictor

Usage of edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor in the project pyramid by cheng-li.

Class BRRerank, method report:

/**
 * Writes set predictions and their confidence scores for the test split of a CBM model.
 *
 * <p>Loads the test data from {@code <dataPath>/test} and the CBM model and calibrators from
 * {@code <outputDir>/models}. It predicts label sets according to {@code predictMode}
 * ({@code independent} or {@code rerank}). It then writes one tab-separated line per data point
 * to {@code <outputDir>/reports/set_prediction_and_confidence.txt}: the predicted set, the
 * uncalibrated confidence (first calibration feature), the calibrated confidence, the ground
 * truth, and 1/0 for whether the prediction equals the ground truth.
 *
 * @param config run configuration
 * @throws IllegalArgumentException if {@code predictMode} is unsupported, or is {@code rerank}
 *     while the stored set calibrator is not a {@code Reranker}
 * @throws Exception propagated from data loading, deserialization and file writing
 */
private static void report(Config config) throws Exception {
    MultiLabelClfDataSet dataset = TRECFormat.loadMultiLabelClfDataSet(Paths.get(config.getString("dataPath"), "test").toFile(), DataSetType.ML_CLF_SPARSE, true);
    String outputDir = config.getString("outputDir");
    CBM cbm = (CBM) Serialization.deserialize(Paths.get(outputDir, "models", "model"));
    cbm.setAllowEmpty(true);
    LabelCalibrator labelCalibrator = (LabelCalibrator) Serialization.deserialize(Paths.get(outputDir, "models", "label_calibrator"));
    VectorCalibrator setCalibrator = (VectorCalibrator) Serialization.deserialize(Paths.get(outputDir, "models", "set_calibrator"));
    PredictionFeatureExtractor predictionFeatureExtractor = (PredictionFeatureExtractor) Serialization.deserialize(Paths.get(outputDir, "models", "calibration_feature_extractor"));
    CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);
    String predictMode = config.getString("predictMode");
    MultiLabelClassifier classifier;
    switch(predictMode) {
        case "independent":
            classifier = new IndependentPredictor(cbm, labelCalibrator);
            break;
        case "rerank":
            // Fail with a clear message instead of a bare ClassCastException.
            if (!(setCalibrator instanceof Reranker)) {
                throw new IllegalArgumentException("predictMode rerank requires a Reranker set calibrator, but found " + (setCalibrator == null ? "null" : setCalibrator.getClass().getName()));
            }
            classifier = (Reranker) setCalibrator;
            break;
        default:
            throw new IllegalArgumentException("illegal predictMode: " + predictMode);
    }
    MultiLabel[] predictions = classifier.predict(dataset);
    MultiLabel[] truth = dataset.getMultiLabels();
    List<CalibInfo> confidenceScores = IntStream.range(0, dataset.getNumDataPoints()).parallel().boxed().map(i -> {
        // The calibration target is fixed to "accuracy" for this report.
        CalibrationDataGenerator.CalibrationInstance predictionInstance = calibrationDataGenerator.createInstance(cbm, dataset.getRow(i), predictions[i], truth[i], "accuracy");
        double calibrated = setCalibrator.calibrate(predictionInstance.vector);
        CalibInfo calibInfo = new CalibInfo();
        calibInfo.uncalibrated = predictionInstance.vector.get(0);
        calibInfo.calibrated = calibrated;
        calibInfo.accuracy = predictionInstance.correctness;
        return calibInfo;
    }).collect(Collectors.toList());
    StringBuilder stringBuilder = new StringBuilder();
    stringBuilder.append("set_prediction").append("\t").append("uncalibrated_confidence").append("\t").append("calibrated_confidence").append("\t").append("ground_truth").append("\t").append("set_accuracy").append("\n");
    for (int i = 0; i < dataset.getNumDataPoints(); i++) {
        CalibInfo info = confidenceScores.get(i);
        stringBuilder.append(predictions[i]).append("\t").append(info.uncalibrated).append("\t").append(info.calibrated).append("\t").append(truth[i]).append("\t").append(predictions[i].equals(truth[i]) ? 1 : 0).append("\n");
    }
    java.nio.file.Path reportPath = Paths.get(outputDir, "reports", "set_prediction_and_confidence.txt");
    FileUtils.writeStringToFile(reportPath.toFile(), stringBuilder.toString());
    System.out.println("set predictions and confidence scores are saved to " + reportPath.toString() + "\n");
}
Also used : IntStream(java.util.stream.IntStream) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) CalibrationEval(edu.neu.ccs.pyramid.eval.CalibrationEval) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) List(java.util.List) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor)

Aggregations

edu.neu.ccs.pyramid.calibration (edu.neu.ccs.pyramid.calibration)3 Config (edu.neu.ccs.pyramid.configuration.Config)3 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)3 MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)3 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)3 IndependentPredictor (edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor)3 Paths (java.nio.file.Paths)3 ArrayList (java.util.ArrayList)3 List (java.util.List)3 Collectors (java.util.stream.Collectors)3 IntStream (java.util.stream.IntStream)3 FileUtils (org.apache.commons.io.FileUtils)3 CalibrationEval (edu.neu.ccs.pyramid.eval.CalibrationEval)2 edu.neu.ccs.pyramid.util (edu.neu.ccs.pyramid.util)2 File (java.io.File)2 Set (java.util.Set)2 FileHandler (java.util.logging.FileHandler)2 Logger (java.util.logging.Logger)2 SimpleFormatter (java.util.logging.SimpleFormatter)2 JsonEncoding (com.fasterxml.jackson.core.JsonEncoding)1