use of com.tencent.angel.ml.feature.LabeledData in project angel by Tencent.
the class ValidationUtils method calLossPrecision.
/**
 * Validates a model on a data block: accumulates the loss and logs accuracy and recall figures.
 *
 * <p>A sample is treated as predicted positive when the prediction is greater than zero. Samples
 * for which {@code prediction * label} is exactly zero are not counted in any of the four
 * confusion-matrix cells, although they still contribute to the loss.
 *
 * @param dataBlock validation data block; its read index is reset before the scan
 * @param weight the weight vector of features
 * @param lossFunc the loss function used for prediction and loss evaluation
 * @return the sum of {@code lossFunc.loss(prediction, label)} over all samples (no
 *     regularization term is added here)
 * @throws IOException declared by this method; presumably raised by data block access
 * @throws InterruptedException declared by this method; presumably raised by data block access
 */
public static double calLossPrecision(DataBlock<LabeledData> dataBlock, TDoubleVector weight, Loss lossFunc) throws IOException, InterruptedException {
    dataBlock.resetReadIndex();
    long startTime = System.currentTimeMillis();
    int totalNum = dataBlock.size();
    double loss = 0.0;
    // ground truth: positive, prediction: positive
    int truePos = 0;
    // ground truth: negative, prediction: positive
    int falsePos = 0;
    // ground truth: negative, prediction: negative
    int trueNeg = 0;
    // ground truth: positive, prediction: negative
    int falseNeg = 0;
    for (int i = 0; i < totalNum; i++) {
        LabeledData data = dataBlock.get(i);
        double pre = lossFunc.predict(weight, data.getX());
        if (pre * data.getY() > 0) {
            // prediction and label share the same sign: correct classification
            if (pre > 0) {
                truePos++;
            } else {
                trueNeg++;
            }
        } else if (pre * data.getY() < 0) {
            // prediction and label have opposite signs: misclassification
            if (pre > 0) {
                falsePos++;
            } else {
                falseNeg++;
            }
        }
        loss += lossFunc.loss(pre, data.getY());
    }
    long cost = System.currentTimeMillis() - startTime;
    // Fraction of correctly classified samples; the ratios below are NaN when a denominator is 0.
    double precision = (double) (truePos + trueNeg) / totalNum;
    double trueRecall = (double) truePos / (truePos + falseNeg);
    double falseRecall = (double) trueNeg / (trueNeg + falsePos);
    // The format string must carry one specifier per argument, in order: the previous version
    // had no slot for totalNum, which pushed the long 'cost' onto a %.5f specifier and made
    // String.format throw IllegalFormatConversionException.
    LOG.debug(String.format("validate %d samples cost %d ms, loss= %.5f, precision=%.5f, trueRecall=%.5f, " + "falseRecall=%.5f", totalNum, cost, loss, precision, trueRecall, falseRecall));
    LOG.debug(String.format("Validation TP=%d, TN=%d, FP=%d, FN=%d", truePos, trueNeg, falsePos, falseNeg));
    return loss;
}
use of com.tencent.angel.ml.feature.LabeledData in project angel by Tencent.
the class LabeledUpdateIndexBaseTask method preProcess.
/**
 * Reads every key/value pair from the task's reader, parses each into a {@link LabeledData} and
 * stores the non-null results in {@code taskDataBlock}. When {@code updateIndexEnable} is set,
 * the indices of every {@code SparseDummyVector} feature vector are also added to
 * {@code indexSet}. The data block is flushed once the reader is exhausted.
 *
 * @param taskContext context supplying the reader for this task
 * @throws AngelException wrapping any exception raised while reading, parsing or storing
 */
@Override
public void preProcess(TaskContext taskContext) {
    try {
        Reader<KEYIN, VALUEIN> reader = taskContext.getReader();
        while (reader.nextKeyValue()) {
            LabeledData parsed = parse(reader.getCurrentKey(), reader.getCurrentValue());
            if (parsed == null) {
                // nothing was produced for this record; move on to the next one
                continue;
            }
            taskDataBlock.put(parsed);
            if (!updateIndexEnable) {
                continue;
            }
            TAbstractVector features = parsed.getX();
            if (features instanceof SparseDummyVector) {
                int[] featureIndices = ((SparseDummyVector) features).getIndices();
                for (int featureIndex : featureIndices) {
                    indexSet.add(featureIndex);
                }
            }
        }
        taskDataBlock.flush();
    } catch (Exception e) {
        throw new AngelException("Pre-Process Error.", e);
    }
}
use of com.tencent.angel.ml.feature.LabeledData in project angel by Tencent.
the class ValidationUtils method calMSER2.
/**
 * Calculates MSE, RMSE, MAE and R2 of a model over a data block and logs them at info level.
 *
 * @param dataBlock samples to evaluate; its read index is reset before the scan
 * @param weight weight vector passed to {@code lossFunc.predict}
 * @param lossFunc supplies the prediction and the per-sample loss value
 * @return a tuple of (MSE, RMSE, MAE, R2), in that order
 * @throws IOException declared by this method; presumably raised by data block access
 * @throws InterruptedException declared by this method; presumably raised by data block access
 */
