
Example 6 with TopFeatures

Use of edu.neu.ccs.pyramid.feature.TopFeatures in project pyramid by cheng-li.

From the class MLLogisticRegressionInspector, method topFeatures:

public static TopFeatures topFeatures(MLLogisticRegression logisticRegression, int classIndex, int limit) {
    FeatureList featureList = logisticRegression.getFeatureList();
    Vector weights = logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);
    Comparator<FeatureUtility> comparator = Comparator.comparing(FeatureUtility::getUtility);
    List<Feature> list = IntStream.range(0, weights.size())
            .mapToObj(i -> new FeatureUtility(featureList.get(i)).setUtility(weights.get(i)))
            .filter(featureUtility -> featureUtility.getUtility() > 0)
            .sorted(comparator.reversed())
            .map(FeatureUtility::getFeature)
            .limit(limit)
            .collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = logisticRegression.getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
Also used: FeatureUtility(edu.neu.ccs.pyramid.feature.FeatureUtility), MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis), IntStream(java.util.stream.IntStream), ClassProbability(edu.neu.ccs.pyramid.classification.ClassProbability), ConstantRule(edu.neu.ccs.pyramid.regression.ConstantRule), FeatureList(edu.neu.ccs.pyramid.feature.FeatureList), Rule(edu.neu.ccs.pyramid.regression.Rule), Collectors(java.util.stream.Collectors), ArrayList(java.util.ArrayList), IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting), PredictionAnalysis(edu.neu.ccs.pyramid.classification.PredictionAnalysis), List(java.util.List), Feature(edu.neu.ccs.pyramid.feature.Feature), LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression), edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset), Vector(org.apache.mahout.math.Vector), TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures), LinearRule(edu.neu.ccs.pyramid.regression.LinearRule), Comparator(java.util.Comparator), ClassScoreCalculation(edu.neu.ccs.pyramid.regression.ClassScoreCalculation)
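The heart of this inspector is the stream pipeline: pair each feature index with its class weight, keep only positive weights, sort descending, and truncate to the requested limit. The same ranking pattern can be sketched without the pyramid classes; TopWeightsDemo and its names below are illustrative, not pyramid API:

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class TopWeightsDemo {

    // Return the names of the top-k features with positive weight,
    // ordered by descending weight -- mirroring topFeatures above.
    static List<String> topK(String[] names, double[] weights, int limit) {
        return IntStream.range(0, weights.length)
                .boxed()
                .filter(i -> weights[i] > 0)
                .sorted(Comparator.comparingDouble((Integer i) -> weights[i]).reversed())
                .limit(limit)
                .map(i -> names[i])
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        String[] names = {"bias", "tf_good", "tf_bad", "length"};
        double[] weights = {0.1, 2.5, -1.0, 0.7};
        System.out.println(topK(names, weights, 2)); // prints [tf_good, length]
    }
}
```

Note the filter: a feature with a large negative weight is discarded here, so this ranks by positive utility only, exactly as the inspector does.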

Example 7 with TopFeatures

Use of edu.neu.ccs.pyramid.feature.TopFeatures in project pyramid by cheng-li.

From the class AdaBoostMHInspector, method topFeatures:

public static TopFeatures topFeatures(AdaBoostMH boosting, int classIndex, int limit) {
    Map<Feature, Double> totalContributions = new HashMap<>();
    List<Regressor> regressors = boosting.getRegressors(classIndex);
    List<RegressionTree> trees = regressors.stream()
            .filter(regressor -> regressor instanceof RegressionTree)
            .map(regressor -> (RegressionTree) regressor)
            .collect(Collectors.toList());
    for (RegressionTree tree : trees) {
        Map<Feature, Double> contributions = RegTreeInspector.featureImportance(tree);
        for (Map.Entry<Feature, Double> entry : contributions.entrySet()) {
            Feature feature = entry.getKey();
            Double contribution = entry.getValue();
            double oldValue = totalContributions.getOrDefault(feature, 0.0);
            double newValue = oldValue + contribution;
            totalContributions.put(feature, newValue);
        }
    }
    Comparator<Map.Entry<Feature, Double>> comparator = Comparator.comparing(Map.Entry::getValue);
    List<Feature> list = totalContributions.entrySet().stream()
            .sorted(comparator.reversed())
            .limit(limit)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = boosting.getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
Also used: MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis), edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression), java.util(java.util), IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator), MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier), MLPlattScaling(edu.neu.ccs.pyramid.multilabel_classification.MLPlattScaling), RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector), Collectors(java.util.stream.Collectors), IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting), Classifier(edu.neu.ccs.pyramid.classification.Classifier), MLACPlattScaling(edu.neu.ccs.pyramid.multilabel_classification.MLACPlattScaling), RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree), PlattScaling(edu.neu.ccs.pyramid.classification.PlattScaling), MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet), TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule), Feature(edu.neu.ccs.pyramid.feature.Feature), LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator), MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel), Vector(org.apache.mahout.math.Vector), TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures), Pair(edu.neu.ccs.pyramid.util.Pair)
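This boosting variant ranks features differently from the logistic-regression inspector: it sums each feature's importance across all trees for a class, then takes the top entries of the accumulated map. A self-contained sketch of that accumulate-then-rank pattern (ContributionDemo and its string keys are illustrative stand-ins for Feature; Map.merge is an equivalent, more compact form of the getOrDefault accumulation above):

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class ContributionDemo {

    // Sum per-tree contributions for each feature, then keep the top-k
    // features by total contribution, as in AdaBoostMHInspector.topFeatures.
    static List<String> topK(List<Map<String, Double>> perTree, int limit) {
        Map<String, Double> total = new HashMap<>();
        for (Map<String, Double> contributions : perTree) {
            contributions.forEach((feature, value) -> total.merge(feature, value, Double::sum));
        }
        return total.entrySet().stream()
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .limit(limit)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Map<String, Double> tree1 = Map.of("x", 1.0, "y", 3.0);
        Map<String, Double> tree2 = Map.of("x", 4.0, "z", 0.5);
        System.out.println(topK(List.of(tree1, tree2), 2)); // prints [x, y]
    }
}
```

Unlike the weight-based inspector, there is no positivity filter here: tree importances are non-negative by construction, so every feature that appears in any tree is a candidate.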

Example 8 with TopFeatures

Use of edu.neu.ccs.pyramid.feature.TopFeatures in project pyramid by cheng-li.

From the class App2, method train:

static void train(Config config, Logger logger) throws Exception {
    String output = config.getString("output.folder");
    int numIterations = config.getInt("train.numIterations");
    int numLeaves = config.getInt("train.numLeaves");
    double learningRate = config.getDouble("train.learningRate");
    int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
    String modelName = "model_app3";
    //        double featureSamplingRate = config.getDouble("train.featureSamplingRate");
    //        double dataSamplingRate = config.getDouble("train.dataSamplingRate");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
    MultiLabelClfDataSet testSet = null;
    if (config.getBoolean("train.showTestProgress")) {
        testSet = loadData(config, config.getString("input.testData"));
    }
    int numClasses = dataSet.getNumClasses();
    logger.info("number of class = " + numClasses);
    IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet)
            .learningRate(learningRate)
            .minDataPerLeaf(minDataPerLeaf)
            .numLeaves(numLeaves)
            .numSplitIntervals(config.getInt("train.numSplitIntervals"))
            .usePrior(config.getBoolean("train.usePrior"))
            .build();
    IMLGradientBoosting boosting;
    if (config.getBoolean("train.warmStart")) {
        boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
    } else {
        boosting = new IMLGradientBoosting(numClasses);
    }
    logger.info("During training, the performance is reported using Hamming loss optimal predictor");
    logger.info("initialing trainer");
    IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
    boolean earlyStop = config.getBoolean("train.earlyStop");
    List<EarlyStopper> earlyStoppers = new ArrayList<>();
    List<Terminator> terminators = new ArrayList<>();
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
            earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
            earlyStoppers.add(earlyStopper);
        }
        for (int l = 0; l < numClasses; l++) {
            Terminator terminator = new Terminator();
            terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience"))
                    .setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval"))
                    .setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange"))
                    .setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange"))
                    .setOperation(Terminator.Operation.OR);
            terminators.add(terminator);
        }
    }
    logger.info("trainer initialized");
    int numLabelsLeftToTrain = numClasses;
    int progressInterval = config.getInt("train.showProgress.interval");
    for (int i = 1; i <= numIterations; i++) {
        logger.info("iteration " + i);
        trainer.iterate();
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training set performance");
            logger.info(new MLMeasures(boosting, dataSet).toString());
        }
        if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("test set performance");
            logger.info(new MLMeasures(boosting, testSet).toString());
            if (earlyStop) {
                for (int l = 0; l < numClasses; l++) {
                    EarlyStopper earlyStopper = earlyStoppers.get(l);
                    Terminator terminator = terminators.get(l);
                    if (!trainer.getShouldStop()[l]) {
                        double kl = KL(boosting, testSet, l);
                        earlyStopper.add(i, kl);
                        terminator.add(kl);
                        if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                            logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
                            logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                            trainer.setShouldStop(l);
                            numLabelsLeftToTrain -= 1;
                            logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                        }
                    }
                }
            }
        }
        if (numLabelsLeftToTrain == 0) {
            logger.info("all label training finished");
            break;
        }
    }
    logger.info("training done");
    File serializedModel = new File(output, modelName);
    //todo pick best models
    boosting.serialize(serializedModel);
    logger.info(stopWatch.toString());
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            logger.info("----------------------------------------------------");
            logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
            logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
        }
    }
    boolean topFeaturesToFile = true;
    if (topFeaturesToFile) {
        logger.info("start writing top features");
        int limit = config.getInt("report.topFeatures.limit");
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses())
                .mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit))
                .collect(Collectors.toList());
        ObjectMapper mapper = new ObjectMapper();
        String file = "top_features.json";
        mapper.writeValue(new File(output, file), topFeaturesList);
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        FileUtils.writeStringToFile(new File(output, "top_features.txt"), sb.toString());
        logger.info("finish writing top features");
    }
}
Also used: IntStream(java.util.stream.IntStream), JsonGenerator(com.fasterxml.jackson.core.JsonGenerator), SimpleFormatter(java.util.logging.SimpleFormatter), PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor), edu.neu.ccs.pyramid.multilabel_classification.imlgb(edu.neu.ccs.pyramid.multilabel_classification.imlgb), ArrayList(java.util.ArrayList), Terminator(edu.neu.ccs.pyramid.optimization.Terminator), FileHandler(java.util.logging.FileHandler), FeatureDistribution(edu.neu.ccs.pyramid.feature_selection.FeatureDistribution), JsonEncoding(com.fasterxml.jackson.core.JsonEncoding), Config(edu.neu.ccs.pyramid.configuration.Config), MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis), EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper), BufferedWriter(java.io.BufferedWriter), edu.neu.ccs.pyramid.eval(edu.neu.ccs.pyramid.eval), Collection(java.util.Collection), ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper), TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier), FileWriter(java.io.FileWriter), Set(java.util.Set), FileUtils(org.apache.commons.io.FileUtils), IOException(java.io.IOException), StopWatch(org.apache.commons.lang3.time.StopWatch), Logger(java.util.logging.Logger), Collectors(java.util.stream.Collectors), File(java.io.File), Progress(edu.neu.ccs.pyramid.util.Progress), List(java.util.List), JsonFactory(com.fasterxml.jackson.core.JsonFactory), Feature(edu.neu.ccs.pyramid.feature.Feature), MacroFMeasureTuner(edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner), Serialization(edu.neu.ccs.pyramid.util.Serialization), Paths(java.nio.file.Paths), edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset), Vector(org.apache.mahout.math.Vector), TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures), SetUtil(edu.neu.ccs.pyramid.util.SetUtil)
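The per-label early stopping in train() hinges on patience-based bookkeeping: each label tracks its best validation KL and stops once the metric fails to improve for a configured number of checks. A minimal sketch of that logic (the Stopper class below is a hypothetical stand-in; pyramid's actual EarlyStopper additionally supports a minimum-iteration floor, a best-iteration query, and a maximize goal):

```java
public class EarlyStopDemo {

    // Minimal patience-based early stopper with goal MINIMIZE: stop once
    // `patience` consecutive checks fail to improve on the best value seen.
    static class Stopper {
        final int patience;
        double best = Double.POSITIVE_INFINITY;
        int sinceImprovement = 0;

        Stopper(int patience) {
            this.patience = patience;
        }

        // Record a new metric value; return true once training should stop.
        boolean add(double value) {
            if (value < best) {
                best = value;
                sinceImprovement = 0;
            } else {
                sinceImprovement++;
            }
            return sinceImprovement >= patience;
        }
    }

    public static void main(String[] args) {
        Stopper stopper = new Stopper(2);
        double[] kl = {1.0, 0.8, 0.9, 0.85, 0.95}; // per-check validation KL
        for (int i = 0; i < kl.length; i++) {
            if (stopper.add(kl[i])) {
                System.out.println("stop at check " + (i + 1)); // prints: stop at check 4
                break;
            }
        }
    }
}
```

In train() one such stopper exists per label; when a label stops, its flag is set on the trainer and numLabelsLeftToTrain is decremented, so the outer loop can break early once every label has converged.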

Aggregations

Feature (edu.neu.ccs.pyramid.feature.Feature): 8
TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures): 8
Collectors (java.util.stream.Collectors): 8
Vector (org.apache.mahout.math.Vector): 7
RegTreeInspector (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector): 6
RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree): 6
MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis): 5
edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression): 5
TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule): 5
java.util (java.util): 5
ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability): 4
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 4
LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator): 4
IntStream (java.util.stream.IntStream): 4
PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis): 3
IdTranslator (edu.neu.ccs.pyramid.dataset.IdTranslator): 3
IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting): 3
Pair (edu.neu.ccs.pyramid.util.Pair): 3
List (java.util.List): 3
ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 2