
Example 6 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

From class AdaBoostMHTrainer, the method fitClassK:

private RegressionTree fitClassK(int k) {
    double[] probs = weightMatrix.getProbsForClass(k);
    // total weight of examples whose label k is positive (match) / negative (notMatch)
    double match = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .filter(i -> labels[i][k])
            .mapToDouble(i -> weightMatrix.getProbsForData(i)[k])
            .sum();
    double notMatch = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .filter(i -> !labels[i][k])
            .mapToDouble(i -> weightMatrix.getProbsForData(i)[k])
            .sum();
    // for each feature, accumulate the weight of positive / negative examples on which
    // the feature occurs (non-zero), then keep the stump with the smallest objective
    StumpInfo optimal = IntStream.range(0, dataSet.getNumFeatures()).parallel().mapToObj(j -> {
        double matchOccur = 0;
        double notMatchOccur = 0;
        Vector vector = dataSet.getColumn(j);
        for (Vector.Element element : vector.nonZeroes()) {
            int i = element.index();
            double prob = probs[i];
            if (labels[i][k]) {
                matchOccur += prob;
            } else {
                notMatchOccur += prob;
            }
        }
        double matchNotOccur = match - matchOccur;
        double notMatchNotOccur = notMatch - notMatchOccur;
        StumpInfo stumpInfo = new StumpInfo();
        stumpInfo.featureIndex = j;
        stumpInfo.matchOccur = matchOccur;
        stumpInfo.matchNotOccur = matchNotOccur;
        stumpInfo.notMatchOccur = notMatchOccur;
        stumpInfo.notMatchNotOccur = notMatchNotOccur;
        return stumpInfo;
    }).min(Comparator.comparing(StumpInfo::getObjective)).get();
    double smooth = 1.0 / (dataSet.getNumDataPoints() * dataSet.getNumClasses());
    double leftOutput = 0.5 * Math.log((optimal.matchNotOccur + smooth) / (optimal.notMatchNotOccur + smooth));
    double rightOutput = 0.5 * Math.log((optimal.matchOccur + smooth) / (optimal.notMatchOccur + smooth));
    RegressionTree tree = RegressionTree.newStump(optimal.featureIndex, 0, leftOutput, rightOutput);
    tree.setFeatureList(dataSet.getFeatureList());
    return tree;
}
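
The two leaf outputs above are the confidence-rated predictions of AdaBoost.MH: each leaf returns half the log-ratio of the smoothed positive weight to the smoothed negative weight on its side of the stump,

    output = 0.5 * ln((W_positive + smooth) / (W_negative + smooth)),    smooth = 1 / (numDataPoints * numClasses)

where, for rightOutput, W_positive and W_negative are the weights of positive and negative examples on which feature j is non-zero (matchOccur, notMatchOccur), and for leftOutput they are the remainders (matchNotOccur, notMatchNotOccur). The smoothing term keeps the logarithm finite when one side carries no weight.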
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) WeightMatrix(edu.neu.ccs.pyramid.dataset.WeightMatrix) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) ScoreMatrix(edu.neu.ccs.pyramid.dataset.ScoreMatrix) Logger(org.apache.logging.log4j.Logger) Regressor(edu.neu.ccs.pyramid.regression.Regressor) ConstantRegressor(edu.neu.ccs.pyramid.regression.ConstantRegressor) Vector(org.apache.mahout.math.Vector) MLPriorProbClassifier(edu.neu.ccs.pyramid.multilabel_classification.MLPriorProbClassifier) Comparator(java.util.Comparator) LogManager(org.apache.logging.log4j.LogManager)

Example 7 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

From class HMLGBInspector, the method decisionProcess:

//
//    public static String analyzeMistake(HMLGradientBoosting boosting, Vector vector,
//                                        MultiLabel trueLabel, MultiLabel prediction,
//                                        LabelTranslator labelTranslator, int limit){
//        StringBuilder sb = new StringBuilder();
//        List<Integer> difference = MultiLabel.symmetricDifference(trueLabel,prediction).stream().sorted().collect(Collectors.toList());
//
//        double[] classScores = boosting.predictClassScores(vector);
//        sb.append("score for the true labels ").append(trueLabel)
//                .append("(").append(trueLabel.toStringWithExtLabels(labelTranslator)).append(") = ");
//        sb.append(boosting.calAssignmentScore(trueLabel,classScores)).append("\n");
//
//        sb.append("score for the predicted labels ").append(prediction)
//                .append("(").append(prediction.toStringWithExtLabels(labelTranslator)).append(") = ");;
//        sb.append(boosting.calAssignmentScore(prediction,classScores)).append("\n");
//
//        for (int k: difference){
//            sb.append("score for class ").append(k).append("(").append(labelTranslator.toExtLabel(k)).append(")")
//                    .append(" =").append(classScores[k]).append("\n");
//        }
//
//        for (int k: difference){
//            sb.append("decision process for class ").append(k).append("(").append(labelTranslator.toExtLabel(k)).append("):\n");
//            sb.append(decisionProcess(boosting,vector,k,limit));
//            sb.append("--------------------------------------------------").append("\n");
//        }
//
//        return sb.toString();
//    }
public static ClassScoreCalculation decisionProcess(HMLGradientBoosting boosting, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), boosting.predictClassScore(vector, classIndex));
    double prob = boosting.predictClassProb(vector, classIndex);
    classScoreCalculation.setClassProbability(prob);
    List<Regressor> regressors = boosting.getRegressors(classIndex);
    List<TreeRule> treeRules = new ArrayList<>();
    for (Regressor regressor : regressors) {
        if (regressor instanceof ConstantRegressor) {
            Rule rule = new ConstantRule(((ConstantRegressor) regressor).getScore());
            classScoreCalculation.addRule(rule);
        }
        if (regressor instanceof RegressionTree) {
            RegressionTree tree = (RegressionTree) regressor;
            TreeRule treeRule = new TreeRule(tree, vector);
            treeRules.add(treeRule);
        }
    }
    // rank merged rules by absolute score contribution, largest first
    Comparator<TreeRule> comparator = Comparator.comparing(decision -> Math.abs(decision.getScore()));
    List<TreeRule> merged = TreeRule.merge(treeRules).stream()
            .sorted(comparator.reversed())
            .limit(limit)
            .collect(Collectors.toList());
    for (TreeRule treeRule : merged) {
        classScoreCalculation.addRule(treeRule);
    }
    return classScoreCalculation;
}
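
A minimal usage sketch for decisionProcess (the helper below is hypothetical, the package paths in the imports are assumptions based on the project layout, and only calls shown in the example above are used):

import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.multilabel_classification.hmlgb.HMLGBInspector;
import edu.neu.ccs.pyramid.multilabel_classification.hmlgb.HMLGradientBoosting;
import edu.neu.ccs.pyramid.regression.ClassScoreCalculation;
import org.apache.mahout.math.Vector;

public class ExplainInstance {

    // Hypothetical helper: explain one class score for one instance,
    // keeping only the 10 merged tree rules with the largest |score|.
    static void explain(HMLGradientBoosting boosting, LabelTranslator labelTranslator,
                        Vector vector, int classIndex) {
        ClassScoreCalculation calc =
                HMLGBInspector.decisionProcess(boosting, labelTranslator, vector, classIndex, 10);
        System.out.println(calc);
    }
}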
Also used : TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)

Example 8 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

From class IMLGBInspector, the method topFeatures:

//todo: consider newton step and learning rate
/**
     * only trees are considered
     * @param boosting the gradient boosting model
     * @param classIndex the class whose trees are inspected
     * @param limit the maximum number of features to report
     * @return the top features for the class, ranked by total contribution
     */
public static TopFeatures topFeatures(IMLGradientBoosting boosting, int classIndex, int limit) {
    Map<Feature, Double> totalContributions = new HashMap<>();
    List<Regressor> regressors = boosting.getRegressors(classIndex);
    List<RegressionTree> trees = regressors.stream()
            .filter(regressor -> regressor instanceof RegressionTree)
            .map(regressor -> (RegressionTree) regressor)
            .collect(Collectors.toList());
    for (RegressionTree tree : trees) {
        Map<Feature, Double> contributions = RegTreeInspector.featureImportance(tree);
        for (Map.Entry<Feature, Double> entry : contributions.entrySet()) {
            Feature feature = entry.getKey();
            Double contribution = entry.getValue();
            double oldValue = totalContributions.getOrDefault(feature, 0.0);
            double newValue = oldValue + contribution;
            totalContributions.put(feature, newValue);
        }
    }
    Comparator<Map.Entry<Feature, Double>> comparator = Comparator.comparing(Map.Entry::getValue);
    List<Feature> list = totalContributions.entrySet().stream()
            .sorted(comparator.reversed())
            .limit(limit)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = boosting.getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
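
A sketch of how topFeatures might be driven for a whole model (hypothetical helper; the imlgb package path in the imports is an assumption):

import edu.neu.ccs.pyramid.feature.TopFeatures;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBInspector;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting;

public class ReportTopFeatures {

    // Hypothetical helper: print the 20 highest total-contribution features
    // for every class of a trained model.
    static void report(IMLGradientBoosting boosting, int numClasses) {
        for (int k = 0; k < numClasses; k++) {
            TopFeatures top = IMLGBInspector.topFeatures(boosting, k, 20);
            System.out.println(top);
        }
    }
}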
Also used : MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) IntStream(java.util.stream.IntStream) java.util(java.util) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) Collectors(java.util.stream.Collectors) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 9 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

From class HMLGBTrainer, the method fitClassK:

/**
     * parallel
     * find the best regression tree for class k
     * apply newton step and learning rate
     * @param k class index
     * @return regressionTreeLk, shrunk
     */
private RegressionTree fitClassK(int k) {
    double[] gradients = gradientMatrix.getGradientsForClass(k);
    int numClasses = this.config.getDataSet().getNumClasses();
    double learningRate = this.config.getLearningRate();
    LeafOutputCalculator leafOutputCalculator = new HMLGBLeafOutputCalculator(numClasses);
    RegTreeConfig regTreeConfig = new RegTreeConfig();
    regTreeConfig.setMaxNumLeaves(this.config.getNumLeaves());
    regTreeConfig.setMinDataPerLeaf(this.config.getMinDataPerLeaf());
    regTreeConfig.setNumSplitIntervals(this.config.getNumSplitIntervals());
    RegressionTree regressionTree = RegTreeTrainer.fit(regTreeConfig, this.config.getDataSet(), gradients, leafOutputCalculator);
    regressionTree.shrink(learningRate);
    return regressionTree;
}
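
The final shrink call is the usual shrinkage step of gradient boosting: assuming it does what its name and the javadoc suggest, every leaf output v of the fitted tree is scaled by the learning rate,

    v <- learningRate * v

so a smaller learning rate makes each boosting iteration a smaller update, typically at the cost of more iterations. The Newton step mentioned in the javadoc is presumably handled by the HMLGBLeafOutputCalculator when the leaf outputs are computed.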
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) LeafOutputCalculator(edu.neu.ccs.pyramid.regression.regression_tree.LeafOutputCalculator)

Example 10 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

From class IMLGBInspector, the method decisionProcess:

public static ClassScoreCalculation decisionProcess(IMLGradientBoosting boosting, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), boosting.predictClassScore(vector, classIndex));
    double prob = boosting.predictClassProb(vector, classIndex);
    classScoreCalculation.setClassProbability(prob);
    List<Regressor> regressors = boosting.getRegressors(classIndex);
    List<TreeRule> treeRules = new ArrayList<>();
    for (Regressor regressor : regressors) {
        if (regressor instanceof ConstantRegressor) {
            Rule rule = new ConstantRule(((ConstantRegressor) regressor).getScore());
            classScoreCalculation.addRule(rule);
        }
        if (regressor instanceof RegressionTree) {
            RegressionTree tree = (RegressionTree) regressor;
            TreeRule treeRule = new TreeRule(tree, vector);
            treeRules.add(treeRule);
        }
    }
    // rank merged rules by absolute score contribution, largest first
    Comparator<TreeRule> comparator = Comparator.comparing(decision -> Math.abs(decision.getScore()));
    List<TreeRule> merged = TreeRule.merge(treeRules).stream()
            .sorted(comparator.reversed())
            .limit(limit)
            .collect(Collectors.toList());
    for (TreeRule treeRule : merged) {
        classScoreCalculation.addRule(treeRule);
    }
    return classScoreCalculation;
}
Also used : TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)

Aggregations

RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) 14
TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) 10
Feature (edu.neu.ccs.pyramid.feature.Feature) 6
TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures) 6
RegTreeInspector (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) 6
Collectors (java.util.stream.Collectors) 6
Vector (org.apache.mahout.math.Vector) 6
edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression) 5
java.util (java.util) 5
LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator) 4
ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability) 3
IdTranslator (edu.neu.ccs.pyramid.dataset.IdTranslator) 3
MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) 3
RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) 3
Pair (edu.neu.ccs.pyramid.util.Pair) 3
IntStream (java.util.stream.IntStream) 3
PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis) 2
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset) 2
ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet) 2
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) 2