Search in sources:

Example 1 with IMLGradientBoosting

Usage of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in the project pyramid by cheng-li.

Class MLACPlattScalingTest, method test1.

/**
 * Smoke test for MLACPlattScaling.
 *
 * <p>Trains an IMLGradientBoosting model for 10 rounds on the ohsumed/3 training split. For the
 * first 10 training rows it then prints the raw class scores, the model's class probabilities and
 * the probabilities produced by MLACPlattScaling.
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "ohsumed/3/train.trec");
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);

    IMLGradientBoosting model = new IMLGradientBoosting(trainSet.getNumClasses());
    // Hand the model the label sets gathered from the training data.
    List<MultiLabel> observedLabelSets = DataSetUtil.gatherMultiLabels(trainSet);
    model.setAssignments(observedLabelSets);

    IMLGBConfig boostingConfig = new IMLGBConfig.Builder(trainSet)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer imlgbTrainer = new IMLGBTrainer(boostingConfig, model);

    StopWatch timer = new StopWatch();
    timer.start();
    int round = 0;
    while (round < 10) {
        System.out.println("round=" + round);
        imlgbTrainer.iterate();
        // Print the stopwatch after every boosting round.
        System.out.println(timer);
        round++;
    }

    MLACPlattScaling scaling = new MLACPlattScaling(trainSet, model);
    for (int row = 0; row < 10; row++) {
        System.out.println(Arrays.toString(model.predictClassScores(trainSet.getRow(row))));
        System.out.println(Arrays.toString(model.predictClassProbs(trainSet.getRow(row))));
        System.out.println(Arrays.toString(scaling.predictClassProbs(trainSet.getRow(row))));
        System.out.println("======================");
    }
}
Also used: IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) IMLGBTrainer(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer) IMLGBConfig(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig) File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 2 with IMLGradientBoosting

Usage of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in the project pyramid by cheng-li.

Class BRInspector, method analyzePrediction.

/**
 * Builds a {@code MultiLabelPredictionAnalysis} for a single data point of {@code dataSet}.
 *
 * <p>The analysis records the point's ids and its true and predicted labels (internal and
 * external form). It also records the set-level calibrated probabilities of the true and the
 * predicted label sets, and per-class score calculations for a selected subset of classes. It
 * then adds a per-class ranking entry for those classes and a ranking of the most probable label
 * sets.
 *
 * @param classProbEstimator produces the per-class probabilities for the data point
 * @param labelCalibrator calibrates the per-class probabilities
 * @param setCalibrator calibrates the feature vector extracted from a label-set candidate
 * @param dataSet data set containing the point to analyze
 * @param classifier produces the predicted label set
 * @param predictionFeatureExtractor turns a {@code PredictionCandidate} into the input of {@code setCalibrator}
 * @param dataPointIndex internal index of the data point in {@code dataSet}
 * @param ruleLimit forwarded to {@code decisionProcess} for each analyzed class
 * @param labelSetLimit maximum number of label sets in the label-set ranking
 * @param classProbThreshold classes whose calibrated probability is at least this value are
 *     analyzed, in addition to every true and every predicted class
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator setCalibrator, MultiLabelClfDataSet dataSet, MultiLabelClassifier classifier, PredictionFeatureExtractor predictionFeatureExtractor, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    // Identity of the data point: internal index and external id.
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));
    // Ground-truth labels, both as internal indices and translated to external names.
    predictionAnalysis.setInternalLabels(dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered());
    List<String> labels = dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered().stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);
    // Per-class probabilities, calibrated once and reused for every candidate and ranking below.
    double[] classProbs = classProbEstimator.predictClassProbs(dataSet.getRow(dataPointIndex));
    double[] calibratedClassProbs = labelCalibrator.calibratedClassProbs(classProbs);
    // Set-level calibrated probability of the true label set.
    PredictionCandidate trueCandidate = new PredictionCandidate();
    trueCandidate.x = dataSet.getRow(dataPointIndex);
    trueCandidate.multiLabel = dataSet.getMultiLabels()[dataPointIndex];
    trueCandidate.labelProbs = calibratedClassProbs;
    predictionAnalysis.setProbForTrueLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(trueCandidate)));
    // The classifier's predicted label set, internal and external form.
    MultiLabel predictedLabels = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);
    // Set-level calibrated probability of the predicted label set.
    PredictionCandidate predictedCandidate = new PredictionCandidate();
    predictedCandidate.x = dataSet.getRow(dataPointIndex);
    predictedCandidate.multiLabel = predictedLabels;
    predictedCandidate.labelProbs = calibratedClassProbs;
    predictionAnalysis.setProbForPredictedLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(predictedCandidate)));
    // Classes to analyze: calibrated prob >= threshold, or a true label, or a predicted label.
    // Collected in ascending class-index order.
    List<Integer> classes = new ArrayList<Integer>();
    for (int k = 0; k < classProbEstimator.getNumClasses(); k++) {
        if (calibratedClassProbs[k] >= classProbThreshold || dataSet.getMultiLabels()[dataPointIndex].matchClass(k) || predictedLabels.matchClass(k)) {
            classes.add(k);
        }
    }
    // todo
    // NOTE(review): only IMLGradientBoosting and CBM estimators are handled; for any other
    // estimator type a null entry is added for each selected class. Confirm consumers tolerate nulls.
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        ClassScoreCalculation classScoreCalculation = null;
        if (classProbEstimator instanceof IMLGradientBoosting) {
            classScoreCalculation = decisionProcess((IMLGradientBoosting) classProbEstimator, labelTranslator, calibratedClassProbs[k], dataSet.getRow(dataPointIndex), k, ruleLimit);
        }
        if (classProbEstimator instanceof CBM) {
            classScoreCalculation = decisionProcess((CBM) classProbEstimator, labelTranslator, calibratedClassProbs[k], dataSet.getRow(dataPointIndex), k, ruleLimit);
        }
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);
    // Per-class ranking entries; they keep the ascending class-index order of `classes`
    // (no sort by probability is applied here).
    List<MultiLabelPredictionAnalysis.ClassRankInfo> labelRanking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(calibratedClassProbs[label]);
        return rankInfo;
    }).collect(Collectors.toList());
    predictionAnalysis.setPredictedRanking(labelRanking);
    // Candidate label sets: restricted to the classifier's support when it exposes one.
    List<Pair<MultiLabel, Double>> topK;
    if (classifier instanceof SupportPredictor) {
        topK = TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(), labelSetLimit);
    } else {
        topK = TopKFinder.topK(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, labelSetLimit);
    }
    // Label-set ranking: sorted by probability, highest first, truncated to labelSetLimit.
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = topK.stream().map(pair -> {
        MultiLabel multiLabel = pair.getFirst();
        double setProb = pair.getSecond();
        MultiLabelPredictionAnalysis.LabelSetProbInfo labelSetProbInfo = new MultiLabelPredictionAnalysis.LabelSetProbInfo(multiLabel, setProb, labelTranslator);
        return labelSetProbInfo;
    }).sorted(Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed()).limit(labelSetLimit).collect(Collectors.toList());
    predictionAnalysis.setPredictedLabelSetRanking(labelSetRanking);
    return predictionAnalysis;
}
Also used: edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) List(java.util.List) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) Vector(org.apache.mahout.math.Vector) PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) Comparator(java.util.Comparator) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 3 with IMLGradientBoosting

Usage of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in the project pyramid by cheng-li.

Class IMLGBInspection, method main.

/**
 * Prints feature-selection statistics for a serialized IMLGradientBoosting model.
 *
 * <p>Expects a single argument: the path of a properties file whose {@code model} entry names the
 * serialized model. It prints the average number of features selected per class and the size of
 * the union of selected features across all classes.
 *
 * @param args exactly one element, the properties file path
 * @throws IllegalArgumentException if the number of arguments is not exactly one
 * @throws Exception if the configuration or the model cannot be loaded
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    // NOTE(review): Java-native deserialization; only load model files from trusted sources.
    IMLGradientBoosting boosting = (IMLGradientBoosting) Serialization.deserialize(config.getString("model"));
    // average() is empty for a model with zero classes; report NaN instead of throwing
    // NoSuchElementException from getAsDouble().
    double averageSelected = IntStream.range(0, boosting.getNumClasses())
            .mapToDouble(l -> IMLGBInspector.getSelectedFeatures(boosting, l).size())
            .average()
            .orElse(Double.NaN);
    System.out.println("average number of features selected by each binary classifier = " + averageSelected);
    System.out.println("total number of features selected (union) = " + IMLGBInspector.getSelectedFeatures(boosting).size());
}
Also used: IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) Config(edu.neu.ccs.pyramid.configuration.Config)

Example 4 with IMLGradientBoosting

Usage of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in the project pyramid by cheng-li.

Class MLPlattScalingTest, method test1.

/**
 * Smoke test for MLPlattScaling.
 *
 * <p>Trains an IMLGradientBoosting model for 100 rounds on the ohsumed/3 training split. For the
 * first 10 training rows it then prints the raw class scores, the model's class probabilities and
 * the probabilities produced by MLPlattScaling.
 */
