Search in sources:

Example 1 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

Class LSBoostInspector, method topFeatures.

/**
 * Ranks features by their total importance across the regression trees of an {@link LSBoost} model.
 *
 * <p>Only the ensemble at index 0 is inspected, and only regressors that are
 * {@link RegressionTree}s contribute. Per-tree importances come from
 * {@link RegTreeInspector#featureImportance(RegressionTree)} and are summed per feature.
 *
 * @param boosting the trained model to inspect
 * @return a {@link TopFeatures} whose feature list is ordered by descending total importance
 */
public static TopFeatures topFeatures(LSBoost boosting) {
    Map<Feature, Double> totalContributions = new HashMap<>();
    List<Regressor> regressors = boosting.getEnsemble(0).getRegressors();
    for (Regressor regressor : regressors) {
        // Non-tree regressors carry no per-feature importance, so they are skipped.
        if (!(regressor instanceof RegressionTree)) {
            continue;
        }
        Map<Feature, Double> contributions = RegTreeInspector.featureImportance((RegressionTree) regressor);
        for (Map.Entry<Feature, Double> entry : contributions.entrySet()) {
            totalContributions.merge(entry.getKey(), entry.getValue(), Double::sum);
        }
    }
    // Highest accumulated importance first; the sort is stable, so ties keep map iteration order.
    Comparator<Map.Entry<Feature, Double>> comparator = Comparator.comparing(Map.Entry::getValue);
    List<Feature> list = totalContributions.entrySet().stream()
            .sorted(comparator.reversed())
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    return topFeatures;
}
Also used : RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) List(java.util.List) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) Regressor(edu.neu.ccs.pyramid.regression.Regressor) Map(java.util.Map) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) HashMap(java.util.HashMap) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost) Comparator(java.util.Comparator) Collectors(java.util.stream.Collectors) HashMap(java.util.HashMap) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Feature(edu.neu.ccs.pyramid.feature.Feature) Regressor(edu.neu.ccs.pyramid.regression.Regressor) Map(java.util.Map) HashMap(java.util.HashMap)

Example 2 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

Class RulesTest, method test1.

/**
 * Manual smoke test: trains a small regression tree and writes a mix of rule types to
 * {@code decision.json} in the directory {@code TMP}.
 *
 * <p>Two {@link TreeRule}s (for data points 100 and 1), a {@link ConstantRule} and a
 * {@link LinearRule} are serialized as one JSON list with Jackson.
 *
 * <p>NOTE(review): the dataset paths are hard-coded to one developer's machine, so this only
 * runs there.
 *
 * @throws Exception if loading the dataset, training, or writing the JSON file fails
 */
static void test1() throws Exception {
    int numLeaves = 4;
    RegDataSet dataSet = StandardFormat.loadRegDataSet("/Users/chengli/Datasets/slice_location/standard/featureList.txt", "/Users/chengli/Datasets/slice_location/standard/labels.txt", ",", DataSetType.REG_DENSE, false);
    System.out.println(dataSet.isDense());
    RegTreeConfig regTreeConfig = new RegTreeConfig();
    regTreeConfig.setMaxNumLeaves(numLeaves);
    regTreeConfig.setMinDataPerLeaf(5);
    regTreeConfig.setNumSplitIntervals(100);
    RegressionTree regressionTree = RegTreeTrainer.fit(regTreeConfig, dataSet);
    TreeRule rule1 = new TreeRule(regressionTree, dataSet.getRow(100));
    TreeRule rule2 = new TreeRule(regressionTree, dataSet.getRow(1));
    ConstantRule rule3 = new ConstantRule(0.8);
    Rule rule4 = new LinearRule();
    List<Rule> rules = new ArrayList<>();
    rules.add(rule1);
    rules.add(rule2);
    rules.add(rule3);
    rules.add(rule4);
    ObjectMapper mapper = new ObjectMapper();
    mapper.writeValue(new File(TMP, "decision.json"), rules);
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) ArrayList(java.util.ArrayList) StopWatch(org.apache.commons.lang3.time.StopWatch) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 3 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

Class LKBInspector, method topFeatures.

/**
 * Ranks features by their summed importance over the class-{@code classIndex} ensembles of
 * several LKBoost models.
 *
 * <p>For every model, only {@link RegressionTree} regressors in the ensemble for
 * {@code classIndex} contribute. Per-tree importances come from
 * {@link RegTreeInspector#featureImportance(RegressionTree)} and are summed per feature across
 * all models.
 *
 * @param lkBoosts ensemble of lktbs; must be non-empty. The first model's label translator
 *     supplies the class name.
 * @param classIndex internal index of the class whose ensembles are inspected
 * @return a {@link TopFeatures} with features ordered by descending total importance, tagged with
 *     {@code classIndex} and the class's external label
 * @throws IllegalArgumentException if {@code lkBoosts} is empty
 */
