Search in sources :

Example 11 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMGB method main.

/**
 * Command-line entry point. Depending on the boolean properties "tune", "train" and
 * "test", runs hyper parameter tuning, training and testing, in that order.
 *
 * @param args exactly one element, which is passed to the {@code Config} constructor
 *             (the error message asks for a properties file)
 * @throws IllegalArgumentException if the number of arguments is not one, or if
 *         "tune.targetMetric" is not instance_set_accuracy, instance_f1 or
 *         instance_hamming_loss
 * @throws IllegalStateException if tuning is requested but the candidate lists produce
 *         no hyper parameter combination to choose from
 * @throws Exception anything propagated by configuration loading, tuning, training or testing
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    VERBOSE = config.getBoolean("output.verbose");
    new File(config.getString("output.dir")).mkdirs();
    if (config.getBoolean("tune")) {
        System.out.println("============================================================");
        System.out.println("Start hyper parameter tuning");
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        // Validate the target metric up front, so that a misconfigured metric is reported
        // before any candidate is tried rather than after all of them have been tuned.
        String predictTarget = config.getString("tune.targetMetric");
        boolean smallerIsBetter;
        switch (predictTarget) {
            case "instance_set_accuracy":
            case "instance_f1":
                smallerIsBetter = false;
                break;
            case "instance_hamming_loss":
                smallerIsBetter = true;
                break;
            default:
                throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
        }
        List<TuneResult> tuneResults = new ArrayList<>();
        // Index 0 is passed to tune() as its third argument and index 1 as its fourth.
        List<MultiLabelClfDataSet> dataSets = loadTrainValidData(config);
        List<Integer> leaveNums = config.getIntegers("tune.numLeaves.candidates");
        List<Integer> components = config.getIntegers("tune.numComponents.candidates");
        // Try every (numLeaves, numComponents) combination.
        for (int numLeaves : leaveNums) {
            for (int component : components) {
                StopWatch stopWatch1 = new StopWatch();
                stopWatch1.start();
                HyperParameters hyperParameters = new HyperParameters();
                hyperParameters.numComponents = component;
                hyperParameters.numLeaves = numLeaves;
                System.out.println("---------------------------");
                System.out.println("Trying hyper parameters:");
                System.out.println("train.numComponents = " + hyperParameters.numComponents);
                System.out.println("train.numLeaves = " + hyperParameters.numLeaves);
                TuneResult tuneResult = tune(config, hyperParameters, dataSets.get(0), dataSets.get(1));
                System.out.println("Found optimal train.iterations = " + tuneResult.hyperParameters.iterations);
                System.out.println("Validation performance = " + tuneResult.performance);
                tuneResults.add(tuneResult);
                System.out.println("Time spent on trying this set of hyper parameters = " + stopWatch1);
            }
        }
        Comparator<TuneResult> comparator = Comparator.comparing(res -> res.performance);
        // Hamming loss is minimized; the other two metrics are maximized. An empty result
        // list (an empty candidate list) is reported explicitly instead of surfacing as a
        // bare NoSuchElementException from Optional.get().
        TuneResult best = (smallerIsBetter
                ? tuneResults.stream().min(comparator)
                : tuneResults.stream().max(comparator))
                .orElseThrow(() -> new IllegalStateException("No hyper parameter combination was tried; check tune.numLeaves.candidates and tune.numComponents.candidates"));
        System.out.println("---------------------------");
        System.out.println("Hyper parameter tuning done.");
        System.out.println("Time spent on entire hyper parameter tuning = " + stopWatch);
        System.out.println("Best validation performance = " + best.performance);
        System.out.println("Best hyper parameters:");
        System.out.println("train.numComponents = " + best.hyperParameters.numComponents);
        System.out.println("train.numLeaves = " + best.hyperParameters.numLeaves);
        System.out.println("train.iterations = " + best.hyperParameters.iterations);
        Config tunedHypers = best.hyperParameters.asConfig();
        File tunedHyperFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
        tunedHypers.store(tunedHyperFile);
        System.out.println("Tuned hyper parameters saved to " + tunedHyperFile.getAbsolutePath());
        System.out.println("============================================================");
    }
    if (config.getBoolean("train")) {
        System.out.println("============================================================");
        HyperParameters hyperParameters;
        if (config.getBoolean("train.useTunedHyperParameters")) {
            File hyperFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
            if (!hyperFile.exists()) {
                System.out.println("train.useTunedHyperParameters is set to true. But no tuned hyper parameters can be found in the output directory.");
                System.out.println("Please either run hyper parameter tuning, or provide hyper parameters manually and set train.useTunedHyperParameters=false.");
                System.exit(1);
            }
            Config tunedHypers = new Config(hyperFile);
            hyperParameters = new HyperParameters(tunedHypers);
            System.out.println("Start training with tuned hyper parameters:");
        } else {
            hyperParameters = new HyperParameters(config);
            System.out.println("Start training with given hyper parameters:");
        }
        // Reporting and training are the same whichever source the hyper parameters came from.
        System.out.println("train.numComponents = " + hyperParameters.numComponents);
        System.out.println("train.numLeaves = " + hyperParameters.numLeaves);
        System.out.println("train.iterations = " + hyperParameters.iterations);
        MultiLabelClfDataSet trainSet = loadTrainData(config);
        train(config, hyperParameters, trainSet);
        System.out.println("============================================================");
    }
    if (config.getBoolean("test")) {
        System.out.println("============================================================");
        test(config);
        System.out.println("============================================================");
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) StopWatch(org.apache.commons.lang3.time.StopWatch) File(java.io.File)

Example 12 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMGB method reportAccPrediction.

/**
 * Runs the instance set accuracy optimal predictor on the given data set, prints the
 * resulting measures, and writes both the measures and the predicted label sets (one per
 * line, followed by ":" and the probability the model assigns to that set) below
 * {@code <output.dir>/test_predictions/instance_accuracy_optimal}.
 *
 * @param config  supplies "output.dir" and "predict.piThreshold"
 * @param cbm     model used to predict and to score each predicted set
 * @param dataSet data to predict on; its labels are used to compute the measures
 * @throws Exception if prediction or writing the output files fails
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String output = config.getString("output.dir");

    AccPredictor predictor = new AccPredictor(cbm);
    double piThreshold = config.getDouble("predict.piThreshold");
    predictor.setComponentContributionThreshold(piThreshold);
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(measures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile);

    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predicted[row]))
            .toArray();

    File predictionFile = Paths.get(output, "test_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            // One line per instance: "<predicted set>:<probability>".
            writer.write(predicted[row].toString() + ":" + setProbs[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(banner);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DataSetUtil(edu.neu.ccs.pyramid.dataset.DataSetUtil) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) FileWriter(java.io.FileWriter) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) BufferedWriter(java.io.BufferedWriter)

Example 13 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class CBMGB method reportF1Prediction.

/**
 * Runs the instance F1 optimal predictor on the given data set, prints the resulting
 * measures, and writes both the measures and the predicted label sets (one per line,
 * followed by ":" and the probability the model assigns to that set) below
 * {@code <output.dir>/test_predictions/instance_f1_optimal}.
 *
 * @param config  supplies "output.dir" and "predict.piThreshold"
 * @param cbm     model used to predict and to score each predicted set
 * @param dataSet data to predict on; its labels are used to compute the measures
 * @throws Exception if deserializing the support, prediction or writing the output fails
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String banner = "============================================================";
    System.out.println(banner);
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String output = config.getString("output.dir");

    PluginF1 predictor = new PluginF1(cbm);
    // The support is read back from the serialized file named "support" in the output directory.
    File supportFile = new File(output, "support");
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(supportFile);
    predictor.setSupport(support);
    double piThreshold = config.getDouble("predict.piThreshold");
    predictor.setPiThreshold(piThreshold);
    MultiLabel[] predicted = predictor.predict(dataSet);

    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(measures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile);

    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predicted[row]))
            .toArray();

    File predictionFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            // One line per instance: "<predicted set>:<probability>".
            writer.write(predicted[row].toString() + ":" + setProbs[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(banner);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DataSetUtil(edu.neu.ccs.pyramid.dataset.DataSetUtil) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures)

Example 14 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class FoldPartitioner method main.

/**
 * Command-line entry point: loads the configuration named by the first argument, prints
 * it, and dispatches to the partitioning routine selected by the "dataSetType" property
 * ("clf", "reg" or "mlclf").
 *
 * @param args the first element is passed to the {@code Config} constructor
 * @throws IllegalArgumentException if no argument is given or "dataSetType" is not one
 *         of the supported values
 * @throws Exception anything propagated by configuration loading or partitioning
 */
