Search in sources:

Example 41 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Class App2, method main.

/**
 * Command-line entry point: loads the properties file named by the single argument and hands
 * the resulting configuration to {@code main(Config)}.
 *
 * @param args exactly one element, the path of the properties file
 * @throws IllegalArgumentException if the number of arguments is not exactly one
 * @throws Exception propagated from loading the configuration or from {@code main(Config)}
 */
public static void main(String[] args) throws Exception {
    if (args.length == 1) {
        main(new Config(args[0]));
        return;
    }
    throw new IllegalArgumentException("Please specify a properties file.");
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config)

Example 42 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Class App2, method train.

/**
 * Trains an {@code IMLGradientBoosting} model with the hyper-parameters in {@code config},
 * serializes it to {@code model_app3} inside {@code output.folder}, and writes every label's top
 * features to {@code top_features.json} and {@code top_features.txt} in the same folder.
 *
 * <p>If {@code train.warmStart} is set, training resumes from the previously serialized model.
 * If {@code train.earlyStop} is set, each label gets its own {@code EarlyStopper} and
 * {@code Terminator}. Both are fed test-set values only at progress checkpoints, so early
 * stopping can only trigger when {@code train.showTestProgress} is also set.
 *
 * @param config application configuration
 * @param logger receives progress and performance messages
 * @throws Exception if loading data, (de)serializing the model or writing the reports fails
 */
