
Example 36 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBMPredictor method predictByDynamic.

// todo fix
//    public MultiLabel predictByGreedy() {
//        Vector predVector = new DenseVector(numLabels);
//        double prevLogProb;
//        if (allowEmpty) {
//            prevLogProb = logProbYnGivenXnLogisticProb(predVector);
//        } else {
//            prevLogProb = Double.NEGATIVE_INFINITY;
//        }
//
//        Set<Integer> labelSet = IntStream.range(0, numLabels).boxed().collect(Collectors.toSet());
//        while (!labelSet.isEmpty()) {
//            double curLogProb = Double.NEGATIVE_INFINITY;
//            int curLabel = -1;
//            // find the maximum one
//            for (int k : labelSet) {
//                Vector curVector = new DenseVector((DenseVector) predVector, false);
//                curVector.set(k, 1.0);
//                double logProb = logProbYnGivenXnLogisticProb(curVector);
//                if (logProb > curLogProb) {
//                    curLabel = k;
//                    curLogProb = logProb;
//                }
//            }
//
//            if (curLogProb > prevLogProb) {
//                predVector.set(curLabel, 1.0);
//                prevLogProb = curLogProb;
//                labelSet.remove(curLabel);
//            } else {
//                break;
//            }
//        }
//
//        MultiLabel predLabel = new MultiLabel();
//        for (int l=0; l<numLabels; l++) {
//            if (predVector.get(l) == 1.0) {
//                predLabel.addLabel(l);
//            }
//        }
//        return predLabel;
//    }
public MultiLabel predictByDynamic() {
    // initialization
    Map<Integer, DynamicProgramming> DPs = new HashMap<>();
    double[] maxClusterProb = new double[numClusters];
    for (int k = 0; k < numClusters; k++) {
        DPs.put(k, new DynamicProgramming(probs[k], logProbs[k]));
        maxClusterProb[k] = DPs.get(k).nextHighestProb();
    }
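    // speed-up: precompute the per-cluster quantities passed to checkStop
    // 1) bound from pi^k (D^k - q) >= 1 - pi^k, rearranged as q <= D^k - 1/pi^k + 1
    double[] cond1 = new double[numClusters];
    // 2) cache sum_{r != k} (pi^r * D^r) for each cluster k
    double[] sumPiD = new double[numClusters];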
    for (int k = 0; k < numClusters; k++) {
        cond1[k] = maxClusterProb[k] - 1.0 / logisticProb[k] + 1;
        double sum = 0.0;
        for (int r = 0; r < numClusters; r++) {
            if (r == k) {
                continue;
            }
            sum += logisticProb[r] * maxClusterProb[r];
        }
        sumPiD[k] = sum;
    }
    double maxLogProb = Double.NEGATIVE_INFINITY;
    MultiLabel bestMultiLabel = new MultiLabel();
    int iter = 0;
    int maxIter = 10;
    while (DPs.size() > 0) {
        List<Integer> removeList = new LinkedList<>();
        for (Map.Entry<Integer, DynamicProgramming> entry : DPs.entrySet()) {
            int k = entry.getKey();
            DynamicProgramming dp = entry.getValue();
            double prob = dp.nextHighestProb();
            MultiLabel multiLabel = dp.nextHighestVector();
            // skip the empty prediction unless allowEmpty is set
            if ((multiLabel.getNumMatchedLabels() == 0) && !allowEmpty) {
                if (dp.getQueue().size() == 0) {
                    removeList.add(k);
                }
                continue;
            }
            double logProb = logProbYnGivenXnLogisticProb(multiLabel);
            if (logProb >= maxLogProb) {
                bestMultiLabel = multiLabel;
                maxLogProb = logProb;
                // maxIter = iter;
            }
            // check whether cluster k should be removed from the candidates
            if (checkStop(prob, cond1[k], maxLogProb, sumPiD[k], k) || dp.getQueue().size() == 0) {
                removeList.add(k);
            }
        }
        for (int k : removeList) {
            DPs.remove(k);
        }
        iter++;
        if (iter >= maxIter) {
            break;
        }
    }
    return bestMultiLabel;
}
Also used : DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
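
The commented-out predictByGreedy outline above describes a greedy forward search: start from the empty label set, repeatedly add the single label that most increases the log probability, and stop when no addition helps. The following is a minimal self-contained sketch of that idea, not the project's implementation. The class name, the ToDoubleFunction scorer, and the independentBernoulli helper are hypothetical stand-ins for logProbYnGivenXnLogisticProb, and plain boolean arrays replace the Mahout vectors.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToDoubleFunction;

public class GreedyLabelSearchSketch {

    /** Greedy forward selection of labels under an arbitrary log-probability scorer. */
    public static List<Integer> predictByGreedy(int numLabels, boolean allowEmpty,
                                                ToDoubleFunction<boolean[]> logProb) {
        boolean[] pred = new boolean[numLabels];
        // If the empty set is not allowed, any first label beats the starting score.
        double prevLogProb = allowEmpty ? logProb.applyAsDouble(pred) : Double.NEGATIVE_INFINITY;
        while (true) {
            int bestLabel = -1;
            double bestLogProb = Double.NEGATIVE_INFINITY;
            // Try adding each label that is not yet selected and keep the best one.
            for (int k = 0; k < numLabels; k++) {
                if (pred[k]) {
                    continue;
                }
                pred[k] = true;
                double lp = logProb.applyAsDouble(pred);
                pred[k] = false;
                if (lp > bestLogProb) {
                    bestLabel = k;
                    bestLogProb = lp;
                }
            }
            // Stop when nothing is left to add or the best addition does not improve the score.
            if (bestLabel < 0 || bestLogProb <= prevLogProb) {
                break;
            }
            pred[bestLabel] = true;
            prevLogProb = bestLogProb;
        }
        List<Integer> labels = new ArrayList<>();
        for (int l = 0; l < numLabels; l++) {
            if (pred[l]) {
                labels.add(l);
            }
        }
        return labels;
    }

    /** Hypothetical toy scorer: labels are independent Bernoulli variables with the given probabilities. */
    public static ToDoubleFunction<boolean[]> independentBernoulli(double[] p) {
        return y -> {
            double sum = 0.0;
            for (int l = 0; l < p.length; l++) {
                sum += Math.log(y[l] ? p[l] : 1.0 - p[l]);
            }
            return sum;
        };
    }

    public static void main(String[] args) {
        ToDoubleFunction<boolean[]> scorer = independentBernoulli(new double[]{0.9, 0.2, 0.7});
        System.out.println(predictByGreedy(3, true, scorer)); // prints [0, 2]
    }
}
```

Greedy search evaluates at most numLabels candidates per round, so it is cheap, but it can miss the best label set when labels interact. That is presumably why predictByDynamic enumerates candidates per cluster instead.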

Example 37 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBMS method predictByMarginals.

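/**
 * Predicts sign(E[y|x]): a label is included when its marginal probability exceeds 0.5.
 * @param vector feature vector of the instance
 * @return the predicted label set
 */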
MultiLabel predictByMarginals(Vector vector) {
    double[] probs = predictClassProbs(vector);
    MultiLabel prediction = new MultiLabel();
    for (int l = 0; l < numLabels; l++) {
        if (probs[l] > 0.5) {
            prediction.addLabel(l);
        }
    }
    return prediction;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
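
The thresholding rule in predictByMarginals can be tried without the pyramid classes. The sketch below is an illustration under assumptions: the class name is hypothetical, the marginals are passed in directly instead of coming from predictClassProbs, and a List<Integer> stands in for MultiLabel.

```java
import java.util.ArrayList;
import java.util.List;

public class MarginalThresholdSketch {

    /** Returns the indices of labels whose marginal probability is strictly above the threshold. */
    public static List<Integer> predictByMarginals(double[] marginals, double threshold) {
        List<Integer> labels = new ArrayList<>();
        for (int l = 0; l < marginals.length; l++) {
            if (marginals[l] > threshold) {
                labels.add(l);
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        double[] marginals = {0.9, 0.5, 0.2, 0.51};
        // Label 1 sits exactly at 0.5 and is excluded because the comparison is strict.
        System.out.println(predictByMarginals(marginals, 0.5)); // prints [0, 3]
    }
}
```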

Example 38 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBMNoiseOptimizerFixed method updateGamma.

private void updateGamma(int n) {
    Vector x = dataSet.getRow(n);
    BMDistribution bmDistribution = cbm.computeBM(x);
    // size = combination * components
    List<double[]> logPosteriors = new ArrayList<>();
    for (int c = 0; c < combinations.size(); c++) {
        MultiLabel combination = combinations.get(c);
        double[] pos = bmDistribution.logPosteriorMembership(combination);
        logPosteriors.add(pos);
    }
    double[] sums = new double[cbm.numComponents];
    for (int k = 0; k < cbm.numComponents; k++) {
        double sum = 0;
        for (int c = 0; c < combinations.size(); c++) {
            sum += targets[n][c] * logPosteriors.get(c)[k];
        }
        sums[k] = sum;
    }
    double[] posterior = MathUtil.softmax(sums);
    for (int k = 0; k < cbm.numComponents; k++) {
        gammas[n][k] = posterior[k];
        gammasT[k][n] = posterior[k];
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) ArrayList(java.util.ArrayList) Vector(org.apache.mahout.math.Vector)
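
The updateGamma method above weights each combination's log posterior by its target value, sums per component, and normalizes with a softmax. The sketch below reproduces that arithmetic on plain arrays. It is a hedged illustration: the class and method names are hypothetical, and whether MathUtil.softmax uses the same max-subtraction trick is an assumption.

```java
public class SoftmaxSketch {

    /** Numerically stable softmax: subtracts the maximum score before exponentiating. */
    public static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] out = new double[scores.length];
        double sum = 0.0;
        for (int k = 0; k < scores.length; k++) {
            out[k] = Math.exp(scores[k] - max);
            sum += out[k];
        }
        for (int k = 0; k < scores.length; k++) {
            out[k] /= sum;
        }
        return out;
    }

    /**
     * Component responsibilities for one instance:
     * sums[k] = sum over c of targets[c] * logPosteriors[c][k], followed by a softmax over k.
     */
    public static double[] gamma(double[] targets, double[][] logPosteriors) {
        int numComponents = logPosteriors[0].length;
        double[] sums = new double[numComponents];
        for (int k = 0; k < numComponents; k++) {
            for (int c = 0; c < targets.length; c++) {
                sums[k] += targets[c] * logPosteriors[c][k];
            }
        }
        return softmax(sums);
    }

    public static void main(String[] args) {
        // Very negative scores would underflow to 0/0 without the max subtraction.
        double[] p = softmax(new double[]{-1000.0, -1001.0});
        System.out.println(p[0] + " " + p[1]);
    }
}
```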

Example 39 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBM method predictLogAssignmentProbs.

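/**
 * For batch jobs, use this to save computation: the BMDistribution for x is built once
 * and reused for every assignment.
 * @param x feature vector of the instance
 * @param assignments candidate label assignments to score
 * @param piThreshold threshold passed to the BMDistribution constructor
 * @return the log probability of each assignment, in the order given
 */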
public double[] predictLogAssignmentProbs(Vector x, List<MultiLabel> assignments, double piThreshold) {
    BMDistribution bmDistribution = new BMDistribution(this, x, piThreshold);
    double[] probs = new double[assignments.size()];
    for (int c = 0; c < assignments.size(); c++) {
        MultiLabel multiLabel = assignments.get(c);
        probs[c] = bmDistribution.logProbability(multiLabel);
    }
    return probs;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
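
To make the scoring step concrete, the sketch below computes the log probability of one assignment under a mixture of independent Bernoulli components, using log-sum-exp for stability. This is an assumption about what BMDistribution.logProbability does, based on the class name. The listing above does not show its internals, the sketch ignores piThreshold, and all names here are hypothetical.

```java
public class BernoulliMixtureSketch {

    /** Computes log(sum_i exp(a[i])) without overflow or underflow. */
    public static double logSumExp(double[] a) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : a) {
            max = Math.max(max, v);
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY; // every term has zero probability
        }
        double sum = 0.0;
        for (double v : a) {
            sum += Math.exp(v - max);
        }
        return max + Math.log(sum);
    }

    /**
     * log p(y) = logSumExp over k of ( logPi[k] + sum over l of log Bernoulli(y[l]; probs[k][l]) ).
     * logPi holds log mixing weights; probs[k][l] is the probability of label l in component k.
     */
    public static double logProbability(double[] logPi, double[][] probs, boolean[] y) {
        double[] terms = new double[logPi.length];
        for (int k = 0; k < logPi.length; k++) {
            double t = logPi[k];
            for (int l = 0; l < y.length; l++) {
                t += Math.log(y[l] ? probs[k][l] : 1.0 - probs[k][l]);
            }
            terms[k] = t;
        }
        return logSumExp(terms);
    }

    public static void main(String[] args) {
        double[] logPi = {Math.log(0.3), Math.log(0.7)};
        double[][] probs = {{0.9, 0.1}, {0.2, 0.6}};
        System.out.println(logProbability(logPi, probs, new boolean[]{true, false}));
    }
}
```

Building the per-component parameters once and scoring many assignments against them mirrors the batch use described in the Javadoc above.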

Example 40 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBMSOptimizer method updateGamma.

private void updateGamma(int n) {
    Vector x = dataSet.getRow(n);
    MultiLabel y = dataSet.getMultiLabels()[n];
    double[] posterior = cbms.posteriorMembership(x, y);
    for (int k = 0; k < cbms.numComponents; k++) {
        gammas[n][k] = posterior[k];
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Vector(org.apache.mahout.math.Vector)

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 101
Vector (org.apache.mahout.math.Vector): 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 21
File (java.io.File): 14
DenseVector (org.apache.mahout.math.DenseVector): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 12
Pair (edu.neu.ccs.pyramid.util.Pair): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
ArrayList (java.util.ArrayList): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 5
Collectors (java.util.stream.Collectors): 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 4
java.util (java.util): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
Config (edu.neu.ccs.pyramid.configuration.Config): 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 3