Search in sources :

Example 1 with CTAT

Use of edu.neu.ccs.pyramid.calibration.CTAT in project pyramid by cheng-li.

In class AppCTAT, method main.

/**
 * Tunes a confidence threshold for a target accuracy (CTAT) on a validation report and
 * evaluates it on a test report.
 *
 * <p>The single command-line argument is the path of a properties file. The following outputs
 * are produced:
 * <ul>
 *   <li>The threshold found on {@code validReportPath} is written to
 *       {@code <output.folder>/<CTAT.name>_unclipped}.</li>
 *   <li>The same threshold clamped to [{@code CTAT.lowerBound}, {@code CTAT.upperBound}] is
 *       written to {@code <output.folder>/<CTAT.name>_clipped}.</li>
 *   <li>Autocoding statistics of both thresholds on {@code testReportPath} are logged.</li>
 * </ul>
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if {@code args} does not contain exactly one element
 * @throws Exception if reading the configuration or reports, or writing the outputs, fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    Logger logger = Logger.getAnonymousLogger();
    String logFile = config.getString("output.log");
    FileHandler fileHandler = null;
    if (!logFile.isEmpty()) {
        // getParentFile() is null for a bare file name such as "app.log".
        File logDir = new File(logFile).getParentFile();
        if (logDir != null) {
            logDir.mkdirs();
        }
        // todo should append?
        fileHandler = new FileHandler(logFile, true);
        Formatter formatter = new SimpleFormatter();
        fileHandler.setFormatter(formatter);
        logger.addHandler(fileHandler);
        logger.setUseParentHandlers(false);
    }
    try {
        logger.info(config.toString());
        String outputFolder = config.getString("output.folder");
        File output = new File(outputFolder);
        output.mkdirs();
        logger.info("start tuning CTAT ");
        String validReportPath = config.getString("validReportPath");
        String testReportPath = config.getString("testReportPath");
        Stream<Pair<Double, Double>> validStream = ReportUtils.getConfidenceCorrectness(validReportPath).stream();
        // Materialized as a list because the test data is streamed twice below.
        List<Pair<Double, Double>> testList = ReportUtils.getConfidenceCorrectness(testReportPath);
        CTAT.Summary summaryValid = CTAT.findThreshold(validStream, config.getDouble("CTAT.targetAccuracy"));
        double ctat = summaryValid.getConfidenceThreshold();
        double upperBound = config.getDouble("CTAT.upperBound");
        double lowerBound = config.getDouble("CTAT.lowerBound");
        // Clamp to the upper bound first, then the lower bound; the lower bound wins if they cross.
        double ctat_clipped = Math.max(lowerBound, Math.min(ctat, upperBound));
        String ctatName = config.getString("CTAT.name");
        FileUtils.writeStringToFile(Paths.get(outputFolder, ctatName + "_unclipped").toFile(), "" + ctat);
        FileUtils.writeStringToFile(Paths.get(outputFolder, ctatName + "_clipped").toFile(), "" + ctat_clipped);
        CTAT.Summary summaryTest = CTAT.applyThreshold(testList.stream(), ctat);
        CTAT.Summary summaryTest_clipped = CTAT.applyThreshold(testList.stream(), ctat_clipped);
        logger.info("tuning CTAT is done");
        logger.info("*****************");
        logger.info("autocoding performance with unclipped CTAT " + summaryTest.getConfidenceThreshold());
        logger.info("autocoding percentage = " + summaryTest.getAutoCodingPercentage());
        logger.info("autocoding accuracy = " + summaryTest.getAutoCodingAccuracy());
        logger.info("number of autocoded documents = " + summaryTest.getNumAutoCoded());
        logger.info("number of correct autocoded documents = " + summaryTest.getNumCorrectAutoCoded());
        logger.info("*****************");
        logger.info("autocoding performance with clipped CTAT " + summaryTest_clipped.getConfidenceThreshold());
        logger.info("autocoding percentage = " + summaryTest_clipped.getAutoCodingPercentage());
        logger.info("autocoding accuracy = " + summaryTest_clipped.getAutoCodingAccuracy());
        logger.info("number of autocoded documents = " + summaryTest_clipped.getNumAutoCoded());
        logger.info("number of correct autocoded documents = " + summaryTest_clipped.getNumCorrectAutoCoded());
    } finally {
        // Release the log file even when tuning fails part-way.
        if (fileHandler != null) {
            fileHandler.close();
        }
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) SimpleFormatter(java.util.logging.SimpleFormatter) Formatter(java.util.logging.Formatter) Logger(java.util.logging.Logger) FileHandler(java.util.logging.FileHandler) CTAT(edu.neu.ccs.pyramid.calibration.CTAT) File(java.io.File) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 2 with CTAT

Use of edu.neu.ccs.pyramid.calibration.CTAT in project pyramid by cheng-li.

In class BRAutomation, method showAutomationPerformance.

/**
 * Logs autocoding performance of the stored unclipped and clipped confidence thresholds on the
 * prediction report generated for {@code dataPath}.
 *
 * <p>Thresholds are read from
 * {@code <output.dir>/model_predictions/<output.modelFolder>/models/threshold/<threshold.name>}
 * and from the sibling file with the {@code _clipped} suffix. The report read is
 * {@code <output.dir>/model_predictions/<output.modelFolder>/predictions/<data file name>_reports/report.csv}.
 *
 * <p>The statistics logged depend on {@code threshold.targetMetric}:
 * <ul>
 *   <li>{@code accuracy}: CTAT statistics are logged.</li>
 *   <li>{@code f1}: CTFT statistics are logged.</li>
 *   <li>Any other value: only a notice is logged.</li>
 * </ul>
 *
 * @param config   configuration supplying output locations and threshold settings
 * @param dataPath path of the evaluated data set; only its file name is used
 * @param logger   destination for the performance report
 * @throws Exception if a threshold file or the report cannot be read or parsed
 */
