Search in sources:

Example 1 with TreeRule

use of edu.neu.ccs.pyramid.regression.regression_tree.TreeRule in project pyramid by cheng-li.

the class RulesTest method test1.

/**
 * Fits a small regression tree and serializes a list of sample rules to JSON.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>load a dense regression data set from two hard-coded local files
 *       (feature list and labels, comma-delimited);</li>
 *   <li>print whether the loaded data set reports itself as dense;</li>
 *   <li>fit a regression tree configured with at most 4 leaves, at least 5 data points
 *       per leaf and 100 split intervals;</li>
 *   <li>build two {@link TreeRule}s from rows 100 and 1 of the data set, one
 *       {@link ConstantRule} with value 0.8 and one default {@link LinearRule};</li>
 *   <li>write the four rules to {@code decision.json} under {@code TMP} with Jackson.</li>
 * </ol>
 *
 * <p>NOTE(review): the data set paths point at a specific developer machine, so this
 * only runs where those files exist.
 *
 * @throws Exception declared broadly; presumably raised when loading the data set or
 *     writing the JSON file fails — confirm against the callee signatures
 */
static void test1() throws Exception {
    int numLeaves = 4;
    RegDataSet dataSet = StandardFormat.loadRegDataSet("/Users/chengli/Datasets/slice_location/standard/featureList.txt", "/Users/chengli/Datasets/slice_location/standard/labels.txt", ",", DataSetType.REG_DENSE, false);
    System.out.println(dataSet.isDense());

    RegTreeConfig regTreeConfig = new RegTreeConfig();
    regTreeConfig.setMaxNumLeaves(numLeaves);
    regTreeConfig.setMinDataPerLeaf(5);
    regTreeConfig.setNumSplitIntervals(100);
    RegressionTree regressionTree = RegTreeTrainer.fit(regTreeConfig, dataSet);

    // One rule of each kind, kept in this order so the serialized list is stable.
    TreeRule rule1 = new TreeRule(regressionTree, dataSet.getRow(100));
    TreeRule rule2 = new TreeRule(regressionTree, dataSet.getRow(1));
    ConstantRule rule3 = new ConstantRule(0.8);
    Rule rule4 = new LinearRule();
    List<Rule> rules = new ArrayList<>();
    rules.add(rule1);
    rules.add(rule2);
    rules.add(rule3);
    rules.add(rule4);

    ObjectMapper mapper = new ObjectMapper();
    mapper.writeValue(new File(TMP, "decision.json"), rules);
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) ArrayList(java.util.ArrayList) StopWatch(org.apache.commons.lang3.time.StopWatch) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 2 with TreeRule

use of edu.neu.ccs.pyramid.regression.regression_tree.TreeRule in project pyramid by cheng-li.

the class LKBInspector method decisionProcess.

/**
 * Builds the score breakdown of one class for a single input vector.
 *
 * <p>The returned calculation is created with the class index, the external label obtained
 * from {@code labelTranslator}, and the class score predicted by {@code boosting}. Every
 * {@link ConstantRegressor} in the class ensemble contributes a {@link ConstantRule}, added
 * in ensemble order. Every {@link RegressionTree} contributes a {@link TreeRule} evaluated
 * on {@code vector}; those tree rules are passed through {@link TreeRule#merge}, ordered by
 * descending absolute score, truncated to {@code limit} entries and then added.
 *
 * @param boosting        model whose ensemble for {@code classIndex} is inspected
 * @param labelTranslator maps {@code classIndex} to its external label
 * @param vector          input the rules are evaluated on
 * @param classIndex      class whose score is explained
 * @param limit           maximum number of merged tree rules to keep
 * @return the populated score calculation
 */
public static ClassScoreCalculation decisionProcess(LKBoost boosting, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    ClassScoreCalculation result = new ClassScoreCalculation(
            classIndex,
            labelTranslator.toExtLabel(classIndex),
            boosting.predictClassScore(vector, classIndex));

    List<TreeRule> perTreeRules = new ArrayList<>();
    for (Regressor member : boosting.getEnsemble(classIndex).getRegressors()) {
        if (member instanceof ConstantRegressor) {
            Rule constantRule = new ConstantRule(((ConstantRegressor) member).getScore());
            result.addRule(constantRule);
        }
        if (member instanceof RegressionTree) {
            perTreeRules.add(new TreeRule((RegressionTree) member, vector));
        }
    }

    // Largest absolute contribution first, so the limit keeps the most influential rules.
    Comparator<TreeRule> largestMagnitudeFirst =
            Comparator.comparing((TreeRule rule) -> Math.abs(rule.getScore())).reversed();
    TreeRule.merge(perTreeRules).stream()
            .sorted(largestMagnitudeFirst)
            .limit(limit)
            .forEachOrdered(rule -> result.addRule(rule));
    return result;
}
Also used : TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)

