Search in sources:

Example 31 with ServerRow

use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

The class AdamUpdateFunc, method update.

/**
 * Applies one Adam-style step to every weight row of this partition and clears the consumed
 * gradient rows.
 *
 * <p>Row layout, as indexed below: for each {@code f} in {@code [0, factor)}, row {@code f} is the
 * weight, row {@code f + factor} the {@code velocity} accumulator, row {@code f + 2 * factor} the
 * {@code square} accumulator and row {@code f + 3 * factor} the gradient.
 *
 * @param partition partition holding the weight, accumulator and gradient rows
 * @param factor number of weight rows; each has companion rows at offsets {@code factor},
 *     {@code 2 * factor} and {@code 3 * factor}
 * @param scalars positional hyper-parameters:
 *     {@code [gamma, epsilon, beta, lr, regParam, epoch, batchSize]}
 */
@Override
public void update(RowBasedPartition partition, int factor, double[] scalars) {
    double gamma = scalars[0];
    // scalars[1] (epsilon) is not referenced by this update.
    // NOTE(review): confirm OptFuncs.adamdelta applies its own epsilon term.
    double beta = scalars[2];
    double lr = scalars[3];
    double regParam = scalars[4];
    double epoch = scalars[5];
    double batchSize = scalars[6];
    // An epoch of 0 is treated as 1 so the decay powers below are not beta^0 == gamma^0 == 1.
    if (epoch == 0) {
        epoch = 1;
    }
    double powBeta = Math.pow(beta, epoch);
    double powGamma = Math.pow(gamma, epoch);
    for (int f = 0; f < factor; f++) {
        ServerRow gradientServerRow = partition.getRow(f + 3 * factor);
        // Acquire before entering try: if startWrite() fails, the finally block must not call
        // endWrite() and try to release a lock that was never acquired.
        // NOTE(review): only the gradient row is locked; the weight/velocity/square rows are
        // written without taking their own locks — confirm callers serialize updates.
        gradientServerRow.startWrite();
        try {
            Vector weight = ServerRowUtils.getVector(partition.getRow(f));
            Vector velocity = ServerRowUtils.getVector(partition.getRow(f + factor));
            Vector square = ServerRowUtils.getVector(partition.getRow(f + 2 * factor));
            Vector gradient = ServerRowUtils.getVector(gradientServerRow);
            // Scale by 1/batchSize (presumably the row holds a sum over the batch).
            if (batchSize > 1) {
                gradient.idiv(batchSize);
            }
            // Presumably gradient += regParam * weight, i.e. an L2 penalty term.
            if (regParam != 0.0) {
                gradient.iaxpy(weight, regParam);
            }
            OptFuncs.iexpsmoothing(velocity, gradient, beta);
            OptFuncs.iexpsmoothing2(square, gradient, gamma);
            Vector delta = OptFuncs.adamdelta(velocity, square, powBeta, powGamma);
            weight.iaxpy(delta, -lr);
            // The gradient has been consumed; reset it for the next accumulation round.
            gradient.clear();
        } finally {
            gradientServerRow.endWrite();
        }
    }
}
Also used : ServerRow(com.tencent.angel.ps.storage.vector.ServerRow) Vector(com.tencent.angel.ml.math2.vector.Vector)

Example 32 with ServerRow

use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

The class MomentumUpdateFunc, method update.

/**
 * Applies one momentum step to every weight row of this partition and clears the consumed
 * gradient rows.
 *
 * <p>Row layout, as indexed below: for each {@code f} in {@code [0, factor)}, row {@code f} is the
 * weight, row {@code f + factor} the {@code velocity} accumulator and row {@code f + 2 * factor}
 * the gradient.
 *
 * @param partition partition holding the weight, velocity and gradient rows
 * @param factor number of weight rows; each has companion rows at offsets {@code factor} and
 *     {@code 2 * factor}
 * @param scalars positional hyper-parameters: {@code [momentum, lr, regParam, batchSize]}
 */