private static void test1() throws Exception {
    File source = new File(DATASETS, "ohsumed/3/train.trec");
    MultiLabelClfDataSet data = TRECFormat.loadMultiLabelClfDataSet(source, DataSetType.ML_CLF_SPARSE, true);

    IMLGradientBoosting gb = new IMLGradientBoosting(data.getNumClasses());
    // Hand the model the label sets gathered from the training data.
    List<MultiLabel> labelSets = DataSetUtil.gatherMultiLabels(data);
    gb.setAssignments(labelSets);

    IMLGBConfig cfg = new IMLGBConfig.Builder(data)
            .numLeaves(2)
            .learningRate(0.1)
            .numSplitIntervals(1000)
            .minDataPerLeaf(2)
            .dataSamplingRate(1)
            .featureSamplingRate(1)
            .build();
    IMLGBTrainer gbTrainer = new IMLGBTrainer(cfg, gb);

    StopWatch watch = new StopWatch();
    watch.start();
    for (int iteration = 0; iteration < 100; ++iteration) {
        System.out.println("round=" + iteration);
        gbTrainer.iterate();
        // Print the stopwatch after every boosting round.
        System.out.println(watch);
    }

    MLPlattScaling platt = new MLPlattScaling(data, gb);
    int row = 0;
    while (row < 10) {
        System.out.println(Arrays.toString(gb.predictClassScores(data.getRow(row))));
        System.out.println(Arrays.toString(gb.predictClassProbs(data.getRow(row))));
        System.out.println(Arrays.toString(platt.predictClassProbs(data.getRow(row))));
        System.out.println("======================");
        row++;
    }
}
Also used: IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) IMLGBTrainer(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer) IMLGBConfig(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig) File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 5 with IMLGradientBoosting