public static TopFeatures topFeatures(List<LKBoost> lkBoosts, int classIndex) {
    // Fail fast: the class name is resolved through lkBoosts.get(0) below.
    if (lkBoosts.isEmpty()) {
        throw new IllegalArgumentException("lkBoosts must contain at least one model");
    }
    Map<Feature, Double> totalContributions = new HashMap<>();
    for (LKBoost lkBoost : lkBoosts) {
        List<Regressor> regressors = lkBoost.getEnsemble(classIndex).getRegressors();
        for (Regressor regressor : regressors) {
            // Only trees carry per-feature importance.
            if (!(regressor instanceof RegressionTree)) {
                continue;
            }
            Map<Feature, Double> contributions = RegTreeInspector.featureImportance((RegressionTree) regressor);
            for (Map.Entry<Feature, Double> entry : contributions.entrySet()) {
                totalContributions.merge(entry.getKey(), entry.getValue(), Double::sum);
            }
        }
    }
    Comparator<Map.Entry<Feature, Double>> comparator = Comparator.comparing(Map.Entry::getValue);
    List<Feature> list = totalContributions.entrySet().stream()
            .sorted(comparator.reversed())
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = lkBoosts.get(0).getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
Also used : edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) ClassProbability(edu.neu.ccs.pyramid.classification.ClassProbability) java.util(java.util) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) Collectors(java.util.stream.Collectors) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) PredictionAnalysis(edu.neu.ccs.pyramid.classification.PredictionAnalysis) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator)

Example 4 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

Class LKBInspector, method decisionProcess.

/**
 * Breaks down how the score of class {@code classIndex} is built for one input vector.
 *
 * <p>Every {@link ConstantRegressor} in the class's ensemble is added as a {@link ConstantRule},
 * in ensemble order. Every {@link RegressionTree} yields a {@link TreeRule} for {@code vector}.
 * Those tree rules are combined with {@link TreeRule#merge}, ordered by descending absolute
 * score, truncated to {@code limit}, and then added after the constant rules.
 *
 * @param boosting        the model to explain
 * @param labelTranslator translates {@code classIndex} to its external label
 * @param vector          the input being explained
 * @param classIndex      internal index of the class whose score is explained
 * @param limit           maximum number of merged tree rules to include
 * @return the score breakdown for the requested class
 */
public static ClassScoreCalculation decisionProcess(LKBoost boosting, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    ClassScoreCalculation calculation = new ClassScoreCalculation(
            classIndex,
            labelTranslator.toExtLabel(classIndex),
            boosting.predictClassScore(vector, classIndex));

    List<TreeRule> perTreeRules = new ArrayList<>();
    for (Regressor member : boosting.getEnsemble(classIndex).getRegressors()) {
        if (member instanceof ConstantRegressor) {
            Rule constantRule = new ConstantRule(((ConstantRegressor) member).getScore());
            calculation.addRule(constantRule);
        }
        if (member instanceof RegressionTree) {
            perTreeRules.add(new TreeRule((RegressionTree) member, vector));
        }
    }

    // Strongest contributions first, regardless of sign.
    Comparator<TreeRule> byMagnitudeDescending =
            Comparator.comparing((TreeRule candidate) -> Math.abs(candidate.getScore())).reversed();
    List<TreeRule> strongest = TreeRule.merge(perTreeRules).stream()
            .sorted(byMagnitudeDescending)
            .limit(limit)
            .collect(Collectors.toList());
    strongest.forEach(calculation::addRule);
    return calculation;
}
Also used : TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)

Example 5 with RegressionTree

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.

Class LKBInspector, method topFeatures.

//todo: consider newton step and learning rate
/**
 * Ranks features by their summed importance over the regression trees of one class's ensemble.
 *
 * <p>Only {@link RegressionTree} regressors are considered; all other regressors are skipped.
 * Per-tree importances come from {@link RegTreeInspector#featureImportance(RegressionTree)} and
 * are summed per feature.
 *
 * @param boosting   the trained model to inspect
 * @param classIndex internal index of the class whose ensemble is inspected
 * @return a {@link TopFeatures} with features ordered by descending total importance, tagged with
 *     {@code classIndex} and the class's external label from the model's label translator
 */
public static TopFeatures topFeatures(LKBoost boosting, int classIndex) {
    Map<Feature, Double> totalContributions = new HashMap<>();
    List<Regressor> regressors = boosting.getEnsemble(classIndex).getRegressors();
    for (Regressor regressor : regressors) {
        if (!(regressor instanceof RegressionTree)) {
            continue;
        }
        Map<Feature, Double> contributions = RegTreeInspector.featureImportance((RegressionTree) regressor);
        for (Map.Entry<Feature, Double> entry : contributions.entrySet()) {
            totalContributions.merge(entry.getKey(), entry.getValue(), Double::sum);
        }
    }
    Comparator<Map.Entry<Feature, Double>> comparator = Comparator.comparing(Map.Entry::getValue);
    List<Feature> list = totalContributions.entrySet().stream()
            .sorted(comparator.reversed())
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = boosting.getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
Also used : edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) ClassProbability(edu.neu.ccs.pyramid.classification.ClassProbability) java.util(java.util) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) Collectors(java.util.stream.Collectors) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) PredictionAnalysis(edu.neu.ccs.pyramid.classification.PredictionAnalysis) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator)

Aggregations

RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)14 TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)10 Feature (edu.neu.ccs.pyramid.feature.Feature)6 TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures)6 RegTreeInspector (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector)6 Collectors (java.util.stream.Collectors)6 Vector (org.apache.mahout.math.Vector)6 edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression)5 java.util (java.util)5 LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator)4 ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability)3 IdTranslator (edu.neu.ccs.pyramid.dataset.IdTranslator)3 MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis)3 RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig)3 Pair (edu.neu.ccs.pyramid.util.Pair)3 IntStream (java.util.stream.IntStream)3 PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis)2 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)2 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)2 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)2