static void train(Config config, Logger logger) throws Exception {
    String output = config.getString("output.folder");
    int numIterations = config.getInt("train.numIterations");
    int numLeaves = config.getInt("train.numLeaves");
    double learningRate = config.getDouble("train.learningRate");
    int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
    String modelName = "model_app3";
    // Run-wide switches: read once up front instead of once per iteration.
    boolean showTrainProgress = config.getBoolean("train.showTrainProgress");
    boolean showTestProgress = config.getBoolean("train.showTestProgress");
    boolean earlyStop = config.getBoolean("train.earlyStop");
    int progressInterval = config.getInt("train.showProgress.interval");
    if (earlyStop && !showTestProgress) {
        // Early stopping is driven by test-set values, which are only computed when test progress is shown.
        logger.warning("train.earlyStop is enabled but train.showTestProgress is disabled; early stopping will never trigger");
    }
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
    MultiLabelClfDataSet testSet = null;
    if (showTestProgress) {
        testSet = loadData(config, config.getString("input.testData"));
    }
    int numClasses = dataSet.getNumClasses();
    logger.info("number of class = " + numClasses);
    IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet)
            .learningRate(learningRate)
            .minDataPerLeaf(minDataPerLeaf)
            .numLeaves(numLeaves)
            .numSplitIntervals(config.getInt("train.numSplitIntervals"))
            .usePrior(config.getBoolean("train.usePrior"))
            .build();
    IMLGradientBoosting boosting;
    if (config.getBoolean("train.warmStart")) {
        // Continue from the model serialized by a previous run.
        boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
    } else {
        boosting = new IMLGradientBoosting(numClasses);
    }
    logger.info("During training, the performance is reported using Hamming loss optimal predictor");
    logger.info("initialing trainer");
    IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
    // One early stopper and one terminator per label (indexed by label); both stay empty when early stopping is off.
    List<EarlyStopper> earlyStoppers = new ArrayList<>();
    List<Terminator> terminators = new ArrayList<>();
    if (earlyStop) {
        int patience = config.getInt("train.earlyStop.patience");
        int minIterations = config.getInt("train.earlyStop.minIterations");
        double absoluteChange = config.getDouble("train.earlyStop.absoluteChange");
        double relativeChange = config.getDouble("train.earlyStop.relativeChange");
        for (int l = 0; l < numClasses; l++) {
            EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, patience);
            earlyStopper.setMinimumIterations(minIterations);
            earlyStoppers.add(earlyStopper);
            Terminator terminator = new Terminator();
            // The terminator receives one value per progress checkpoint (see the loop below), not one
            // per iteration, so the minimum is converted from iterations to checkpoints.
            terminator.setMaxStableIterations(patience)
                    .setMinIterations(minIterations / progressInterval)
                    .setAbsoluteEpsilon(absoluteChange)
                    .setRelativeEpsilon(relativeChange)
                    .setOperation(Terminator.Operation.OR);
            terminators.add(terminator);
        }
    }
    logger.info("trainer initialized");
    // Decremented each time the early-stopping logic below stops a label.
    int numLabelsLeftToTrain = numClasses;
    for (int i = 1; i <= numIterations; i++) {
        logger.info("iteration " + i);
        trainer.iterate();
        if (showTrainProgress && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training set performance");
            logger.info(new MLMeasures(boosting, dataSet).toString());
        }
        if (showTestProgress && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("test set performance");
            logger.info(new MLMeasures(boosting, testSet).toString());
            if (earlyStop) {
                for (int l = 0; l < numClasses; l++) {
                    EarlyStopper earlyStopper = earlyStoppers.get(l);
                    Terminator terminator = terminators.get(l);
                    if (!trainer.getShouldStop()[l]) {
                        // Value is minimized (Goal.MINIMIZE); presumably a KL divergence on the test set.
                        double kl = KL(boosting, testSet, l);
                        earlyStopper.add(i, kl);
                        terminator.add(kl);
                        if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                            logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
                            logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                            trainer.setShouldStop(l);
                            numLabelsLeftToTrain -= 1;
                            logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                        }
                    }
                }
            }
        }
        if (numLabelsLeftToTrain == 0) {
            logger.info("all label training finished");
            break;
        }
    }
    logger.info("training done");
    File serializedModel = new File(output, modelName);
    //todo pick best models
    boosting.serialize(serializedModel);
    logger.info(stopWatch.toString());
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            logger.info("----------------------------------------------------");
            logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
            logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
        }
    }
    boolean topFeaturesToFile = true;
    if (topFeaturesToFile) {
        logger.info("start writing top features");
        int limit = config.getInt("report.topFeatures.limit");
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses())
                .mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit))
                .collect(Collectors.toList());
        ObjectMapper mapper = new ObjectMapper();
        String file = "top_features.json";
        mapper.writeValue(new File(output, file), topFeaturesList);
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        // Use an explicit charset: the two-argument overload uses the platform default encoding,
        // which can garble non-ASCII label or feature names.
        FileUtils.writeStringToFile(new File(output, "top_features.txt"), sb.toString(), "UTF-8");
        logger.info("finish writing top features");
    }
}
Also used : IntStream(java.util.stream.IntStream) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) SimpleFormatter(java.util.logging.SimpleFormatter) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) edu.neu.ccs.pyramid.multilabel_classification.imlgb(edu.neu.ccs.pyramid.multilabel_classification.imlgb) ArrayList(java.util.ArrayList) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) FileHandler(java.util.logging.FileHandler) FeatureDistribution(edu.neu.ccs.pyramid.feature_selection.FeatureDistribution) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) edu.neu.ccs.pyramid.eval(edu.neu.ccs.pyramid.eval) Collection(java.util.Collection) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) FileWriter(java.io.FileWriter) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) StopWatch(org.apache.commons.lang3.time.StopWatch) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) Progress(edu.neu.ccs.pyramid.util.Progress) List(java.util.List) JsonFactory(com.fasterxml.jackson.core.JsonFactory) Feature(edu.neu.ccs.pyramid.feature.Feature) MacroFMeasureTuner(edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) SetUtil(edu.neu.ccs.pyramid.util.SetUtil) ArrayList(java.util.ArrayList) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) 
Terminator(edu.neu.ccs.pyramid.optimization.Terminator) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) Feature(edu.neu.ccs.pyramid.feature.Feature) StopWatch(org.apache.commons.lang3.time.StopWatch) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 43 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Class App2, method report.

/**
 * Evaluates the serialized {@code model_app3} model on data set {@code dataName} and writes the
 * reports into {@code <output.folder>/reports_app3/<dataName>_reports}, which is emptied first.
 *
 * <p>Files written:
 * <ul>
 *   <li>{@code report.csv};</li>
 *   <li>{@code report_<k>.json} partitions, when {@code report.showPredictionDetail} is set;</li>
 *   <li>{@code data_info.json};</li>
 *   <li>{@code model_config.json};</li>
 *   <li>a copy of the data set's {@code data_config.json}, if one exists;</li>
 *   <li>{@code performance.json};</li>
 *   <li>{@code individual_performance.json}.</li>
 * </ul>
 *
 * @param config   application configuration
 * @param dataName name of the data set to load and report on
 * @param logger   receives progress messages and the overall performance
 * @throws IllegalArgumentException if {@code predict.target} is unknown, or if prediction details
 *                                  are requested and {@code report.numDocsPerFile} is not positive
 * @throws Exception if the model cannot be loaded or a report file cannot be written
 */