@Override
public void update(RowBasedPartition partition, int factor, double[] scalars) {
    double momentum = scalars[0];
    double lr = scalars[1];
    double regParam = scalars[2];
    double batchSize = scalars[3];
    for (int f = 0; f < factor; f++) {
        ServerRow gradientServerRow = partition.getRow(f + 2 * factor);
        // Acquire before entering try: if startWrite() fails, the finally block must not call
        // endWrite() and try to release a lock that was never acquired.
        // NOTE(review): only the gradient row is locked; the weight/velocity rows are written
        // without taking their own locks — confirm callers serialize updates.
        gradientServerRow.startWrite();
        try {
            Vector weight = ServerRowUtils.getVector(partition.getRow(f));
            Vector velocity = ServerRowUtils.getVector(partition.getRow(f + factor));
            Vector gradient = ServerRowUtils.getVector(gradientServerRow);
            // Scale by 1/batchSize (presumably the row holds a sum over the batch).
            if (batchSize > 1) {
                gradient.idiv(batchSize);
            }
            // Presumably gradient += regParam * weight, i.e. an L2 penalty term.
            if (regParam != 0.0) {
                gradient.iaxpy(weight, regParam);
            }
            // velocity <- momentum * velocity + gradient (in place), then step the weight.
            velocity.imul(momentum).iadd(gradient);
            weight.isub(velocity.mul(lr));
            // The gradient has been consumed; reset it for the next accumulation round.
            gradient.clear();
        } finally {
            gradientServerRow.endWrite();
        }
    }
}
Also used : ServerRow(com.tencent.angel.ps.storage.vector.ServerRow) Vector(com.tencent.angel.ml.math2.vector.Vector)

Example 33 with ServerRow

use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

The class GetPartFunc, method partitionGet.

/**
 * Fetches the requested rows of one partition and wraps them in a {@link PartCSRResult}.
 *
 * <p>The partition key carried by the request is replaced by the {@link PartitionKey} held in the
 * PS matrix metadata (looked up by matrix id and partition id) before rows are read.
 *
 * @param partParam request parameter; only {@link PartitionGetRowsParam} is handled
 * @return a {@link PartCSRResult} with one {@link ServerRow} per requested row index, in request
 *     order, or {@code null} when {@code partParam} is not a {@link PartitionGetRowsParam}
 */
@Override
public PartitionGetResult partitionGet(PartitionGetParam partParam) {
    if (!(partParam instanceof PartitionGetRowsParam)) {
        // Unsupported request type: keep the existing contract of returning null.
        return null;
    }
    PartitionGetRowsParam param = (PartitionGetRowsParam) partParam;
    PartitionKey requestKey = param.getPartKey();
    // Resolve the key from the PS-side metadata rather than using the request's instance directly.
    PartitionKey pkey = psContext.getMatrixMetaManager()
        .getMatrixMeta(requestKey.getMatrixId())
        .getPartitionMeta(requestKey.getPartitionId())
        .getPartitionKey();
    List<Integer> reqRows = param.getRowIndexes();
    MatrixStorageManager manager = psContext.getMatrixStorageManager();
    List<ServerRow> rows = new ArrayList<>(reqRows.size());
    for (int rowId : reqRows) {
        rows.add(manager.getRow(pkey, rowId));
    }
    return new PartCSRResult(rows);
}
Also used : PartitionGetRowsParam(com.tencent.angel.ml.matrix.psf.get.getrows.PartitionGetRowsParam) MatrixStorageManager(com.tencent.angel.ps.storage.MatrixStorageManager) ArrayList(java.util.ArrayList) PartitionKey(com.tencent.angel.PartitionKey) ServerRow(com.tencent.angel.ps.storage.vector.ServerRow)

Example 34 with ServerRow

use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

The class IncrementRows, method getVector.