Example 3 with TreeRule

use of edu.neu.ccs.pyramid.regression.regression_tree.TreeRule in project pyramid by cheng-li.

the class HMLGBInspector method decisionProcess.

//
//    public static String analyzeMistake(HMLGradientBoosting boosting, Vector vector,
//                                        MultiLabel trueLabel, MultiLabel prediction,
//                                        LabelTranslator labelTranslator, int limit){
//        StringBuilder sb = new StringBuilder();
//        List<Integer> difference = MultiLabel.symmetricDifference(trueLabel,prediction).stream().sorted().collect(Collectors.toList());
//
//        double[] classScores = boosting.predictClassScores(vector);
//        sb.append("score for the true labels ").append(trueLabel)
//                .append("(").append(trueLabel.toStringWithExtLabels(labelTranslator)).append(") = ");
//        sb.append(boosting.calAssignmentScore(trueLabel,classScores)).append("\n");
//
//        sb.append("score for the predicted labels ").append(prediction)
//                .append("(").append(prediction.toStringWithExtLabels(labelTranslator)).append(") = ");;
//        sb.append(boosting.calAssignmentScore(prediction,classScores)).append("\n");
//
//        for (int k: difference){
//            sb.append("score for class ").append(k).append("(").append(labelTranslator.toExtLabel(k)).append(")")
//                    .append(" =").append(classScores[k]).append("\n");
//        }
//
//        for (int k: difference){
//            sb.append("decision process for class ").append(k).append("(").append(labelTranslator.toExtLabel(k)).append("):\n");
//            sb.append(decisionProcess(boosting,vector,k,limit));
//            sb.append("--------------------------------------------------").append("\n");
//        }
//
//        return sb.toString();
//    }
/**
 * Builds the score breakdown of one class for a single input vector, including the
 * class probability predicted by {@code boosting}.
 *
 * <p>Constant regressors of the class are added as {@link ConstantRule}s in the order they
 * appear. Regression trees are turned into {@link TreeRule}s evaluated on {@code vector},
 * passed through {@link TreeRule#merge}, sorted by descending absolute score, cut off at
 * {@code limit} entries and added after the loop over the regressors has finished.
 *
 * @param boosting        model providing the class score, class probability and regressors
 * @param labelTranslator maps {@code classIndex} to its external label
 * @param vector          input the rules are evaluated on
 * @param classIndex      class whose score is explained
 * @param limit           maximum number of merged tree rules to keep
 * @return the populated score calculation
 */
public static ClassScoreCalculation decisionProcess(HMLGradientBoosting boosting, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    ClassScoreCalculation calculation = new ClassScoreCalculation(
            classIndex,
            labelTranslator.toExtLabel(classIndex),
            boosting.predictClassScore(vector, classIndex));
    double classProbability = boosting.predictClassProb(vector, classIndex);
    calculation.setClassProbability(classProbability);

    List<TreeRule> collectedTreeRules = new ArrayList<>();
    for (Regressor candidate : boosting.getRegressors(classIndex)) {
        if (candidate instanceof ConstantRegressor) {
            Rule constantRule = new ConstantRule(((ConstantRegressor) candidate).getScore());
            calculation.addRule(constantRule);
        }
        if (candidate instanceof RegressionTree) {
            collectedTreeRules.add(new TreeRule((RegressionTree) candidate, vector));
        }
    }

    // Rank merged tree rules by the magnitude of their score, strongest first.
    Comparator<TreeRule> byMagnitudeDescending =
            Comparator.comparing((TreeRule rule) -> Math.abs(rule.getScore())).reversed();
    TreeRule.merge(collectedTreeRules).stream()
            .sorted(byMagnitudeDescending)
            .limit(limit)
            .forEachOrdered(rule -> calculation.addRule(rule));
    return calculation;
}
Also used : TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)

Example 4 with TreeRule

use of edu.neu.ccs.pyramid.regression.regression_tree.TreeRule in project pyramid by cheng-li.

the class IMLGBInspector method decisionProcess.

