Use of com.tencent.angel.ml.lda.psf.CSRPartUpdateParam in the project angel by Tencent.
Class Sampler, method reset.
/**
 * Recounts the topic assignments stored in {@code data.topics} for every row of the given
 * partition and pushes those counts to the matrix identified by {@code model.wtMat()}.
 *
 * <p>For each row {@code w} in {@code [pkey.getStartRow(), pkey.getEndRow())} the entries
 * {@code data.ws[w] .. data.ws[w + 1] - 1} are scanned; each entry adds one to the per-row
 * topic counter and to the local {@code nk} array. Rows whose range is empty keep a
 * {@code null} slot in the update array.
 *
 * @param pkey partition whose rows are recounted
 * @return the future returned by the matrix transport client for the submitted update
 */
public Future<VoidResult> reset(PartitionKey pkey) {
  final int firstRow = pkey.getStartRow();
  final int endRow = pkey.getEndRow();
  final Short2IntOpenHashMap[] rowCounts = new Short2IntOpenHashMap[endRow - firstRow];

  for (int row = firstRow; row < endRow; row++) {
    final int begin = data.ws[row];
    final int end = data.ws[row + 1];
    if (begin != end) {
      final Short2IntOpenHashMap counts = new Short2IntOpenHashMap();
      rowCounts[row - firstRow] = counts;
      for (int pos = begin; pos < end; pos++) {
        final int topic = data.topics[pos];
        // NOTE(review): the cast assumes topic ids fit in a short -- confirm against K.
        counts.addTo((short) topic, 1);
        nk[topic]++;
      }
    }
  }

  final CSRPartUpdateParam updateParam =
      new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, rowCounts);
  return PSAgentContext.get()
      .getMatrixTransportClient()
      .update(new UpdatePartFunc(null), updateParam);
}
Use of com.tencent.angel.ml.lda.psf.CSRPartUpdateParam in the project angel by Tencent.
Class Sampler, method initialize.
/**
 * Assigns a uniformly random topic in {@code [0, K)} to every entry of every row in the given
 * partition and pushes the resulting per-row topic counts to the matrix identified by
 * {@code model.wtMat()}.
 *
 * <p>For each entry the chosen topic is written to {@code data.topics}, the local {@code nk}
 * counter is incremented, and the counter object {@code data.dks[d]} of the entry's
 * {@code data.docs} index is incremented while holding that object's monitor. Rows whose
 * range {@code data.ws[w] .. data.ws[w + 1]} is empty keep a {@code null} slot in the update
 * array.
 *
 * @param pkey partition whose rows are initialised
 * @return the future returned by the matrix transport client for the submitted update
 */
public Future<VoidResult> initialize(PartitionKey pkey) {
  int ws = pkey.getStartRow();
  int es = pkey.getEndRow();

  // Seeding with System.currentTimeMillis() hands identical random streams to every partition
  // initialised within the same millisecond. The no-arg constructor picks a seed that is very
  // likely to be distinct for each invocation.
  Random rand = new Random();

  Short2IntOpenHashMap[] updates = new Short2IntOpenHashMap[es - ws];

  for (int w = ws; w < es; w++) {
    if (data.ws[w + 1] == data.ws[w]) {
      continue;
    }

    Short2IntOpenHashMap update = new Short2IntOpenHashMap();
    updates[w - ws] = update;

    for (int wi = data.ws[w]; wi < data.ws[w + 1]; wi++) {
      int d = data.docs[wi];
      int t = rand.nextInt(K);
      data.topics[wi] = t;
      nk[t]++;
      // The per-d counter is guarded by its own monitor, matching sample().
      synchronized (data.dks[d]) {
        data.dks[d].inc(t);
      }
      // NOTE(review): the cast assumes topic ids fit in a short -- confirm against K.
      update.addTo((short) t, 1);
    }
  }

  CSRPartUpdateParam param = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
  return PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), param);
}
Use of com.tencent.angel.ml.lda.psf.CSRPartUpdateParam in the project angel by Tencent.
Class Sampler, method sample.
/**
 * Redraws the topic of every entry of every row in the given partition and pushes the
 * resulting per-row count deltas to the matrix identified by {@code model.wtMat()}.
 *
 * <p>For each non-empty row the current counts are read from {@code csr} into {@code wk} and
 * {@code buildFTree()} is called. Then, for each entry of the row:
 * <ol>
 *   <li>the entry's current topic is removed from {@code wk}, {@code nk}, the tree and the
 *       {@code data.dks[d]} counter, and recorded as {@code -1} in the row update;</li>
 *   <li>a new topic is drawn: with {@code sum = build(dk)}, a uniform value in
 *       {@code [0, sum + alpha * tree.first())} selects either a binary search over
 *       {@code psum}/{@code tidx} or {@code tree.sample(...)};</li>
 *   <li>the new topic is added back to {@code dk}, {@code wk}, {@code nk}, the tree and
 *       {@code data.topics}, and recorded as {@code +1} in the row update.</li>
 * </ol>
 * An entry whose current topic has {@code wk[tt] <= 0} is logged and left unchanged.
 *
 * @param pkey partition whose rows are resampled
 * @param csr  source from which {@code wk} is filled once per non-empty row
 * @return the future returned by the matrix transport client for the submitted update
 * @throws AngelException if {@code csr.read(wk)} returns {@code false}
 */
