Search in sources :

Example 1 with CSRPartUpdateParam

use of com.tencent.angel.ml.lda.psf.CSRPartUpdateParam in project angel by Tencent.

In the class Sampler, method reset:

/**
 * Recounts the topic assignments stored in {@code data.topics} for every row of the
 * given partition and submits the per-row counts as a partition update.
 *
 * <p>For each row {@code w} in {@code [pkey.getStartRow(), pkey.getEndRow())}, the entries
 * {@code data.ws[w] .. data.ws[w + 1] - 1} of {@code data.topics} are tallied into a
 * per-row map and added to {@code nk}. Rows with an empty range keep a {@code null}
 * slot in the update array.
 *
 * @param pkey partition whose row range is processed and which is passed to the update
 * @return the future returned by the matrix transport client's {@code update} call
 */
public Future<VoidResult> reset(PartitionKey pkey) {
    final int firstRow = pkey.getStartRow();
    final int lastRow = pkey.getEndRow();
    final Short2IntOpenHashMap[] updates = new Short2IntOpenHashMap[lastRow - firstRow];

    for (int row = firstRow; row < lastRow; row++) {
        final int end = data.ws[row + 1];
        final int begin = data.ws[row];
        if (end != begin) {
            final Short2IntOpenHashMap counts = new Short2IntOpenHashMap();
            updates[row - firstRow] = counts;
            for (int pos = begin; pos < end; pos++) {
                final int topic = data.topics[pos];
                // Topic ids are stored as short keys in the update map.
                counts.addTo((short) topic, 1);
                nk[topic]++;
            }
        }
    }

    final CSRPartUpdateParam param =
        new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
    return PSAgentContext.get()
        .getMatrixTransportClient()
        .update(new UpdatePartFunc(null), param);
}
Also used : UpdatePartFunc(com.tencent.angel.ml.lda.psf.UpdatePartFunc) CSRPartUpdateParam(com.tencent.angel.ml.lda.psf.CSRPartUpdateParam) VoidResult(com.tencent.angel.ml.matrix.psf.update.enhance.VoidResult) Short2IntOpenHashMap(it.unimi.dsi.fastutil.shorts.Short2IntOpenHashMap)

Example 2 with CSRPartUpdateParam

use of com.tencent.angel.ml.lda.psf.CSRPartUpdateParam in project angel by Tencent.

In the class Sampler, method initialize:

/**
 * Assigns a random topic in {@code [0, K)} to every entry belonging to the rows of the
 * given partition and submits the resulting per-row topic counts as a partition update.
 *
 * <p>For each row {@code w} in {@code [pkey.getStartRow(), pkey.getEndRow())}, the entries
 * {@code data.ws[w] .. data.ws[w + 1] - 1} are visited. For each entry the chosen topic is
 * written to {@code data.topics}, {@code nk} is incremented, and the counter
 * {@code data.dks[d]} of the entry's document index {@code d = data.docs[wi]} is
 * incremented while holding that counter's monitor. Rows with an empty range keep a
 * {@code null} slot in the update array.
 *
 * @param pkey partition whose row range is processed and which is passed to the update
 * @return the future returned by the matrix transport client's {@code update} call
 */
public Future<VoidResult> initialize(PartitionKey pkey) {
    int ws = pkey.getStartRow();
    int es = pkey.getEndRow();
    // Use the no-arg constructor: seeding with System.currentTimeMillis() gives calls that
    // start within the same millisecond identical random sequences, whereas the no-arg
    // constructor picks a seed very likely to be distinct for every invocation.
    Random rand = new Random();
    Short2IntOpenHashMap[] updates = new Short2IntOpenHashMap[es - ws];
    for (int w = ws; w < es; w++) {
        if (data.ws[w + 1] == data.ws[w]) {
            // Empty row: leave its update slot null.
            continue;
        }
        updates[w - ws] = new Short2IntOpenHashMap();
        for (int wi = data.ws[w]; wi < data.ws[w + 1]; wi++) {
            int d = data.docs[wi];
            int t = rand.nextInt(K);
            data.topics[wi] = t;
            // NOTE(review): nk is updated without synchronization, unlike data.dks[d];
            // presumably nk is private to the calling task - confirm.
            nk[t]++;
            synchronized (data.dks[d]) {
                data.dks[d].inc(t);
            }
            // Topic ids are stored as short keys in the update map.
            updates[w - ws].addTo((short) t, 1);
        }
    }
    CSRPartUpdateParam param = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
    Future<VoidResult> future = PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), param);
    return future;
}
Also used : UpdatePartFunc(com.tencent.angel.ml.lda.psf.UpdatePartFunc) CSRPartUpdateParam(com.tencent.angel.ml.lda.psf.CSRPartUpdateParam) Random(java.util.Random) VoidResult(com.tencent.angel.ml.matrix.psf.update.enhance.VoidResult) Short2IntOpenHashMap(it.unimi.dsi.fastutil.shorts.Short2IntOpenHashMap)

Example 3 with CSRPartUpdateParam

use of com.tencent.angel.ml.lda.psf.CSRPartUpdateParam in project angel by Tencent.

