Search in sources:

Example 11 with ServerRow

Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

From class GetColumnFunc, method partitionGet.

/**
 * Reads a range of columns out of one matrix partition and returns, for each requested column,
 * the strictly positive values keyed by row id.
 *
 * <p>The request reuses {@link PartitionGetRowsParam}: its "row indexes" list is read as a column
 * range, where element 0 is the first column and element 1 is the exclusive end column.
 *
 * @param partParam the partition request; only {@link PartitionGetRowsParam} is handled
 * @return a {@code PartColumnResult} holding one map per column in {@code [start, end)}, or
 *     {@code null} when {@code partParam} is not a {@link PartitionGetRowsParam}
 */
@Override
public PartitionGetResult partitionGet(PartitionGetParam partParam) {
    if (!(partParam instanceof PartitionGetRowsParam)) {
        return null;
    }
    PartitionGetRowsParam param = (PartitionGetRowsParam) partParam;
    PartitionKey pkey = param.getPartKey();
    // Swap the request's key for the one held in the matrix metadata; its start/end rows bound
    // the scan below.
    pkey = psContext.getMatrixMetaManager().getMatrixMeta(pkey.getMatrixId())
        .getPartitionMeta(pkey.getPartitionId()).getPartitionKey();
    List<Integer> reqCols = param.getRowIndexes();
    int start = reqCols.get(0);
    int end = reqCols.get(1);
    MatrixStorageManager manager = psContext.getMatrixStorageManager();
    Map<Integer, Int2IntOpenHashMap> cks = new HashMap<>();
    for (int col = start; col < end; col++) {
        cks.put(col, new Int2IntOpenHashMap());
    }
    int startRow = pkey.getStartRow();
    int endRow = pkey.getEndRow();
    for (int r = startRow; r < endRow; r++) {
        ServerRow row = manager.getRow(pkey, r);
        // Rows that are not int->int rows (including null) contribute nothing.
        if (!(row instanceof ServerIntIntRow)) {
            continue;
        }
        ServerIntIntRow intRow = (ServerIntIntRow) row;
        for (int col = start; col < end; col++) {
            int k = intRow.get(col);
            // Only strictly positive values are recorded.
            if (k > 0) {
                cks.get(col).put(row.getRowId(), k);
            }
        }
    }
    return new PartColumnResult(cks);
}
Also used: Int2IntOpenHashMap(it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap) HashMap(java.util.HashMap) PartitionGetRowsParam(com.tencent.angel.ml.matrix.psf.get.getrows.PartitionGetRowsParam) MatrixStorageManager(com.tencent.angel.ps.storage.MatrixStorageManager) PartitionKey(com.tencent.angel.PartitionKey) ServerRow(com.tencent.angel.ps.storage.vector.ServerRow) ServerIntIntRow(com.tencent.angel.ps.storage.vector.ServerIntIntRow)

Example 12 with ServerRow

Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

From class PGDUpdateFunc, method update.

/**
 * Applies one regularized gradient step to each of the {@code factor} weight rows of a partition.
 *
 * <p>Row {@code f} is used as the weight row and row {@code f + factor} as its gradient row. For
 * each pair, the steps are:
 * <ol>
 *   <li>Divide the gradient by the batch size when that size is greater than 1.</li>
 *   <li>Scale the weights by {@code 1 - lrTemp * l2} (only when {@code l2 != 0}).</li>
 *   <li>Subtract {@code lrTemp * gradient} from the weights, where
 *       {@code lrTemp = lr / (1 + l2 * lr)}.</li>
 *   <li>Apply soft-thresholding with {@code lrTemp * l1} when {@code l1 != 0}.</li>
 *   <li>Clear the gradient row.</li>
 * </ol>
 *
 * @param partition the row-based partition holding weight and gradient rows
 * @param factor number of weight rows; gradient rows start at index {@code factor}
 * @param scalars {@code [lr, l1RegParam, l2RegParam, batchSize]}
 */
