Search in sources :

Example 51 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The example below is from the class ClusterLabels, method main.

// Entry point: requires exactly one command-line argument, the path to a
// properties file. Loads the configuration, echoes it, then runs the two
// pipeline stages: model fitting followed by plotting.
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config conf = new Config(args[0]);
    System.out.println(conf);
    fitModel(conf);
    plot(conf);
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config)

Example 52 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The example below is from the class CBMLR, method main.

// Entry point for the CBM-LR experiment driver.
// Expects exactly one argument: the path to a properties file. Depending on the
// boolean switches "tune", "train" and "test" in that file, it runs hyper
// parameter tuning, model training, and/or evaluation on the test set.
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    VERBOSE = config.getBoolean("output.verbose");
    new File(config.getString("output.dir")).mkdirs();
    if (config.getBoolean("tune")) {
        System.out.println("============================================================");
        System.out.println("Start hyper parameter tuning");
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        List<TuneResult> tuneResults = new ArrayList<>();
        // dataSets.get(0) = training split, dataSets.get(1) = validation split.
        List<MultiLabelClfDataSet> dataSets = loadTrainValidData(config);
        List<Double> variances = config.getDoubles("tune.variance.candidates");
        List<Integer> components = config.getIntegers("tune.numComponents.candidates");
        // Grid search over the cross product of variance and component candidates.
        for (double variance : variances) {
            for (int component : components) {
                StopWatch stopWatch1 = new StopWatch();
                stopWatch1.start();
                HyperParameters hyperParameters = new HyperParameters();
                hyperParameters.numComponents = component;
                hyperParameters.variance = variance;
                System.out.println("---------------------------");
                System.out.println("Trying hyper parameters:");
                System.out.println("train.numComponents = " + hyperParameters.numComponents);
                System.out.println("train.variance = " + hyperParameters.variance);
                // tune() also selects the optimal iteration count on the validation set.
                TuneResult tuneResult = tune(config, hyperParameters, dataSets.get(0), dataSets.get(1));
                System.out.println("Found optimal train.iterations = " + tuneResult.hyperParameters.iterations);
                System.out.println("Validation performance = " + tuneResult.performance);
                tuneResults.add(tuneResult);
                System.out.println("Time spent on trying this set of hyper parameters = " + stopWatch1);
            }
        }
        Comparator<TuneResult> comparator = Comparator.comparing(res -> res.performance);
        TuneResult best;
        String predictTarget = config.getString("tune.targetMetric");
        switch(predictTarget) {
            // accuracy and F1 are maximized; Hamming loss is minimized
            case "instance_set_accuracy":
            case "instance_f1":
                best = tuneResults.stream().max(comparator).get();
                break;
            case "instance_hamming_loss":
                best = tuneResults.stream().min(comparator).get();
                break;
            default:
                throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
        }
        System.out.println("---------------------------");
        System.out.println("Hyper parameter tuning done.");
        System.out.println("Time spent on entire hyper parameter tuning = " + stopWatch);
        System.out.println("Best validation performance = " + best.performance);
        System.out.println("Best hyper parameters:");
        System.out.println("train.numComponents = " + best.hyperParameters.numComponents);
        System.out.println("train.variance = " + best.hyperParameters.variance);
        System.out.println("train.iterations = " + best.hyperParameters.iterations);
        Config tunedHypers = best.hyperParameters.asConfig();
        // Persist the winning hyper parameters so a later "train" run can pick them up.
        File tunedFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
        tunedHypers.store(tunedFile);
        System.out.println("Tuned hyper parameters saved to " + tunedFile.getAbsolutePath());
        System.out.println("============================================================");
    }
    if (config.getBoolean("train")) {
        System.out.println("============================================================");
        // Resolve hyper parameters either from a previous tuning run or directly
        // from the user-supplied configuration.
        HyperParameters hyperParameters;
        if (config.getBoolean("train.useTunedHyperParameters")) {
            File hyperFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
            if (!hyperFile.exists()) {
                System.out.println("train.useTunedHyperParameters is set to true. But no tuned hyper parameters can be found in the output directory.");
                System.out.println("Please either run hyper parameter tuning, or provide hyper parameters manually and set train.useTunedHyperParameters=false.");
                System.exit(1);
            }
            Config tunedHypers = new Config(hyperFile);
            hyperParameters = new HyperParameters(tunedHypers);
            System.out.println("Start training with tuned hyper parameters:");
        } else {
            hyperParameters = new HyperParameters(config);
            System.out.println("Start training with given hyper parameters:");
        }
        // Common to both branches: echo the hyper parameters and train on the full training set.
        System.out.println("train.numComponents = " + hyperParameters.numComponents);
        System.out.println("train.variance = " + hyperParameters.variance);
        System.out.println("train.iterations = " + hyperParameters.iterations);
        MultiLabelClfDataSet trainSet = loadTrainData(config);
        train(config, hyperParameters, trainSet);
        System.out.println("============================================================");
    }
    if (config.getBoolean("test")) {
        System.out.println("============================================================");
        test(config);
        System.out.println("============================================================");
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) StopWatch(org.apache.commons.lang3.time.StopWatch) File(java.io.File)

