Example 1 with VoidResult

use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project angel by Tencent.

the class Sampler method sample.

public Future<VoidResult> sample(PartitionKey pkey, PartCSRResult csr, boolean update) {
    int ws = pkey.getStartRow();
    int we = pkey.getEndRow();
    Random rand = new Random(System.currentTimeMillis());
    Int2IntOpenHashMap[] updates = null;
    try {
        // allocate update maps
        if (update)
            updates = new Int2IntOpenHashMap[we - ws];
        float sum, u;
        int idx;
        for (int w = ws; w < we; w++) {
            // Skip if no token for this word
            if (data.ws[w + 1] - data.ws[w] == 0)
                continue;
            // Check whether an error occurred when fetching the word-topic data
            if (!csr.read(wk))
                throw new AngelException("some error happens");
            // Build FTree for current word
            buildFTree();
            if (update)
                updates[w - ws] = new Int2IntOpenHashMap();
            for (int wi = data.ws[w]; wi < data.ws[w + 1]; wi++) {
                // current doc
                int d = data.docs[wi];
                // old topic assignment
                int tt = data.topics[data.dindex[wi]];
                // Check for an invalid count; if this happens, the memory settings or network fetching parameters may need adjusting.
                if (update && wk[tt] <= 0) {
                    LOG.error(String.format("Error wk[%d] = %d for word %d", tt, wk[tt], w));
                    continue;
                }
                // Update statistics if needed
                if (update) {
                    wk[tt]--;
                    nk[tt]--;
                    tree.update(tt, (wk[tt] + beta) / (nk[tt] + vbeta));
                    updates[w - ws].addTo(tt, -1);
                }
                // Calculate psum and sample new topic
                synchronized (data.docIds[d]) {
                    if (data.dks[d] == null)
                        sum = build(d, maxDoc, tree, tt);
                    else {
                        data.dks[d].dec(tt);
                        sum = build(data.dks[d]);
                    }
                    u = rand.nextFloat() * (sum + alpha * tree.first());
                    if (u < sum) {
                        u = rand.nextFloat() * sum;
                        if (data.dks[d] == null) {
                            int length = data.ds[d + 1] - data.ds[d];
                            idx = BinarySearch.binarySearch(maxDoc, u, 0, length - 1);
                            tt = data.topics[data.ds[d] + idx];
                        } else {
                            if (data.dks[d].size == 1)
                                tt = tidx[0];
                            else {
                                idx = BinarySearch.binarySearch(psum, u, 0, data.dks[d].size - 1);
                                tt = tidx[idx];
                            }
                        }
                    } else
                        tt = tree.sample(rand.nextFloat() * tree.first());
                    if (data.dks[d] != null)
                        data.dks[d].inc(tt);
                }
                // Update statistics if needed
                if (update) {
                    wk[tt]++;
                    nk[tt]++;
                    tree.update(tt, (wk[tt] + beta) / (nk[tt] + vbeta));
                    updates[w - ws].addTo(tt, 1);
                }
                // Assign new topic
                data.topics[data.dindex[wi]] = tt;
            }
        }
    } finally {
        csr.clear();
    }
    Future<VoidResult> future = null;
    if (update) {
        CSRPartUpdateParam param = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
        future = PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), param);
    }
    return future;
}
Also used : AngelException(com.tencent.angel.exception.AngelException) UpdatePartFunc(com.tencent.angel.ml.lda.psf.UpdatePartFunc) CSRPartUpdateParam(com.tencent.angel.ml.lda.psf.CSRPartUpdateParam) Random(java.util.Random) VoidResult(com.tencent.angel.ml.matrix.psf.update.base.VoidResult) Int2IntOpenHashMap(it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap)
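
In the sample method above, the returned future stays null when update is false, so a caller that collects these futures has to skip null entries before waiting. A minimal sketch of that pattern follows; the class name FutureWaitSketch, the method waitAll, and the use of String as a stand-in for VoidResult are assumptions made for illustration and are not part of Angel.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

public class FutureWaitSketch {
    /**
     * Waits on every non-null future and returns how many were awaited.
     * String is a hypothetical stand-in for VoidResult.
     */
    public static int waitAll(List<Future<String>> futures) {
        int awaited = 0;
        for (Future<String> future : futures) {
            // A null entry means no update was sent for that partition
            if (future == null) {
                continue;
            }
            try {
                future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while waiting", e);
            } catch (ExecutionException e) {
                throw new IllegalStateException("update failed", e);
            }
            awaited++;
        }
        return awaited;
    }

    public static void main(String[] args) {
        List<Future<String>> futures = new ArrayList<>();
        futures.add(CompletableFuture.completedFuture("SUCCESS"));
        futures.add(null); // partition sampled with update == false
        futures.add(CompletableFuture.completedFuture("SUCCESS"));
        System.out.println(waitAll(futures));
    }
}
```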

Example 2 with VoidResult

use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project angel by Tencent.

the class Sampler method reset.

public Future<VoidResult> reset(PartitionKey pkey) {
    int ws = pkey.getStartRow();
    int es = pkey.getEndRow();
    // One sparse topic-count map per word in this partition
    Int2IntOpenHashMap[] updates = new Int2IntOpenHashMap[es - ws];
    for (int w = ws; w < es; w++) {
        // Skip if no token for this word
        if (data.ws[w + 1] == data.ws[w])
            continue;
        updates[w - ws] = new Int2IntOpenHashMap();
        // Count the current topic assignment of every token of word w
        for (int wi = data.ws[w]; wi < data.ws[w + 1]; wi++) {
            int tt = data.topics[data.dindex[wi]];
            updates[w - ws].addTo(tt, 1);
            nk[tt]++;
        }
    }
    // Send the accumulated counts as an update to the word-topic matrix
    CSRPartUpdateParam param = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
    Future<VoidResult> future = PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), param);
    return future;
}
Also used : UpdatePartFunc(com.tencent.angel.ml.lda.psf.UpdatePartFunc) CSRPartUpdateParam(com.tencent.angel.ml.lda.psf.CSRPartUpdateParam) VoidResult(com.tencent.angel.ml.matrix.psf.update.base.VoidResult) Int2IntOpenHashMap(it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap)
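
The reset method above accumulates one count per topic with repeated addTo(tt, 1) calls on a fastutil map. As a rough illustration of that accumulation step only, the sketch below does the same with java.util.HashMap and Map.merge; TopicCountSketch and countTopics are hypothetical names, and HashMap is a stand-in chosen to keep the example free of external dependencies.

```java
import java.util.HashMap;
import java.util.Map;

public class TopicCountSketch {
    /**
     * Counts how often each topic id occurs in the given assignments.
     * Map.merge(t, 1, Integer::sum) plays the role of addTo(t, 1).
     */
    public static Map<Integer, Integer> countTopics(int[] topics) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int t : topics) {
            counts.merge(t, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        // Topic assignments of the tokens of one hypothetical word
        int[] topics = {2, 0, 2, 2, 5};
        System.out.println(countTopics(topics));
    }
}
```

A primitive-keyed map such as Int2IntOpenHashMap avoids the boxing that HashMap<Integer, Integer> incurs, which is presumably why the original code uses it.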

Example 3 with VoidResult

use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project angel by Tencent.

the class Sampler method initialize.

public Future<VoidResult> initialize(PartitionKey pkey, boolean update) {
    int ws = pkey.getStartRow();
    int es = pkey.getEndRow();
    Random rand = new Random(System.currentTimeMillis());
    Int2IntOpenHashMap[] updates = null;
    if (update)
        updates = new Int2IntOpenHashMap[es - ws];
    for (int w = ws; w < es; w++) {
        // Skip if no token for this word
        if (data.ws[w + 1] == data.ws[w])
            continue;
        if (update)
            updates[w - ws] = new Int2IntOpenHashMap();
        for (int wi = data.ws[w]; wi < data.ws[w + 1]; wi++) {
            int t = rand.nextInt(K);
            data.topics[data.dindex[wi]] = t;
            if (update) {
                updates[w - ws].addTo(t, 1);
                nk[t]++;
            }
            int d = data.docs[wi];
            if (data.dks[d] != null) {
                synchronized (data.dks[d]) {
                    data.dks[d].inc(t);
                }
            }
        }
    }
    Future<VoidResult> future = null;
    if (update) {
        CSRPartUpdateParam param = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
        future = PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), param);
    }
    return future;
}
Also used : UpdatePartFunc(com.tencent.angel.ml.lda.psf.UpdatePartFunc) CSRPartUpdateParam(com.tencent.angel.ml.lda.psf.CSRPartUpdateParam) Random(java.util.Random) VoidResult(com.tencent.angel.ml.matrix.psf.update.base.VoidResult) Int2IntOpenHashMap(it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap)

Example 4 with VoidResult

use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project angel by Tencent.

the class UpdateUDFHandler method handle.

@Override
public void handle(FutureResult finalResult, UserRequest userRequest, ResponseCache responseCache) {
    MapResponseCache cache = (MapResponseCache) responseCache;
    // Check update result
    Map<Request, Response> responses = cache.getResponses();
    boolean success = true;
    String detail = "";
    for (Response response : responses.values()) {
        success = success && (response.getResponseType() == ResponseType.SUCCESS);
        if (!success) {
            detail = response.getDetail();
            break;
        }
    }
    // Set the final result
    if (success) {
        finalResult.set(new VoidResult(com.tencent.angel.psagent.matrix.ResponseType.SUCCESS));
    } else {
        finalResult.set(new VoidResult(com.tencent.angel.psagent.matrix.ResponseType.FAILED, detail));
    }
}
Also used : Response(com.tencent.angel.ps.server.data.response.Response) UpdateUDFResponse(com.tencent.angel.ps.server.data.response.UpdateUDFResponse) VoidResult(com.tencent.angel.ml.matrix.psf.update.base.VoidResult) Request(com.tencent.angel.ps.server.data.request.Request) UserRequest(com.tencent.angel.psagent.matrix.transport.adapter.UserRequest)
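
The handle method above treats the whole update as failed as soon as one response is not SUCCESS and keeps the detail of that first failure. The sketch below isolates that check with simplified types; ResponseCheckSketch, Status, Reply, and firstFailure are hypothetical names introduced here and do not correspond to Angel classes.

```java
import java.util.Arrays;
import java.util.List;

public class ResponseCheckSketch {
    public enum Status { SUCCESS, FAILED }

    /** Simplified stand-in for a server response. */
    public static final class Reply {
        public final Status status;
        public final String detail;

        public Reply(Status status, String detail) {
            this.status = status;
            this.detail = detail;
        }
    }

    /** Returns null if every reply succeeded, otherwise the detail of the first failure. */
    public static String firstFailure(List<Reply> replies) {
        for (Reply reply : replies) {
            if (reply.status != Status.SUCCESS) {
                return reply.detail;
            }
        }
        return null;
    }

    public static void main(String[] args) {
        List<Reply> replies = Arrays.asList(
            new Reply(Status.SUCCESS, ""),
            new Reply(Status.FAILED, "partition unavailable"),
            new Reply(Status.FAILED, "timeout"));
        String failure = firstFailure(replies);
        System.out.println(failure == null ? "SUCCESS" : "FAILED: " + failure);
    }
}
```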

Example 5 with VoidResult

use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project angel by Tencent.

the class MatrixTransportClient method update.

@Override
public FutureResult<VoidResult> update(UpdateFunc func, PartitionUpdateParam param) {
    // Request header
    RequestHeader header = createRequestHeader(-1, TransportMethod.UPDATE_PSF, param.getMatrixId(), param.getPartKey().getPartitionId());
    // Request body
    UpdateUDFRequest requestData = new UpdateUDFRequest(func.getClass().getName(), param);
    // Request
    Request request = new Request(header, requestData);
    FutureResult<VoidResult> resultFuture = new FutureResult<>();
    requestToResultMap.put(request, resultFuture);
    // Send the request
    sendUpdateRequest(request);
    return resultFuture;
}
Also used : VoidResult(com.tencent.angel.ml.matrix.psf.update.base.VoidResult) UserRequest(com.tencent.angel.psagent.matrix.transport.adapter.UserRequest) Request(com.tencent.angel.ps.server.data.request.Request) UpdateUDFRequest(com.tencent.angel.ps.server.data.request.UpdateUDFRequest) GetUDFRequest(com.tencent.angel.ps.server.data.request.GetUDFRequest) RequestHeader(com.tencent.angel.ps.server.data.request.RequestHeader)
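
The update method above registers a future in requestToResultMap before sending the request, so that a later response can be matched to it. A minimal sketch of that register-then-complete pattern follows, using only the JDK; PendingRequestsSketch, send, and onResponse are hypothetical names, an int id stands in for the Request key, and CompletableFuture<String> stands in for FutureResult<VoidResult>.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

public class PendingRequestsSketch {
    // Maps each in-flight request id to the future its caller is waiting on
    private final ConcurrentHashMap<Integer, CompletableFuture<String>> pending =
        new ConcurrentHashMap<>();

    /** Registers a future for the request and returns it; actual sending is omitted. */
    public CompletableFuture<String> send(int requestId) {
        CompletableFuture<String> future = new CompletableFuture<>();
        pending.put(requestId, future);
        return future;
    }

    /**
     * Completes and unregisters the future for the given request.
     * Returns false if no such request is pending.
     */
    public boolean onResponse(int requestId, String result) {
        CompletableFuture<String> future = pending.remove(requestId);
        return future != null && future.complete(result);
    }

    public static void main(String[] args) {
        PendingRequestsSketch client = new PendingRequestsSketch();
        CompletableFuture<String> future = client.send(1);
        System.out.println("done before response: " + future.isDone());
        client.onResponse(1, "SUCCESS");
        System.out.println("result: " + future.join());
    }
}
```

Registering the future before sending, as the original method does, avoids a race in which the response arrives before the map entry exists.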

Aggregations

VoidResult (com.tencent.angel.ml.matrix.psf.update.base.VoidResult) 14
FutureResult (com.tencent.angel.psagent.matrix.transport.FutureResult) 8
MatrixTransportClient (com.tencent.angel.psagent.matrix.transport.MatrixTransportClient) 4
MapResponseCache (com.tencent.angel.psagent.matrix.transport.response.MapResponseCache) 4
ResponseCache (com.tencent.angel.psagent.matrix.transport.response.ResponseCache) 4
PartitionKey (com.tencent.angel.PartitionKey) 3
CSRPartUpdateParam (com.tencent.angel.ml.lda.psf.CSRPartUpdateParam) 3
UpdatePartFunc (com.tencent.angel.ml.lda.psf.UpdatePartFunc) 3
MatrixMeta (com.tencent.angel.ml.matrix.MatrixMeta) 3
CompStreamKeyValuePart (com.tencent.angel.psagent.matrix.transport.router.CompStreamKeyValuePart) 3
Int2IntOpenHashMap (it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap) 3
AngelException (com.tencent.angel.exception.AngelException) 2
Request (com.tencent.angel.ps.server.data.request.Request) 2
Response (com.tencent.angel.ps.server.data.response.Response) 2
UserRequest (com.tencent.angel.psagent.matrix.transport.adapter.UserRequest) 2
Random (java.util.Random) 2
PartitionUpdateParam (com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam) 1
UpdateParam (com.tencent.angel.ml.matrix.psf.update.base.UpdateParam) 1
GetUDFRequest (com.tencent.angel.ps.server.data.request.GetUDFRequest) 1
RequestHeader (com.tencent.angel.ps.server.data.request.RequestHeader) 1