Example use of com.tencent.angel.ml.matrix.psf.get.multi.GetRowsParam in the angel project by Tencent.
From the class MatrixOpLogTest, method testUDF.
/**
 * Exercises a {@code GetRowsFunc} over rows 0..99 of matrix "w1" while two task clients
 * repeatedly increment those rows.
 *
 * <p>Each of the 10 rounds reads the rows through task 0's client, logs the sum of every
 * returned row, pushes an all-ones delta of length 100000 to each row from both task clients,
 * and finally waits on the result of each client's {@code clock()} call.
 */
@Test
public void testUDF() throws ServiceException, IOException, InvalidParameterException, AngelException, InterruptedException, ExecutionException {
  Worker worker = LocalClusterContext.get().getWorker(workerAttempt0Id).getWorker();
  MatrixClient task0Client = worker.getPSAgent().getMatrixClient("w1", 0);
  MatrixClient task1Client = worker.getPSAgent().getMatrixClient("w1", 1);
  int matrixId = task0Client.getMatrixId();

  // Request rows 0..99 in a single get function.
  List<Integer> requestedRows = new ArrayList<Integer>();
  for (int row = 0; row < 100; row++) {
    requestedRows.add(row);
  }
  GetRowsFunc getRows = new GetRowsFunc(new GetRowsParam(matrixId, requestedRows));

  // Every update adds 1 to each of the 100000 columns.
  int[] ones = new int[100000];
  for (int col = 0; col < ones.length; col++) {
    ones[col] = 1;
  }

  for (int round = 0; round < 10; round++) {
    GetRowsResult fetched = (GetRowsResult) task0Client.get(getRows);
    Map<Integer, TVector> rows = fetched.getRows();
    for (Entry<Integer, TVector> rowEntry : rows.entrySet()) {
      DenseIntVector rowVector = (DenseIntVector) rowEntry.getValue();
      LOG.info("index " + rowEntry.getKey() + " sum of w1 = " + sum(rowVector));
    }

    // For each row, task 0 pushes its delta first, then task 1; each push gets its own vector.
    MatrixClient[] clients = new MatrixClient[] {task0Client, task1Client};
    for (int row = 0; row < 100; row++) {
      for (MatrixClient client : clients) {
        DenseIntVector update = new DenseIntVector(100000, ones);
        update.setMatrixId(matrixId);
        update.setRowId(row);
        client.increment(update);
      }
    }

    // Block on the result of each client's clock() before starting the next round.
    task0Client.clock().get();
    task1Client.clock().get();
  }
}
Example use of com.tencent.angel.ml.matrix.psf.get.multi.GetRowsParam in the angel project by Tencent.
From the class ConsistencyController, method getRowsFlow.
/**
 * Fetches a batch of rows of one matrix, choosing the source by the matrix's staleness.
 *
 * <p>With a non-negative staleness (BSP/SSP) the rows are first looked up in local
 * storage/cache, and the parameter servers are queried only if that lookup did not complete
 * the result. With a negative staleness (ASYNC) the rows are always requested from the
 * parameter servers through a {@code GetRowsFunc}.
 *
 * @param taskContext task context
 * @param rowIndex indexes of the rows to fetch
 * @param rpcBatchSize number of rows fetched in one rpc request
 * @return the result holding the fetched rows; returned empty and marked fetch-over when
 *         {@code rowIndex} contains no rows
 * @throws Exception if fetching fails; an {@code IOException} is thrown when the ASYNC get
 *         function reports a FAILED response
 */
public GetRowsResult getRowsFlow(TaskContext taskContext, RowIndex rowIndex, int rpcBatchSize) throws Exception {
  GetRowsResult result = new GetRowsResult();

  // Nothing requested: hand back an already-finished, empty result.
  if (rowIndex.getRowsNumber() == 0) {
    LOG.error("need get rowId set is empty, just return");
    result.fetchOver();
    return result;
  }

  int staleness = getStaleness(rowIndex.getMatrixId());

  if (staleness < 0) {
    // ASYNC: skip the local storage/cache and ask the parameter servers directly.
    List<Integer> rowIds = new ArrayList<Integer>(rowIndex.getRowIds());
    GetRowsFunc getRows = new GetRowsFunc(new GetRowsParam(rowIndex.getMatrixId(), rowIds));
    com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult psfResult =
        (com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult)
            PSAgentContext.get().getMatrixClientAdapter().get(getRows);

    if (psfResult.getResponseType() == ResponseType.FAILED) {
      throw new IOException("get rows from ps failed.");
    }

    for (TVector row : psfResult.getRows().values()) {
      result.put(row);
    }
    result.fetchOver();
    return result;
  }

  // BSP/SSP: try the local storage/cache first, using the task's matrix clock minus staleness.
  int requiredClock = taskContext.getMatrixClock(rowIndex.getMatrixId()) - staleness;
  findRowsInStorage(taskContext, result, rowIndex, requiredClock);
  if (result.isFetchOver()) {
    return result;
  }

  // The local lookup did not finish the result, so go to the parameter servers.
  LOG.debug("need fetch from parameterserver");
  PSAgentContext.get().getMatrixClientAdapter().getRowsFlow(result, rowIndex, rpcBatchSize, requiredClock);
  return result;
}
Aggregations