Search in sources:

Example 71 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class ArffFormat method writeConfigFile.

/**
 * Writes the ARFF config file describing the dimensions of a data set.
 *
 * <p>The number of data points, features and classes reported by {@code dataSet}
 * are stored in a file named {@code ARFF_CONFIG_FILE_NAME} resolved against
 * {@code arffFile}. A failure while storing is reported only through a printed
 * stack trace; it is not propagated to the caller.
 *
 * @param dataSet  data set whose dimensions are recorded
 * @param arffFile parent location under which the config file is created
 */
private static void writeConfigFile(ClfDataSet dataSet, File arffFile) {
    // Collect the three dimension entries first, then resolve the target file.
    Config dimensions = new Config();
    dimensions.setInt(ARFF_CONFIG_NUM_DATA_POINTS, dataSet.getNumDataPoints());
    dimensions.setInt(ARFF_CONFIG_NUM_FEATURES, dataSet.getNumFeatures());
    dimensions.setInt(ARFF_CONFIG_NUM_CLASSES, dataSet.getNumClasses());

    File target = new File(arffFile, ARFF_CONFIG_FILE_NAME);
    try {
        dimensions.store(target);
    } catch (Exception storeFailure) {
        // Best effort: report the problem and carry on.
        storeFailure.printStackTrace();
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) File(java.io.File) IOException(java.io.IOException)

Example 72 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class ArffFormat method writeConfigFile.

/**
 * Writes the ARFF config file describing the dimensions of a multi-label data set.
 *
 * <p>The number of data points, features and classes reported by {@code dataSet}
 * are stored in a file named {@code ARFF_CONFIG_FILE_NAME} resolved against
 * {@code arffFile}. A failure while storing is reported only through a printed
 * stack trace; it is not propagated to the caller.
 *
 * @param dataSet  multi-label data set whose dimensions are recorded
 * @param arffFile parent location under which the config file is created
 */
private static void writeConfigFile(MultiLabelClfDataSet dataSet, File arffFile) {
    // Collect the three dimension entries first, then resolve the target file.
    Config dimensions = new Config();
    dimensions.setInt(ARFF_CONFIG_NUM_DATA_POINTS, dataSet.getNumDataPoints());
    dimensions.setInt(ARFF_CONFIG_NUM_FEATURES, dataSet.getNumFeatures());
    dimensions.setInt(ARFF_CONFIG_NUM_CLASSES, dataSet.getNumClasses());

    File target = new File(arffFile, ARFF_CONFIG_FILE_NAME);
    try {
        dimensions.store(target);
    } catch (Exception storeFailure) {
        // Best effort: report the problem and carry on.
        storeFailure.printStackTrace();
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) File(java.io.File) IOException(java.io.IOException)

Example 73 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class BRRerank method main.

/**
 * Command-line entry point for the BR-rerank pipeline.
 *
 * <p>Reads the properties file named by the single argument, optionally trains
 * the BR classifier (when {@code trainBR} is true), then runs calibration,
 * reporting, calibration evaluation and classification evaluation in that order.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if the argument count is not one
 * @throws Exception if any pipeline stage fails
 */
public static void main(String[] args) throws Exception {
    boolean hasSingleArgument = args.length == 1;
    if (!hasSingleArgument) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }

    Config config = new Config(args[0]);

    boolean shouldTrainBR = config.getBoolean("trainBR");
    if (shouldTrainBR) {
        System.out.println("Start training BR classifier");
        CBMEN.main(produceBRConfig(config));
        System.out.println("Finish training BR classifier");
    }

    // Calibration works on its own derived config; the remaining stages use the original one.
    calibrate(produceCaliConfig(config));
    report(config);
    calibration_eval(config);
    classification_eval(config);
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config)

Example 74 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class BRRerank method report.

/**
 * Predicts label sets for the test data and writes them, with confidence scores, to a report file.
 *
 * <p>Loads the test set from {@code <dataPath>/test} and the serialized model, label calibrator,
 * set calibrator and calibration feature extractor from {@code <outputDir>/models}. The predictor
 * is chosen by the {@code predictMode} setting ({@code "independent"} or {@code "rerank"}). One
 * tab-separated row per data point is written to
 * {@code <outputDir>/reports/set_prediction_and_confidence.txt}.
 *
 * @param config run configuration; reads {@code dataPath}, {@code outputDir} and {@code predictMode}
 * @throws IllegalArgumentException if {@code predictMode} is neither {@code "independent"} nor {@code "rerank"}
 * @throws Exception if loading, deserialization or writing fails
 */
private static void report(Config config) throws Exception {
    MultiLabelClfDataSet dataset = TRECFormat.loadMultiLabelClfDataSet(Paths.get(config.getString("dataPath"), "test").toFile(), DataSetType.ML_CLF_SPARSE, true);
    String outputDir = config.getString("outputDir");

    CBM cbm = (CBM) Serialization.deserialize(Paths.get(outputDir, "models", "model"));
    cbm.setAllowEmpty(true);
    LabelCalibrator labelCalibrator = (LabelCalibrator) Serialization.deserialize(Paths.get(outputDir, "models", "label_calibrator"));
    VectorCalibrator setCalibrator = (VectorCalibrator) Serialization.deserialize(Paths.get(outputDir, "models", "set_calibrator"));
    PredictionFeatureExtractor predictionFeatureExtractor = (PredictionFeatureExtractor) Serialization.deserialize(Paths.get(outputDir, "models", "calibration_feature_extractor"));
    CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);

    String predictMode = config.getString("predictMode");
    final MultiLabelClassifier classifier;
    switch (predictMode) {
        case "independent":
            classifier = new IndependentPredictor(cbm, labelCalibrator);
            break;
        case "rerank":
            // In rerank mode the deserialized set calibrator itself acts as the classifier.
            classifier = (Reranker) setCalibrator;
            break;
        default:
            // Name the key that is actually read ("predictMode") and echo the bad value.
            throw new IllegalArgumentException("illegal predictMode: " + predictMode + "; expected \"independent\" or \"rerank\"");
    }

    MultiLabel[] predictions = classifier.predict(dataset);
    MultiLabel[] groundTruth = dataset.getMultiLabels();
    int numDataPoints = dataset.getNumDataPoints();

    // Encounter order is preserved by collect(toList()), so index i in the list matches data point i.
    List<CalibInfo> confidenceScores = IntStream.range(0, numDataPoints).parallel().mapToObj(i -> {
        CalibrationDataGenerator.CalibrationInstance predictionInstance = calibrationDataGenerator.createInstance(cbm, dataset.getRow(i), predictions[i], groundTruth[i], "accuracy");
        CalibInfo calibInfo = new CalibInfo();
        calibInfo.uncalibrated = predictionInstance.vector.get(0);
        calibInfo.calibrated = setCalibrator.calibrate(predictionInstance.vector);
        calibInfo.accuracy = predictionInstance.correctness;
        return calibInfo;
    }).collect(Collectors.toList());

    StringBuilder stringBuilder = new StringBuilder();
    stringBuilder.append("set_prediction").append("\t").append("uncalibrated_confidence").append("\t").append("calibrated_confidence").append("\t").append("ground_truth").append("\t").append("set_accuracy").append("\n");
    for (int i = 0; i < numDataPoints; i++) {
        CalibInfo info = confidenceScores.get(i);
        stringBuilder.append(predictions[i]).append("\t")
                .append(info.uncalibrated).append("\t")
                .append(info.calibrated).append("\t")
                .append(groundTruth[i]).append("\t")
                .append(predictions[i].equals(groundTruth[i]) ? 1 : 0).append("\n");
    }

    FileUtils.writeStringToFile(Paths.get(outputDir, "reports", "set_prediction_and_confidence.txt").toFile(), stringBuilder.toString());
    System.out.println("set predictions and confidence scores are saved to " + Paths.get(outputDir, "reports", "set_prediction_and_confidence.txt").toString() + "\n");
}
Also used : IntStream(java.util.stream.IntStream) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) CalibrationEval(edu.neu.ccs.pyramid.eval.CalibrationEval) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) List(java.util.List) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor)

