Use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project Angel by Tencent.
The class Sampler, method sample.
/**
 * Re-samples the topic of every token whose word lies in the row range of {@code pkey}.
 *
 * <p>For each word row that has tokens, the word-topic counts are read into {@code wk} via
 * {@code csr.read(wk)} and {@code buildFTree()} is called. Then, for every token of that word,
 * the old topic is removed from the statistics, a new topic is drawn, and the result is written
 * back to {@code data.topics}.
 *
 * @param pkey   partition whose rows [startRow, endRow) are the words to sample
 * @param csr    fetched word-topic rows for this partition; always cleared before returning
 * @param update if true, maintain {@code wk}, {@code nk} and {@code tree} while sampling and
 *               push the per-word topic count deltas to the parameter server
 * @return the future of the delta push when {@code update} is true, otherwise {@code null}
 * @throws AngelException if a word-topic row cannot be read from {@code csr}
 */
public Future<VoidResult> sample(PartitionKey pkey, PartCSRResult csr, boolean update) {
  int ws = pkey.getStartRow();
  int we = pkey.getEndRow();
  // Use the no-arg constructor, whose seed is very likely distinct per instance.
  // Seeding with System.currentTimeMillis() gives identical random streams to samplers
  // created in the same millisecond (e.g. partitions sampled in parallel).
  Random rand = new Random();
  Int2IntOpenHashMap[] updates = null;
  try {
    // One delta map per word row; rows without tokens are skipped and stay null.
    if (update) {
      updates = new Int2IntOpenHashMap[we - ws];
    }
    float sum, u;
    int idx;
    for (int w = ws; w < we; w++) {
      // Skip if no token for this word
      if (data.ws[w + 1] - data.ws[w] == 0) {
        continue;
      }
      // Check whether an error occurred when fetching the word-topic row
      if (!csr.read(wk)) {
        throw new AngelException("some error happens");
      }
      // Build FTree for current word
      buildFTree();
      if (update) {
        updates[w - ws] = new Int2IntOpenHashMap();
      }
      for (int wi = data.ws[w]; wi < data.ws[w + 1]; wi++) {
        // current doc
        int d = data.docs[wi];
        // old topic assignment
        int tt = data.topics[data.dindex[wi]];
        // A non-positive count means the fetched word-topic row disagrees with the local
        // assignment; the token is skipped. Per the original note, memory settings or
        // network fetching parameters may need adjusting when this happens.
        if (update && wk[tt] <= 0) {
          LOG.error(String.format("Error wk[%d] = %d for word %d", tt, wk[tt], w));
          continue;
        }
        // Remove the old assignment from the statistics if needed
        if (update) {
          wk[tt]--;
          nk[tt]--;
          tree.update(tt, (wk[tt] + beta) / (nk[tt] + vbeta));
          updates[w - ws].addTo(tt, -1);
        }
        // Calculate the document-side mass and sample a new topic under the document's lock
        synchronized (data.docIds[d]) {
          if (data.dks[d] == null) {
            sum = build(d, maxDoc, tree, tt);
          } else {
            data.dks[d].dec(tt);
            sum = build(data.dks[d]);
          }
          // Choose between the document term (mass sum) and the FTree term (alpha * tree.first())
          u = rand.nextFloat() * (sum + alpha * tree.first());
          if (u < sum) {
            u = rand.nextFloat() * sum;
            if (data.dks[d] == null) {
              int length = data.ds[d + 1] - data.ds[d];
              idx = BinarySearch.binarySearch(maxDoc, u, 0, length - 1);
              tt = data.topics[data.ds[d] + idx];
            } else {
              // presumably psum/tidx are filled by build(data.dks[d]) - TODO confirm
              if (data.dks[d].size == 1) {
                tt = tidx[0];
              } else {
                idx = BinarySearch.binarySearch(psum, u, 0, data.dks[d].size - 1);
                tt = tidx[idx];
              }
            }
          } else {
            tt = tree.sample(rand.nextFloat() * tree.first());
          }
          if (data.dks[d] != null) {
            data.dks[d].inc(tt);
          }
        }
        // Add the new assignment to the statistics if needed
        if (update) {
          wk[tt]++;
          nk[tt]++;
          tree.update(tt, (wk[tt] + beta) / (nk[tt] + vbeta));
          updates[w - ws].addTo(tt, 1);
        }
        // Assign new topic
        data.topics[data.dindex[wi]] = tt;
      }
    }
  } finally {
    // Release the fetched rows even if sampling failed
    csr.clear();
  }
  Future<VoidResult> future = null;
  if (update) {
    CSRPartUpdateParam param = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
    future = PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), param);
  }
  return future;
}
Use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project Angel by Tencent.
The class Sampler, method reset.
/**
 * Recounts the current topic assignments of every word in {@code pkey}'s row range.
 *
 * <p>Each token adds one to its topic in the per-word delta map and increments {@code nk}.
 * The deltas are then pushed to the word-topic matrix.
 *
 * @param pkey partition whose rows [startRow, endRow) are the words to recount
 * @return the future of the delta push
 */