Usage of edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting in the project pyramid by cheng-li.

Class MLFlatScalingTest, method test1.

/**
 * Smoke test for MLFlatScaling.
 *
 * <p>Trains an IMLGradientBoosting model for 100 rounds on the ohsumed/3 training split. For the
 * first 10 training rows it then prints the raw class scores, the model's class probabilities and
 * the probabilities produced by MLFlatScaling.
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting boosting = new IMLGradientBoosting(dataSet.getNumClasses());
    // The training label sets were previously gathered but never handed to the model (dead
    // local). Register them so the model is set up with its assignments before training.
    List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
    boosting.setAssignments(assignments);
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet).numLeaves(2).learningRate(0.1).numSplitIntervals(1000).minDataPerLeaf(2).build();
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        // Print the stopwatch after every boosting round.
        System.out.println(stopWatch);
    }
    MLFlatScaling scaling = new MLFlatScaling(dataSet, boosting);
    for (int i = 0; i < 10; i++) {
        System.out.println(Arrays.toString(boosting.predictClassScores(dataSet.getRow(i))));
        System.out.println(Arrays.toString(boosting.predictClassProbs(dataSet.getRow(i))));
        System.out.println(Arrays.toString(scaling.predictClassProbs(dataSet.getRow(i))));
        System.out.println("======================");
    }
}
Also used: MLFlatScaling(edu.neu.ccs.pyramid.multilabel_classification.MLFlatScaling) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) IMLGBTrainer(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer) IMLGBConfig(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig) File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Aggregations

IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting): 5, IMLGBConfig (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig): 3, IMLGBTrainer (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer): 3, File (java.io.File): 3, StopWatch (org.apache.commons.lang3.time.StopWatch): 3, edu.neu.ccs.pyramid.calibration (edu.neu.ccs.pyramid.calibration): 1, PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier): 1, LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression): 1, Config (edu.neu.ccs.pyramid.configuration.Config): 1, IdTranslator (edu.neu.ccs.pyramid.dataset.IdTranslator): 1, LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator): 1, MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 1, MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 1, Feature (edu.neu.ccs.pyramid.feature.Feature): 1, MLFlatScaling (edu.neu.ccs.pyramid.multilabel_classification.MLFlatScaling): 1, CBM (edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM): 1, SupportPredictor (edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor): 1, edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression): 1, RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree): 1, TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule): 1