public static Tuple4<Double, Double, Double, Double> calMSER2(DataBlock<LabeledData> dataBlock, TDoubleVector weight, Loss lossFunc) throws IOException, InterruptedException {
    dataBlock.resetReadIndex();
    final long begin = System.currentTimeMillis();
    final int sampleCount = dataBlock.size();

    // sum of squared lossFunc.loss(prediction, label): the numerator of the R2 ratio
    double modelLossSquares = 0.0;
    // sum of the labels, needed for their mean
    double labelTotal = 0.0;
    // sum of |label - prediction|
    double absErrorTotal = 0.0;

    int idx = 0;
    while (idx < sampleCount) {
        LabeledData sample = dataBlock.get(idx);
        double label = sample.getY();
        double predicted = lossFunc.predict(weight, sample.getX());
        modelLossSquares += Math.pow(lossFunc.loss(predicted, label), 2);
        labelTotal += label;
        absErrorTotal += Math.abs(label - predicted);
        idx++;
    }

    // Second pass: the same squared loss, but with the label mean used as a constant prediction.
    final double labelMean = labelTotal / sampleCount;
    double baselineLossSquares = 0.0;
    idx = 0;
    while (idx < sampleCount) {
        LabeledData sample = dataBlock.get(idx);
        baselineLossSquares += Math.pow(lossFunc.loss(labelMean, sample.getY()), 2);
        idx++;
    }

    double mse = modelLossSquares / sampleCount;
    double rmse = Math.sqrt(mse);
    double mae = absErrorTotal / sampleCount;
    double r2 = 1 - modelLossSquares / baselineLossSquares;

    long elapsedMs = System.currentTimeMillis() - begin;
    LOG.info(String.format("validate %d samples cost %d ms, MSE= %.5f ,RMSE= %.5f ,MAE=%.5f ,R2= %.5f", sampleCount, elapsedMs, mse, rmse, mae, r2));
    return new Tuple4<>(mse, rmse, mae, r2);
}
use of com.tencent.angel.ml.feature.LabeledData in project angel by Tencent.
the class ValidationUtils method calMetrics.
/**
 * Validates a model on a data block and returns loss, precision, AUC and recall figures.
 *
 * <p>A sample is treated as predicted positive when the prediction is greater than zero. Samples
 * for which {@code prediction * label} is exactly zero are not counted in any of the four
 * confusion-matrix cells, although they still contribute to the loss and to the score/label
 * arrays handed to {@code calAUC}.
 *
 * @param dataBlock validation data block; its read index is reset and it is then read
 *     sequentially
 * @param weight the weight vector of features
 * @param lossFunc the loss function used for prediction, loss and regularization
 * @return a tuple of (loss, precision, AUC, trueRecall, falseRecall), where loss is the summed
 *     per-sample loss plus {@code lossFunc.getReg(weight)}, precision is the fraction of
 *     correctly classified samples, and the last three values are taken from {@code calAUC}
 * @throws IOException declared by this method; presumably raised by data block access
 * @throws InterruptedException declared by this method; presumably raised by data block access
 */
public static Tuple5<Double, Double, Double, Double, Double> calMetrics(DataBlock<LabeledData> dataBlock, TDoubleVector weight, Loss lossFunc) throws IOException, InterruptedException {
    dataBlock.resetReadIndex();
    int totalNum = dataBlock.size();
    LOG.debug("Start calculate loss and auc, sample number: " + totalNum);
    double loss = 0.0;
    double[] scoresArray = new double[totalNum];
    double[] labelsArray = new double[totalNum];
    // ground truth: positive, prediction: positive
    int truePos = 0;
    // ground truth: negative, prediction: positive
    int falsePos = 0;
    // ground truth: negative, prediction: negative
    int trueNeg = 0;
    // ground truth: positive, prediction: negative
    int falseNeg = 0;
    for (int i = 0; i < totalNum; i++) {
        LabeledData data = dataBlock.read();
        double pre = lossFunc.predict(weight, data.getX());
        if (pre * data.getY() > 0) {
            // prediction and label share the same sign: correct classification
            if (pre > 0) {
                truePos++;
            } else {
                trueNeg++;
            }
        } else if (pre * data.getY() < 0) {
            // prediction and label have opposite signs: misclassification
            if (pre > 0) {
                falsePos++;
            } else {
                falseNeg++;
            }
        }
        scoresArray[i] = pre;
        labelsArray[i] = data.getY();
        loss += lossFunc.loss(pre, data.getY());
    }
    // The regularization term is added once, on top of the summed per-sample losses.
    loss += lossFunc.getReg(weight);
    // NaN when the data block is empty (0 / 0 in double arithmetic).
    double precision = (double) (truePos + trueNeg) / totalNum;
    Tuple3<Double, Double, Double> tuple3 = calAUC(scoresArray, labelsArray, truePos, trueNeg, falsePos, falseNeg);
    double aucResult = tuple3._1();
    double trueRecall = tuple3._2();
    double falseRecall = tuple3._3();
    // Diamond instead of a raw Tuple5 so the result is type-checked against the return type.
    return new Tuple5<>(loss, precision, aucResult, trueRecall, falseRecall);
}
Aggregations