public Future<VoidResult> reset(PartitionKey pkey) {
  final int firstWord = pkey.getStartRow();
  final int endWord = pkey.getEndRow();
  Int2IntOpenHashMap[] deltas = new Int2IntOpenHashMap[endWord - firstWord];
  for (int word = firstWord; word < endWord; word++) {
    int begin = data.ws[word];
    int end = data.ws[word + 1];
    // Words without tokens keep a null slot
    if (begin != end) {
      Int2IntOpenHashMap counts = new Int2IntOpenHashMap();
      for (int pos = begin; pos < end; pos++) {
        int topic = data.topics[data.dindex[pos]];
        counts.addTo(topic, 1);
        nk[topic]++;
      }
      deltas[word - firstWord] = counts;
    }
  }
  CSRPartUpdateParam request = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, deltas);
  return PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), request);
}
Use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project Angel by Tencent.
The class Sampler, method initialize.
/**
 * Assigns a uniformly random topic in [0, K) to every token of the words in {@code pkey}'s
 * row range.
 *
 * <p>When a token's document has a {@code data.dks[d]} entry, the new topic is also added
 * to it under that entry's lock.
 *
 * @param pkey   partition whose rows [startRow, endRow) are the words to initialize
 * @param update if true, also increment {@code nk} and push the per-word topic counts to the
 *               word-topic matrix
 * @return the future of the push when {@code update} is true, otherwise {@code null}
 */
public Future<VoidResult> initialize(PartitionKey pkey, boolean update) {
  int ws = pkey.getStartRow();
  int es = pkey.getEndRow();
  // Use the no-arg constructor, whose seed is very likely distinct per instance.
  // Seeding with System.currentTimeMillis() gives identical topic sequences to partitions
  // initialized in the same millisecond.
  Random rand = new Random();
  Int2IntOpenHashMap[] updates = null;
  if (update) {
    updates = new Int2IntOpenHashMap[es - ws];
  }
  for (int w = ws; w < es; w++) {
    // Skip if no token for this word (its update slot stays null)
    if (data.ws[w + 1] == data.ws[w]) {
      continue;
    }
    if (update) {
      updates[w - ws] = new Int2IntOpenHashMap();
    }
    for (int wi = data.ws[w]; wi < data.ws[w + 1]; wi++) {
      int t = rand.nextInt(K);
      data.topics[data.dindex[wi]] = t;
      if (update) {
        updates[w - ws].addTo(t, 1);
        nk[t]++;
      }
      int d = data.docs[wi];
      if (data.dks[d] != null) {
        synchronized (data.dks[d]) {
          data.dks[d].inc(t);
        }
      }
    }
  }
  Future<VoidResult> future = null;
  if (update) {
    CSRPartUpdateParam param = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
    future = PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), param);
  }
  return future;
}
Use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project Angel by Tencent.
The class UpdateUDFHandler, method handle.
/**
 * Combines the responses collected in {@code responseCache} into a single {@link VoidResult}.
 *
 * <p>The result is SUCCESS when every response is SUCCESS. Otherwise it is FAILED and
 * carries the detail of the first non-successful response encountered.
 *
 * @param finalResult   future that receives the combined result
 * @param userRequest   the originating user request (not used here)
 * @param responseCache expected to be a {@code MapResponseCache} holding the responses
 */
@Override
public void handle(FutureResult finalResult, UserRequest userRequest, ResponseCache responseCache) {
  Map<Request, Response> responses = ((MapResponseCache) responseCache).getResponses();
  // Find the first response that did not succeed, if any
  Response failed = null;
  for (Response response : responses.values()) {
    if (response.getResponseType() != ResponseType.SUCCESS) {
      failed = response;
      break;
    }
  }
  if (failed == null) {
    finalResult.set(new VoidResult(com.tencent.angel.psagent.matrix.ResponseType.SUCCESS));
    return;
  }
  finalResult.set(new VoidResult(com.tencent.angel.psagent.matrix.ResponseType.FAILED, failed.getDetail()));
}
Use of com.tencent.angel.ml.matrix.psf.update.base.VoidResult in project Angel by Tencent.
The class MatrixTransportClient, method update.
/**
 * Sends an update PSF request for a single matrix partition.
 *
 * @param func  the update function; only its class name is sent
 * @param param the partition-level update parameter (matrix id and partition key)
 * @return a future that is registered in {@code requestToResultMap} under the sent request
 */
@Override
public FutureResult<VoidResult> update(UpdateFunc func, PartitionUpdateParam param) {
  // The header targets the matrix/partition; the body names the function class and carries the param
  Request request = new Request(
      createRequestHeader(-1, TransportMethod.UPDATE_PSF, param.getMatrixId(),
          param.getPartKey().getPartitionId()),
      new UpdateUDFRequest(func.getClass().getName(), param));
  FutureResult<VoidResult> pending = new FutureResult<>();
  // Register the future before sending, presumably so the response handler can find it
  requestToResultMap.put(request, pending);
  sendUpdateRequest(request);
  return pending;
}
Aggregations