Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.
The class AdaBoostMHTrainer, method fitClassK.
private RegressionTree fitClassK(int k) {
    // per-example weights for class k under the current AdaBoost.MH distribution
    double[] probs = weightMatrix.getProbsForClass(k);
    // total weight of examples that carry label k ...
    double match = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .filter(i -> labels[i][k])
            .mapToDouble(i -> weightMatrix.getProbsForData(i)[k])
            .sum();
    // ... and of examples that do not
    double notMatch = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
            .filter(i -> !labels[i][k])
            .mapToDouble(i -> weightMatrix.getProbsForData(i)[k])
            .sum();
    // for every feature, accumulate the weight of positive/negative examples
    // in which the feature occurs (non-zero entries only), then pick the
    // stump with the smallest objective
    StumpInfo optimal = IntStream.range(0, dataSet.getNumFeatures()).parallel()
            .mapToObj(j -> {
                double matchOccur = 0;
                double notMatchOccur = 0;
                Vector vector = dataSet.getColumn(j);
                for (Vector.Element element : vector.nonZeroes()) {
                    int i = element.index();
                    double prob = probs[i];
                    if (labels[i][k]) {
                        matchOccur += prob;
                    } else {
                        notMatchOccur += prob;
                    }
                }
                // weights on the complementary (feature-absent) branch
                double matchNotOccur = match - matchOccur;
                double notMatchNotOccur = notMatch - notMatchOccur;
                StumpInfo stumpInfo = new StumpInfo();
                stumpInfo.featureIndex = j;
                stumpInfo.matchOccur = matchOccur;
                stumpInfo.matchNotOccur = matchNotOccur;
                stumpInfo.notMatchOccur = notMatchOccur;
                stumpInfo.notMatchNotOccur = notMatchNotOccur;
                return stumpInfo;
            })
            .min(Comparator.comparing(StumpInfo::getObjective))
            .get();
    // smoothed half log-odds leaf outputs; the smoothing term keeps the
    // outputs finite when one of the weight sums is zero
    double smooth = 1.0 / (dataSet.getNumDataPoints() * dataSet.getNumClasses());
    double leftOutput = 0.5 * Math.log((optimal.matchNotOccur + smooth) / (optimal.notMatchNotOccur + smooth));
    double rightOutput = 0.5 * Math.log((optimal.matchOccur + smooth) / (optimal.notMatchOccur + smooth));
    RegressionTree tree = RegressionTree.newStump(optimal.featureIndex, 0, leftOutput, rightOutput);
    tree.setFeatureList(dataSet.getFeatureList());
    return tree;
}
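The leaf outputs are the smoothed half log-odds of the weighted positive/negative counts in each branch, in the style of real-valued AdaBoost.MH. A minimal standalone sketch with made-up weight sums (all numbers hypothetical) shows the computation and why the smoothing term matters:

public class StumpOutputSketch {
    public static void main(String[] args) {
        // hypothetical weight sums for one candidate stump
        double matchOccur = 0.30;      // positives where the feature occurs
        double notMatchOccur = 0.05;   // negatives where the feature occurs
        double matchNotOccur = 0.10;   // positives where the feature is absent
        double notMatchNotOccur = 0.0; // negatives where the feature is absent
        // 1 / (numDataPoints * numClasses), with toy sizes 1000 and 10
        double smooth = 1.0 / (1000 * 10);
        // without smoothing, leftOutput would be 0.5 * log(0.10 / 0) = infinity
        double leftOutput = 0.5 * Math.log((matchNotOccur + smooth) / (notMatchNotOccur + smooth));
        double rightOutput = 0.5 * Math.log((matchOccur + smooth) / (notMatchOccur + smooth));
        System.out.printf("left=%.3f right=%.3f%n", leftOutput, rightOutput);
    }
}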
Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.
The class HMLGBInspector, method decisionProcess.
//
//    public static String analyzeMistake(HMLGradientBoosting boosting, Vector vector,
//                                        MultiLabel trueLabel, MultiLabel prediction,
//                                        LabelTranslator labelTranslator, int limit){
//        StringBuilder sb = new StringBuilder();
//        List<Integer> difference = MultiLabel.symmetricDifference(trueLabel, prediction).stream().sorted().collect(Collectors.toList());
//
//        double[] classScores = boosting.predictClassScores(vector);
//        sb.append("score for the true labels ").append(trueLabel)
//                .append("(").append(trueLabel.toStringWithExtLabels(labelTranslator)).append(") = ");
//        sb.append(boosting.calAssignmentScore(trueLabel, classScores)).append("\n");
//
//        sb.append("score for the predicted labels ").append(prediction)
//                .append("(").append(prediction.toStringWithExtLabels(labelTranslator)).append(") = ");
//        sb.append(boosting.calAssignmentScore(prediction, classScores)).append("\n");
//
//        for (int k : difference) {
//            sb.append("score for class ").append(k).append("(").append(labelTranslator.toExtLabel(k)).append(")")
//                    .append(" = ").append(classScores[k]).append("\n");
//        }
//
//        for (int k : difference) {
//            sb.append("decision process for class ").append(k).append("(").append(labelTranslator.toExtLabel(k)).append("):\n");
//            sb.append(decisionProcess(boosting, vector, k, limit));
//            sb.append("--------------------------------------------------").append("\n");
//        }
//
//        return sb.toString();
//    }
public static ClassScoreCalculation decisionProcess(HMLGradientBoosting boosting, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex),
            boosting.predictClassScore(vector, classIndex));
    double prob = boosting.predictClassProb(vector, classIndex);
    classScoreCalculation.setClassProbability(prob);
    // the class score decomposes into a constant regressor plus the
    // outputs of all regression trees fitted for that class
    List<Regressor> regressors = boosting.getRegressors(classIndex);
    List<TreeRule> treeRules = new ArrayList<>();
    for (Regressor regressor : regressors) {
        if (regressor instanceof ConstantRegressor) {
            Rule rule = new ConstantRule(((ConstantRegressor) regressor).getScore());
            classScoreCalculation.addRule(rule);
        }
        if (regressor instanceof RegressionTree) {
            RegressionTree tree = (RegressionTree) regressor;
            TreeRule treeRule = new TreeRule(tree, vector);
            treeRules.add(treeRule);
        }
    }
    // merge rules across trees, then keep the `limit` rules with the
    // largest absolute scores
    Comparator<TreeRule> comparator = Comparator.comparing(decision -> Math.abs(decision.getScore()));
    List<TreeRule> merged = TreeRule.merge(treeRules).stream()
            .sorted(comparator.reversed())
            .limit(limit)
            .collect(Collectors.toList());
    for (TreeRule treeRule : merged) {
        classScoreCalculation.addRule(treeRule);
    }
    return classScoreCalculation;
}
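A possible call site, as a sketch rather than code from the project: the caller is assumed to already hold a trained model, its label translator, and the feature vector of one test instance, and the class index and limit below are arbitrary choices.

// Hypothetical call site. The surrounding application supplies the
// trained model, the label translator, and one instance's vector.
static void explainClassScore(HMLGradientBoosting boosting,
                              LabelTranslator labelTranslator,
                              Vector vector) {
    int classIndex = 3; // the class whose score we want to explain (assumed)
    int limit = 10;     // keep only the 10 tree rules with the largest |score|
    ClassScoreCalculation explanation =
            HMLGBInspector.decisionProcess(boosting, labelTranslator, vector, classIndex, limit);
    // how the rules are rendered depends on ClassScoreCalculation's
    // own formatting; we only assume it prints something readable
    System.out.println(explanation);
}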
Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.
The class IMLGBInspector, method topFeatures.
// TODO: consider the Newton step and learning rate
/**
 * Only trees are considered; constant regressors are ignored.
 * @param boosting the trained boosting model
 * @param classIndex the class whose trees are inspected
 * @param limit the maximum number of features to return
 * @return the top features for the class, ranked by total contribution
 */
public static TopFeatures topFeatures(IMLGradientBoosting boosting, int classIndex, int limit) {
    // sum each feature's importance over all regression trees for the class
    Map<Feature, Double> totalContributions = new HashMap<>();
    List<Regressor> regressors = boosting.getRegressors(classIndex);
    List<RegressionTree> trees = regressors.stream()
            .filter(regressor -> regressor instanceof RegressionTree)
            .map(regressor -> (RegressionTree) regressor)
            .collect(Collectors.toList());
    for (RegressionTree tree : trees) {
        Map<Feature, Double> contributions = RegTreeInspector.featureImportance(tree);
        for (Map.Entry<Feature, Double> entry : contributions.entrySet()) {
            Feature feature = entry.getKey();
            Double contribution = entry.getValue();
            double oldValue = totalContributions.getOrDefault(feature, 0.0);
            double newValue = oldValue + contribution;
            totalContributions.put(feature, newValue);
        }
    }
    // sort by total contribution, descending, and keep the top `limit` features
    Comparator<Map.Entry<Feature, Double>> comparator = Comparator.comparing(Map.Entry::getValue);
    List<Feature> list = totalContributions.entrySet().stream()
            .sorted(comparator.reversed())
            .limit(limit)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = boosting.getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
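The inner accumulation loop is the classic "sum into a map" idiom; with Java 8's Map.merge it collapses to one line per entry. A minimal sketch of the same pattern on plain strings (toy keys and values):

import java.util.HashMap;
import java.util.Map;

public class MergeSketch {
    public static void main(String[] args) {
        Map<String, Double> total = new HashMap<>();
        // merge(key, value, Double::sum) inserts the value if the key is
        // absent, otherwise adds it to the existing total
        total.merge("featureA", 0.4, Double::sum);
        total.merge("featureA", 0.1, Double::sum);
        total.merge("featureB", 0.2, Double::sum);
        System.out.println(total); // featureA=0.5, featureB=0.2 (order not guaranteed)
    }
}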
Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.
The class HMLGBTrainer, method fitClassK.
/**
 * Finds the best regression tree for class k (the search runs in parallel),
 * applying the Newton step and the learning rate.
 * @param k class index
 * @return the regression tree for class k, shrunk by the learning rate
 */
private RegressionTree fitClassK(int k) {
    // per-data-point gradients for class k; the tree is fit to these
    double[] gradients = gradientMatrix.getGradientsForClass(k);
    int numClasses = this.config.getDataSet().getNumClasses();
    double learningRate = this.config.getLearningRate();
    // the leaf output calculator applies the Newton step
    LeafOutputCalculator leafOutputCalculator = new HMLGBLeafOutputCalculator(numClasses);
    RegTreeConfig regTreeConfig = new RegTreeConfig();
    regTreeConfig.setMaxNumLeaves(this.config.getNumLeaves());
    regTreeConfig.setMinDataPerLeaf(this.config.getMinDataPerLeaf());
    regTreeConfig.setNumSplitIntervals(this.config.getNumSplitIntervals());
    RegressionTree regressionTree = RegTreeTrainer.fit(regTreeConfig, this.config.getDataSet(), gradients, leafOutputCalculator);
    // scale the leaf outputs by the learning rate
    regressionTree.shrink(learningRate);
    return regressionTree;
}
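The trainer presumably calls fitClassK once per class in each boosting round; a hedged sketch of such a driver loop follows. Everything other than fitClassK and the config accessors shown above is an assumption, including the commented-out addRegressor hook.

// Sketch only, not the trainer's actual code: fit one tree per class
// in parallel, then attach each tree to the boosting model for its class.
// Assumes java.util.List, java.util.stream.IntStream, and
// java.util.stream.Collectors are imported.
private void fitOneRound() {
    int numClasses = this.config.getDataSet().getNumClasses();
    List<RegressionTree> trees = IntStream.range(0, numClasses)
            .parallel()
            .mapToObj(this::fitClassK)
            .collect(Collectors.toList());
    // hypothetical hook: add each shrunk tree to its class's ensemble
    // for (int k = 0; k < numClasses; k++) { boosting.addRegressor(trees.get(k), k); }
}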
Use of edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree in project pyramid by cheng-li.
The class IMLGBInspector, method decisionProcess.
public static ClassScoreCalculation decisionProcess(IMLGradientBoosting boosting, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
    // same procedure as HMLGBInspector.decisionProcess, for IMLGradientBoosting models
    ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex),
            boosting.predictClassScore(vector, classIndex));
    double prob = boosting.predictClassProb(vector, classIndex);
    classScoreCalculation.setClassProbability(prob);
    List<Regressor> regressors = boosting.getRegressors(classIndex);
    List<TreeRule> treeRules = new ArrayList<>();
    for (Regressor regressor : regressors) {
        if (regressor instanceof ConstantRegressor) {
            Rule rule = new ConstantRule(((ConstantRegressor) regressor).getScore());
            classScoreCalculation.addRule(rule);
        }
        if (regressor instanceof RegressionTree) {
            RegressionTree tree = (RegressionTree) regressor;
            TreeRule treeRule = new TreeRule(tree, vector);
            treeRules.add(treeRule);
        }
    }
    // keep the `limit` merged rules with the largest absolute scores
    Comparator<TreeRule> comparator = Comparator.comparing(decision -> Math.abs(decision.getScore()));
    List<TreeRule> merged = TreeRule.merge(treeRules).stream()
            .sorted(comparator.reversed())
            .limit(limit)
            .collect(Collectors.toList());
    for (TreeRule treeRule : merged) {
        classScoreCalculation.addRule(treeRule);
    }
    return classScoreCalculation;
}
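Both decisionProcess variants rank rules by absolute score before truncating. The same comparator.reversed() + limit(...) idiom, isolated on plain doubles (toy data):

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

public class TopByAbsSketch {
    public static void main(String[] args) {
        List<Double> scores = Arrays.asList(0.2, -1.5, 0.9, -0.1);
        // sort by |score| descending, keep the top 2
        List<Double> top2 = scores.stream()
                .sorted(Comparator.comparing((Double s) -> Math.abs(s)).reversed())
                .limit(2)
                .collect(Collectors.toList());
        System.out.println(top2); // [-1.5, 0.9]
    }
}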