Search in sources :

Example 1 with GetRowsResult

use of com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult in project angel by Tencent.

The method testGetFlowDenseDoubleMatrix of the class TransportTest.

@Test
public void testGetFlowDenseDoubleMatrix() throws Exception {
    try {
        Worker worker = LocalClusterContext.get().getWorker(worker0Attempt0Id).getWorker();
        MatrixClient mat = worker.getPSAgent().getMatrixClient("dense_double_mat_1", 0);
        double[][] data = new double[ddRow][ddCol];
        DenseDoubleMatrix expect = new DenseDoubleMatrix(ddRow, ddCol, data);
        RowIndex rowIndex = new RowIndex();
        for (int i = 0; i < ddRow; i++) rowIndex.addRowId(i);
        GetRowsResult result = mat.getRowsFlow(rowIndex, ddRow / 2);
        TVector row;
        while ((row = result.take()) != null) {
            LOG.info("===========get row index=" + row.getRowId());
            assertArrayEquals(((DenseDoubleVector) expect.getRow(row.getRowId())).getValues(), ((DenseDoubleVector) row).getValues(), 0.0);
        }
        Random rand = new Random(System.currentTimeMillis());
        for (int rowId = 0; rowId < ddRow; rowId++) {
            DenseDoubleVector update = new DenseDoubleVector(ddCol);
            for (int j = 0; j < ddCol; j += 3) update.set(j, rand.nextDouble());
            mat.increment(rowId, update);
            expect.getRow(rowId).plusBy(update);
        }
        mat.clock().get();
        rowIndex = new RowIndex();
        for (int i = 0; i < ddRow; i++) rowIndex.addRowId(i);
        result = mat.getRowsFlow(rowIndex, 2);
        while ((row = result.take()) != null) {
            assertArrayEquals(((DenseDoubleVector) expect.getRow(row.getRowId())).getValues(), ((DenseDoubleVector) row).getValues(), 0.0);
        }
        rowIndex = new RowIndex();
        for (int i = 0; i < ddRow; i++) rowIndex.addRowId(i);
        result = mat.getRowsFlow(rowIndex, 2);
        while (true) {
            row = result.poll();
            if (result.isFetchOver() && row == null)
                break;
            if (row == null)
                continue;
            assertArrayEquals(((DenseDoubleVector) expect.getRow(row.getRowId())).getValues(), ((DenseDoubleVector) row).getValues(), 0.0);
        }
    } catch (Exception x) {
        LOG.error("run testGetFlowDenseDoubleMatrix failed ", x);
        throw x;
    }
}
Also used : RowIndex(com.tencent.angel.psagent.matrix.transport.adapter.RowIndex) Random(java.util.Random) DenseDoubleVector(com.tencent.angel.ml.math.vector.DenseDoubleVector) GetRowsResult(com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult) DenseDoubleMatrix(com.tencent.angel.ml.math.matrix.DenseDoubleMatrix) MatrixClient(com.tencent.angel.psagent.matrix.MatrixClient) TVector(com.tencent.angel.ml.math.TVector) IOException(java.io.IOException) MasterServiceTest(com.tencent.angel.master.MasterServiceTest) Test(org.junit.Test)

Example 2 with GetRowsResult

use of com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult in project angel by Tencent.

The method testGetFlowDenseIntMatrix of the class TransportTest.

@Test
public void testGetFlowDenseIntMatrix() throws Exception {
    try {
        Worker worker = LocalClusterContext.get().getWorker(worker0Attempt0Id).getWorker();
        MatrixClient mat = worker.getPSAgent().getMatrixClient("dense_int_mat_1", 0);
        DenseIntMatrix expect = new DenseIntMatrix(diRow, diCol);
        RowIndex rowIndex = new RowIndex();
        for (int i = 0; i < diRow; i++) rowIndex.addRowId(i);
        GetRowsResult result = mat.getRowsFlow(rowIndex, diRow / 2);
        TVector row;
        while ((row = result.take()) != null) {
            assertArrayEquals(((DenseIntVector) expect.getRow(row.getRowId())).getValues(), ((DenseIntVector) row).getValues());
        }
        Random rand = new Random(System.currentTimeMillis());
        for (int rowId = 0; rowId < diRow; rowId++) {
            DenseIntVector update = new DenseIntVector(diCol);
            for (int j = 0; j < ddCol; j += 3) update.set(j, rand.nextInt());
            mat.increment(rowId, update);
            expect.getRow(rowId).plusBy(update);
        }
        mat.clock().get();
        rowIndex = new RowIndex();
        for (int i = 0; i < ddRow; i++) rowIndex.addRowId(i);
        result = mat.getRowsFlow(rowIndex, 2);
        while ((row = result.take()) != null) {
            assertArrayEquals(((DenseIntVector) expect.getRow(row.getRowId())).getValues(), ((DenseIntVector) row).getValues());
        }
        rowIndex = new RowIndex();
        for (int i = 0; i < ddRow; i++) rowIndex.addRowId(i);
        result = mat.getRowsFlow(rowIndex, 2);
        while (true) {
            row = result.poll();
            if (result.isFetchOver() && row == null)
                break;
            if (row == null)
                continue;
            assertArrayEquals(((DenseIntVector) expect.getRow(row.getRowId())).getValues(), ((DenseIntVector) row).getValues());
        }
    } catch (Exception x) {
        LOG.error("run testGetFlowDenseIntMatrix failed ", x);
        throw x;
    }
}
Also used : DenseIntMatrix(com.tencent.angel.ml.math.matrix.DenseIntMatrix) RowIndex(com.tencent.angel.psagent.matrix.transport.adapter.RowIndex) Random(java.util.Random) GetRowsResult(com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult) MatrixClient(com.tencent.angel.psagent.matrix.MatrixClient) TVector(com.tencent.angel.ml.math.TVector) IOException(java.io.IOException) DenseIntVector(com.tencent.angel.ml.math.vector.DenseIntVector) MasterServiceTest(com.tencent.angel.master.MasterServiceTest) Test(org.junit.Test)