public Future<VoidResult> sample(PartitionKey pkey, PartCSRResult csr) {
  int ws = pkey.getStartRow();
  int we = pkey.getEndRow();

  // Seeding with System.currentTimeMillis() hands identical random streams to every partition
  // sampled within the same millisecond. The no-arg constructor picks a seed that is very
  // likely to be distinct for each invocation.
  Random rand = new Random();

  Short2IntOpenHashMap[] updates = new Short2IntOpenHashMap[we - ws];

  for (int w = ws; w < we; w++) {
    if (data.ws[w + 1] - data.ws[w] == 0) {
      continue;
    }

    if (!csr.read(wk)) {
      throw new AngelException(
          "Failed to read topic counts for word " + w + " of partition " + pkey);
    }

    // NOTE(review): presumably rebuilds `tree` from the freshly read `wk` -- confirm.
    buildFTree();

    Short2IntOpenHashMap update = new Short2IntOpenHashMap();
    updates[w - ws] = update;

    for (int wi = data.ws[w]; wi < data.ws[w + 1]; wi++) {
      int d = data.docs[wi];
      TraverseHashMap dk = data.dks[d];
      int tt = data.topics[wi];

      // A non-positive count means local and remote state disagree; skip rather than push
      // the count below zero.
      if (wk[tt] <= 0) {
        LOG.error(String.format("Error wk[%d] = %d for word %d", tt, wk[tt], w));
        continue;
      }

      // Take the entry's current assignment out of all counters.
      wk[tt]--;
      nk[tt]--;
      float value = (wk[tt] + beta) / (nk[tt] + vbeta);
      tree.update(tt, value);
      update.addTo((short) tt, -1);

      synchronized (dk) {
        dk.dec(tt);
        // NOTE(review): build(dk) appears to fill psum/tidx and return their total -- confirm.
        float sum = build(dk);
        float u = rand.nextFloat() * (sum + alpha * tree.first());
        if (u < sum) {
          u = rand.nextFloat() * sum;
          int idx = BinarySearch.binarySearch(psum, u, 0, dk.size - 1);
          tt = tidx[idx];
        } else {
          tt = tree.sample(rand.nextFloat() * tree.first());
        }
        dk.inc(tt);
      }

      // Put the newly drawn assignment back into all counters.
      wk[tt]++;
      nk[tt]++;
      value = (wk[tt] + beta) / (nk[tt] + vbeta);
      tree.update(tt, value);
      data.topics[wi] = tt;
      update.addTo((short) tt, 1);
    }
  }

  CSRPartUpdateParam param = new CSRPartUpdateParam(model.wtMat().getMatrixId(), pkey, updates);
  return PSAgentContext.get().getMatrixTransportClient().update(new UpdatePartFunc(null), param);
}
Aggregations