static void report(Config config, String dataName, Logger logger) throws Exception {
    logger.info("generating reports for data set " + dataName);
    String output = config.getString("output.folder");
    String modelName = "model_app3";
    File analysisFolder = new File(new File(output, "reports_app3"), dataName + "_reports");
    analysisFolder.mkdirs();
    // Start from an empty folder so files from a previous run are not mixed into this report.
    FileUtils.cleanDirectory(analysisFolder);
    IMLGradientBoosting boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
    String predictTarget = config.getString("predict.target");
    PluginPredictor<IMLGradientBoosting> pluginPredictorTmp = null;
    switch (predictTarget) {
        case "subsetAccuracy":
            pluginPredictorTmp = new SubsetAccPredictor(boosting);
            break;
        case "hammingLoss":
            pluginPredictorTmp = new HammingPredictor(boosting);
            break;
        case "instanceFMeasure":
            pluginPredictorTmp = new InstanceF1Predictor(boosting);
            break;
        case "macroFMeasure":
            // NOTE(review): presumably produced by a separate tuning step and saved next to the model.
            TunedMarginalClassifier tunedMarginalClassifier = (TunedMarginalClassifier) Serialization.deserialize(new File(output, "predictor_macro_f"));
            pluginPredictorTmp = new MacroF1Predictor(boosting, tunedMarginalClassifier);
            break;
        default:
            throw new IllegalArgumentException("unknown prediction target measure " + predictTarget);
    }
    // The lambdas below need an effectively final reference.
    final PluginPredictor<IMLGradientBoosting> pluginPredictor = pluginPredictorTmp;
    MultiLabelClfDataSet dataSet = loadData(config, dataName);
    MLMeasures mlMeasures = new MLMeasures(pluginPredictor, dataSet);
    mlMeasures.getMacroAverage().setLabelTranslator(boosting.getLabelTranslator());
    logger.info("performance on dataset " + dataName);
    logger.info(mlMeasures.toString());
    // One mapper shared by all JSON outputs below.
    ObjectMapper objectMapper = new ObjectMapper();
    boolean simpleCSV = true;
    if (simpleCSV) {
        logger.info("start generating simple CSV report");
        double probThreshold = config.getDouble("report.classProbThreshold");
        File csv = new File(analysisFolder, "report.csv");
        List<String> strs = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .mapToObj(i -> IMLGBInspector.simplePredictionAnalysis(boosting, pluginPredictor, dataSet, i, probThreshold))
                .collect(Collectors.toList());
        // FileWriter would use the platform default charset; write UTF-8 explicitly.
        try (BufferedWriter bw = java.nio.file.Files.newBufferedWriter(csv.toPath(), java.nio.charset.StandardCharsets.UTF_8)) {
            for (String str : strs) {
                bw.write(str);
            }
        }
        logger.info("finish generating simple CSV report");
    }
    boolean rulesToJson = config.getBoolean("report.showPredictionDetail");
    if (rulesToJson) {
        logger.info("start writing rules to json");
        int ruleLimit = config.getInt("report.rule.limit");
        int numDocsPerFile = config.getInt("report.numDocsPerFile");
        if (numDocsPerFile <= 0) {
            // A zero or negative partition size would make numFiles meaningless (0 -> Integer.MAX_VALUE files).
            throw new IllegalArgumentException("report.numDocsPerFile must be positive, but was " + numDocsPerFile);
        }
        int numDataPoints = dataSet.getNumDataPoints();
        int numFiles = (int) Math.ceil((double) numDataPoints / numDocsPerFile);
        double probThreshold = config.getDouble("report.classProbThreshold");
        int labelSetLimit = config.getInt("report.labelSetLimit");
        // A plain loop (rather than forEach with a lambda) lets write failures propagate like every
        // other report file, instead of being printed and silently leaving the report incomplete.
        for (int i = 0; i < numFiles; i++) {
            int start = i * numDocsPerFile;
            int end = Math.min(start + numDocsPerFile, numDataPoints);
            List<MultiLabelPredictionAnalysis> partition = IntStream.range(start, end).parallel()
                    .mapToObj(a -> IMLGBInspector.analyzePrediction(boosting, pluginPredictor, dataSet, a, ruleLimit, labelSetLimit, probThreshold))
                    .collect(Collectors.toList());
            String file = "report_" + (i + 1) + ".json";
            objectMapper.writeValue(new File(analysisFolder, file), partition);
            logger.info("progress = " + Progress.percentage(i + 1, numFiles));
        }
        logger.info("finish writing rules to json");
    }
    boolean dataInfoToJson = true;
    if (dataInfoToJson) {
        logger.info("start writing data info to json");
        Set<String> modelLabels = IntStream.range(0, boosting.getNumClasses())
                .mapToObj(i -> boosting.getLabelTranslator().toExtLabel(i))
                .collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream()
                .map(i -> dataSet.getLabelTranslator().toExtLabel(i))
                .collect(Collectors.toSet());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
        // try-with-resources closes the generator (and its file) even if a write fails.
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(new File(analysisFolder, "data_info.json"), JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", dataName);
            jsonGenerator.writeNumberField("numClassesInModel", boosting.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
            jsonGenerator.writeEndObject();
        }
        logger.info("finish writing data info to json");
    }
    boolean modelConfigToJson = true;
    if (modelConfigToJson) {
        logger.info("start writing model config to json");
        objectMapper.writeValue(new File(analysisFolder, "model_config.json"), config);
        logger.info("finish writing model config to json");
    }
    boolean dataConfigToJson = true;
    if (dataConfigToJson) {
        logger.info("start writing data config to json");
        File dataConfigFile = Paths.get(config.getString("input.folder"), "data_sets", dataName, "data_config.json").toFile();
        if (dataConfigFile.exists()) {
            FileUtils.copyFileToDirectory(dataConfigFile, analysisFolder);
        }
        logger.info("finish writing data config to json");
    }
    boolean performanceToJson = true;
    if (performanceToJson) {
        objectMapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);
    }
    boolean individualPerformance = true;
    if (individualPerformance) {
        logger.info("start writing individual label performance to json");
        objectMapper.writeValue(new File(analysisFolder, "individual_performance.json"), mlMeasures.getMacroAverage());
        logger.info("finish writing individual label performance to json");
    }
    logger.info("reports generated");
}
Also used : IntStream(java.util.stream.IntStream) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) SimpleFormatter(java.util.logging.SimpleFormatter) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) edu.neu.ccs.pyramid.multilabel_classification.imlgb(edu.neu.ccs.pyramid.multilabel_classification.imlgb) ArrayList(java.util.ArrayList) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) FileHandler(java.util.logging.FileHandler) FeatureDistribution(edu.neu.ccs.pyramid.feature_selection.FeatureDistribution) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) edu.neu.ccs.pyramid.eval(edu.neu.ccs.pyramid.eval) Collection(java.util.Collection) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) FileWriter(java.io.FileWriter) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) StopWatch(org.apache.commons.lang3.time.StopWatch) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) Progress(edu.neu.ccs.pyramid.util.Progress) List(java.util.List) JsonFactory(com.fasterxml.jackson.core.JsonFactory) Feature(edu.neu.ccs.pyramid.feature.Feature) MacroFMeasureTuner(edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) SetUtil(edu.neu.ccs.pyramid.util.SetUtil) 
MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) FileWriter(java.io.FileWriter) JsonFactory(com.fasterxml.jackson.core.JsonFactory) BufferedWriter(java.io.BufferedWriter) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) IOException(java.io.IOException) File(java.io.File)

