Usage of com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult in the Tencent Angel project.
Class TransportTest, method testGetFlowDenseDoubleMatrix.
/**
 * Exercises streamed row fetching ({@code getRowsFlow}) on "dense_double_mat_1".
 *
 * <p>Three passes are made over all {@code ddRow} rows:
 * <ol>
 *   <li>a blocking {@code take()} loop (ddRow / 2 rows per RPC) before any update, compared
 *       against a freshly built local matrix;</li>
 *   <li>after random increments on every third column plus a clock, a blocking {@code take()}
 *       loop with 2 rows per RPC;</li>
 *   <li>a non-blocking {@code poll()} loop with 2 rows per RPC that stops once the result
 *       reports fetch-over and no row is pending.</li>
 * </ol>
 * Any failure is logged and rethrown.
 */
@Test
public void testGetFlowDenseDoubleMatrix() throws Exception {
  try {
    Worker localWorker = LocalClusterContext.get().getWorker(worker0Attempt0Id).getWorker();
    MatrixClient client = localWorker.getPSAgent().getMatrixClient("dense_double_mat_1", 0);
    DenseDoubleMatrix expected = new DenseDoubleMatrix(ddRow, ddCol, new double[ddRow][ddCol]);

    // Pass 1: read every row before any update.
    RowIndex wanted = new RowIndex();
    for (int r = 0; r < ddRow; r++) {
      wanted.addRowId(r);
    }
    GetRowsResult flow = client.getRowsFlow(wanted, ddRow / 2);
    for (TVector fetched = flow.take(); fetched != null; fetched = flow.take()) {
      LOG.info("===========get row index=" + fetched.getRowId());
      double[] want = ((DenseDoubleVector) expected.getRow(fetched.getRowId())).getValues();
      double[] got = ((DenseDoubleVector) fetched).getValues();
      assertArrayEquals(want, got, 0.0);
    }

    // Push random increments to every third column and mirror them locally.
    Random random = new Random(System.currentTimeMillis());
    for (int r = 0; r < ddRow; r++) {
      DenseDoubleVector delta = new DenseDoubleVector(ddCol);
      for (int c = 0; c < ddCol; c += 3) {
        delta.set(c, random.nextDouble());
      }
      client.increment(r, delta);
      expected.getRow(r).plusBy(delta);
    }
    client.clock().get();

    // Pass 2: blocking reads, two rows per RPC.
    wanted = new RowIndex();
    for (int r = 0; r < ddRow; r++) {
      wanted.addRowId(r);
    }
    flow = client.getRowsFlow(wanted, 2);
    for (TVector fetched = flow.take(); fetched != null; fetched = flow.take()) {
      double[] want = ((DenseDoubleVector) expected.getRow(fetched.getRowId())).getValues();
      double[] got = ((DenseDoubleVector) fetched).getValues();
      assertArrayEquals(want, got, 0.0);
    }

    // Pass 3: non-blocking reads; keep polling until fetch-over with nothing pending.
    wanted = new RowIndex();
    for (int r = 0; r < ddRow; r++) {
      wanted.addRowId(r);
    }
    flow = client.getRowsFlow(wanted, 2);
    for (;;) {
      TVector polled = flow.poll();
      boolean fetchOver = flow.isFetchOver();
      if (polled != null) {
        double[] want = ((DenseDoubleVector) expected.getRow(polled.getRowId())).getValues();
        double[] got = ((DenseDoubleVector) polled).getValues();
        assertArrayEquals(want, got, 0.0);
      } else if (fetchOver) {
        break;
      }
    }
  } catch (Exception failure) {
    LOG.error("run testGetFlowDenseDoubleMatrix failed ", failure);
    throw failure;
  }
}
Usage of com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult in the Tencent Angel project.
Class TransportTest, method testGetFlowDenseIntMatrix.
/**
 * Exercises streamed row fetching ({@code getRowsFlow}) on "dense_int_mat_1".
 *
 * <p>All rows are read once before any update. Random increments are then applied to every
 * third column, and all rows are read twice more: once with the blocking {@code take()} loop and
 * once with the non-blocking {@code poll()} loop. Every fetched row is compared against a locally
 * maintained expected matrix.
 *
 * <p>Every row and column bound uses the int matrix's dimensions ({@code diRow}/{@code diCol}),
 * not the double matrix's {@code ddRow}/{@code ddCol}.
 */