Example 53 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The example below is from the class CBMLR, method reportGeneral.

// Computes predictor-independent report metrics on the test set:
// per-instance label probabilities and the exact (non-approximated) log
// likelihood of each ground-truth label set, excluding instances that contain
// labels never observed during training.
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String outputDir = config.getString("output.dir");
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    File labelProbFile = Paths.get(outputDir, "test_predictions", "label_probabilities.txt").toFile();
    // One line per test instance: the labels whose probability exceeds the threshold.
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(labelProbFile))) {
        int numDataPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numDataPoints; row++) {
            writer.write(CBMInspector.topLabels(cbm, dataSet.getRow(row), labelProbThreshold));
            writer.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // Parse the comma-separated list of labels that never appeared in training.
    List<Integer> unobservedLabels = Arrays
            .stream(FileUtils.readFileToString(new File(outputDir, "unobserved_labels.txt")).split(","))
            .map(String::trim)
            .filter(token -> !token.isEmpty())
            .map(Integer::parseInt)
            .collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints())
            .parallel()
            .mapToDouble(row -> cbm.predictLogAssignmentProb(dataSet.getRow(row), dataSet.getMultiLabels()[row]))
            .toArray();
    // Average only over instances whose label sets contain no unobserved label.
    double average = IntStream.range(0, dataSet.getNumDataPoints())
            .filter(row -> !containsNovelClass(dataSet.getMultiLabels()[row], unobservedLabels))
            .mapToDouble(row -> logLikelihoods[row])
            .average()
            .getAsDouble();
    File logLikelihoodFile = Paths.get(outputDir, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 54 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The example below is from the class CBMEN, method reportHammingPrediction.

// Predicts label sets on the test set with the marginal (per-label) predictor,
// which is optimal for instance Hamming loss, then reports the multi-label
// metrics and writes both the metrics and the individual predictions to disk.
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String outputDir = config.getString("output.dir");
    MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
    marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = marginalPredictor.predict(dataSet);
    // Evaluate against the ground truth and persist the aggregate metrics.
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString());
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predictions[row]))
            .toArray();
    File predictionFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    // One line per instance, formatted as "predictedSet:probability".
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            writer.write(predictions[row].toString());
            writer.write(":");
            writer.write("" + setProbs[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 55 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The example below is from the class CBMGB, method reportGeneral.

// Computes predictor-independent report metrics on the test set: writes each
// instance's high-probability labels to label_probabilities.txt and the exact
// log likelihood of each ground-truth label set to
// ground_truth_log_likelihood.txt under <output.dir>/test_predictions.
private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("computing other predictor-independent metrics");
    String output = config.getString("output.dir");
    File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
    double labelProbThreshold = config.getDouble("report.labelProbThreshold");
    // One line per test instance: the labels whose probability exceeds the threshold.
    try (BufferedWriter br = new BufferedWriter(new FileWriter(labelProbFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(CBMInspector.topLabels(cbm, dataSet.getRow(i), labelProbThreshold));
            br.newLine();
        }
    }
    System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
    // Comma-separated label ids that never appeared in the training set; these
    // instances are excluded from the average log likelihood below.
    List<Integer> unobservedLabels = Arrays.stream(FileUtils.readFileToString(new File(output, "unobserved_labels.txt")).split(",")).map(s -> s.trim()).filter(s -> !s.isEmpty()).map(s -> Integer.parseInt(s)).collect(Collectors.toList());
    // Here we do not use approximation
    double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel().mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i])).toArray();
    // Average only over instances whose label sets contain no unobserved label.
    // NOTE(review): getAsDouble() throws if every instance contains a novel label — presumed not to happen in practice.
    double average = IntStream.range(0, dataSet.getNumDataPoints()).filter(i -> !containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels)).mapToDouble(i -> logLikelihoods[i]).average().getAsDouble();
    File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
    FileUtils.writeStringToFile(logLikelihoodFile, PrintUtil.toMutipleLines(logLikelihoods));
    System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
    System.out.println("average log likelihood of the test ground truth label sets = " + average);
    if (!unobservedLabels.isEmpty()) {
        System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
        System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
        System.out.println(ListUtil.toSimpleString(unobservedLabels));
    }
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DataSetUtil(edu.neu.ccs.pyramid.dataset.DataSetUtil) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config)59 File (java.io.File)35 Collectors (java.util.stream.Collectors)18 FileUtils (org.apache.commons.io.FileUtils)18 StopWatch (org.apache.commons.lang3.time.StopWatch)18 Serialization (edu.neu.ccs.pyramid.util.Serialization)17 BufferedWriter (java.io.BufferedWriter)17 FileWriter (java.io.FileWriter)17 Paths (java.nio.file.Paths)17 IntStream (java.util.stream.IntStream)16 Pair (edu.neu.ccs.pyramid.util.Pair)15 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)14 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)14 java.util (java.util)14 MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)13 PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil)13 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)12 edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm)12 ListUtil (edu.neu.ccs.pyramid.util.ListUtil)12 IOException (java.io.IOException)8