Search in sources :

Example 1 with Tuple5

use of scala.Tuple5 in project angel by Tencent.

The method calMetrics of the class ValidationUtils.

/**
 * Computes validation metrics (loss, precision, AUC and the two recall values) for a
 * weight vector over a block of labeled samples.
 *
 * <p>Every sample is scored with {@code lossFunc.predict}. A sample is counted as correctly
 * classified when its score and its label have the same sign, and as misclassified when the
 * signs differ. Samples for which the product of score and label is exactly zero are counted
 * in none of the four confusion-matrix cells, but they still contribute to the loss and are
 * still passed on to {@code calAUC}.
 *
 * @param dataBlock validation data block; its read index is reset before the samples are read
 * @param weight    the weight vector of features
 * @param lossFunc  the loss function used for prediction, per-sample loss and regularization
 * @return a tuple {@code (loss, precision, auc, trueRecall, falseRecall)}: {@code loss} is the
 *         sum of the per-sample losses plus {@code lossFunc.getReg(weight)}; {@code precision}
 *         is {@code (truePos + trueNeg) / totalNum} (0.0 for an empty block); the last three
 *         values are the three components returned by {@code calAUC}, in order
 * @throws IOException          propagated from the data block or helper calls made here
 * @throws InterruptedException propagated from the data block or helper calls made here
 */
public static Tuple5<Double, Double, Double, Double, Double> calMetrics(DataBlock<LabeledData> dataBlock, TDoubleVector weight, Loss lossFunc) throws IOException, InterruptedException {
    dataBlock.resetReadIndex();
    int totalNum = dataBlock.size();
    LOG.debug("Start calculate loss and auc, sample number: " + totalNum);
    double loss = 0.0;
    double[] scoresArray = new double[totalNum];
    double[] labelsArray = new double[totalNum];
    // ground truth: positive, prediction: positive
    int truePos = 0;
    // ground truth: negative, prediction: positive
    int falsePos = 0;
    // ground truth: negative, prediction: negative
    int trueNeg = 0;
    // ground truth: positive, prediction: negative
    int falseNeg = 0;
    for (int i = 0; i < totalNum; i++) {
        LabeledData data = dataBlock.read();
        double label = data.getY();
        double pre = lossFunc.predict(weight, data.getX());
        // The sign of score * label tells whether the prediction agrees with the label;
        // a product of exactly zero is left out of all four counters.
        double agreement = pre * label;
        if (agreement > 0) {
            if (pre > 0) {
                truePos++;
            } else {
                trueNeg++;
            }
        } else if (agreement < 0) {
            if (pre > 0) {
                falsePos++;
            } else {
                falseNeg++;
            }
        }
        scoresArray[i] = pre;
        labelsArray[i] = label;
        loss += lossFunc.loss(pre, label);
    }
    // The regularization term is added once, after all per-sample losses.
    loss += lossFunc.getReg(weight);
    // Guard the empty block: 0 / 0 in floating point would yield NaN.
    double precision = totalNum == 0 ? 0.0 : (double) (truePos + trueNeg) / totalNum;
    Tuple3<Double, Double, Double> tuple3 = calAUC(scoresArray, labelsArray, truePos, trueNeg, falsePos, falseNeg);
    double aucResult = tuple3._1();
    double trueRecall = tuple3._2();
    double falseRecall = tuple3._3();
    // Explicit type arguments instead of the raw Tuple5 type avoid an unchecked conversion.
    return new Tuple5<Double, Double, Double, Double, Double>(loss, precision, aucResult, trueRecall, falseRecall);
}
Also used : Tuple5(scala.Tuple5) LabeledData(com.tencent.angel.ml.feature.LabeledData)

Aggregations

LabeledData (com.tencent.angel.ml.feature.LabeledData): 1, Tuple5 (scala.Tuple5): 1