
Example 1 with GetRowsResult

Use of com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult in the angel project by Tencent.

Source: class MatrixOpLogTest, method testUDF.

@Test
public void testUDF() throws ServiceException, IOException, InvalidParameterException, AngelException, InterruptedException, ExecutionException {
    // Look up the worker for attempt workerAttempt0Id in the local test cluster.
    Worker worker = LocalClusterContext.get().getWorker(workerAttempt0Id).getWorker();
    // Two clients for matrix "w1", one bound to task 0 and one to task 1.
    MatrixClient w1Task0Client = worker.getPSAgent().getMatrixClient("w1", 0);
    MatrixClient w1Task1Client = worker.getPSAgent().getMatrixClient("w1", 1);
    int matrixW1Id = w1Task0Client.getMatrixId();
    // Build a GetRowsFunc that fetches rows 0..99 of w1 in a single call.
    List<Integer> rowIndexes = new ArrayList<Integer>();
    for (int i = 0; i < 100; i++) {
        rowIndexes.add(i);
    }
    GetRowsFunc func = new GetRowsFunc(new GetRowsParam(matrixW1Id, rowIndexes));
    // All-ones delta with one entry per column (100000 columns).
    int[] delta = new int[100000];
    for (int i = 0; i < 100000; i++) {
        delta[i] = 1;
    }
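    // Ten iterations: read and log every row, then update all rows, then clock.
    int index = 0;
    while (index++ < 10) {
        // Fetch rows 0..99 through the PS function and log the sum of each row.
        // sum(DenseIntVector) is a helper defined elsewhere in the test class.
        Map<Integer, TVector> rows = ((GetRowsResult) w1Task0Client.get(func)).getRows();
        for (Entry<Integer, TVector> rowEntry : rows.entrySet()) {
            LOG.info("index " + rowEntry.getKey() + " sum of w1 = " + sum((DenseIntVector) rowEntry.getValue()));
        }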
        // Both tasks add the all-ones delta to every row; a fresh DenseIntVector
        // is created for each increment call.
        for (int i = 0; i < 100; i++) {
            DenseIntVector deltaVec = new DenseIntVector(100000, delta);
            deltaVec.setMatrixId(matrixW1Id);
            deltaVec.setRowId(i);
            w1Task0Client.increment(deltaVec);
            deltaVec = new DenseIntVector(100000, delta);
            deltaVec.setMatrixId(matrixW1Id);
            deltaVec.setRowId(i);
            w1Task1Client.increment(deltaVec);
        }
        // Advance each task's clock and wait on the returned future.
        w1Task0Client.clock().get();
        w1Task1Client.clock().get();
    }
}
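The values the test logs can be worked out from the loop, under two assumptions not stated in the source: w1 starts at all zeros, and each clock().get() makes the previous iteration's increments visible to the next get (for example, with zero staleness). Per iteration, c = 2 clients each add an all-ones vector of length n = 100000 to every row, and the read happens before that iteration's updates, so the sum logged for any row at iteration k is:

$$
S_k = c \cdot n \cdot (k - 1) = 2 \cdot 100000 \cdot (k - 1), \qquad k = 1, \dots, 10
$$

This gives S_1 = 0, S_2 = 200000, up to S_10 = 1800000, and each row would hold 10 · 200000 = 2000000 after the loop ends. With a nonzero staleness bound, reads could lag behind these values.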
Also used: GetRowsResult(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsResult) ArrayList(java.util.ArrayList) DenseIntVector(com.tencent.angel.ml.math.vector.DenseIntVector) GetRowsParam(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsParam) GetRowsFunc(com.tencent.angel.ml.matrix.psf.get.multi.GetRowsFunc) Worker(com.tencent.angel.worker.Worker) MatrixClient(com.tencent.angel.psagent.matrix.MatrixClient) TVector(com.tencent.angel.ml.math.TVector) Test(org.junit.Test)
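To make the read-then-increment pattern of testUDF concrete without a running Angel cluster, here is a minimal, self-contained toy model. ToyMatrix, increment, getRows and run are hypothetical names for illustration only; they are not Angel's API. The sketch assumes updates are applied synchronously, which stands in for the case where every increment is visible before the next read.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical toy model of the loop in testUDF; not Angel's API.
class ToyMatrixDemo {

    // In-memory stand-in for matrix "w1" (illustrative only).
    static class ToyMatrix {
        private final int[][] data;

        ToyMatrix(int rows, int cols) {
            data = new int[rows][cols];
        }

        // Element-wise add of delta into one row, like an increment call.
        synchronized void increment(int rowId, int[] delta) {
            int[] row = data[rowId];
            for (int j = 0; j < delta.length; j++) {
                row[j] += delta[j];
            }
        }

        // Return copies of the requested rows, like a multi-row get.
        synchronized Map<Integer, int[]> getRows(List<Integer> rowIndexes) {
            Map<Integer, int[]> result = new HashMap<>();
            for (int r : rowIndexes) {
                result.put(r, data[r].clone());
            }
            return result;
        }
    }

    // Sum with a long accumulator to avoid int overflow on wide rows.
    static long sum(int[] values) {
        long total = 0;
        for (int v : values) {
            total += v;
        }
        return total;
    }

    // Read-then-increment loop; returns the sum of row 0 seen at the start of each iteration.
    static List<Long> run(int rows, int cols, int iterations) {
        ToyMatrix w1 = new ToyMatrix(rows, cols);
        List<Integer> rowIndexes = new ArrayList<>();
        for (int i = 0; i < rows; i++) {
            rowIndexes.add(i);
        }
        int[] delta = new int[cols];
        Arrays.fill(delta, 1);

        List<Long> row0Sums = new ArrayList<>();
        for (int iter = 0; iter < iterations; iter++) {
            Map<Integer, int[]> snapshot = w1.getRows(rowIndexes);
            row0Sums.add(sum(snapshot.get(0)));
            for (int i = 0; i < rows; i++) {
                w1.increment(i, delta); // update from "task 0"
                w1.increment(i, delta); // update from "task 1"
            }
            // No clock here: updates are applied synchronously, which models
            // the case where every increment is visible before the next read.
        }
        return row0Sums;
    }

    public static void main(String[] args) {
        System.out.println(run(100, 100000, 10));
    }
}
```

With the test's dimensions, run(100, 100000, 10) returns [0, 200000, 400000, ..., 1800000]. In the real test the visible values also depend on how clock() and the configured staleness interact, which this toy deliberately ignores.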
