Example 1 with CompressUpdateFunc

Use of com.tencent.angel.ml.matrix.psf.update.enhance.CompressUpdateFunc in project angel by Tencent.

From the class GradHistThread, method run.

@Override
public void run() {
    LOG.debug(String.format("Run active node[%d]", this.nid));
    // 1. name of this node's grad histogram on PS
    String histParaName = this.controller.param.gradHistNamePrefix + nid;
    // 2. build the grad histogram of this node
    GradHistHelper histMaker = new GradHistHelper(this.controller, this.nid);
    DenseDoubleVector histogram = histMaker.buildHistogram(insStart, insEnd);
    int bytesPerItem = this.controller.taskContext.getConf().getInt(MLConf.ML_COMPRESS_BYTES(), MLConf.DEFAULT_ML_COMPRESS_BYTES());
    // Fall back to the default when the configured width is outside 1..8 bytes
    if (bytesPerItem < 1 || bytesPerItem > 8) {
        LOG.info("Invalid compress configuration: " + bytesPerItem + ", it should be in [1, 8]; using the default.");
        bytesPerItem = MLConf.DEFAULT_ML_COMPRESS_BYTES();
    }
    // 3. push the histogram to PS: uncompressed at 8 bytes per item, compressed otherwise
    try {
        if (bytesPerItem == 8) {
            this.model.increment(0, histogram);
        } else {
            // CompressUpdateFunc takes the width in bits, hence bytesPerItem * 8
            CompressUpdateFunc func = new CompressUpdateFunc(this.model.getMatrixId(), 0, histogram, bytesPerItem * 8);
            this.model.update(func);
        }
    } catch (Exception e) {
        LOG.error(histParaName + " push failed, ", e);
    }
    // 4. reset thread stats to finished
    this.controller.activeNodeStat[this.nid]--;
    LOG.debug(String.format("Active node[%d] finish", this.nid));
}
Also used : DenseDoubleVector(com.tencent.angel.ml.math.vector.DenseDoubleVector) CompressUpdateFunc(com.tencent.angel.ml.matrix.psf.update.enhance.CompressUpdateFunc) GradHistHelper(com.tencent.angel.ml.GBDT.algo.RegTree.GradHistHelper)

Example 2 with CompressUpdateFunc

Use of com.tencent.angel.ml.matrix.psf.update.enhance.CompressUpdateFunc in project angel by Tencent.

From the class UpdateFuncTest, method testCompress.

@Test
public void testCompress() throws Exception {
    UpdateFunc func = new CompressUpdateFunc(w2Client.getMatrixId(), 5, localArray1, 8);
    w2Client.update(func).get();
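    // Largest positive value representable with 8 signed bits: 2^(8-1) - 1 = 127
    int maxPoint = (int) Math.pow(2, 8 - 1) - 1;
    // Largest absolute value in the pushed array; the tolerance below is scaled by it
    double maxMaxAbs = 0.0;
    for (int i = 0; i < localArray1.length; i++) {
        maxMaxAbs = Math.max(maxMaxAbs, Math.abs(localArray1[i]));
    }
    double[] result = pull(w2Client, 5);
    // JUnit assertion, so the check also runs when JVM assertions (-ea) are disabled
    Assert.assertEquals(dim, result.length);
    for (int i = 0; i < result.length; i++) {
        Assert.assertEquals(localArray1[i], result[i], 2 * maxMaxAbs / maxPoint);
    }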
}
Also used : CompressUpdateFunc(com.tencent.angel.ml.matrix.psf.update.enhance.CompressUpdateFunc) UpdateFunc(com.tencent.angel.ml.matrix.psf.update.enhance.UpdateFunc) Test(org.junit.Test)
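
The tolerance used in testCompress, 2 * maxMaxAbs / maxPoint, is consistent with linear quantization of each value to a signed integer of the given bit width. The following is a minimal sketch of such a scheme, not Angel's actual implementation: the class QuantizeSketch and its methods maxAbs, quantize, and dequantize are hypothetical names introduced for illustration, and the sketch assumes a bit width between 2 and 31.

```java
public class QuantizeSketch {

    /** Largest absolute value in the array; used here as the quantization scale. */
    public static double maxAbs(double[] values) {
        double max = 0.0;
        for (double v : values) {
            max = Math.max(max, Math.abs(v));
        }
        return max;
    }

    /** Largest positive integer representable with the given number of signed bits. */
    public static int maxPoint(int bits) {
        if (bits < 2 || bits > 31) {
            throw new IllegalArgumentException("bits must be in [2, 31], got " + bits);
        }
        return (1 << (bits - 1)) - 1;
    }

    /** Map each value to a signed integer in [-maxPoint, maxPoint]. */
    public static int[] quantize(double[] values, int bits) {
        int maxPoint = maxPoint(bits);
        double scale = maxAbs(values);
        int[] out = new int[values.length];
        if (scale == 0.0) {
            return out; // all inputs are zero, so every quantized value is zero
        }
        for (int i = 0; i < values.length; i++) {
            out[i] = (int) Math.round(values[i] / scale * maxPoint);
        }
        return out;
    }

    /** Inverse mapping: recover approximate doubles from the quantized integers. */
    public static double[] dequantize(int[] quantized, double scale, int bits) {
        int maxPoint = maxPoint(bits);
        double[] out = new double[quantized.length];
        for (int i = 0; i < quantized.length; i++) {
            out[i] = quantized[i] * scale / maxPoint;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] values = {0.5, -1.25, 3.0, -0.01};
        int bits = 8;
        double scale = maxAbs(values);
        double[] restored = dequantize(quantize(values, bits), scale, bits);
        // Same tolerance expression as in testCompress
        double tolerance = 2 * scale / maxPoint(bits);
        for (int i = 0; i < values.length; i++) {
            double error = Math.abs(values[i] - restored[i]);
            System.out.println(values[i] + " -> " + restored[i]
                    + " (within tolerance: " + (error <= tolerance) + ")");
        }
    }
}
```

In this sketch, rounding to the nearest integer keeps the per-element error at or below half a quantization step (0.5 * scale / maxPoint), so the test's tolerance of two full steps would leave headroom for a different rounding rule.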

Aggregations

CompressUpdateFunc (com.tencent.angel.ml.matrix.psf.update.enhance.CompressUpdateFunc): 2
GradHistHelper (com.tencent.angel.ml.GBDT.algo.RegTree.GradHistHelper): 1
DenseDoubleVector (com.tencent.angel.ml.math.vector.DenseDoubleVector): 1
UpdateFunc (com.tencent.angel.ml.matrix.psf.update.enhance.UpdateFunc): 1
Test (org.junit.Test): 1