Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in project pyramid by cheng-li.
The class MLACPlattScalingTest, method test1:
private static void test1() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"),
            DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting boosting = new IMLGradientBoosting(dataSet.getNumClasses());
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    boosting.setAssignments(assignments);
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 10; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(stopWatch);
    }
    MLACPlattScaling plattScaling = new MLACPlattScaling(dataSet, boosting);
    // print raw scores, the model's own probabilities, and the calibrated probabilities
    // for the first 10 training points
    for (int i = 0; i < 10; i++) {
        System.out.println(Arrays.toString(boosting.predictClassScores(dataSet.getRow(i))));
        System.out.println(Arrays.toString(boosting.predictClassProbs(dataSet.getRow(i))));
        System.out.println(Arrays.toString(plattScaling.predictClassProbs(dataSet.getRow(i))));
        System.out.println("======================");
    }
}
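The three printouts above place the raw margins from predictClassScores next to the boosting model's own probabilities from predictClassProbs and the Platt-calibrated probabilities. As a reference for what the calibration step computes, here is a minimal standalone sketch of the Platt transform itself; it is not taken from the pyramid API, and the coefficients a and b are hypothetical placeholders that would normally be fitted per label on held-out scores.
// Minimal sketch of the Platt transform for a single raw score.
// a and b are placeholder coefficients, normally fitted per label by
// minimizing the negative log-likelihood on held-out (score, label) pairs.
static double plattProb(double rawScore, double a, double b) {
    return 1.0 / (1.0 + Math.exp(a * rawScore + b));
}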
Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in project pyramid by cheng-li.
The class BRInspector, method analyzePrediction:
public static MultiLabelPredictionAnalysis analyzePrediction(MultiLabelClassifier.ClassProbEstimator classProbEstimator,
                                                             LabelCalibrator labelCalibrator,
                                                             VectorCalibrator setCalibrator,
                                                             MultiLabelClfDataSet dataSet,
                                                             MultiLabelClassifier classifier,
                                                             PredictionFeatureExtractor predictionFeatureExtractor,
                                                             int dataPointIndex,
                                                             int ruleLimit,
                                                             int labelSetLimit,
                                                             double classProbThreshold) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));
    predictionAnalysis.setInternalLabels(dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered());
    List<String> labels = dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered().stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);
    // per-label probabilities from the estimator, followed by label-level calibration
    double[] classProbs = classProbEstimator.predictClassProbs(dataSet.getRow(dataPointIndex));
    double[] calibratedClassProbs = labelCalibrator.calibratedClassProbs(classProbs);
    // calibrated set probability for the true label set
    PredictionCandidate trueCandidate = new PredictionCandidate();
    trueCandidate.x = dataSet.getRow(dataPointIndex);
    trueCandidate.multiLabel = dataSet.getMultiLabels()[dataPointIndex];
    trueCandidate.labelProbs = calibratedClassProbs;
    predictionAnalysis.setProbForTrueLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(trueCandidate)));
    // calibrated set probability for the predicted label set
    MultiLabel predictedLabels = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);
    PredictionCandidate predictedCandidate = new PredictionCandidate();
    predictedCandidate.x = dataSet.getRow(dataPointIndex);
    predictedCandidate.multiLabel = predictedLabels;
    predictedCandidate.labelProbs = calibratedClassProbs;
    predictionAnalysis.setProbForPredictedLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(predictedCandidate)));
    // keep labels that pass the probability threshold, are in the true set, or are in the predicted set
    List<Integer> classes = new ArrayList<>();
    for (int k = 0; k < classProbEstimator.getNumClasses(); k++) {
        if (calibratedClassProbs[k] >= classProbThreshold
                || dataSet.getMultiLabels()[dataPointIndex].matchClass(k)
                || predictedLabels.matchClass(k)) {
            classes.add(k);
        }
    }
    // todo
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        ClassScoreCalculation classScoreCalculation = null;
        if (classProbEstimator instanceof IMLGradientBoosting) {
            classScoreCalculation = decisionProcess((IMLGradientBoosting) classProbEstimator, labelTranslator,
                    calibratedClassProbs[k], dataSet.getRow(dataPointIndex), k, ruleLimit);
        }
        if (classProbEstimator instanceof CBM) {
            classScoreCalculation = decisionProcess((CBM) classProbEstimator, labelTranslator,
                    calibratedClassProbs[k], dataSet.getRow(dataPointIndex), k, ruleLimit);
        }
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);
    // per-label ranking info based on the calibrated probabilities
    List<MultiLabelPredictionAnalysis.ClassRankInfo> labelRanking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(calibratedClassProbs[label]);
        return rankInfo;
    }).collect(Collectors.toList());
    predictionAnalysis.setPredictedRanking(labelRanking);
    // top candidate label sets; restrict the search to the support set when the classifier is a SupportPredictor
    List<Pair<MultiLabel, Double>> topK;
    if (classifier instanceof SupportPredictor) {
        topK = TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator,
                setCalibrator, predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(), labelSetLimit);
    } else {
        topK = TopKFinder.topK(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator,
                setCalibrator, predictionFeatureExtractor, labelSetLimit);
    }
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = topK.stream().map(pair -> {
        MultiLabel multiLabel = pair.getFirst();
        double setProb = pair.getSecond();
        return new MultiLabelPredictionAnalysis.LabelSetProbInfo(multiLabel, setProb, labelTranslator);
    }).sorted(Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed())
      .limit(labelSetLimit)
      .collect(Collectors.toList());
    predictionAnalysis.setPredictedLabelSetRanking(labelSetRanking);
    return predictionAnalysis;
}
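The returned analysis object is typically turned into a human-readable report. A hedged sketch of one way to do that follows: it dumps the analysis to pretty-printed JSON with Jackson, assuming MultiLabelPredictionAnalysis exposes standard bean getters and that Jackson is on the classpath; this is not necessarily how the pyramid project itself reports the analysis.
// Hedged sketch (an assumption, not the project's own reporting code):
// given the predictionAnalysis returned above, write it out as pretty-printed JSON.
ObjectMapper mapper = new ObjectMapper();  // com.fasterxml.jackson.databind.ObjectMapper
String json = mapper.writerWithDefaultPrettyPrinter().writeValueAsString(predictionAnalysis);  // throws JsonProcessingException
System.out.println(json);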
Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in project pyramid by cheng-li.
The class IMLGBInspection, method main:
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    IMLGradientBoosting boosting = (IMLGradientBoosting) Serialization.deserialize(config.getString("model"));
    System.out.println("average number of features selected by each binary classifier = "
            + IntStream.range(0, boosting.getNumClasses())
                    .mapToDouble(l -> IMLGBInspector.getSelectedFeatures(boosting, l).size())
                    .average().getAsDouble());
    System.out.println("total number of features selected (union) = " + IMLGBInspector.getSelectedFeatures(boosting).size());
}
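The only key read from the properties file above is model, which must point at a previously serialized IMLGradientBoosting. A minimal properties file could therefore look like the following; the path is a placeholder, not one used by the project.
# minimal properties file for IMLGBInspection
model=/path/to/serialized/imlgb_model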
Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in project pyramid by cheng-li.
The class MLPlattScalingTest, method test1:
private static void test1() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"),
            DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting boosting = new IMLGradientBoosting(dataSet.getNumClasses());
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    boosting.setAssignments(assignments);
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(stopWatch);
    }
    MLPlattScaling plattScaling = new MLPlattScaling(dataSet, boosting);
    // print raw scores, the model's own probabilities, and the calibrated probabilities
    // for the first 10 training points
    for (int i = 0; i < 10; i++) {
        System.out.println(Arrays.toString(boosting.predictClassScores(dataSet.getRow(i))));
        System.out.println(Arrays.toString(boosting.predictClassProbs(dataSet.getRow(i))));
        System.out.println(Arrays.toString(plattScaling.predictClassProbs(dataSet.getRow(i))));
        System.out.println("======================");
    }
}
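The IMLGBInspection example above deserializes a saved model with Serialization.deserialize. A hedged sketch of persisting the model trained in this test with plain Java serialization follows; it assumes IMLGradientBoosting is Serializable (implied by the deserialization call, but not verified here), and the output path is a placeholder. It would sit at the end of a method like test1(), which already declares throws Exception.
// Hedged sketch: persist the trained model with standard Java serialization,
// assuming IMLGradientBoosting implements java.io.Serializable.
try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream("/tmp/imlgb_model.ser"))) {
    out.writeObject(boosting);
}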
Use of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in project pyramid by cheng-li.
The class MLFlatScalingTest, method test1:
private static void test1() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"),
            DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting boosting = new IMLGradientBoosting(dataSet.getNumClasses());
    // note: unlike the Platt scaling tests above, the gathered assignments are never
    // passed to boosting.setAssignments() in this test
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .build();
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println(stopWatch);
    }
    MLFlatScaling scaling = new MLFlatScaling(dataSet, boosting);
    // print raw scores, the model's own probabilities, and the scaled probabilities
    // for the first 10 training points
    for (int i = 0; i < 10; i++) {
        System.out.println(Arrays.toString(boosting.predictClassScores(dataSet.getRow(i))));
        System.out.println(Arrays.toString(boosting.predictClassProbs(dataSet.getRow(i))));
        System.out.println(Arrays.toString(scaling.predictClassProbs(dataSet.getRow(i))));
        System.out.println("======================");
    }
}
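As a quick way to quantify how much the flat scaling changes the model's own probability estimates, one could append something like the following after the loop above; this is an illustrative addition, not part of the original test, and uses only the predictClassProbs calls already shown.
// Illustrative addition (not in the original test): average absolute difference between
// the model's own probabilities and the flat-scaled probabilities on the first 10 points.
double totalAbsDiff = 0;
int count = 0;
for (int i = 0; i < 10; i++) {
    double[] modelProbs = boosting.predictClassProbs(dataSet.getRow(i));
    double[] scaledProbs = scaling.predictClassProbs(dataSet.getRow(i));
    for (int k = 0; k < modelProbs.length; k++) {
        totalAbsDiff += Math.abs(modelProbs[k] - scaledProbs[k]);
        count++;
    }
}
System.out.println("mean |model prob - scaled prob| = " + totalAbsDiff / count);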