@Test
public void testGetFlowDenseIntMatrix() throws Exception {
  try {
    Worker worker = LocalClusterContext.get().getWorker(worker0Attempt0Id).getWorker();
    MatrixClient mat = worker.getPSAgent().getMatrixClient("dense_int_mat_1", 0);
    DenseIntMatrix expect = new DenseIntMatrix(diRow, diCol);

    // Phase 1: fetch all rows (diRow / 2 rows per RPC) before any update.
    RowIndex rowIndex = new RowIndex();
    for (int i = 0; i < diRow; i++) {
      rowIndex.addRowId(i);
    }
    GetRowsResult result = mat.getRowsFlow(rowIndex, diRow / 2);
    TVector row;
    while ((row = result.take()) != null) {
      assertArrayEquals(((DenseIntVector) expect.getRow(row.getRowId())).getValues(),
          ((DenseIntVector) row).getValues());
    }

    // Phase 2: increment every third column of each row and mirror it locally.
    Random rand = new Random(System.currentTimeMillis());
    for (int rowId = 0; rowId < diRow; rowId++) {
      DenseIntVector update = new DenseIntVector(diCol);
      // Bound by diCol, the length of this vector (ddCol belongs to the double matrix).
      for (int j = 0; j < diCol; j += 3) {
        update.set(j, rand.nextInt());
      }
      mat.increment(rowId, update);
      expect.getRow(rowId).plusBy(update);
    }
    // Block on the clock future before reading back (presumably flushes the increments).
    mat.clock().get();

    // Phase 3: blocking reads of all int-matrix rows, two rows per RPC.
    rowIndex = new RowIndex();
    for (int i = 0; i < diRow; i++) {
      rowIndex.addRowId(i);
    }
    result = mat.getRowsFlow(rowIndex, 2);
    while ((row = result.take()) != null) {
      assertArrayEquals(((DenseIntVector) expect.getRow(row.getRowId())).getValues(),
          ((DenseIntVector) row).getValues());
    }

    // Phase 4: non-blocking reads; stop once fetch is over and nothing is pending.
    rowIndex = new RowIndex();
    for (int i = 0; i < diRow; i++) {
      rowIndex.addRowId(i);
    }
    result = mat.getRowsFlow(rowIndex, 2);
    while (true) {
      row = result.poll();
      if (result.isFetchOver() && row == null) {
        break;
      }
      if (row == null) {
        continue;
      }
      assertArrayEquals(((DenseIntVector) expect.getRow(row.getRowId())).getValues(),
          ((DenseIntVector) row).getValues());
    }
  } catch (Exception x) {
    LOG.error("run testGetFlowDenseIntMatrix failed ", x);
    throw x;
  }
}
Usage of com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult in the Tencent Angel project.
Class ConsistencyController, method getRowsFlow.
/**
 * Get a batch of rows from local storage/cache or from the parameter servers.
 *
 * <p>When the matrix's staleness is non-negative (BSP/SSP), rows are first looked up in local
 * storage using the staleness-adjusted clock. If the result is not yet fetch-over, fetching of
 * the remaining rows is handed to the matrix client adapter in batches of {@code rpcBatchSize};
 * the returned result may still be receiving rows at that point (TODO confirm against the
 * adapter). When staleness is negative (ASYNC), all rows are fetched with a single
 * {@link GetRowsFunc} call, and the result is marked fetch-over before it is returned.
 *
 * <p>An empty {@code rowIndex} is not an error: an already-finished, empty result is returned.
 *
 * @param taskContext task context, used to read the task's clock for this matrix
 * @param rowIndex row indexes to fetch (also carries the matrix id)
 * @param rpcBatchSize number of rows fetched per RPC request (BSP/SSP path only)
 * @return GetRowsResult that yields the requested rows
 * @throws IOException if the ASYNC fetch from the parameter servers reports failure
 * @throws Exception if looking up or fetching rows fails
 */
public GetRowsResult getRowsFlow(TaskContext taskContext, RowIndex rowIndex, int rpcBatchSize) throws Exception {
  GetRowsResult result = new GetRowsResult();
  if (rowIndex.getRowsNumber() == 0) {
    // Benign request: nothing to fetch, so hand back an already-finished result.
    LOG.warn("need get rowId set is empty, just return");
    result.fetchOver();
    return result;
  }

  int staleness = getStaleness(rowIndex.getMatrixId());
  if (staleness >= 0) {
    // BSP/SSP: serve what we can from storage/cache first, then fall back to the PSs.
    int stalenessClock = taskContext.getMatrixClock(rowIndex.getMatrixId()) - staleness;
    findRowsInStorage(taskContext, result, rowIndex, stalenessClock);
    if (!result.isFetchOver()) {
      LOG.debug("need fetch from parameterserver");
      PSAgentContext.get().getMatrixClientAdapter().getRowsFlow(result, rowIndex, rpcBatchSize, stalenessClock);
    }
    return result;
  }

  // ASYNC: fetch all requested rows from the PSs in one function call.
  IntOpenHashSet rowIdSet = rowIndex.getRowIds();
  List<Integer> rowIds = new ArrayList<Integer>(rowIdSet.size());
  rowIds.addAll(rowIdSet);
  GetRowsFunc func = new GetRowsFunc(new GetRowsParam(rowIndex.getMatrixId(), rowIds));
  com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult funcResult =
      (com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult) PSAgentContext.get().getMatrixClientAdapter().get(func);
  if (funcResult.getResponseType() == ResponseType.FAILED) {
    throw new IOException("get rows from ps failed, matrixId=" + rowIndex.getMatrixId()
        + ", rowNumber=" + rowIds.size());
  }
  // Only the row vectors are needed; the map keys (row ids) are carried by each vector.
  for (TVector row : funcResult.getRows().values()) {
    result.put(row);
  }
  result.fetchOver();
  return result;
}
Aggregations