Example 75 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMLR method main.

/**
 * Command-line entry point for CBM-LR: optional hyper parameter tuning, training and testing.
 *
 * <p>The single argument names a properties file. Depending on the boolean settings
 * {@code tune}, {@code train} and {@code test}, the corresponding stages run in that order.
 * Tuning tries every combination of {@code tune.variance.candidates} and
 * {@code tune.numComponents.candidates}, picks the best result according to
 * {@code tune.targetMetric}, and saves it to {@code tuned_hyper_parameters.properties}
 * in {@code output.dir}. Training uses either that saved file or the hyper parameters
 * given in the config, depending on {@code train.useTunedHyperParameters}.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if the argument count is not one or
 *         {@code tune.targetMetric} is not a supported metric
 * @throws Exception if any stage fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    final String banner = "============================================================";
    final String rule = "---------------------------";

    Config config = new Config(args[0]);
    System.out.println(config);
    VERBOSE = config.getBoolean("output.verbose");
    new File(config.getString("output.dir")).mkdirs();

    if (config.getBoolean("tune")) {
        System.out.println(banner);
        System.out.println("Start hyper parameter tuning");
        StopWatch totalWatch = new StopWatch();
        totalWatch.start();

        List<TuneResult> results = new ArrayList<>();
        List<MultiLabelClfDataSet> trainAndValid = loadTrainValidData(config);
        List<Double> varianceCandidates = config.getDoubles("tune.variance.candidates");
        List<Integer> componentCandidates = config.getIntegers("tune.numComponents.candidates");

        // Grid search: variance is the outer dimension, number of components the inner one.
        for (double variance : varianceCandidates) {
            for (int numComponents : componentCandidates) {
                StopWatch trialWatch = new StopWatch();
                trialWatch.start();

                HyperParameters candidate = new HyperParameters();
                candidate.numComponents = numComponents;
                candidate.variance = variance;

                System.out.println(rule);
                System.out.println("Trying hyper parameters:");
                System.out.println("train.numComponents = " + candidate.numComponents);
                System.out.println("train.variance = " + candidate.variance);

                TuneResult outcome = tune(config, candidate, trainAndValid.get(0), trainAndValid.get(1));
                System.out.println("Found optimal train.iterations = " + outcome.hyperParameters.iterations);
                System.out.println("Validation performance = " + outcome.performance);
                results.add(outcome);
                System.out.println("Time spent on trying this set of hyper parameters = " + trialWatch);
            }
        }

        Comparator<TuneResult> byPerformance = Comparator.comparing(result -> result.performance);
        TuneResult best;
        String targetMetric = config.getString("tune.targetMetric");
        switch (targetMetric) {
            // Higher is better for these three metrics.
            case "instance_set_accuracy":
            case "instance_f1":
            case "instance_log_likelihood":
                best = results.stream().max(byPerformance).get();
                break;
            // Hamming loss is minimised.
            case "instance_hamming_loss":
                best = results.stream().min(byPerformance).get();
                break;
            default:
                throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1, instance_hamming_loss, instance_log_likelihood");
        }

        System.out.println(rule);
        System.out.println("Hyper parameter tuning done.");
        System.out.println("Time spent on entire hyper parameter tuning = " + totalWatch);
        System.out.println("Best validation performance = " + best.performance);
        System.out.println("Best hyper parameters:");
        System.out.println("train.numComponents = " + best.hyperParameters.numComponents);
        System.out.println("train.variance = " + best.hyperParameters.variance);
        System.out.println("train.iterations = " + best.hyperParameters.iterations);

        Config tunedHypers = best.hyperParameters.asConfig();
        File tunedFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
        tunedHypers.store(tunedFile);
        System.out.println("Tuned hyper parameters saved to " + tunedFile.getAbsolutePath());
        System.out.println(banner);
    }

    if (config.getBoolean("train")) {
        System.out.println(banner);

        HyperParameters hyperParameters;
        String startMessage;
        if (config.getBoolean("train.useTunedHyperParameters")) {
            File hyperFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
            if (!hyperFile.exists()) {
                System.out.println("train.useTunedHyperParameters is set to true. But no tuned hyper parameters can be found in the output directory.");
                System.out.println("Please either run hyper parameter tuning, or provide hyper parameters manually and set train.useTunedHyperParameters=false.");
                System.exit(1);
            }
            hyperParameters = new HyperParameters(new Config(hyperFile));
            startMessage = "Start training with tuned hyper parameters:";
        } else {
            hyperParameters = new HyperParameters(config);
            startMessage = "Start training with given hyper parameters:";
        }

        System.out.println(startMessage);
        System.out.println("train.numComponents = " + hyperParameters.numComponents);
        System.out.println("train.variance = " + hyperParameters.variance);
        System.out.println("train.iterations = " + hyperParameters.iterations);
        train(config, hyperParameters, loadTrainData(config));

        System.out.println(banner);
    }

    if (config.getBoolean("test")) {
        System.out.println(banner);
        test(config);
        System.out.println(banner);
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) StopWatch(org.apache.commons.lang3.time.StopWatch) File(java.io.File)

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config): 119; File (java.io.File): 68; Collectors (java.util.stream.Collectors): 40; FileUtils (org.apache.commons.io.FileUtils): 40; Paths (java.nio.file.Paths): 39; IntStream (java.util.stream.IntStream): 37; Pair (edu.neu.ccs.pyramid.util.Pair): 36; Serialization (edu.neu.ccs.pyramid.util.Serialization): 35; StopWatch (org.apache.commons.lang3.time.StopWatch): 34; MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 33; BufferedWriter (java.io.BufferedWriter): 32; FileWriter (java.io.FileWriter): 32; java.util (java.util): 32; edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 31; MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 29; EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 28; PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil): 26; edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm): 25; ListUtil (edu.neu.ccs.pyramid.util.ListUtil): 25; ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 22