
Example 26 with MultiLabelClfDataSet

Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

From class MultiLabelSynthesizer, method flipOneNonUniform.

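/**
 * Generates a synthetic multi-label dataset with 2 features and 4 classes.
 * Feature values are drawn with Sampling.doubleUniform(-1, 1), and class k is
 * assigned to a point x when dot(w_k, x) >= 0, using these weight vectors:
 * y0: w=(0,1)
 * y1: w=(1,1)
 * y2: w=(1,0)
 * y3: w=(1,-1)
 * Afterwards, exactly one class per data point is flipped (removed if present,
 * added if absent). Class 0 is chosen with probability 0.4 and each other class
 * with probability 0.2.
 * @param numData number of data points to generate
 * @return the generated dataset
 */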
public static MultiLabelClfDataSet flipOneNonUniform(int numData) {
    int numClass = 4;
    int numFeature = 2;
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numFeatures(numFeature).numClasses(numClass).numDataPoints(numData).build();
    // generate weights
    Vector[] weights = new Vector[numClass];
    for (int k = 0; k < numClass; k++) {
        Vector vector = new DenseVector(numFeature);
        weights[k] = vector;
    }
    weights[0].set(0, 0);
    weights[0].set(1, 1);
    weights[1].set(0, 1);
    weights[1].set(1, 1);
    weights[2].set(0, 1);
    weights[2].set(1, 0);
    weights[3].set(0, 1);
    weights[3].set(1, -1);
    // generate features
    for (int i = 0; i < numData; i++) {
        for (int j = 0; j < numFeature; j++) {
            dataSet.setFeatureValue(i, j, Sampling.doubleUniform(-1, 1));
        }
    }
    // assign labels
    for (int i = 0; i < numData; i++) {
        for (int k = 0; k < numClass; k++) {
            double dot = weights[k].dot(dataSet.getRow(i));
            if (dot >= 0) {
                dataSet.addLabel(i, k);
            }
        }
    }
    int[] indices = { 0, 1, 2, 3 };
    double[] probs = { 0.4, 0.2, 0.2, 0.2 };
    IntegerDistribution distribution = new EnumeratedIntegerDistribution(indices, probs);
    // flip
    for (int i = 0; i < numData; i++) {
        int toChange = distribution.sample();
        MultiLabel label = dataSet.getMultiLabels()[i];
        if (label.matchClass(toChange)) {
            label.removeLabel(toChange);
        } else {
            label.addLabel(toChange);
        }
    }
    return dataSet;
}
Also used : EnumeratedIntegerDistribution(org.apache.commons.math3.distribution.EnumeratedIntegerDistribution) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) IntegerDistribution(org.apache.commons.math3.distribution.IntegerDistribution) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
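The labeling scheme in flipOneNonUniform can be reproduced without pyramid or Mahout. This is a minimal, dependency-free sketch under stated assumptions: the class SyntheticFlip and its methods are hypothetical, java.util.Random stands in for Sampling.doubleUniform, and a hand-written inverse-CDF sampler stands in for EnumeratedIntegerDistribution. It assigns class k when the dot product of w_k and x is non-negative, then toggles exactly one class per point.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;

/** Hypothetical, dependency-free sketch of the flipOneNonUniform labeling scheme. */
final class SyntheticFlip {
    // Same weight vectors as the Javadoc: y0..y3.
    static final double[][] WEIGHTS = { {0, 1}, {1, 1}, {1, 0}, {1, -1} };
    // Flip probabilities: class 0 is picked twice as often as each other class.
    static final double[] FLIP_PROBS = { 0.4, 0.2, 0.2, 0.2 };

    /** Class k is assigned when dot(w_k, x) >= 0; a dot product of exactly 0 counts as positive. */
    static Set<Integer> labelsFor(double x0, double x1) {
        Set<Integer> labels = new TreeSet<>();
        for (int k = 0; k < WEIGHTS.length; k++) {
            double dot = WEIGHTS[k][0] * x0 + WEIGHTS[k][1] * x1;
            if (dot >= 0) {
                labels.add(k);
            }
        }
        return labels;
    }

    /** Inverse-CDF sampling: maps u in [0, 1) to a class index (stand-in for EnumeratedIntegerDistribution). */
    static int sampleIndex(double[] probs, double u) {
        double cumulative = 0;
        for (int k = 0; k < probs.length; k++) {
            cumulative += probs[k];
            if (u < cumulative) {
                return k;
            }
        }
        return probs.length - 1; // guards against rounding when u is close to 1
    }

    /** Toggles class k: removes it if present, adds it otherwise. */
    static void flip(Set<Integer> labels, int k) {
        if (!labels.remove(k)) {
            labels.add(k);
        }
    }

    /** Generates numData label sets, each with exactly one flipped class. */
    static List<Set<Integer>> generate(int numData, long seed) {
        Random random = new Random(seed);
        List<Set<Integer>> result = new ArrayList<>();
        for (int i = 0; i < numData; i++) {
            double x0 = -1 + 2 * random.nextDouble(); // uniform in [-1, 1)
            double x1 = -1 + 2 * random.nextDouble();
            Set<Integer> labels = labelsFor(x0, x1);
            flip(labels, sampleIndex(FLIP_PROBS, random.nextDouble()));
            result.add(labels);
        }
        return result;
    }
}
```

For example, labelsFor(-0.5, 0.5) yields {0, 1}: w1 gives a dot product of exactly 0, which the >= comparison counts as positive, while w2 and w3 give negative values.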

Example 27 with MultiLabelClfDataSet

Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

From class Meka2Trec, method main.

/**
 * Converts MEKA-format datasets to TREC format.
 * Only multi-label classification datasets are supported.
 * @param args a single argument: the path to a properties file
 */
public static void main(String[] args) throws IOException {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    List<String> trecs = config.getStrings("trec");
    List<String> mekas = config.getStrings("meka");
    int numLabels = config.getInt("numLabels");
    int numFeatures = config.getInt("numFeatures");
    String dataMode = config.getString("dataMode");
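    // Each meka input is written to the trec output at the same position.
    if (trecs.size() != mekas.size()) {
        throw new IllegalArgumentException("The trec and meka lists must have the same length.");
    }
    for (int i = 0; i < mekas.size(); i++) {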
        System.out.println("processing on: " + trecs.get(i));
        MultiLabelClfDataSet dataSet = MekaFormat.loadMLClfDataset(mekas.get(i), numFeatures, numLabels, dataMode);
        TRECFormat.save(dataSet, trecs.get(i));
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
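Meka2Trec pairs the i-th meka input with the i-th trec output, so the two lists must line up. The sketch below assumes, purely for illustration, that the properties file stores each list as a comma-separated value (for example `meka=a.arff,b.arff`); pyramid's Config may parse lists differently. Meka2TrecPlan and its methods are hypothetical and rely only on java.util.Properties.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

/** Hypothetical sketch: reads and validates the meka/trec path lists before any conversion runs. */
final class Meka2TrecPlan {
    /** Assumed list format: a comma-separated property value, split into trimmed, non-empty entries. */
    static List<String> getStrings(Properties props, String key) {
        String value = props.getProperty(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing property: " + key);
        }
        List<String> items = new ArrayList<>();
        for (String part : value.split(",")) {
            String trimmed = part.trim();
            if (!trimmed.isEmpty()) {
                items.add(trimmed);
            }
        }
        return items;
    }

    /** Returns {mekaPath, trecPath} pairs, failing fast if the two lists differ in length. */
    static List<String[]> plan(String propertiesText) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(propertiesText));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        List<String> mekas = getStrings(props, "meka");
        List<String> trecs = getStrings(props, "trec");
        if (mekas.size() != trecs.size()) {
            throw new IllegalArgumentException("The trec and meka lists must have the same length.");
        }
        List<String[]> pairs = new ArrayList<>();
        for (int i = 0; i < mekas.size(); i++) {
            pairs.add(new String[] { mekas.get(i), trecs.get(i) });
        }
        return pairs;
    }
}
```

Validating the lengths up front means a mismatched configuration fails before any file is converted, instead of partway through the loop.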

Example 28 with MultiLabelClfDataSet

Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

From class AdaBoostMHInspector, method analyzePrediction.

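/**
 * Builds a prediction analysis for a single data point.
 * The scaling estimator can be either binary scaling or across-class scaling.
 * @param boosting the trained AdaBoostMH model
 * @param scaling the class probability estimator
 * @param dataSet the dataset containing the data point
 * @param dataPointIndex internal index of the data point to analyze
 * @param classes class indices for which score calculations and probabilities are reported
 * @param limit passed through to decisionProcess for each class
 * @return the populated MultiLabelPredictionAnalysis
 */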
public static MultiLabelPredictionAnalysis analyzePrediction(AdaBoostMH boosting, MultiLabelClassifier.ClassProbEstimator scaling, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));
    predictionAnalysis.setInternalLabels(dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered());
    List<String> labels = dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered().stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);
    double probForTrueLabels = Double.NaN;
    if (scaling instanceof MultiLabelClassifier.AssignmentProbEstimator) {
        probForTrueLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(dataSet.getRow(dataPointIndex), dataSet.getMultiLabels()[dataPointIndex]);
    }
    predictionAnalysis.setProbForTrueLabels(probForTrueLabels);
    MultiLabel predictedLabels = boosting.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);
    double probForPredictedLabels = Double.NaN;
    if (scaling instanceof MultiLabelClassifier.AssignmentProbEstimator) {
        probForPredictedLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(dataSet.getRow(dataPointIndex), predictedLabels);
    }
    predictionAnalysis.setProbForPredictedLabels(probForPredictedLabels);
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        ClassScoreCalculation classScoreCalculation = decisionProcess(boosting, scaling, labelTranslator, dataSet.getRow(dataPointIndex), k, limit);
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);
    List<MultiLabelPredictionAnalysis.ClassRankInfo> ranking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(scaling.predictClassProb(dataSet.getRow(dataPointIndex), label));
        return rankInfo;
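    }).collect(Collectors.toList());
    // Note: `ranking` is built but never attached to predictionAnalysis, so it is discarded on return.
    return predictionAnalysis;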
}
Also used : MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) java.util(java.util) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) MLPlattScaling(edu.neu.ccs.pyramid.multilabel_classification.MLPlattScaling) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) Collectors(java.util.stream.Collectors) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) Classifier(edu.neu.ccs.pyramid.classification.Classifier) MLACPlattScaling(edu.neu.ccs.pyramid.multilabel_classification.MLACPlattScaling) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) PlattScaling(edu.neu.ccs.pyramid.classification.PlattScaling) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Pair(edu.neu.ccs.pyramid.util.Pair)
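analyzePrediction reports every label twice: as an internal class index and as an external name obtained through LabelTranslator.toExtLabel. The sketch below shows that translation and the construction of one rank entry per requested class, kept in the order given by `classes` (the original does not sort its ranking). Translator, RankInfo and PredictionSketch are hypothetical stand-ins for LabelTranslator and MultiLabelPredictionAnalysis.ClassRankInfo, and the probability function stands in for scaling.predictClassProb.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntToDoubleFunction;
import java.util.stream.Collectors;

/** Hypothetical sketch of the internal-to-external label translation used in analyzePrediction. */
final class PredictionSketch {
    /** Hypothetical stand-in for LabelTranslator: internal index i maps to names.get(i). */
    static final class Translator {
        private final List<String> names;

        Translator(List<String> names) {
            this.names = new ArrayList<>(names);
        }

        String toExtLabel(int internal) {
            return names.get(internal);
        }
    }

    /** Hypothetical stand-in for ClassRankInfo: the same three fields the original sets. */
    static final class RankInfo {
        final int classIndex;
        final String className;
        final double prob;

        RankInfo(int classIndex, String className, double prob) {
            this.classIndex = classIndex;
            this.className = className;
            this.prob = prob;
        }
    }

    /** Maps internal label indices to external names, preserving order. */
    static List<String> toExtLabels(Translator translator, List<Integer> internal) {
        return internal.stream().map(translator::toExtLabel).collect(Collectors.toList());
    }

    /** One entry per requested class, in the order of `classes`; classProb stands in for predictClassProb. */
    static List<RankInfo> rank(Translator translator, List<Integer> classes, IntToDoubleFunction classProb) {
        return classes.stream()
                .map(k -> new RankInfo(k, translator.toExtLabel(k), classProb.applyAsDouble(k)))
                .collect(Collectors.toList());
    }
}
```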

Example 29 with MultiLabelClfDataSet

Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

From class SparkCBMOptimizer, method updateBinaryClassifiers.

private void updateBinaryClassifiers() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateBinaryClassifiers");
    }
    Classifier.ProbabilityEstimator[][] localBinaryClassifiers = cbm.binaryClassifiers;
    double[][] localGammasT = gammasT;
    Broadcast<MultiLabelClfDataSet> localDataSetBroadcast = dataSetBroadCast;
    Broadcast<double[][][]> localTargetsBroadcast = targetDisBroadCast;
    double localVariance = priorVarianceBinary;
    List<BinaryTask> binaryTaskList = new ArrayList<>();
    for (int k = 0; k < cbm.numComponents; k++) {
        for (int l = 0; l < cbm.numLabels; l++) {
            LogisticRegression logisticRegression = (LogisticRegression) localBinaryClassifiers[k][l];
            double[] weights = localGammasT[k];
            binaryTaskList.add(new BinaryTask(k, l, logisticRegression, weights));
        }
    }
    JavaRDD<BinaryTask> binaryTaskRDD = sparkContext.parallelize(binaryTaskList, binaryTaskList.size());
    List<BinaryTaskResult> results = binaryTaskRDD.map(binaryTask -> {
        int labelIndex = binaryTask.classIndex;
        return updateBinaryLogisticRegression(binaryTask.componentIndex, binaryTask.classIndex, binaryTask.logisticRegression, localDataSetBroadcast.value(), binaryTask.weights, localTargetsBroadcast.value()[labelIndex], localVariance);
    }).collect();
    for (BinaryTaskResult result : results) {
        cbm.binaryClassifiers[result.componentIndex][result.classIndex] = result.binaryClassifier;
    }
    //        IntStream.range(0, cbm.numComponents).forEach(this::updateBinaryClassifiers);
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateBinaryClassifiers");
    }
}
Also used : IntStream(java.util.stream.IntStream) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ArrayList(java.util.ArrayList) Classifier(edu.neu.ccs.pyramid.classification.Classifier) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) RidgeLogisticOptimizer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost) LogisticLoss(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticLoss) JavaRDD(org.apache.spark.api.java.JavaRDD) Broadcast(org.apache.spark.broadcast.Broadcast) LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) Serializable(java.io.Serializable) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) KLDivergence(edu.neu.ccs.pyramid.eval.KLDivergence) List(java.util.List) Logger(org.apache.logging.log4j.Logger) ElasticNetLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer) Entropy(edu.neu.ccs.pyramid.eval.Entropy) Vector(org.apache.mahout.math.Vector) LogManager(org.apache.logging.log4j.LogManager) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)
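updateBinaryClassifiers fans out one independent task per (component, label) pair, runs them with parallelize/map/collect, and writes each result back by the indices it carries. The sketch below mirrors that shape on a single JVM, with a Java parallel stream in place of Spark; Task, Result and train are hypothetical, and train stands in for updateBinaryLogisticRegression. In the original, copying fields such as gammasT and the broadcasts into local variables before the lambda is likely there so the closure does not capture the enclosing optimizer, which Spark would otherwise need to serialize.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/** Hypothetical single-JVM sketch of the (component, label) task fan-out in updateBinaryClassifiers. */
final class GridFanOut {
    /** One unit of work per (component, label) pair, similar in role to BinaryTask. */
    static final class Task {
        final int componentIndex;
        final int classIndex;

        Task(int componentIndex, int classIndex) {
            this.componentIndex = componentIndex;
            this.classIndex = classIndex;
        }
    }

    /** Result that carries its own indices, similar in role to BinaryTaskResult. */
    static final class Result {
        final int componentIndex;
        final int classIndex;
        final String model;

        Result(int componentIndex, int classIndex, String model) {
            this.componentIndex = componentIndex;
            this.classIndex = classIndex;
            this.model = model;
        }
    }

    /** Hypothetical stand-in for updateBinaryLogisticRegression: any pure function of the task works here. */
    static Result train(Task task) {
        String model = "model[" + task.componentIndex + "][" + task.classIndex + "]";
        return new Result(task.componentIndex, task.classIndex, model);
    }

    static String[][] updateAll(int numComponents, int numLabels) {
        List<Task> tasks = new ArrayList<>();
        for (int k = 0; k < numComponents; k++) {
            for (int l = 0; l < numLabels; l++) {
                tasks.add(new Task(k, l));
            }
        }
        // parallelStream().map(...).collect(...) plays the role of parallelize(...).map(...).collect().
        List<Result> results = tasks.parallelStream().map(GridFanOut::train).collect(Collectors.toList());
        String[][] models = new String[numComponents][numLabels];
        // Write-back uses the stored indices, so it does not depend on the order of results.
        for (Result result : results) {
            models[result.componentIndex][result.classIndex] = result.model;
        }
        return models;
    }
}
```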

Example 30 with MultiLabelClfDataSet

Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in project pyramid by cheng-li.

From class CBMTest, method test3.

private static void test3() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    int numComponents = 4;
    CBM cbm = CBM.getBuilder().setNumClasses(dataSet.getNumClasses()).setNumFeatures(dataSet.getNumFeatures()).setNumComponents(numComponents).setBinaryClassifierType("lr").setMultiClassClassifierType("boost").build();
    cbm.setPredictMode("dynamic");
    CBMOptimizer optimizer = new CBMOptimizer(cbm, dataSet);
    optimizer.setPriorVarianceBinary(10);
    optimizer.setPriorVarianceMultiClass(10);
    CBMInitializer.initialize(cbm, dataSet, optimizer);
    cbm.setNumSample(100);
    System.out.println("num cluster: " + cbm.numComponents);
    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(cbm, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(cbm, testSet));
    for (int i = 1; i <= 30; i++) {
        optimizer.iterate();
        System.out.print("iter : " + i + "\t");
        System.out.print("objective: " + optimizer.getTerminator().getLastValue() + "\t");
        System.out.print("trainAcc : " + Accuracy.accuracy(cbm, dataSet) + "\t");
        System.out.print("trainOver: " + Overlap.overlap(cbm, dataSet) + "\t");
        System.out.print("testAcc  : " + Accuracy.accuracy(cbm, testSet) + "\t");
        System.out.println("testOver : " + Overlap.overlap(cbm, testSet) + "\t");
    }
    System.out.println("history = " + optimizer.getTerminator().getHistory());
    System.out.println(cbm);
}
Also used : File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
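test3 prints two multi-label metrics per iteration. The sketch below assumes that Accuracy.accuracy measures subset (exact-match) accuracy and Overlap.overlap measures the mean Jaccard overlap between predicted and true label sets; check pyramid's eval package before relying on that. MultiLabelMetrics is hypothetical, and scoring two empty label sets as a perfect overlap is an assumption of this sketch.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical sketch of subset accuracy and mean Jaccard overlap for multi-label predictions. */
final class MultiLabelMetrics {
    private static void checkSizes(List<Set<Integer>> truth, List<Set<Integer>> predicted) {
        if (truth.isEmpty() || truth.size() != predicted.size()) {
            throw new IllegalArgumentException("Need equally sized, non-empty lists.");
        }
    }

    /** Fraction of data points whose predicted label set equals the true set exactly. */
    static double subsetAccuracy(List<Set<Integer>> truth, List<Set<Integer>> predicted) {
        checkSizes(truth, predicted);
        int correct = 0;
        for (int i = 0; i < truth.size(); i++) {
            if (truth.get(i).equals(predicted.get(i))) {
                correct++;
            }
        }
        return (double) correct / truth.size();
    }

    /** Mean of |intersection| / |union| over data points. */
    static double overlap(List<Set<Integer>> truth, List<Set<Integer>> predicted) {
        checkSizes(truth, predicted);
        double sum = 0;
        for (int i = 0; i < truth.size(); i++) {
            Set<Integer> union = new HashSet<>(truth.get(i));
            union.addAll(predicted.get(i));
            if (union.isEmpty()) {
                sum += 1.0; // assumption: two empty sets count as a perfect match
                continue;
            }
            Set<Integer> intersection = new HashSet<>(truth.get(i));
            intersection.retainAll(predicted.get(i));
            sum += (double) intersection.size() / union.size();
        }
        return sum / truth.size();
    }
}
```

With true sets {0,1} and {2} and predictions {0,1} and {1,2}, subset accuracy is 0.5 and overlap is (1.0 + 0.5) / 2 = 0.75: overlap gives partial credit where exact-match accuracy gives none.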

Aggregations

MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 48
File (java.io.File): 24
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 23
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 13
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 12
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 9
Vector (org.apache.mahout.math.Vector): 9
Config (edu.neu.ccs.pyramid.configuration.Config): 7
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 7
DenseVector (org.apache.mahout.math.DenseVector): 7
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
Pair (edu.neu.ccs.pyramid.util.Pair): 5
java.util (java.util): 5
Collectors (java.util.stream.Collectors): 5
IntStream (java.util.stream.IntStream): 5
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 4
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 4
MLScorer (edu.neu.ccs.pyramid.multilabel_classification.MLScorer): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
AccScorer (edu.neu.ccs.pyramid.multilabel_classification.AccScorer): 3