Example 44 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Class App3, method createApp1Config.

/**
 * Builds the configuration for the App1 stage by copying the listed keys from the App3
 * configuration under the same names (via {@code Config.copyExisting}).
 *
 * @param config the App3 configuration to copy from
 * @return a new configuration holding only the copied keys
 */
private static Config createApp1Config(Config config) {
    String[] sharedKeys = {
        "output.folder",
        "output.trainFolder",
        "output.testFolder",
        "output.log",
        "train.feature.useInitialFeatures",
        "train.feature.categFeature.filter",
        "train.feature.categFeature.percentThreshold",
        "train.feature.ngram.n",
        "train.feature.ngram.minDf",
        "train.feature.ngram.slop",
        "train.feature.missingValue",
        "train.feature.addExternalNgrams",
        "train.feature.externalNgramFile",
        "train.feature.analyzer",
        "train.feature.filterNgramsByKeyWords",
        "train.feature.filterNgrams.keyWordsFile",
        "train.feature.filterNgramsByRegex",
        "train.feature.filterNgrams.regex",
        "train.feature.useCodeDescription",
        "train.feature.codeDesc.File",
        "train.feature.codeDesc.analyzer",
        "train.feature.codeDesc.matchField",
        "train.feature.codeDesc.minMatchPercentage",
        "index.indexName",
        "index.clusterName",
        "index.documentType",
        "index.clientType",
        "index.hosts",
        "index.ports",
        "train.label.field",
        "train.label.filter",
        "train.label.filter.prefix",
        "train.feature.featureFieldPrefix",
        "train.feature.ngram.extractionFields",
        "train.splitQuery",
        "test.splitQuery",
        "train.feature.ngram.matchScoreType",
        "createTrainSet",
        "createTestSet",
        "train.feature.ngram.selection",
        "train.feature.ngram.selectPerLabel",
        "train.label.order"
    };
    Config result = new Config();
    Config.copyExisting(config, result, sharedKeys);
    return result;
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config)

