Search in sources :

Example 1 with LabeledData

use of com.tencent.angel.ml.feature.LabeledData in project angel by Tencent.

the class ValidationUtils method calLossPrecision.

/**
 * Computes the summed loss over a validation block and logs classification statistics.
 *
 * <p>Each sample is classified by comparing the sign of the prediction with the sign of
 * the label. Samples for which {@code prediction * label == 0} are counted in none of the
 * four confusion-matrix buckets, but still contribute to the loss.
 *
 * @param dataBlock validation data block; its read index is reset before it is scanned
 * @param weight the weight vector of features handed to {@code lossFunc.predict}
 * @param lossFunc the loss function used both for prediction and for the per-sample loss
 * @return the sum of the per-sample losses (no regularization term is added here)
 * @throws IOException declared by the signature; presumably raised by data block access
 * @throws InterruptedException declared by the signature; presumably raised by data block access
 */
public static double calLossPrecision(DataBlock<LabeledData> dataBlock, TDoubleVector weight, Loss lossFunc) throws IOException, InterruptedException {
    dataBlock.resetReadIndex();
    long startTime = System.currentTimeMillis();
    int totalNum = dataBlock.size();
    double loss = 0.0;
    // ground truth: positive, prediction: positive
    int truePos = 0;
    // ground truth: negative, prediction: positive
    int falsePos = 0;
    // ground truth: negative, prediction: negative
    int trueNeg = 0;
    // ground truth: positive, prediction: negative
    int falseNeg = 0;
    for (int i = 0; i < totalNum; i++) {
        LabeledData data = dataBlock.get(i);
        double pre = lossFunc.predict(weight, data.getX());
        if (pre * data.getY() > 0) {
            // Same sign: the prediction agrees with the label.
            if (pre > 0) {
                truePos++;
            } else {
                trueNeg++;
            }
        } else if (pre * data.getY() < 0) {
            // Opposite sign: the prediction disagrees with the label.
            if (pre > 0) {
                falsePos++;
            } else {
                falseNeg++;
            }
        }
        loss += lossFunc.loss(pre, data.getY());
    }
    long cost = System.currentTimeMillis() - startTime;
    // Fraction of samples whose prediction sign matches the label sign (logged as "precision").
    double precision = (double) (truePos + trueNeg) / totalNum;
    // Floating-point division: these evaluate to NaN when the denominator is zero.
    double trueRecall = (double) truePos / (truePos + falseNeg);
    double falseRecall = (double) trueNeg / (trueNeg + falsePos);
    // The format string must carry one specifier per argument, in order. Previously the sample
    // count had no specifier, which bound the long "cost" to %.5f and made String.format throw
    // IllegalFormatConversionException on every call.
    LOG.debug(String.format("validate %d samples cost %d ms, loss= %.5f, precision=%.5f, trueRecall=%.5f, " + "falseRecall=%.5f", totalNum, cost, loss, precision, trueRecall, falseRecall));
    LOG.debug(String.format("Validation TP=%d, TN=%d, FP=%d, FN=%d", truePos, trueNeg, falsePos, falseNeg));
    return loss;
}
Also used : LabeledData(com.tencent.angel.ml.feature.LabeledData)

Example 2 with LabeledData

use of com.tencent.angel.ml.feature.LabeledData in project angel by Tencent.

the class LabeledUpdateIndexBaseTask method preProcess.

/**
 * Reads every key/value pair from the task's reader, parses it into a {@link LabeledData}
 * sample and stores the non-null results in {@code taskDataBlock}.
 *
 * <p>When {@code updateIndexEnable} is set and a sample's feature vector is a
 * {@link SparseDummyVector}, all of its indices are added to {@code indexSet}. The data
 * block is flushed once the reader is exhausted.
 *
 * @param taskContext the context that supplies the input reader
 * @throws AngelException wrapping any exception raised while reading, parsing or storing
 */
