Search in sources :

Example 1 with GetRowsFunc

use of com.tencent.angel.ml.matrix.psf.get.multi.GetRowsFunc in project angel by Tencent.

The class MatrixOpLogTest, method testUDF.

/**
 * Exercises the {@link GetRowsFunc} UDF against matrix "w1".
 *
 * <p>For a fixed number of rounds the test fetches the first 100 rows through the task-0 client,
 * logs the sum of each returned row, then pushes an all-ones increment for every row through both
 * the task-0 and task-1 clients and finally waits on the result of each client's {@code clock()}.
 * The test only logs the sums; it makes no assertions on them.
 */
@Test
public void testUDF() throws ServiceException, IOException, InvalidParameterException, AngelException, InterruptedException, ExecutionException {
    final int rowCount = 100;
    final int columnCount = 100000;
    final int rounds = 10;

    // Both clients come from the same worker's PS agent; they differ only in the task index.
    Worker worker = LocalClusterContext.get().getWorker(workerAttempt0Id).getWorker();
    MatrixClient task0Client = worker.getPSAgent().getMatrixClient("w1", 0);
    MatrixClient task1Client = worker.getPSAgent().getMatrixClient("w1", 1);
    int matrixId = task0Client.getMatrixId();

    // Request rows 0 .. rowCount-1 in a single GetRowsFunc call.
    List<Integer> requestedRows = new ArrayList<Integer>();
    for (int rowId = 0; rowId < rowCount; rowId++) {
        requestedRows.add(rowId);
    }
    GetRowsFunc func = new GetRowsFunc(new GetRowsParam(matrixId, requestedRows));

    // Backing array shared by every increment vector: each element is 1.
    int[] ones = new int[columnCount];
    for (int col = 0; col < ones.length; col++) {
        ones[col] = 1;
    }

    for (int round = 0; round < rounds; round++) {
        // Read phase: fetch the rows via task 0 and log the sum of each one.
        GetRowsResult fetched = (GetRowsResult) task0Client.get(func);
        Map<Integer, TVector> rows = fetched.getRows();
        for (Entry<Integer, TVector> entry : rows.entrySet()) {
            Integer fetchedRowId = entry.getKey();
            DenseIntVector fetchedRow = (DenseIntVector) entry.getValue();
            LOG.info("index " + fetchedRowId + " sum of w1 = " + sum(fetchedRow));
        }

        // Write phase: each task increments every row with its own freshly built vector.
        for (int rowId = 0; rowId < rowCount; rowId++) {
            DenseIntVector task0Delta = new DenseIntVector(columnCount, ones);
            task0Delta.setMatrixId(matrixId);
            task0Delta.setRowId(rowId);
            task0Client.increment(task0Delta);

            DenseIntVector task1Delta = new DenseIntVector(columnCount, ones);
            task1Delta.setMatrixId(matrixId);
            task1Delta.setRowId(rowId);
            task1Client.increment(task1Delta);
        }

        // Wait for both clock operations to complete before starting the next round.
        task0Client.clock().get();
        task1Client.clock().get();
    }
}
Also used : GetRowsResult(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult) ArrayList(java.util.ArrayList) DenseIntVector(com.tencent.angel.ml.math.vector.DenseIntVector) GetRowsParam(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsParam) GetRowsFunc(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsFunc) Worker(com.tencent.angel.worker.Worker) MatrixClient(com.tencent.angel.psagent.matrix.MatrixClient) TVector(com.tencent.angel.ml.math.TVector) Test(org.junit.Test)

Example 2 with GetRowsFunc

use of com.tencent.angel.ml.matrix.psf.get.multi.GetRowsFunc in project angel by Tencent.

The class ConsistencyController, method getRowsFlow.

/**
 * Fetches a batch of rows, choosing the source according to the matrix's staleness setting.
 *
 * <p>An empty {@code rowIndex} yields an empty result that is already marked as fetch-over.
 * With a negative staleness (ASYNC) the rows are always requested from the parameter servers
 * through a {@link GetRowsFunc}. With a non-negative staleness (BSP/SSP) the local storage/cache
 * is consulted first and the parameter servers are only contacted if the result is not yet
 * complete.
 *
 * @param taskContext task context
 * @param rowIndex indexes of the rows to fetch
 * @param rpcBatchSize number of rows fetched per rpc request
 * @return the result holder the fetched rows are put into
 * @throws IOException if the parameter servers report a failed response in ASYNC mode
 * @throws Exception if any of the underlying storage or transport calls fails
 */
public GetRowsResult getRowsFlow(TaskContext taskContext, RowIndex rowIndex, int rpcBatchSize) throws Exception {
    GetRowsResult result = new GetRowsResult();

    // Nothing requested: hand back an empty, already-finished result.
    if (rowIndex.getRowsNumber() == 0) {
        LOG.error("need get rowId set is empty, just return");
        result.fetchOver();
        return result;
    }

    int matrixId = rowIndex.getMatrixId();
    int staleness = getStaleness(matrixId);

    if (staleness < 0) {
        // For ASYNC, skip the local storage/cache and read straight from the pss.
        IntOpenHashSet requestedIds = rowIndex.getRowIds();
        List<Integer> rowIds = new ArrayList<Integer>(requestedIds);
        GetRowsParam param = new GetRowsParam(matrixId, rowIds);

        // Fully qualified: the psf result type shares its simple name with this method's return type.
        com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult psfResult =
            (com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult) PSAgentContext.get().getMatrixClientAdapter().get(new GetRowsFunc(param));
        if (psfResult.getResponseType() == ResponseType.FAILED) {
            throw new IOException("get rows from ps failed.");
        }

        for (TVector row : psfResult.getRows().values()) {
            result.put(row);
        }
        result.fetchOver();
        return result;
    }

    // For BSP/SSP, look in the storage/cache first using the task's matrix clock minus the staleness.
    int stalenessClock = taskContext.getMatrixClock(matrixId) - staleness;
    findRowsInStorage(taskContext, result, rowIndex, stalenessClock);
    if (result.isFetchOver()) {
        return result;
    }

    // The result is not complete yet, so fetch from the ps.
    LOG.debug("need fetch from parameterserver");
    PSAgentContext.get().getMatrixClientAdapter().getRowsFlow(result, rowIndex, rpcBatchSize, stalenessClock);
    return result;
}
Also used : GetRowsResult(com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult) ArrayList(java.util.ArrayList) IOException(java.io.IOException) IntOpenHashSet(it.unimi.dsi.fastutil.ints.IntOpenHashSet) GetRowsParam(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsParam) GetRowsFunc(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsFunc) TVector(com.tencent.angel.ml.math.TVector)

Aggregations

TVector (com.tencent.angel.ml.math.TVector)2 GetRowsFunc (com.tencent.angel.ml.matrix.psf.get.multi.GetRowsFunc)2 GetRowsParam (com.tencent.angel.ml.matrix.psf.get.multi.GetRowsParam)2 ArrayList (java.util.ArrayList)2 DenseIntVector (com.tencent.angel.ml.math.vector.DenseIntVector)1 GetRowsResult (com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult)1 MatrixClient (com.tencent.angel.psagent.matrix.MatrixClient)1 GetRowsResult (com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult)1 Worker (com.tencent.angel.worker.Worker)1 IntOpenHashSet (it.unimi.dsi.fastutil.ints.IntOpenHashSet)1 IOException (java.io.IOException)1 Test (org.junit.Test)1