Usage of com.tencent.angel.ps.storage.vector.ServerRow in the Angel project by Tencent.
Example from class FTRLUpdateFunc, method update.
/**
 * Applies one FTRL update step to every factor slice held in this partition.
 *
 * <p>For each {@code f} in {@code [0, factor)} the partition is read as four row groups:
 * row {@code f} (weight), row {@code f + factor} (z accumulator), row {@code f + 2 * factor}
 * (n accumulator) and row {@code f + 3 * factor} (gradient). The gradient row is cleared after
 * it has been applied.
 *
 * <p>NOTE(review): only the gradient row is write-locked here, while the weight, z and n rows are
 * mutated as well — confirm that callers serialize updates to those rows.
 *
 * @param partition row-based partition holding the weight, z, n and gradient rows
 * @param factor number of factor slices; also the row-index stride between the four groups
 * @param scalars hyper-parameters laid out as
 *     {@code [alpha, beta, lambda1, lambda2, epoch, batchSize]}; the epoch slot (index 4) is not
 *     used by this update
 */
@Override
public void update(RowBasedPartition partition, int factor, double[] scalars) {
  double alpha = scalars[0];
  double beta = scalars[1];
  double lambda1 = scalars[2];
  double lambda2 = scalars[3];
  // scalars[4] carries the epoch, which this update does not need.
  int batchSize = (int) scalars[5];
  for (int f = 0; f < factor; f++) {
    ServerRow gradientServerRow = partition.getRow(f + 3 * factor);
    // Acquire before entering try: if acquisition fails, finally must not release a lock
    // that was never taken (which would mask the original failure).
    gradientServerRow.startWrite();
    try {
      Vector weight = ServerRowUtils.getVector(partition.getRow(f));
      Vector zModel = ServerRowUtils.getVector(partition.getRow(f + factor));
      Vector nModel = ServerRowUtils.getVector(partition.getRow(f + 2 * factor));
      Vector gradient = ServerRowUtils.getVector(gradientServerRow);
      // Presumably the gradient row holds a sum over the mini-batch; scale it to a mean.
      if (batchSize > 1) {
        gradient.idiv(batchSize);
      }
      // delta must be computed from n BEFORE n is updated below; presumably this is the
      // FTRL sigma term (sqrt(n + g^2) - sqrt(n)) / alpha — confirm against OptFuncs.
      Vector delta = OptFuncs.ftrldelta(nModel, gradient, alpha);
      // Presumably an in-place n += 1 * g^2 — confirm against Ufuncs.iaxpy2.
      Ufuncs.iaxpy2(nModel, gradient, 1);
      // z += g - delta * w
      zModel.iadd(gradient.sub(delta.mul(weight)));
      // Recompute the weights from the updated z and n accumulators.
      Vector newWeight = Ufuncs.ftrlthreshold(zModel, nModel, alpha, beta, lambda1, lambda2);
      weight.setStorage(newWeight.getStorage());
      // The gradient has been consumed; clear it.
      gradient.clear();
    } finally {
      gradientServerRow.endWrite();
    }
  }
}
Usage of com.tencent.angel.ps.storage.vector.ServerRow in the Angel project by Tencent.
Example from class StreamSerdeUtils, method deserializeServerRow.
/**
 * Reads one server row from a stream.
 *
 * <p>The stream is expected to start with an {@code int} row-type code, followed by the row's
 * own serialized payload.
 *
 * @param in stream positioned at the start of a serialized row
 * @return a newly created row of the decoded type, populated from {@code in}
 * @throws IOException if reading from {@code in} fails
 */
public static ServerRow deserializeServerRow(DataInputStream in) throws IOException {
  // The type code comes first; it decides which empty row implementation to create.
  final RowType rowType = RowType.valueOf(in.readInt());
  final ServerRow row = ServerRowFactory.createEmptyServerRow(rowType);
  // The row then consumes the rest of its own payload from the same stream.
  row.deserialize(in);
  return row;
}
Aggregations