
Example 31 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBMInspector method visualizeClusters.

public static List<Map<MultiLabel, Double>> visualizeClusters(CBM bmm, MultiLabelClfDataSet dataSet) {
    int numClusters = bmm.getNumComponents();
    List<Map<MultiLabel, Double>> list = new ArrayList<>();
    for (int k = 0; k < numClusters; k++) {
        list.add(new HashMap<>());
    }
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        double[] clusterProbs = bmm.getMultiClassClassifier().predictClassProbs(dataSet.getRow(i));
        MultiLabel multiLabel = dataSet.getMultiLabels()[i];
        for (int k = 0; k < numClusters; k++) {
            Map<MultiLabel, Double> map = list.get(k);
            double count = map.getOrDefault(multiLabel, 0.0);
            double newcount = count + clusterProbs[k];
            map.put(multiLabel, newcount);
        }
    }
    return list;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
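
The soft-count accumulation in visualizeClusters can be tried without the pyramid library. The following is a minimal sketch, not the project's code: the class name SoftClusterCounts and the method accumulate are hypothetical, label sets are modelled as plain Set<Integer> instead of MultiLabel, and the cluster probabilities are passed in as an array rather than computed by a classifier. It assumes Java 9 or later for List.of and Set.of.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class SoftClusterCounts {
    /**
     * clusterProbs[i][k] is the probability that data point i belongs to cluster k;
     * labelSets.get(i) is the label set of data point i.
     * Returns, for each cluster, the probability-weighted count of every label set.
     */
    public static List<Map<Set<Integer>, Double>> accumulate(double[][] clusterProbs,
                                                             List<Set<Integer>> labelSets) {
        int numClusters = clusterProbs.length == 0 ? 0 : clusterProbs[0].length;
        List<Map<Set<Integer>, Double>> counts = new ArrayList<>();
        for (int k = 0; k < numClusters; k++) {
            counts.add(new HashMap<>());
        }
        for (int i = 0; i < clusterProbs.length; i++) {
            for (int k = 0; k < numClusters; k++) {
                // merge adds the membership probability to the running total for this label set
                counts.get(k).merge(labelSets.get(i), clusterProbs[i][k], Double::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        double[][] probs = {{0.75, 0.25}, {0.25, 0.75}};
        List<Set<Integer>> labels = List.of(Set.of(0), Set.of(0));
        System.out.println(accumulate(probs, labels));
    }
}
```

Map.merge replaces the getOrDefault/put pair used in the original method; the accumulated values are the same.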

Example 32 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBMInspector method visualizePrediction.

public static void visualizePrediction(CBM CBM, Vector vector, LabelTranslator labelTranslator) {
    int numClusters = CBM.getNumComponents();
    int numClasses = CBM.getNumClasses();
    double[] proportions = CBM.getMultiClassClassifier().predictClassProbs(vector);
    double[][] probabilities = new double[numClusters][numClasses];
    for (int k = 0; k < numClusters; k++) {
        for (int l = 0; l < numClasses; l++) {
            probabilities[k][l] = CBM.getBinaryClassifiers()[k][l].predictClassProb(vector, 1);
        }
    }
    int[] sorted = ArgSort.argSortDescending(proportions);
    double[] topProbs = probabilities[sorted[0]];
    // labels predicted by the most probable cluster alone
    MultiLabel trivialPred = new MultiLabel();
    for (int l = 0; l < numClasses; l++) {
        if (topProbs[l] >= 0.5) {
            trivialPred.addLabel(l);
        }
    }
    // labels predicted by the second most probable cluster alone
    MultiLabel secondPred = new MultiLabel();
    for (int l = 0; l < numClasses; l++) {
        if (probabilities[sorted[1]][l] >= 0.5) {
            secondPred.addLabel(l);
        }
    }
    MultiLabel predicted = CBM.predict(vector);
    if (!predicted.equals(trivialPred)) {
        System.out.println("interesting case !");
        if (!predicted.equals(secondPred)) {
            System.out.println("very interesting case !");
        }
    }
    //        System.out.println("proportion = "+ Arrays.toString(proportions));
    double[] sortedProportions = new double[numClusters];
    for (int t = 0; t < sorted.length; t++) {
        int k = sorted[t];
        sortedProportions[t] = proportions[k];
    }
    System.out.println("proportions = " + Arrays.toString(sortedProportions));
    for (int t = 0; t < sorted.length; t++) {
        int k = sorted[t];
        System.out.println("prob" + (t + 1) + " = " + Arrays.toString(probabilities[k]));
    }
    for (int t = 0; t < sorted.length; t++) {
        int k = sorted[t];
        double[] probs = probabilities[k];
        List<String> labels = new ArrayList<>();
        for (int l = 0; l < numClasses; l++) {
            if (probs[l] > 0.5) {
                labels.add("\"" + labelTranslator.toExtLabel(l) + "\"");
            }
        }
        System.out.println("labels" + (t + 1) + " = " + labels);
    }
    System.out.println("perplexity=" + Math.pow(2, Entropy.entropy2Based(proportions)));
    for (int t = 0; t < numClusters; t++) {
        System.out.println("cluster" + (t + 1) + " = " + Arrays.toString(probabilities[t]));
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
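
The "perplexity" line in visualizePrediction raises 2 to the base-2 entropy of the cluster proportions, which can be read as the effective number of clusters in use. The sketch below is a standalone illustration. It assumes that Entropy.entropy2Based computes the ordinary Shannon entropy in bits, which the snippet above does not show; the class Perplexity and its methods are hypothetical names.

```java
public class Perplexity {
    /** Shannon entropy in bits; zero-probability entries contribute nothing. */
    public static double entropyBase2(double[] p) {
        double h = 0.0;
        for (double v : p) {
            if (v > 0.0) {
                h -= v * (Math.log(v) / Math.log(2.0));
            }
        }
        return h;
    }

    /** Perplexity = 2^entropy: close to 1 when one cluster dominates, n when n clusters are equally likely. */
    public static double perplexity(double[] p) {
        return Math.pow(2.0, entropyBase2(p));
    }

    public static void main(String[] args) {
        // uniform proportions over four clusters
        System.out.println(perplexity(new double[] {0.25, 0.25, 0.25, 0.25}));
        // all mass on a single cluster
        System.out.println(perplexity(new double[] {1.0, 0.0, 0.0, 0.0}));
    }
}
```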

Example 33 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class MLPriorProbClassifier method fit.

public void fit(MultiLabelClfDataSet dataSet) {
    if (dataSet.getNumClasses() != this.numClasses) {
        throw new IllegalArgumentException("dataSet.getNumClasses()!=this.numClasses");
    }
    int[] counts = new int[this.numClasses];
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    for (MultiLabel multiLabel : multiLabels) {
        for (int matchedClass : multiLabel.getMatchedLabels()) {
            counts[matchedClass] += 1;
        }
    }
    for (int k = 0; k < this.numClasses; k++) {
        this.classProbs[k] = ((double) counts[k]) / dataSet.getNumDataPoints();
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
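
The fit method above estimates, for each class, the fraction of data points whose label set contains it. A minimal standalone sketch of the same counting follows; LabelPriors and estimate are hypothetical names, and each label set is modelled as an int[] of class indices rather than a MultiLabel. Unlike the original, this sketch rejects an empty data set instead of dividing by zero. It assumes Java 9 or later for List.of.

```java
import java.util.Arrays;
import java.util.List;

public class LabelPriors {
    /** Returns priors[k] = (number of data points labelled with class k) / (number of data points). */
    public static double[] estimate(int numClasses, List<int[]> labelSets) {
        if (labelSets.isEmpty()) {
            throw new IllegalArgumentException("labelSets must not be empty");
        }
        int[] counts = new int[numClasses];
        for (int[] labels : labelSets) {
            for (int label : labels) {
                counts[label] += 1;
            }
        }
        double[] priors = new double[numClasses];
        for (int k = 0; k < numClasses; k++) {
            priors[k] = ((double) counts[k]) / labelSets.size();
        }
        return priors;
    }

    public static void main(String[] args) {
        List<int[]> labelSets = List.of(new int[] {0}, new int[] {0, 1}, new int[] {1}, new int[] {});
        System.out.println(Arrays.toString(estimate(3, labelSets)));
    }
}
```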

Example 34 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBMOptimizer method updateGamma.

private void updateGamma(int n) {
    Vector x = dataSet.getRow(n);
    MultiLabel y = dataSet.getMultiLabels()[n];
    double[] posterior = cbm.posteriorMembership(x, y, noiseLabelWeights[n]);
    for (int k = 0; k < cbm.numComponents; k++) {
        gammas[n][k] = posterior[k];
        gammasT[k][n] = posterior[k];
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Vector(org.apache.mahout.math.Vector)

Example 35 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBMPredictor method predictByHardAssignment.

//todo fix
//    /**
//     * predict by sampling.
//     * @return
//     */
//    public MultiLabel predictBySampling() {
//
//        double maxLogProb = Double.NEGATIVE_INFINITY;
//        Vector predVector = new DenseVector(numLabels);
//
//        int[] clusters = IntStream.range(0, numClusters).toArray();
//        EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(clusters, logisticProb);
//
//        for (int s=0; s<numSample; s++) {
//            int k = enumeratedIntegerDistribution.sample();
//            Vector candidateY = new DenseVector(numLabels);
//
//            for (int l=0; l<numLabels; l++) {
//                BernoulliDistribution bernoulliDistribution = new BernoulliDistribution(probs[k][l][1]);
//                candidateY.set(l, bernoulliDistribution.sample());
//            }
//            // whether consider empty prediction
//            if ((candidateY.maxValue() == 0.0) && !allowEmpty) {
//                continue;
//            }
//
//            double logProb = logProbYnGivenXnLogisticProb(candidateY);
//
//            if (logProb >= maxLogProb) {
//                predVector = candidateY;
//                maxLogProb = logProb;
//            }
//        }
//
//        MultiLabel predLabel = new MultiLabel();
//        for (int l=0; l<numLabels; l++) {
//            if (predVector.get(l) == 1.0) {
//                predLabel.addLabel(l);
//            }
//        }
//        return predLabel;
//    }
/**
 * Predict by hard assignment: use only the cluster with the maximum probability.
 * @return the predicted MultiLabel
 */
public MultiLabel predictByHardAssignment() {
    // find the max cluster
    int maxK = 0;
    double maxPi = logisticLogProb[0];
    for (int k = 1; k < logisticLogProb.length; k++) {
        if (maxPi < logisticLogProb[k]) {
            maxK = k;
            maxPi = logisticLogProb[k];
        }
    }
    MultiLabel predict = new MultiLabel();
    for (int l = 0; l < numLabels; l++) {
        if (probs[maxK][l][1] > 0.5) {
            predict.addLabel(l);
        }
    }
    return predict;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
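
Hard assignment as used in predictByHardAssignment can be illustrated in isolation: pick the cluster with the largest log probability, then keep every label whose probability in that cluster exceeds 0.5. The sketch below is an assumption-laden simplification, not the CBMPredictor code: HardAssignment and predict are hypothetical names, labelProbs[k][l] holds only the positive-class probability (the original indexes probs[k][l][1]), and the result is a Set<Integer> instead of a MultiLabel.

```java
import java.util.Set;
import java.util.TreeSet;

public class HardAssignment {
    /**
     * clusterLogProbs[k] is the log probability of cluster k;
     * labelProbs[k][l] is the probability that label l is present under cluster k.
     */
    public static Set<Integer> predict(double[] clusterLogProbs, double[][] labelProbs) {
        // find the most probable cluster; ties keep the lowest index
        int maxK = 0;
        for (int k = 1; k < clusterLogProbs.length; k++) {
            if (clusterLogProbs[k] > clusterLogProbs[maxK]) {
                maxK = k;
            }
        }
        // threshold that cluster's label probabilities at 0.5
        Set<Integer> predicted = new TreeSet<>();
        for (int l = 0; l < labelProbs[maxK].length; l++) {
            if (labelProbs[maxK][l] > 0.5) {
                predicted.add(l);
            }
        }
        return predicted;
    }

    public static void main(String[] args) {
        double[] logProbs = {-2.0, -0.1};
        double[][] labelProbs = {{0.9, 0.9, 0.9}, {0.2, 0.7, 0.6}};
        System.out.println(predict(logProbs, labelProbs));
    }
}
```

Because only the winning cluster is consulted, the prediction can differ from one that mixes all clusters, which is the situation visualizePrediction flags as an "interesting case".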

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 101
Vector (org.apache.mahout.math.Vector): 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 21
File (java.io.File): 14
DenseVector (org.apache.mahout.math.DenseVector): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 12
Pair (edu.neu.ccs.pyramid.util.Pair): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
ArrayList (java.util.ArrayList): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 5
Collectors (java.util.stream.Collectors): 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 4
java.util (java.util): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
Config (edu.neu.ccs.pyramid.configuration.Config): 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 3