public static void showAutomationPerformance(Config config, String dataPath, Logger logger) throws Exception {
    File thresholdDir = Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "threshold").toFile();
    String thresholdName = config.getString("threshold.name");
    double confidenceThreshold = Double.parseDouble(FileUtils.readFileToString(new File(thresholdDir, thresholdName)));
    File testDataFile = new File(dataPath);
    String reportFolder = Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "predictions", testDataFile.getName() + "_reports").toString();
    String reportPath = Paths.get(reportFolder, "report.csv").toString();
    String metric = config.getString("threshold.targetMetric");
    if (metric.equals("accuracy")) {
        CTAT.Summary summary = CTAT.applyThreshold(ReportUtils.getConfidenceCorrectness(reportPath).stream(), confidenceThreshold);
        logger.info("autocoding performance with unclipped CTAT " + summary.getConfidenceThreshold());
        logger.info("autocoding percentage = " + summary.getAutoCodingPercentage());
        logger.info("autocoding accuracy = " + summary.getAutoCodingAccuracy());
        logger.info("number of autocoded documents = " + summary.getNumAutoCoded());
        logger.info("number of correct autocoded documents = " + summary.getNumCorrectAutoCoded());
        double confidenceThresholdClipped = Double.parseDouble(FileUtils.readFileToString(new File(thresholdDir, thresholdName + "_clipped")));
        CTAT.Summary summaryClipped = CTAT.applyThreshold(ReportUtils.getConfidenceCorrectness(reportPath).stream(), confidenceThresholdClipped);
        logger.info("autocoding performance with clipped CTAT " + summaryClipped.getConfidenceThreshold());
        logger.info("autocoding percentage = " + summaryClipped.getAutoCodingPercentage());
        logger.info("autocoding accuracy = " + summaryClipped.getAutoCodingAccuracy());
        logger.info("number of autocoded documents = " + summaryClipped.getNumAutoCoded());
        logger.info("number of correct autocoded documents = " + summaryClipped.getNumCorrectAutoCoded());
    } else if (metric.equals("f1")) {
        CTFT.Summary summary = CTFT.applyThreshold(ReportUtils.getConfidenceF1(reportPath).stream(), confidenceThreshold);
        logger.info("autocoding performance with unclipped CTFT " + summary.getConfidenceThreshold());
        logger.info("autocoding percentage = " + summary.getAutoCodingPercentage());
        logger.info("autocoding accuracy = " + summary.getAutoCodingAccuracy());
        logger.info("autocoding F1 = " + summary.getAutoCodingF1());
        logger.info("number of autocoded documents = " + summary.getNumAutoCoded());
        logger.info("number of correct autocoded documents = " + summary.getNumCorrectAutoCoded());
        double confidenceThresholdClipped = Double.parseDouble(FileUtils.readFileToString(new File(thresholdDir, thresholdName + "_clipped")));
        CTFT.Summary summaryClipped = CTFT.applyThreshold(ReportUtils.getConfidenceF1(reportPath).stream(), confidenceThresholdClipped);
        logger.info("autocoding performance with clipped CTFT " + summaryClipped.getConfidenceThreshold());
        logger.info("autocoding percentage = " + summaryClipped.getAutoCodingPercentage());
        logger.info("autocoding accuracy = " + summaryClipped.getAutoCodingAccuracy());
        logger.info("autocoding F1 = " + summaryClipped.getAutoCodingF1());
        logger.info("number of autocoded documents = " + summaryClipped.getNumAutoCoded());
        logger.info("number of correct autocoded documents = " + summaryClipped.getNumCorrectAutoCoded());
    } else {
        // Previously an unknown metric produced no output at all, hiding configuration mistakes.
        logger.info("no autocoding performance reported: unsupported threshold.targetMetric = " + metric);
    }
}
Also used : CTFT(edu.neu.ccs.pyramid.calibration.CTFT) CTAT(edu.neu.ccs.pyramid.calibration.CTAT) File(java.io.File)

