Search in sources :

Example 1 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CBMGB method reportHammingPrediction.

/**
 * Predicts label sets for {@code dataSet} with a {@code MarginalPredictor} built on the given
 * CBM, prints the resulting {@code MLMeasures}, and saves two files under
 * {@code <output.dir>/test_predictions/instance_hamming_loss_optimal/}:
 * {@code performance.txt} (the measures) and {@code predictions.txt} (one line per data point
 * in the form {@code prediction:probability}).
 *
 * @param config  supplies {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     model used both for prediction and for scoring each predicted set
 * @param dataSet data set to predict on; its labels are used as ground truth for the measures
 * @throws Exception if prediction or writing either file fails
 */
private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
    String outputDir = config.getString("output.dir");

    MarginalPredictor predictor = new MarginalPredictor(cbm);
    predictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance Hamming loss optimal predictor");
    System.out.println(measures);
    File measuresFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(measuresFile, measures.toString());
    System.out.println("test performance is saved to " + measuresFile);

    // Probability the model assigns to each predicted set, computed in parallel over rows.
    // (Original note: no approximation is used here.)
    double[] probabilities = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predicted[row]))
            .toArray();

    File setsFile = Paths.get(outputDir, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(setsFile))) {
        int row = 0;
        while (row < dataSet.getNumDataPoints()) {
            writer.write(predicted[row].toString() + ":" + probabilities[row]);
            writer.newLine();
            row++;
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + setsFile.getAbsolutePath());
    System.out.println(banner);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DataSetUtil(edu.neu.ccs.pyramid.dataset.DataSetUtil) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 2 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CBMGB method main.

/**
 * Command-line entry point. Reads the properties file given as the single argument and,
 * depending on its boolean {@code tune}, {@code train} and {@code test} settings, runs
 * hyper parameter tuning, training and testing in that order.
 *
 * <p>Tuning evaluates every combination of {@code tune.numLeaves.candidates} and
 * {@code tune.numComponents.candidates}, picks the best result according to
 * {@code tune.targetMetric}, and stores it as {@code tuned_hyper_parameters.properties}
 * in {@code output.dir}. Training either reloads that file
 * ({@code train.useTunedHyperParameters=true}) or takes hyper parameters from the config.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if the argument count is not one, or if
 *         {@code tune.targetMetric} is not one of the supported metrics
 * @throws IllegalStateException if tuning is requested but no hyper parameter combination
 *         was evaluated (an empty candidate list)
 * @throws Exception anything propagated from configuration loading, tuning, training or testing
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    final String separator = "============================================================";
    final String tunedFileName = "tuned_hyper_parameters.properties";

    Config config = new Config(args[0]);
    System.out.println(config);
    VERBOSE = config.getBoolean("output.verbose");
    String outputDir = config.getString("output.dir");
    // The result of mkdirs() is ignored, as before: it also returns false when the directory
    // already exists. NOTE(review): a genuine creation failure only surfaces at the first write.
    new File(outputDir).mkdirs();

    if (config.getBoolean("tune")) {
        // Validate the target metric before the grid search so a typo does not
        // waste the whole tuning run.
        String targetMetric = config.getString("tune.targetMetric");
        boolean higherIsBetter;
        switch (targetMetric) {
            case "instance_set_accuracy":
            case "instance_f1":
                higherIsBetter = true;
                break;
            case "instance_hamming_loss":
                higherIsBetter = false;
                break;
            default:
                throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
        }

        System.out.println(separator);
        System.out.println("Start hyper parameter tuning");
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        List<TuneResult> tuneResults = new ArrayList<>();
        // Index 0 is passed to tune() first and index 1 second; presumably train and
        // validation sets respectively - confirm against loadTrainValidData.
        List<MultiLabelClfDataSet> dataSets = loadTrainValidData(config);
        List<Integer> leaveNums = config.getIntegers("tune.numLeaves.candidates");
        List<Integer> components = config.getIntegers("tune.numComponents.candidates");
        for (int numLeaves : leaveNums) {
            for (int component : components) {
                StopWatch stopWatch1 = new StopWatch();
                stopWatch1.start();
                HyperParameters hyperParameters = new HyperParameters();
                hyperParameters.numComponents = component;
                hyperParameters.numLeaves = numLeaves;
                System.out.println("---------------------------");
                System.out.println("Trying hyper parameters:");
                System.out.println("train.numComponents = " + hyperParameters.numComponents);
                System.out.println("train.numLeaves = " + hyperParameters.numLeaves);
                TuneResult tuneResult = tune(config, hyperParameters, dataSets.get(0), dataSets.get(1));
                System.out.println("Found optimal train.iterations = " + tuneResult.hyperParameters.iterations);
                System.out.println("Validation performance = " + tuneResult.performance);
                tuneResults.add(tuneResult);
                System.out.println("Time spent on trying this set of hyper parameters = " + stopWatch1);
            }
        }

        Comparator<TuneResult> byPerformance = Comparator.comparing(res -> res.performance);
        Optional<TuneResult> selected = higherIsBetter
                ? tuneResults.stream().max(byPerformance)
                : tuneResults.stream().min(byPerformance);
        // An empty candidate list leaves nothing to select; fail with an actionable message
        // instead of the bare NoSuchElementException that Optional.get() would raise.
        TuneResult best = selected.orElseThrow(() -> new IllegalStateException(
                "No hyper parameter combination was evaluated; check tune.numLeaves.candidates and tune.numComponents.candidates"));

        System.out.println("---------------------------");
        System.out.println("Hyper parameter tuning done.");
        System.out.println("Time spent on entire hyper parameter tuning = " + stopWatch);
        System.out.println("Best validation performance = " + best.performance);
        System.out.println("Best hyper parameters:");
        System.out.println("train.numComponents = " + best.hyperParameters.numComponents);
        System.out.println("train.numLeaves = " + best.hyperParameters.numLeaves);
        System.out.println("train.iterations = " + best.hyperParameters.iterations);
        Config tunedHypers = best.hyperParameters.asConfig();
        File tunedFile = new File(outputDir, tunedFileName);
        tunedHypers.store(tunedFile);
        System.out.println("Tuned hyper parameters saved to " + tunedFile.getAbsolutePath());
        System.out.println(separator);
    }

    if (config.getBoolean("train")) {
        System.out.println(separator);
        HyperParameters hyperParameters;
        if (config.getBoolean("train.useTunedHyperParameters")) {
            File hyperFile = new File(outputDir, tunedFileName);
            if (!hyperFile.exists()) {
                System.out.println("train.useTunedHyperParameters is set to true. But no tuned hyper parameters can be found in the output directory.");
                System.out.println("Please either run hyper parameter tuning, or provide hyper parameters manually and set train.useTunedHyperParameters=false.");
                System.exit(1);
            }
            hyperParameters = new HyperParameters(new Config(hyperFile));
            System.out.println("Start training with tuned hyper parameters:");
        } else {
            hyperParameters = new HyperParameters(config);
            System.out.println("Start training with given hyper parameters:");
        }
        // Both branches share the same reporting and training sequence from here on.
        System.out.println("train.numComponents = " + hyperParameters.numComponents);
        System.out.println("train.numLeaves = " + hyperParameters.numLeaves);
        System.out.println("train.iterations = " + hyperParameters.iterations);
        MultiLabelClfDataSet trainSet = loadTrainData(config);
        train(config, hyperParameters, trainSet);
        System.out.println(separator);
    }

    if (config.getBoolean("test")) {
        System.out.println(separator);
        test(config);
        System.out.println(separator);
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) StopWatch(org.apache.commons.lang3.time.StopWatch) File(java.io.File)

Example 3 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CBMGB method reportAccPrediction.

/**
 * Predicts label sets for {@code dataSet} with an {@code AccPredictor} built on the given CBM,
 * prints the resulting {@code MLMeasures}, and saves two files under
 * {@code <output.dir>/test_predictions/instance_accuracy_optimal/}:
 * {@code performance.txt} (the measures) and {@code predictions.txt} (one line per data point
 * in the form {@code prediction:probability}).
 *
 * @param config  supplies {@code output.dir} and {@code predict.piThreshold}; the latter is
 *                passed to the predictor's component contribution threshold
 * @param cbm     model used both for prediction and for scoring each predicted set
 * @param dataSet data set to predict on; its labels are used as ground truth for the measures
 * @throws Exception if prediction or writing either file fails
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String outputDir = config.getString("output.dir");

    AccPredictor predictor = new AccPredictor(cbm);
    predictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(measures);
    File measuresFile = Paths.get(outputDir, "test_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(measuresFile, measures.toString());
    System.out.println("test performance is saved to " + measuresFile);

    // Probability the model assigns to each predicted set, computed in parallel over rows.
    // (Original note: no approximation is used here.)
    double[] probabilities = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predicted[row]))
            .toArray();

    File setsFile = Paths.get(outputDir, "test_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(setsFile))) {
        int row = 0;
        while (row < dataSet.getNumDataPoints()) {
            writer.write(predicted[row].toString() + ":" + probabilities[row]);
            writer.newLine();
            row++;
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + setsFile.getAbsolutePath());
    System.out.println(banner);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DataSetUtil(edu.neu.ccs.pyramid.dataset.DataSetUtil) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 4 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class CBMGB method reportF1Prediction.

/**
 * Predicts label sets for {@code dataSet} with a {@code PluginF1} predictor built on the given
 * CBM, prints the resulting {@code MLMeasures}, and saves two files under
 * {@code <output.dir>/test_predictions/instance_f1_optimal/}:
 * {@code performance.txt} (the measures) and {@code predictions.txt} (one line per data point
 * in the form {@code prediction:probability}).
 *
 * <p>The predictor's support is deserialized from the file {@code support} in
 * {@code output.dir}, so that file must already exist.
 *
 * @param config  supplies {@code output.dir} and {@code predict.piThreshold}
 * @param cbm     model used both for prediction and for scoring each predicted set
 * @param dataSet data set to predict on; its labels are used as ground truth for the measures
 * @throws Exception if deserialization, prediction or writing either file fails
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String outputDir = config.getString("output.dir");

    PluginF1 predictor = new PluginF1(cbm);
    File supportFile = new File(outputDir, "support");
    List<MultiLabel> labelSupport = (List<MultiLabel>) Serialization.deserialize(supportFile);
    predictor.setSupport(labelSupport);
    predictor.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(measures);
    File measuresFile = Paths.get(outputDir, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(measuresFile, measures.toString());
    System.out.println("test performance is saved to " + measuresFile);

    // Probability the model assigns to each predicted set, computed in parallel over rows.
    // (Original note: no approximation is used here.)
    double[] probabilities = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predicted[row]))
            .toArray();

    File setsFile = Paths.get(outputDir, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(setsFile))) {
        int row = 0;
        while (row < dataSet.getNumDataPoints()) {
            writer.write(predicted[row].toString() + ":" + probabilities[row]);
            writer.newLine();
            row++;
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + setsFile.getAbsolutePath());
    System.out.println(banner);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DataSetUtil(edu.neu.ccs.pyramid.dataset.DataSetUtil) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures)

Example 5 with MultiLabelClfDataSet

use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

the class MultiLabelSynthesizer method crfArgmax.

/**
 * Builds a synthetic multi-label data set of 1000 data points with 10 features and 4 classes.
 *
 * <p>A {@code CMLCRF} over the enumerated label support is given hand-set weights on features
 * 0 and 1 for each class; every feature value is drawn with
 * {@code Sampling.doubleUniform(-1, 1)}; each row is then labelled with the prediction of a
 * {@code SubsetAccPredictor} wrapping that CRF.
 *
 * @return the generated data set with features and labels filled in
 */
public static MultiLabelClfDataSet crfArgmax() {
    final int numData = 1000;
    final int numClass = 4;
    final int numFeature = 10;
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(numFeature)
            .numClasses(numClass)
            .numDataPoints(numData)
            .build();

    List<MultiLabel> support = Enumerator.enumerate(numClass);
    CMLCRF cmlcrf = new CMLCRF(numClass, numFeature, support);

    // Row k holds the weights of class k on features 0 and 1; remaining weights are left untouched.
    double[][] handSetWeights = {
        {0, 10},
        {10, 10},
        {10, 0},
        {-10, -10}
    };
    for (int k = 0; k < handSetWeights.length; k++) {
        for (int f = 0; f < handSetWeights[k].length; f++) {
            cmlcrf.getWeights().getWeightsWithoutBiasForClass(k).set(f, handSetWeights[k][f]);
        }
    }

    // generate features: all values are drawn before any label is assigned
    for (int row = 0; row < numData; row++) {
        for (int col = 0; col < numFeature; col++) {
            double value = Sampling.doubleUniform(-1, 1);
            dataSet.setFeatureValue(row, col, value);
        }
    }

    // assign labels from the CRF's prediction on each generated row
    SubsetAccPredictor predictor = new SubsetAccPredictor(cmlcrf);
    for (int row = 0; row < numData; row++) {
        dataSet.setLabels(row, predictor.predict(dataSet.getRow(row)));
    }
    return dataSet;
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) SubsetAccPredictor(edu.neu.ccs.pyramid.multilabel_classification.crf.SubsetAccPredictor) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 48; File (java.io.File): 24; MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 23; CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 13; MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 12; LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 9; Vector (org.apache.mahout.math.Vector): 9; Config (edu.neu.ccs.pyramid.configuration.Config): 7; CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 7; DenseVector (org.apache.mahout.math.DenseVector): 7; MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5; Pair (edu.neu.ccs.pyramid.util.Pair): 5; java.util (java.util): 5; Collectors (java.util.stream.Collectors): 5; IntStream (java.util.stream.IntStream): 5; DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 4; TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 4; MLScorer (edu.neu.ccs.pyramid.multilabel_classification.MLScorer): 4; StopWatch (org.apache.commons.lang3.time.StopWatch): 4; AccScorer (edu.neu.ccs.pyramid.multilabel_classification.AccScorer): 3