Example 45 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

Class App3, method createApp2Config.

/**
 * Builds the configuration for the App2 stage. The listed keys are copied from the App3
 * configuration under the same names. App3's output locations are then mapped onto App2's
 * {@code input.*} keys.
 *
 * @param config the App3 configuration to copy from
 * @return a new configuration for App2
 */
private static Config createApp2Config(Config config) {
    String[] sharedKeys = {
        "output.folder",
        "output.log",
        "train",
        "test",
        "tune",
        "predict.target",
        "train.warmStart",
        "train.usePrior",
        "train.numIterations",
        "train.numLeaves",
        "train.learningRate",
        "train.minDataPerLeaf",
        "train.numSplitIntervals",
        "train.showTrainProgress",
        "train.showTestProgress",
        "train.earlyStop.patience",
        "train.earlyStop.minIterations",
        "train.earlyStop",
        "train.earlyStop.absoluteChange",
        "train.earlyStop.relativeChange",
        "train.showProgress.interval",
        "train.generateReports",
        "tune.data",
        "tune.FMeasure.beta",
        "report.topFeatures.limit",
        "report.rule.limit",
        "report.numDocsPerFile",
        "report.classProbThreshold",
        "report.labelSetLimit",
        "report.showPredictionDetail"
    };
    Config result = new Config();
    Config.copyExisting(config, result, sharedKeys);
    // What App3 wrote as output is what App2 reads as input.
    String outputFolder = config.getString("output.folder");
    String trainFolder = config.getString("output.trainFolder");
    String testFolder = config.getString("output.testFolder");
    result.setString("input.folder", outputFolder);
    result.setString("input.trainData", trainFolder);
    result.setString("input.testData", testFolder);
    return result;
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config)

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config)59 File (java.io.File)35 Collectors (java.util.stream.Collectors)18 FileUtils (org.apache.commons.io.FileUtils)18 StopWatch (org.apache.commons.lang3.time.StopWatch)18 Serialization (edu.neu.ccs.pyramid.util.Serialization)17 BufferedWriter (java.io.BufferedWriter)17 FileWriter (java.io.FileWriter)17 Paths (java.nio.file.Paths)17 IntStream (java.util.stream.IntStream)16 Pair (edu.neu.ccs.pyramid.util.Pair)15 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)14 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)14 java.util (java.util)14 MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)13 PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil)13 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)12 edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm)12 ListUtil (edu.neu.ccs.pyramid.util.ListUtil)12 IOException (java.io.IOException)8