@Override
public void update(RowBasedPartition partition, int factor, double[] scalars) {
    double lr = scalars[0];
    double l1RegParam = scalars[1];
    double l2RegParam = scalars[2];
    // Note: the fractional part of scalars[3] is discarded by the int cast.
    double batchSize = (int) scalars[3];
    // Neither value depends on f, so compute them once instead of per row.
    double lrTemp = lr / (1 + l2RegParam * lr);
    double l1Threshold = lrTemp * l1RegParam;
    for (int f = 0; f < factor; f++) {
        ServerRow gradientServerRow = partition.getRow(f + factor);
        // Acquire before entering try: if startWrite() fails, endWrite() must not be called on a
        // lock that was never taken.
        gradientServerRow.startWrite();
        try {
            // NOTE(review): the weight row is modified below without taking its own write lock;
            // confirm callers serialize updates to it.
            Vector weight = ServerRowUtils.getVector(partition.getRow(f));
            Vector gradient = ServerRowUtils.getVector(gradientServerRow);
            if (batchSize > 1) {
                gradient.idiv(batchSize);
            }
            if (l2RegParam != 0.0) {
                weight.imul(1 - lrTemp * l2RegParam).iaxpy(gradient, -lrTemp);
            } else {
                weight.iaxpy(gradient, -lrTemp);
            }
            if (l1RegParam != 0) {
                Ufuncs.isoftthreshold(weight, l1Threshold);
            }
            gradient.clear();
        } finally {
            gradientServerRow.endWrite();
        }
    }
}
Also used : ServerRow(com.tencent.angel.ps.storage.vector.ServerRow) Vector(com.tencent.angel.ml.math2.vector.Vector)

Example 13 with ServerRow

Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

From class GBDTGradHistGetRowFunc, method merge.

/**
 * Combines the per-partition split candidates into a single {@link SplitEntry}.
 *
 * <p>Each partition result is cast to {@code PartitionGetRowResult}, and its row split is cast to
 * {@code ServerIntDoubleRow}. Relative to the row's start column, the first seven entries hold:
 * feature id, split index, loss gain, left grad sum, left hess sum, right grad sum and right hess
 * sum. A feature id of {@code -1} means the partition has no candidate, and it is skipped.
 *
 * <p>Candidates are folded in with {@link SplitEntry#update}. Presumably this keeps the better
 * split (confirm against {@code SplitEntry}).
 *
 * @param partResults per-partition results; each must be a {@code PartitionGetRowResult}
 * @return a {@code GBDTGradHistGetRowResult} with {@code ResponseType.SUCCESS} and the merged entry
 */
@Override
public GetResult merge(List<PartitionGetResult> partResults) {
    SplitEntry splitEntry = new SplitEntry();
    for (PartitionGetResult partResult : partResults) {
        ServerIntDoubleRow row =
            (ServerIntDoubleRow) ((PartitionGetRowResult) partResult).getRowSplit();
        // All fields are addressed relative to the row's first column.
        int base = (int) row.getStartCol();
        int fid = (int) row.get(base);
        if (fid == -1) {
            continue;
        }
        int splitIndex = (int) row.get(base + 1);
        float lossGain = (float) row.get(base + 2);
        float leftSumGrad = (float) row.get(base + 3);
        float leftSumHess = (float) row.get(base + 4);
        float rightSumGrad = (float) row.get(base + 5);
        float rightSumHess = (float) row.get(base + 6);
        LOG.info(String.format("psFunc: the best split after looping a split: fid[%d], fvalue[%d], loss gain[%f]" + ", leftSumGrad[%f], leftSumHess[%f], rightSumGrad[%f], rightSumHess[%f]", fid, splitIndex, lossGain, leftSumGrad, leftSumHess, rightSumGrad, rightSumHess));
        SplitEntry curSplitEntry = new SplitEntry(fid, splitIndex, lossGain);
        curSplitEntry.leftGradStat = new GradStats(leftSumGrad, leftSumHess);
        curSplitEntry.rightGradStat = new GradStats(rightSumGrad, rightSumHess);
        splitEntry.update(curSplitEntry);
    }
    return new GBDTGradHistGetRowResult(ResponseType.SUCCESS, splitEntry);
}
Also used : ServerIntDoubleRow(com.tencent.angel.ps.storage.vector.ServerIntDoubleRow) SplitEntry(com.tencent.angel.ml.GBDT.algo.tree.SplitEntry) ArrayList(java.util.ArrayList) GradStats(com.tencent.angel.ml.GBDT.algo.RegTree.GradStats) ServerRow(com.tencent.angel.ps.storage.vector.ServerRow)

Example 14 with ServerRow

Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

From class MyIncrement, method partitionUpdate.

/**
 * Adds every row update carried by the request to the corresponding row of this partition.
 *
 * <p>Each target row's write lock is held while its vector is incremented. The lock is released
 * in a {@code finally} block.
 *
 * @param partParam the partition update; must be a {@code PartIncrementRowsParam}
 */