public static void main(String[] args) throws Exception {
    // Fail with a readable message instead of an ArrayIndexOutOfBoundsException on args[0].
    if (args.length < 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    String dataType = config.getString("dataSetType");
    switch(dataType) {
        case "clf":
            partitionClfData(config);
            break;
        case "reg":
            partitionRegData(config);
            break;
        case "mlclf":
            partitionMLClfData(config);
            break;
        default:
            throw new IllegalArgumentException("illegal dataSetType");
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config)

Example 15 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

the class App1 method loadData.

//todo keep track of feature types(numerical /binary)
/**
 * Builds a multi-label data set with one row per id known to {@code idTranslator}.
 * Settings are read from the config file saved at
 * {@code <output.folder>/meta_data/saved_config_app1}. Labels come from {@code index},
 * optionally restricted to those starting with the configured prefix, and features are
 * filled in by {@code FeatureLoader.loadFeatures}.
 *
 * @param config          supplies "output.folder", used to locate the saved config
 * @param index           source of the external labels of each data point; also passed
 *                        to the feature loader
 * @param featureList     features to load, passed to the feature loader
 * @param idTranslator    gives the number of data points and the external id of each row;
 *                        attached to the returned data set
 * @param totalDim        number of features of the data set
 * @param labelTranslator gives the number of classes and maps external labels to integer
 *                        labels; attached to the returned data set
 * @param docFilter       passed through to the feature loader
 * @return the populated data set
 * @throws IllegalArgumentException if "train.feature.ngram.matchScoreType" is not
 *         es_original, binary, frequency or tfifl
 * @throws Exception anything propagated by config loading, the index or the feature loader
 */
static MultiLabelClfDataSet loadData(Config config, MultiLabelIndex index, FeatureList featureList, IdTranslator idTranslator, int totalDim, LabelTranslator labelTranslator, String docFilter) throws Exception {
    File metaDataFolder = new File(config.getString("output.folder"), "meta_data");
    Config savedConfig = new Config(new File(metaDataFolder, "saved_config_app1"));

    // Resolve and validate the match score type before touching the index, so that an
    // invalid setting is reported before any per-data-point work is done.
    String matchScoreTypeString = savedConfig.getString("train.feature.ngram.matchScoreType");
    FeatureLoader.MatchScoreType matchScoreType;
    switch(matchScoreTypeString) {
        case "es_original":
            matchScoreType = FeatureLoader.MatchScoreType.ES_ORIGINAL;
            break;
        case "binary":
            matchScoreType = FeatureLoader.MatchScoreType.BINARY;
            break;
        case "frequency":
            matchScoreType = FeatureLoader.MatchScoreType.FREQUENCY;
            break;
        case "tfifl":
            matchScoreType = FeatureLoader.MatchScoreType.TFIFL;
            break;
        default:
            throw new IllegalArgumentException("unknown ngramMatchScoreType");
    }

    // The label filter settings do not change per data point; read them once instead of
    // on every loop iteration. The prefix is only looked up when filtering is enabled.
    boolean filterLabels = savedConfig.getBoolean("train.label.filter");
    String labelPrefix = filterLabels ? savedConfig.getString("train.label.filter.prefix") : null;

    int numDataPoints = idTranslator.numData();
    int numClasses = labelTranslator.getNumClasses();
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numDataPoints(numDataPoints).numFeatures(totalDim).numClasses(numClasses).density(Density.SPARSE_RANDOM).missingValue(savedConfig.getBoolean("train.feature.missingValue")).build();
    for (int i = 0; i < numDataPoints; i++) {
        String dataIndexId = idTranslator.toExtId(i);
        List<String> extMultiLabel = index.getExtMultiLabel(dataIndexId);
        if (filterLabels) {
            extMultiLabel = extMultiLabel.stream().filter(extLabel -> extLabel.startsWith(labelPrefix)).collect(Collectors.toList());
        }
        for (String extLabel : extMultiLabel) {
            int intLabel = labelTranslator.toIntLabel(extLabel);
            dataSet.addLabel(i, intLabel);
        }
    }
    FeatureLoader.loadFeatures(index, dataSet, featureList, idTranslator, matchScoreType, docFilter);
    dataSet.setIdTranslator(idTranslator);
    dataSet.setLabelTranslator(labelTranslator);
    return dataSet;
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) FeatureLoader(edu.neu.ccs.pyramid.elasticsearch.FeatureLoader) File(java.io.File)

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config): 59; File (java.io.File): 35; Collectors (java.util.stream.Collectors): 18; FileUtils (org.apache.commons.io.FileUtils): 18; StopWatch (org.apache.commons.lang3.time.StopWatch): 18; Serialization (edu.neu.ccs.pyramid.util.Serialization): 17; BufferedWriter (java.io.BufferedWriter): 17; FileWriter (java.io.FileWriter): 17; Paths (java.nio.file.Paths): 17; IntStream (java.util.stream.IntStream): 16; Pair (edu.neu.ccs.pyramid.util.Pair): 15; edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 14; EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 14; java.util (java.util): 14; MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 13; PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil): 13; MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 12; edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm): 12; ListUtil (edu.neu.ccs.pyramid.util.ListUtil): 12; IOException (java.io.IOException): 8