Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.
Class AdamUpdateFunc, method update.
/**
 * Runs one Adam-style update over every row group stored in {@code partition}.
 *
 * <p>For each {@code f} in {@code [0, factor)} the method reads four rows from the partition:
 * row {@code f} (weight), {@code f + factor} (velocity), {@code f + 2 * factor} (square) and
 * {@code f + 3 * factor} (gradient). Only the gradient row is bracketed by
 * {@code startWrite()} / {@code endWrite()}. The gradient is cleared at the end of each
 * iteration.
 *
 * @param partition the row-based partition holding the weight, velocity, square and gradient rows
 * @param factor number of row groups to update
 * @param scalars hyper-parameters, read positionally: {@code [0]} gamma, {@code [1]} epsilon
 *     (not used by this method), {@code [2]} beta, {@code [3]} learning rate,
 *     {@code [4]} regularization parameter, {@code [5]} epoch, {@code [6]} batch size
 */
@Override
public void update(RowBasedPartition partition, int factor, double[] scalars) {
  double gamma = scalars[0];
  // scalars[1] carries epsilon, which this method never passes on, so it is not read.
  double beta = scalars[2];
  double lr = scalars[3];
  double regParam = scalars[4];
  double epoch = scalars[5];
  double batchSize = scalars[6];

  // An epoch of 0 is treated as 1 so the powers below are never beta^0 / gamma^0 (== 1).
  if (epoch == 0) {
    epoch = 1;
  }
  double powBeta = Math.pow(beta, epoch);
  double powGamma = Math.pow(gamma, epoch);

  for (int f = 0; f < factor; f++) {
    ServerRow gradientServerRow = partition.getRow(f + 3 * factor);
    // Acquire before entering the try block: endWrite() must only run when startWrite()
    // actually succeeded, otherwise the release would fail and mask the original error.
    gradientServerRow.startWrite();
    try {
      Vector weight = ServerRowUtils.getVector(partition.getRow(f));
      Vector velocity = ServerRowUtils.getVector(partition.getRow(f + factor));
      Vector square = ServerRowUtils.getVector(partition.getRow(f + 2 * factor));
      Vector gradient = ServerRowUtils.getVector(gradientServerRow);

      // Scale the gradient by the batch size only when the batch holds more than one sample.
      if (batchSize > 1) {
        gradient.idiv(batchSize);
      }
      // Regularization term is folded into the gradient in place; skipped when disabled.
      if (regParam != 0.0) {
        gradient.iaxpy(weight, regParam);
      }

      OptFuncs.iexpsmoothing(velocity, gradient, beta);
      OptFuncs.iexpsmoothing2(square, gradient, gamma);
      Vector delta = OptFuncs.adamdelta(velocity, square, powBeta, powGamma);
      weight.iaxpy(delta, -lr);

      // Reset the gradient row so the next round starts from an empty gradient.
      gradient.clear();
    } finally {
      gradientServerRow.endWrite();
    }
  }
}
Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.
Class MomentumUpdateFunc, method update.
/**
 * Runs one momentum update over every row group stored in {@code partition}.
 *
 * <p>For each {@code f} in {@code [0, factor)} the method reads three rows from the partition:
 * row {@code f} (weight), {@code f + factor} (velocity) and {@code f + 2 * factor} (gradient).
 * Only the gradient row is bracketed by {@code startWrite()} / {@code endWrite()}. The
 * gradient is cleared at the end of each iteration.
 *
 * @param partition the row-based partition holding the weight, velocity and gradient rows
 * @param factor number of row groups to update
 * @param scalars hyper-parameters, read positionally: {@code [0]} momentum,
 *     {@code [1]} learning rate, {@code [2]} regularization parameter, {@code [3]} batch size
 */