@Override
public void preProcess(TaskContext taskContext) {
    try {
        Reader<KEYIN, VALUEIN> input = taskContext.getReader();
        while (input.nextKeyValue()) {
            LabeledData sample = parse(input.getCurrentKey(), input.getCurrentValue());
            // Records that parse to null are dropped.
            if (sample == null) {
                continue;
            }
            taskDataBlock.put(sample);
            if (!updateIndexEnable) {
                continue;
            }
            TAbstractVector features = sample.getX();
            if (features instanceof SparseDummyVector) {
                for (int featureIndex : ((SparseDummyVector) features).getIndices()) {
                    indexSet.add(featureIndex);
                }
            }
        }
        taskDataBlock.flush();
    } catch (Exception e) {
        throw new AngelException("Pre-Process Error.", e);
    }
}
Also used : AngelException(com.tencent.angel.exception.AngelException) LabeledData(com.tencent.angel.ml.feature.LabeledData) TAbstractVector(com.tencent.angel.ml.math.TAbstractVector) SparseDummyVector(com.tencent.angel.ml.math.vector.SparseDummyVector) IOException(java.io.IOException)

Example 3 with LabeledData

use of com.tencent.angel.ml.feature.LabeledData in project angel by Tencent.

the class ValidationUtils method calMSER2.

/**
 * Computes MSE, RMSE, MAE and R2 over a validation block and logs them at INFO level.
 *
 * <p>The block is scanned twice: the first pass accumulates the prediction statistics and
 * the label sum, the second pass measures how far the labels are from their mean.
 *
 * @param dataBlock validation data block; its read index is reset before it is scanned
 * @param weight the weight vector of features handed to {@code lossFunc.predict}
 * @param lossFunc the loss function used for prediction and for the per-sample loss
 * @return a tuple of (MSE, RMSE, MAE, R2), in that order
 * @throws IOException declared by the signature; presumably raised by data block access
 * @throws InterruptedException declared by the signature; presumably raised by data block access
 */
public static Tuple4<Double, Double, Double, Double> calMSER2(DataBlock<LabeledData> dataBlock, TDoubleVector weight, Loss lossFunc) throws IOException, InterruptedException {
    dataBlock.resetReadIndex();
    final long begin = System.currentTimeMillis();
    final int sampleCount = dataBlock.size();

    // Sum of squared loss(prediction, label) over all samples.
    double squaredPredLoss = 0.0;
    // Sum of the labels, needed for their mean.
    double labelSum = 0.0;
    // Sum of |label - prediction| over all samples.
    double absErrorSum = 0.0;
    for (int idx = 0; idx < sampleCount; idx++) {
        LabeledData sample = dataBlock.get(idx);
        double predicted = lossFunc.predict(weight, sample.getX());
        double label = sample.getY();
        squaredPredLoss += Math.pow(lossFunc.loss(predicted, label), 2);
        labelSum += label;
        absErrorSum += Math.abs(label - predicted);
    }

    final double labelMean = labelSum / sampleCount;
    // Sum of squared loss(labelMean, label): the baseline of always predicting the mean.
    double squaredMeanLoss = 0.0;
    for (int idx = 0; idx < sampleCount; idx++) {
        LabeledData sample = dataBlock.get(idx);
        squaredMeanLoss += Math.pow(lossFunc.loss(labelMean, sample.getY()), 2);
    }

    final double mse = squaredPredLoss / sampleCount;
    final double rmse = Math.sqrt(mse);
    final double mae = absErrorSum / sampleCount;
    final double r2 = 1 - squaredPredLoss / squaredMeanLoss;
    LOG.info(String.format("validate %d samples cost %d ms, MSE= %.5f ,RMSE= %.5f ,MAE=%.5f ," + "R2= %.5f", sampleCount, System.currentTimeMillis() - begin, mse, rmse, mae, r2));
    return new Tuple4<>(mse, rmse, mae, r2);
}
Also used : Tuple4(scala.Tuple4) LabeledData(com.tencent.angel.ml.feature.LabeledData)