In the class Sampler, method sample:

/**
 * Resamples the topic of every entry belonging to the rows of the given partition and
 * submits the net per-row topic-count changes as a partition update.
 *
 * <p>For each non-empty row {@code w} in {@code [pkey.getStartRow(), pkey.getEndRow())}:
 * {@code csr.read(wk)} is called, {@code buildFTree()} is called, and then each entry
 * {@code wi} in {@code data.ws[w] .. data.ws[w + 1] - 1} has its current topic removed
 * from {@code wk}, {@code nk} and the document counter {@code data.dks[d]}, a new topic
 * drawn, and the counters incremented for the new topic. Each removal and addition is
 * recorded as {@code -1} / {@code +1} in the row's update map. Entries whose current
 * topic has {@code wk[tt] <= 0} are logged as an error and skipped. Rows with an empty
 * range keep a {@code null} slot in the update array.
 *
 * @param pkey partition whose row range is processed and which is passed to the update
 * @param csr source read once per non-empty row via {@code csr.read(wk)}
 * @return the future returned by the matrix transport client's {@code update} call
 * @throws AngelException if {@code csr.read(wk)} returns {@code false}
 */
public Future<VoidResult> sample(PartitionKey pkey, PartCSRResult csr) {
    int ws = pkey.getStartRow();
    int we = pkey.getEndRow();
    // Use the no-arg constructor: seeding with System.currentTimeMillis() gives calls that
    // start within the same millisecond identical random sequences, whereas the no-arg
    // constructor picks a seed very likely to be distinct for every invocation.
    Random rand = new Random();
    Short2IntOpenHashMap[] updates = new Short2IntOpenHashMap[we - ws];
    for (int w = ws; w < we; w++) {
        if (data.ws[w + 1] - data.ws[w] == 0) {
            // Empty row: nothing is read from csr and its update slot stays null.
            continue;
        }
        // NOTE(review): presumably loads this row's topic counts into wk - confirm
        // against PartCSRResult.read.
        if (!csr.read(wk)) {
            throw new AngelException("some error happens");
        }
        // NOTE(review): presumably rebuilds `tree` from the freshly read wk - confirm.
        buildFTree();
        updates[w - ws] = new Short2IntOpenHashMap();
        for (int wi = data.ws[w]; wi < data.ws[w + 1]; wi++) {
            int d = data.docs[wi];
            TraverseHashMap dk = data.dks[d];
            int tt = data.topics[wi];
            if (wk[tt] <= 0) {
                // The current topic has no remaining count in wk; report and leave this
                // entry's assignment untouched.
                LOG.error(String.format("Error wk[%d] = %d for word %d", tt, wk[tt], w));
                continue;
            }
            // Remove the entry's current topic from the counters and refresh its tree value.
            wk[tt]--;
            nk[tt]--;
            float value = (wk[tt] + beta) / (nk[tt] + vbeta);
            tree.update(tt, value);
            updates[w - ws].addTo((short) tt, -1);
            // The document counter is modified only while holding its monitor.
            synchronized (dk) {
                dk.dec(tt);
                // NOTE(review): build(dk) presumably fills psum/tidx for the topics present
                // in dk and returns their total weight - confirm.
                float sum = build(dk);
                float u = rand.nextFloat() * (sum + alpha * tree.first());
                if (u < sum) {
                    // Draw from the entries indexed by psum/tidx.
                    u = rand.nextFloat() * sum;
                    int idx = BinarySearch.binarySearch(psum, u, 0, dk.size - 1);
                    tt = tidx[idx];
                } else {
                    // Draw from the tree instead.
                    tt = tree.sample(rand.nextFloat() * tree.first());
                }
                dk.inc(tt);
            }
            // Add the newly drawn topic to the counters and refresh its tree value.
            wk[tt]++;
            nk[tt]++;
            value = (wk[tt] + beta) / (nk[tt] + vbeta);
            tree.update(tt, value);
            data.topics[wi] = tt;
            updates[w - ws].addTo((short) tt, 1);
        }
    }
    CSRPartUpdateParam param = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
    Future<VoidResult> future = PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), param);
    return future;
}
Also used : AngelException(com.tencent.angel.exception.AngelException) UpdatePartFunc(com.tencent.angel.ml.lda.psf.UpdatePartFunc) CSRPartUpdateParam(com.tencent.angel.ml.lda.psf.CSRPartUpdateParam) Random(java.util.Random) VoidResult(com.tencent.angel.ml.matrix.psf.update.enhance.VoidResult) Short2IntOpenHashMap(it.unimi.dsi.fastutil.shorts.Short2IntOpenHashMap)

Aggregations

CSRPartUpdateParam (com.tencent.angel.ml.lda.psf.CSRPartUpdateParam)3 UpdatePartFunc (com.tencent.angel.ml.lda.psf.UpdatePartFunc)3 VoidResult (com.tencent.angel.ml.matrix.psf.update.enhance.VoidResult)3 Short2IntOpenHashMap (it.unimi.dsi.fastutil.shorts.Short2IntOpenHashMap)3 Random (java.util.Random)2 AngelException (com.tencent.angel.exception.AngelException)1