/**
 * Explains how the score of one class is assembled for a single input vector.
 *
 * <p>The result carries the class index, its external label, the predicted class score and
 * the predicted class probability. Rules are attached in two phases: first one
 * {@link ConstantRule} per {@link ConstantRegressor} encountered among the class regressors,
 * then the {@link TreeRule}s built from the {@link RegressionTree}s, after they have been
 * merged with {@link TreeRule#merge}, sorted by descending absolute score and limited to
 * {@code limit} entries.
 *
 * @param boosting        model providing the class score, class probability and regressors
 * @param labelTranslator maps {@code classIndex} to its external label
 * @param vector          input the rules are evaluated on
 * @param classIndex      class whose score is explained
 * @param limit           maximum number of merged tree rules to keep
 * @return the populated score calculation
 */
public static ClassScoreCalculation decisionProcess(IMLGradientBoosting boosting, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    ClassScoreCalculation breakdown = new ClassScoreCalculation(
            classIndex,
            labelTranslator.toExtLabel(classIndex),
            boosting.predictClassScore(vector, classIndex));
    double probability = boosting.predictClassProb(vector, classIndex);
    breakdown.setClassProbability(probability);

    List<TreeRule> rulesFromTrees = new ArrayList<>();
    for (Regressor component : boosting.getRegressors(classIndex)) {
        if (component instanceof ConstantRegressor) {
            Rule constantRule = new ConstantRule(((ConstantRegressor) component).getScore());
            breakdown.addRule(constantRule);
        }
        if (component instanceof RegressionTree) {
            rulesFromTrees.add(new TreeRule((RegressionTree) component, vector));
        }
    }

    // Keep only the tree rules with the largest absolute scores.
    Comparator<TreeRule> strongestFirst =
            Comparator.comparing((TreeRule rule) -> Math.abs(rule.getScore())).reversed();
    TreeRule.merge(rulesFromTrees).stream()
            .sorted(strongestFirst)
            .limit(limit)
            .forEachOrdered(rule -> breakdown.addRule(rule));
    return breakdown;
}
Also used : TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)

Example 5 with TreeRule

use of edu.neu.ccs.pyramid.regression.regression_tree.TreeRule in project pyramid by cheng-li.

the class AdaBoostMHInspector method decisionProcess.

/**
 * Explains how the score of one class is assembled for a single input vector.
 *
 * <p>The class score comes from {@code boosting}, while the class probability stored on the
 * result comes from the separate {@code scaling} estimator. One {@link ConstantRule} is
 * added per {@link ConstantRegressor} found among the class regressors. The
 * {@link RegressionTree}s are converted to {@link TreeRule}s on {@code vector}, merged with
 * {@link TreeRule#merge}, sorted by descending absolute score, limited to {@code limit}
 * entries and added afterwards.
 *
 * @param boosting        model providing the class score and regressors
 * @param scaling         estimator providing the class probability
 * @param labelTranslator maps {@code classIndex} to its external label
 * @param vector          input the rules are evaluated on
 * @param classIndex      class whose score is explained
 * @param limit           maximum number of merged tree rules to keep
 * @return the populated score calculation
 */
public static ClassScoreCalculation decisionProcess(AdaBoostMH boosting, MultiLabelClassifier.ClassProbEstimator scaling, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    ClassScoreCalculation explanation = new ClassScoreCalculation(
            classIndex,
            labelTranslator.toExtLabel(classIndex),
            boosting.predictClassScore(vector, classIndex));
    double scaledProbability = scaling.predictClassProb(vector, classIndex);
    explanation.setClassProbability(scaledProbability);

    List<TreeRule> treeContributions = new ArrayList<>();
    for (Regressor learner : boosting.getRegressors(classIndex)) {
        if (learner instanceof ConstantRegressor) {
            Rule constantRule = new ConstantRule(((ConstantRegressor) learner).getScore());
            explanation.addRule(constantRule);
        }
        if (learner instanceof RegressionTree) {
            treeContributions.add(new TreeRule((RegressionTree) learner, vector));
        }
    }

    // Order merged tree rules from largest to smallest absolute score before truncating.
    Comparator<TreeRule> descendingMagnitude =
            Comparator.comparing((TreeRule rule) -> Math.abs(rule.getScore())).reversed();
    TreeRule.merge(treeContributions).stream()
            .sorted(descendingMagnitude)
            .limit(limit)
            .forEachOrdered(rule -> explanation.addRule(rule));
    return explanation;
}
Also used : TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)

Aggregations

RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree): 5, TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule): 5, ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 1, RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet): 1, RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig): 1, File (java.io.File): 1, ArrayList (java.util.ArrayList): 1, StopWatch (org.apache.commons.lang3.time.StopWatch): 1