Example 4 with LabeledData

use of com.tencent.angel.ml.feature.LabeledData in project angel by Tencent.

the class ValidationUtils method calMetrics.

/**
 * Computes loss, precision, AUC and the two recall values over a validation block.
 *
 * <p>Each sample is classified by comparing the sign of the prediction with the sign of
 * the label. Samples for which {@code prediction * label == 0} are counted in none of the
 * four confusion-matrix buckets. The returned loss is the sum of the per-sample losses plus
 * {@code lossFunc.getReg(weight)}.
 *
 * @param dataBlock validation data block; its read index is reset and samples are then
 *     consumed sequentially with {@code read()}
 * @param weight the weight vector of features
 * @param lossFunc the loss function used for prediction, per-sample loss and regularization
 * @return a tuple of (loss, precision, AUC, trueRecall, falseRecall), in that order; the last
 *     three are the values returned by {@code calAUC}
 * @throws IOException declared by the signature; presumably raised by data block access
 * @throws InterruptedException declared by the signature; presumably raised by data block access
 */
public static Tuple5<Double, Double, Double, Double, Double> calMetrics(DataBlock<LabeledData> dataBlock, TDoubleVector weight, Loss lossFunc) throws IOException, InterruptedException {
    dataBlock.resetReadIndex();
    int totalNum = dataBlock.size();
    LOG.debug("Start calculate loss and auc, sample number: " + totalNum);
    double loss = 0.0;
    // Per-sample predictions and labels, kept for the AUC computation.
    double[] scoresArray = new double[totalNum];
    double[] labelsArray = new double[totalNum];
    // ground truth: positive, prediction: positive
    int truePos = 0;
    // ground truth: negative, prediction: positive
    int falsePos = 0;
    // ground truth: negative, prediction: negative
    int trueNeg = 0;
    // ground truth: positive, prediction: negative
    int falseNeg = 0;
    for (int i = 0; i < totalNum; i++) {
        LabeledData data = dataBlock.read();
        double pre = lossFunc.predict(weight, data.getX());
        if (pre * data.getY() > 0) {
            // Same sign: the prediction agrees with the label.
            if (pre > 0) {
                truePos++;
            } else {
                trueNeg++;
            }
        } else if (pre * data.getY() < 0) {
            // Opposite sign: the prediction disagrees with the label.
            if (pre > 0) {
                falsePos++;
            } else {
                falseNeg++;
            }
        }
        scoresArray[i] = pre;
        labelsArray[i] = data.getY();
        loss += lossFunc.loss(pre, data.getY());
    }
    loss += lossFunc.getReg(weight);
    // Fraction of samples whose prediction sign matches the label sign.
    double precision = (double) (truePos + trueNeg) / totalNum;
    Tuple3<Double, Double, Double> tuple3 = calAUC(scoresArray, labelsArray, truePos, trueNeg, falsePos, falseNeg);
    double aucResult = tuple3._1();
    double trueRecall = tuple3._2();
    double falseRecall = tuple3._3();
    // Diamond instead of a raw Tuple5, so the result is type-checked against the return type.
    return new Tuple5<>(loss, precision, aucResult, trueRecall, falseRecall);
}
Also used : Tuple5(scala.Tuple5) LabeledData(com.tencent.angel.ml.feature.LabeledData)

Aggregations

LabeledData (com.tencent.angel.ml.feature.LabeledData): 4, AngelException (com.tencent.angel.exception.AngelException): 1, TAbstractVector (com.tencent.angel.ml.math.TAbstractVector): 1, SparseDummyVector (com.tencent.angel.ml.math.vector.SparseDummyVector): 1, IOException (java.io.IOException): 1, Tuple4 (scala.Tuple4): 1, Tuple5 (scala.Tuple5): 1