/**
 * Returns the inner vector that backs one row of a server matrix partition.
 *
 * <p>This can only be used when the partition is a {@link RowBasedPartition} and the row is of a
 * basic row type; for any other partition type the cast fails with a {@link ClassCastException}.
 *
 * @param matrixId id of the matrix to read from
 * @param rowId id of the row inside the partition
 * @param part key identifying the partition
 * @return the row's inner vector
 */
protected Vector getVector(int matrixId, int rowId, PartitionKey part) {
    ServerMatrix serverMatrix = psContext.getMatrixStorageManager().getMatrix(matrixId);
    int partitionId = part.getPartitionId();
    RowBasedPartition rowPartition = (RowBasedPartition) serverMatrix.getPartition(partitionId);
    ServerRow serverRow = rowPartition.getRow(rowId);
    return ServerRowUtils.getVector(serverRow);
}
Also used : ServerMatrix(com.tencent.angel.ps.storage.matrix.ServerMatrix) ServerRow(com.tencent.angel.ps.storage.vector.ServerRow) RowBasedPartition(com.tencent.angel.ps.storage.partition.RowBasedPartition)

Example 35 with ServerRow

use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.

The class PartitionGetRowsResult, method deserialize.

/**
 * Rebuilds {@code rowSplits} from {@code buf}.
 *
 * <p>Expected wire layout: an int row count, then for each row an int row-type code followed by
 * the row's own serialized form. An empty row of the decoded type is created and then fills
 * itself from the buffer.
 *
 * @param buf buffer positioned at the start of the serialized result
 */
@Override
public void deserialize(ByteBuf buf) {
    final int rowCount = buf.readInt();
    rowSplits = new ArrayList<ServerRow>(rowCount);
    int remaining = rowCount;
    while (remaining-- > 0) {
        ServerRow row = ServerRowFactory.createEmptyServerRow(RowType.valueOf(buf.readInt()));
        row.deserialize(buf);
        rowSplits.add(row);
    }
}
Also used : RowType(com.tencent.angel.ml.matrix.RowType) ServerRow(com.tencent.angel.ps.storage.vector.ServerRow)

Aggregations

ServerRow (com.tencent.angel.ps.storage.vector.ServerRow)42 Vector (com.tencent.angel.ml.math2.vector.Vector)13 RowBasedPartition (com.tencent.angel.ps.storage.partition.RowBasedPartition)6 PartIncrementRowsParam (com.tencent.angel.ml.matrix.psf.update.update.PartIncrementRowsParam)5 RowUpdateSplit (com.tencent.angel.psagent.matrix.oplog.cache.RowUpdateSplit)5 ArrayList (java.util.ArrayList)5 PartitionKey (com.tencent.angel.PartitionKey)4 ServerIntIntRow (com.tencent.angel.ps.storage.vector.ServerIntIntRow)4 AngelException (com.tencent.angel.exception.AngelException)2 FloatVector (com.tencent.angel.ml.math2.vector.FloatVector)2 RowType (com.tencent.angel.ml.matrix.RowType)2 PartitionGetRowsParam (com.tencent.angel.ml.matrix.psf.get.getrows.PartitionGetRowsParam)2 GetRowSplitResponse (com.tencent.angel.ps.server.data.response.GetRowSplitResponse)2 GetRowsSplitResponse (com.tencent.angel.ps.server.data.response.GetRowsSplitResponse)2 Response (com.tencent.angel.ps.server.data.response.Response)2 MatrixStorageManager (com.tencent.angel.ps.storage.MatrixStorageManager)2 ServerMatrix (com.tencent.angel.ps.storage.matrix.ServerMatrix)2 ServerPartition (com.tencent.angel.ps.storage.partition.ServerPartition)2 ServerRowsStorage (com.tencent.angel.ps.storage.partition.storage.ServerRowsStorage)2 KeyPart (com.tencent.angel.psagent.matrix.transport.router.KeyPart)2