Example 3 with CTAT

Use of edu.neu.ccs.pyramid.calibration.CTAT in project pyramid by cheng-li.

In class AppEnsemble, method main.

/**
 * Builds an ensemble from several models' prediction reports, evaluates it, and optionally tunes
 * a confidence threshold on the validation set.
 *
 * <p>The single command-line argument is the path of a properties file. The method runs these
 * steps:
 * <ul>
 *   <li>Ground truth is taken from the data sets located next to the first entry of
 *       {@code modelPaths}, under its {@code data_sets/} folder.</li>
 *   <li>Each model's test and validation reports are loaded.</li>
 *   <li>Ensemble reports are generated for both folders.</li>
 *   <li>Multi-label measures on the test folder are logged.</li>
 * </ul>
 *
 * <p>When {@code tuneThreshold} is true, a threshold for {@code threshold.targetMetric}
 * ({@code accuracy} via CTAT, {@code f1} via CTFT) is searched on the validation report. That
 * threshold and its clipped version are then written under the ensemble's
 * {@code models/threshold} folder and evaluated on the test report.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if {@code args} does not contain exactly one element,
 *         {@code modelPaths} is empty, or {@code modelNames} has fewer entries than
 *         {@code modelPaths}
 * @throws Exception if loading data or reports, or writing outputs, fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    Logger logger = Logger.getAnonymousLogger();
    String logFile = config.getString("output.log");
    FileHandler fileHandler = null;
    if (!logFile.isEmpty()) {
        // getParentFile() is null for a bare file name such as "app.log".
        File logDir = new File(logFile).getParentFile();
        if (logDir != null) {
            logDir.mkdirs();
        }
        // todo should append?
        fileHandler = new FileHandler(logFile, true);
        Formatter formatter = new SimpleFormatter();
        fileHandler.setFormatter(formatter);
        logger.addHandler(fileHandler);
        logger.setUseParentHandlers(false);
    }
    try {
        logger.info(config.toString());
        File output = new File(config.getString("output.folder"));
        output.mkdirs();
        List<String> modelPaths = config.getStrings("modelPaths");
        List<String> modelNames = config.getStrings("modelNames");
        String ensembleName = config.getString("ensembleModelName");
        String testFolder = config.getString("testFolder");
        String validFolder = config.getString("validFolder");
        double targetValue = config.getDouble("threshold.targetValue");
        // Fail before the expensive data-set loading rather than with an index error later.
        if (modelPaths.isEmpty()) {
            throw new IllegalArgumentException("modelPaths must list at least one model");
        }
        if (modelNames.size() < modelPaths.size()) {
            throw new IllegalArgumentException("modelNames has " + modelNames.size() + " entries but modelPaths has " + modelPaths.size());
        }
        logger.info("start loading all reports and getting ground truth");
        List<Map<String, DocumentReport>> testlistMaps = new ArrayList<>();
        List<Map<String, DocumentReport>> validlistMaps = new ArrayList<>();
        // The data sets are expected next to the first model's "model_predictions" folder.
        String dataSetPath = modelPaths.get(0).split("model_predictions")[0] + "data_sets/";
        String testSetPath = dataSetPath + testFolder;
        String validSetPath = dataSetPath + validFolder;
        MultiLabelClfDataSet testSetModel0 = TRECFormat.loadMultiLabelClfDataSet(testSetPath, DataSetType.ML_CLF_SPARSE, true);
        MultiLabelClfDataSet validSetModel0 = TRECFormat.loadMultiLabelClfDataSet(validSetPath, DataSetType.ML_CLF_SPARSE, true);
        Map<String, String> groundTruthTest = ReportUtils.getIDGroundTruth(testSetModel0);
        Map<String, String> groundTruthValid = ReportUtils.getIDGroundTruth(validSetModel0);
        for (int i = 0; i < modelPaths.size(); i++) {
            String modelPath = modelPaths.get(i);
            String modelName = modelNames.get(i);
            testlistMaps.add(loadReportCSV(Paths.get(modelPath, "predictions", testFolder + "_reports", "report.csv").toString(), modelName));
            validlistMaps.add(loadReportCSV(Paths.get(modelPath, "predictions", validFolder + "_reports", "report.csv").toString(), modelName));
        }
        logger.info("finish loading all reports and getting ground truth");
        logger.info("start generating ensemble test report");
        LabelTranslator newLabelTranslatorTest = getLabelTranslatorEnsemble(config, testFolder);
        List<String> testDocIds = ReportUtils.getDocIds(Paths.get(modelPaths.get(0), "predictions", testFolder + "_reports", "report.csv").toString());
        generateReport(config, groundTruthTest, testlistMaps, ensembleName, testFolder, testDocIds, newLabelTranslatorTest);
        logger.info("ensemble test report generated");
        logger.info("start generating ensemble validation report");
        LabelTranslator newLabelTranslatorValid = getLabelTranslatorEnsemble(config, validFolder);
        List<String> validDocIds = ReportUtils.getDocIds(Paths.get(modelPaths.get(0), "predictions", validFolder + "_reports", "report.csv").toString());
        generateReport(config, groundTruthValid, validlistMaps, ensembleName, validFolder, validDocIds, newLabelTranslatorValid);
        logger.info("ensemble validation report generated");
        logger.info("classification performance on dataset " + testFolder);
        MlMeasureInfo measureInfo_test = getmlMeasureInfo(config, testSetModel0, testFolder, newLabelTranslatorTest);
        MLMeasures mlMeasures = new MLMeasures(measureInfo_test.numClasses, measureInfo_test.multiLabels, measureInfo_test.predictions);
        logger.info(mlMeasures.toString());
        if (config.getBoolean("tuneThreshold")) {
            logger.info("start tuning confidence threshold");
            String metric = config.getString("threshold.targetMetric");
            String validReportPath = Paths.get(config.getString("output.folder"), "model_predictions", ensembleName, "predictions", validFolder + "_reports", "report.csv").toString();
            String testReportPath = Paths.get(config.getString("output.folder"), "model_predictions", ensembleName, "predictions", testFolder + "_reports", "report.csv").toString();
            // Default used when the metric is not recognised; it is still written out below.
            double threshold = 1.1;
            if (metric.equals("accuracy")) {
                CTAT.Summary validSummary = CTAT.findThreshold(ReportUtils.getConfidenceCorrectness(validReportPath).stream(), targetValue);
                threshold = validSummary.getConfidenceThreshold();
            } else if (metric.equals("f1")) {
                CTFT.Summary summary_valid = CTFT.findThreshold(ReportUtils.getConfidenceF1(validReportPath).stream(), targetValue);
                threshold = summary_valid.getConfidenceThreshold();
            } else {
                logger.warning("unsupported threshold.targetMetric \"" + metric + "\"; keeping default threshold " + threshold);
            }
            FileUtils.writeStringToFile(Paths.get(config.getString("output.folder"), "model_predictions", ensembleName, "models", "threshold", config.getString("threshold.name")).toFile(), "" + threshold);
            double confidenceThresholdClipped = CTAT.clip(threshold, config.getDouble("threshold.lowerBound"), config.getDouble("threshold.upperBound"));
            FileUtils.writeStringToFile(Paths.get(config.getString("output.folder"), "model_predictions", ensembleName, "models", "threshold", config.getString("threshold.name") + "_clipped").toFile(), "" + confidenceThresholdClipped);
            logger.info("tuning threshold is done");
            if (metric.equals("accuracy")) {
                // Materialized as a list because the test data is streamed twice.
                List<Pair<Double, Double>> testStream = ReportUtils.getConfidenceCorrectness(testReportPath);
                CTAT.Summary testSummary_unclipped = CTAT.applyThreshold(testStream.stream(), threshold);
                CTAT.Summary testSummary_clipped = CTAT.applyThreshold(testStream.stream(), confidenceThresholdClipped);
                logger.info("*****************");
                logger.info("autocoding performance with unclipped CTAT " + testSummary_unclipped.getConfidenceThreshold());
                logger.info("autocoding percentage = " + testSummary_unclipped.getAutoCodingPercentage());
                logger.info("autocoding accuracy = " + testSummary_unclipped.getAutoCodingAccuracy());
                logger.info("number of autocoded documents = " + testSummary_unclipped.getNumAutoCoded());
                logger.info("number of correct autocoded documents = " + testSummary_unclipped.getNumCorrectAutoCoded());
                logger.info("*****************");
                logger.info("autocoding performance with clipped CTAT " + testSummary_clipped.getConfidenceThreshold());
                logger.info("autocoding percentage = " + testSummary_clipped.getAutoCodingPercentage());
                logger.info("autocoding accuracy = " + testSummary_clipped.getAutoCodingAccuracy());
                logger.info("number of autocoded documents = " + testSummary_clipped.getNumAutoCoded());
                logger.info("number of correct autocoded documents = " + testSummary_clipped.getNumCorrectAutoCoded());
            } else if (metric.equals("f1")) {
                List<Pair<Double, Double>> testStream = ReportUtils.getConfidenceF1(testReportPath);
                CTFT.Summary summary_test = CTFT.applyThreshold(testStream.stream(), threshold);
                CTFT.Summary summary_test_clipped = CTFT.applyThreshold(testStream.stream(), confidenceThresholdClipped);
                logger.info("*****************");
                logger.info("autocoding performance with unclipped CTFT " + summary_test.getConfidenceThreshold());
                logger.info("autocoding percentage = " + summary_test.getAutoCodingPercentage());
                logger.info("autocoding accuracy = " + summary_test.getAutoCodingAccuracy());
                logger.info("autocoding F1 = " + summary_test.getAutoCodingF1());
                logger.info("number of autocoded documents = " + summary_test.getNumAutoCoded());
                logger.info("number of correct autocoded documents = " + summary_test.getNumCorrectAutoCoded());
                logger.info("*****************");
                logger.info("autocoding performance with clipped CTFT " + summary_test_clipped.getConfidenceThreshold());
                logger.info("autocoding percentage = " + summary_test_clipped.getAutoCodingPercentage());
                logger.info("autocoding accuracy = " + summary_test_clipped.getAutoCodingAccuracy());
                logger.info("autocoding F1 = " + summary_test_clipped.getAutoCodingF1());
                logger.info("number of autocoded documents = " + summary_test_clipped.getNumAutoCoded());
                logger.info("number of correct autocoded documents = " + summary_test_clipped.getNumCorrectAutoCoded());
            }
        }
    } finally {
        // Release the log file even when a step fails part-way.
        if (fileHandler != null) {
            fileHandler.close();
        }
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) SimpleFormatter(java.util.logging.SimpleFormatter) Formatter(java.util.logging.Formatter) Logger(java.util.logging.Logger) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Pair(edu.neu.ccs.pyramid.util.Pair) CTFT(edu.neu.ccs.pyramid.calibration.CTFT) FileHandler(java.util.logging.FileHandler) CTAT(edu.neu.ccs.pyramid.calibration.CTAT) File(java.io.File)

Aggregations

CTAT (edu.neu.ccs.pyramid.calibration.CTAT): 3, File (java.io.File): 3, CTFT (edu.neu.ccs.pyramid.calibration.CTFT): 2, Config (edu.neu.ccs.pyramid.configuration.Config): 2, Pair (edu.neu.ccs.pyramid.util.Pair): 2, FileHandler (java.util.logging.FileHandler): 2, Formatter (java.util.logging.Formatter): 2, Logger (java.util.logging.Logger): 2, SimpleFormatter (java.util.logging.SimpleFormatter): 2, MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 1