@Override
public void update(RowBasedPartition partition, int factor, double[] scalars) {
  double momentum = scalars[0];
  double lr = scalars[1];
  double regParam = scalars[2];
  double batchSize = scalars[3];

  for (int f = 0; f < factor; f++) {
    ServerRow gradientServerRow = partition.getRow(f + 2 * factor);
    // Acquire before entering the try block: endWrite() must only run when startWrite()
    // actually succeeded, otherwise the release would fail and mask the original error.
    gradientServerRow.startWrite();
    try {
      Vector weight = ServerRowUtils.getVector(partition.getRow(f));
      Vector velocity = ServerRowUtils.getVector(partition.getRow(f + factor));
      Vector gradient = ServerRowUtils.getVector(gradientServerRow);

      // Scale the gradient by the batch size only when the batch holds more than one sample.
      if (batchSize > 1) {
        gradient.idiv(batchSize);
      }
      // Regularization term is folded into the gradient in place; skipped when disabled.
      if (regParam != 0.0) {
        gradient.iaxpy(weight, regParam);
      }

      // velocity is updated in place first; the weight step then uses the updated velocity.
      velocity.imul(momentum).iadd(gradient);
      weight.isub(velocity.mul(lr));

      // Reset the gradient row so the next round starts from an empty gradient.
      gradient.clear();
    } finally {
      gradientServerRow.endWrite();
    }
  }
}
Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.
Class GetPartFunc, method partitionGet.
/**
 * Collects the requested server rows of one partition into a {@code PartCSRResult}.
 *
 * <p>The partition key carried by the request is replaced with the key registered in the
 * matrix meta manager for the same matrix id and partition id before any row is fetched.
 *
 * @param partParam the partition request; only {@code PartitionGetRowsParam} is handled
 * @return a {@code PartCSRResult} wrapping the requested rows in request order, or
 *     {@code null} when {@code partParam} is not a {@code PartitionGetRowsParam}
 */
@Override
public PartitionGetResult partitionGet(PartitionGetParam partParam) {
  if (!(partParam instanceof PartitionGetRowsParam)) {
    // Unsupported parameter type: callers receive null, as before.
    return null;
  }

  PartitionGetRowsParam param = (PartitionGetRowsParam) partParam;
  PartitionKey pkey = param.getPartKey();
  pkey = psContext.getMatrixMetaManager().getMatrixMeta(pkey.getMatrixId())
      .getPartitionMeta(pkey.getPartitionId()).getPartitionKey();

  List<Integer> reqRows = param.getRowIndexes();
  MatrixStorageManager manager = psContext.getMatrixStorageManager();

  // One entry per requested row index, so the final size is known up front.
  List<ServerRow> rows = new ArrayList<>(reqRows.size());
  for (int w : reqRows) {
    rows.add(manager.getRow(pkey, w));
  }
  return new PartCSRResult(rows);
}
Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.
Class IncrementRows, method getVector.
/**
 * Looks up a row in the server-side matrix storage and returns its inner vector.
 *
 * <p>This can only be used with a {@code RowBasedPartition} and a basic row type: the
 * partition is cast to {@code RowBasedPartition} unconditionally.
 *
 * @param matrixId matrix id
 * @param rowId row id
 * @param part partition key; only its partition id is used
 * @return inner vector of the located server row
 */
protected Vector getVector(int matrixId, int rowId, PartitionKey part) {
  ServerMatrix serverMatrix = psContext.getMatrixStorageManager().getMatrix(matrixId);
  int partitionId = part.getPartitionId();
  RowBasedPartition rowPartition = (RowBasedPartition) serverMatrix.getPartition(partitionId);
  return ServerRowUtils.getVector(rowPartition.getRow(rowId));
}
Use of com.tencent.angel.ps.storage.vector.ServerRow in project angel by Tencent.
Class PartitionGetRowsResult, method deserialize.
/**
 * Rebuilds {@code rowSplits} from {@code buf}.
 *
 * <p>Wire layout consumed here: an int row count, then for each row an int that
 * {@code RowType.valueOf} maps to a row type, followed by whatever that row's own
 * {@code deserialize} reads.
 *
 * @param buf the buffer to read from; its reader position advances past everything consumed
 */
@Override
public void deserialize(ByteBuf buf) {
  int rowCount = buf.readInt();
  // The field is replaced before any row is read, so it holds the rows read so far
  // if a later read fails.
  rowSplits = new ArrayList<>(rowCount);
  for (int remaining = rowCount; remaining > 0; remaining--) {
    // An empty row of the right type is created first; it then fills itself from the buffer.
    ServerRow split = ServerRowFactory.createEmptyServerRow(RowType.valueOf(buf.readInt()));
    split.deserialize(buf);
    rowSplits.add(split);
  }
}
Aggregations