Example 3 with GetRowsResult

use of com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult in project angel by Tencent.

The method getRowsFlow of the class ConsistencyController.

/**
 * Get a batch of row from storage/cache or pss.
 *
 * @param taskContext task context
 * @param rowIndex row indexes
 * @param rpcBatchSize fetch row number in one rpc request
 * @return GetRowsResult rows
 * @throws Exception
 */
public GetRowsResult getRowsFlow(TaskContext taskContext, RowIndex rowIndex, int rpcBatchSize) throws Exception {
    GetRowsResult result = new GetRowsResult();
    if (rowIndex.getRowsNumber() == 0) {
        LOG.error("need get rowId set is empty, just return");
        result.fetchOver();
        return result;
    }
    int staleness = getStaleness(rowIndex.getMatrixId());
    if (staleness >= 0) {
        // For BSP/SSP, get rows from storage/cache first
        int stalnessClock = taskContext.getMatrixClock(rowIndex.getMatrixId()) - staleness;
        findRowsInStorage(taskContext, result, rowIndex, stalnessClock);
        if (!result.isFetchOver()) {
            LOG.debug("need fetch from parameterserver");
            // Get from ps.
            PSAgentContext.get().getMatrixClientAdapter().getRowsFlow(result, rowIndex, rpcBatchSize, stalnessClock);
        }
        return result;
    } else {
        // For ASYNC, just get rows from pss.
        IntOpenHashSet rowIdSet = rowIndex.getRowIds();
        List<Integer> rowIndexes = new ArrayList<Integer>(rowIdSet.size());
        rowIndexes.addAll(rowIdSet);
        GetRowsFunc func = new GetRowsFunc(new GetRowsParam(rowIndex.getMatrixId(), rowIndexes));
        com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult funcResult = ((com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult) PSAgentContext.get().getMatrixClientAdapter().get(func));
        if (funcResult.getResponseType() == ResponseType.FAILED) {
            throw new IOException("get rows from ps failed.");
        } else {
            Map<Integer, TVector> rows = funcResult.getRows();
            for (Entry<Integer, TVector> rowEntry : rows.entrySet()) {
                result.put(rowEntry.getValue());
            }
            result.fetchOver();
            return result;
        }
    }
}
Also used : GetRowsResult(com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult) ArrayList(java.util.ArrayList) IOException(java.io.IOException) IntOpenHashSet(it.unimi.dsi.fastutil.ints.IntOpenHashSet) GetRowsParam(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsParam) GetRowsFunc(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsFunc) TVector(com.tencent.angel.ml.math.TVector)

Aggregations

TVector (com.tencent.angel.ml.math.TVector)3 GetRowsResult (com.tencent.angel.psagent.matrix.transport.adapter.GetRowsResult)3 IOException (java.io.IOException)3 MasterServiceTest (com.tencent.angel.master.MasterServiceTest)2 MatrixClient (com.tencent.angel.psagent.matrix.MatrixClient)2 RowIndex (com.tencent.angel.psagent.matrix.transport.adapter.RowIndex)2 Random (java.util.Random)2 Test (org.junit.Test)2 DenseDoubleMatrix (com.tencent.angel.ml.math.matrix.DenseDoubleMatrix)1 DenseIntMatrix (com.tencent.angel.ml.math.matrix.DenseIntMatrix)1 DenseDoubleVector (com.tencent.angel.ml.math.vector.DenseDoubleVector)1 DenseIntVector (com.tencent.angel.ml.math.vector.DenseIntVector)1 GetRowsFunc (com.tencent.angel.ml.matrix.psf.get.multi.GetRowsFunc)1 GetRowsParam (com.tencent.angel.ml.matrix.psf.get.multi.GetRowsParam)1 IntOpenHashSet (it.unimi.dsi.fastutil.ints.IntOpenHashSet)1 ArrayList (java.util.ArrayList)1