@Override
public void partitionUpdate(PartitionUpdateParam partParam) {
    PartIncrementRowsParam incParam = (PartIncrementRowsParam) partParam;
    for (RowUpdateSplit split : incParam.getUpdates()) {
        int rowId = split.getRowId();
        ServerRow target = psContext.getMatrixStorageManager().getRow(incParam.getPartKey(), rowId);
        target.startWrite();
        try {
            getVector(incParam.getMatrixId(), rowId, incParam.getPartKey()).iadd(split.getVector());
        } finally {
            target.endWrite();
        }
    }
}
Also used : PartIncrementRowsParam(com.tencent.angel.ml.matrix.psf.update.update.PartIncrementRowsParam) RowUpdateSplit(com.tencent.angel.psagent.matrix.oplog.cache.RowUpdateSplit) ServerRow(com.tencent.angel.ps.storage.vector.ServerRow) Vector(com.tencent.angel.ml.math2.vector.Vector)

Example 15 with ServerRow

Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

From class MyPull, method partitionGet.

/**
 * Loads the delta row and the sums row named by the request. It returns them, together with the
 * request's keys and parameters, as a {@code MyPullPartResult}.
 *
 * <p>Both rows are cast to {@code ServerLongFloatRow} before their vectors are extracted.
 *
 * @param partParam the partition request; must be a {@code MyPullPartParam}
 * @return the result wrapping the request keys, the partition's start column, both row vectors,
 *     and the reset probability and tolerance from the request
 */
@Override
public PartitionGetResult partitionGet(PartitionGetParam partParam) {
    MyPullPartParam param = (MyPullPartParam) partParam;
    ServerRow msgRows = psContext.getMatrixStorageManager().getRow(param.getPartKey(), param.getDeltaId());
    ServerRow sumsRow = psContext.getMatrixStorageManager().getRow(param.getPartKey(), param.getSumId());
    // First column of this partition, passed to the result alongside the vectors.
    long start = param.getPartKey().getStartCol();
    FloatVector msgs = ServerRowUtils.getVector((ServerLongFloatRow) msgRows);
    FloatVector sums = ServerRowUtils.getVector((ServerLongFloatRow) sumsRow);
    return new MyPullPartResult(param.getKeys(), start, msgs, sums, param.getResetProb(), param.getTol());
}
Also used : ServerRow(com.tencent.angel.ps.storage.vector.ServerRow) FloatVector(com.tencent.angel.ml.math2.vector.FloatVector)

Aggregations

ServerRow (com.tencent.angel.ps.storage.vector.ServerRow)42 Vector (com.tencent.angel.ml.math2.vector.Vector)13 RowBasedPartition (com.tencent.angel.ps.storage.partition.RowBasedPartition)6 PartIncrementRowsParam (com.tencent.angel.ml.matrix.psf.update.update.PartIncrementRowsParam)5 RowUpdateSplit (com.tencent.angel.psagent.matrix.oplog.cache.RowUpdateSplit)5 ArrayList (java.util.ArrayList)5 PartitionKey (com.tencent.angel.PartitionKey)4 ServerIntIntRow (com.tencent.angel.ps.storage.vector.ServerIntIntRow)4 AngelException (com.tencent.angel.exception.AngelException)2 FloatVector (com.tencent.angel.ml.math2.vector.FloatVector)2 RowType (com.tencent.angel.ml.matrix.RowType)2 PartitionGetRowsParam (com.tencent.angel.ml.matrix.psf.get.getrows.PartitionGetRowsParam)2 GetRowSplitResponse (com.tencent.angel.ps.server.data.response.GetRowSplitResponse)2 GetRowsSplitResponse (com.tencent.angel.ps.server.data.response.GetRowsSplitResponse)2 Response (com.tencent.angel.ps.server.data.response.Response)2 MatrixStorageManager (com.tencent.angel.ps.storage.MatrixStorageManager)2 ServerMatrix (com.tencent.angel.ps.storage.matrix.ServerMatrix)2 ServerPartition (com.tencent.angel.ps.storage.partition.ServerPartition)2 ServerRowsStorage (com.tencent.angel.ps.storage.partition.storage.ServerRowsStorage)2 KeyPart (com.tencent.angel.psagent.matrix.transport.router.KeyPart)2