Use of `com.tencent.angel.ml.math2.vector.Vector` in project angel by Tencent.
Class `RangeRouterUtils`, method `splitIntIntVector`.
/**
 * Splits an int-key / int-value vector into key-value parts according to the matrix metadata.
 *
 * <p>Sparse storage is passed on as an unsorted key/value pair. Dense storage passes only its
 * value array. Any other storage is cast to sorted storage and passed on as a sorted key/value
 * pair.
 *
 * @param matrixMeta metadata of the matrix the vector belongs to
 * @param vector the row vector to split
 * @return the parts produced by {@code split}
 */
public static KeyValuePart[] splitIntIntVector(MatrixMeta matrixMeta, IntIntVector vector) {
  IntIntVectorStorage storage = vector.getStorage();

  if (storage.isSparse()) {
    // Unsorted key/value arrays straight from the sparse storage
    IntIntSparseVectorStorage sparse = (IntIntSparseVectorStorage) storage;
    int[] sparseKeys = sparse.getIndices();
    int[] sparseValues = sparse.getValues();
    return split(matrixMeta, vector.getRowId(), sparseKeys, sparseValues, false);
  }

  if (storage.isDense()) {
    // Dense storage: only the value array is handed to split
    int[] denseValues = ((IntIntDenseVectorStorage) storage).getValues();
    return split(matrixMeta, vector.getRowId(), denseValues);
  }

  // Every remaining storage kind is treated as sorted key/value storage
  IntIntSortedVectorStorage sorted = (IntIntSortedVectorStorage) storage;
  int[] sortedKeys = sorted.getIndices();
  int[] sortedValues = sorted.getValues();
  return split(matrixMeta, vector.getRowId(), sortedKeys, sortedValues, true);
}
Use of `com.tencent.angel.ml.math2.vector.Vector` in project angel by Tencent.
Class `RangeRouterUtils`, method `splitLongDoubleVector`.
/**
 * Splits a long-key / double-value vector into key-value parts according to the matrix metadata.
 *
 * <p>Sparse storage yields an unsorted key/value pair. Any other storage is cast to sorted
 * storage and yields a sorted key/value pair.
 *
 * @param matrixMeta metadata of the matrix the vector belongs to
 * @param vector the row vector to split
 * @return the parts produced by {@code split}
 */
public static KeyValuePart[] splitLongDoubleVector(MatrixMeta matrixMeta, LongDoubleVector vector) {
  LongDoubleVectorStorage storage = vector.getStorage();
  final boolean sorted = !storage.isSparse();
  final long[] keys;
  final double[] values;

  if (sorted) {
    // Non-sparse storage is treated as an already-sorted key/value pair
    LongDoubleSortedVectorStorage sortedStorage = (LongDoubleSortedVectorStorage) storage;
    keys = sortedStorage.getIndices();
    values = sortedStorage.getValues();
  } else {
    LongDoubleSparseVectorStorage sparseStorage = (LongDoubleSparseVectorStorage) storage;
    keys = sparseStorage.getIndices();
    values = sparseStorage.getValues();
  }

  return split(matrixMeta, vector.getRowId(), keys, values, sorted);
}
Use of `com.tencent.angel.ml.math2.vector.Vector` in project angel by Tencent.
Class `Min`, method `partitionUpdate`.
/**
 * Applies an element-wise in-place minimum of each update split into the matching row of the
 * target partition.
 *
 * <p>Each row's write lock is held while {@code Ufuncs.imin} runs and is released in a
 * {@code finally} block.
 *
 * @param partParam expected to be a {@code PartIncrementRowsParam}
 */
@Override
public void partitionUpdate(PartitionUpdateParam partParam) {
  PartIncrementRowsParam incParam = (PartIncrementRowsParam) partParam;
  int matrixId = incParam.getMatrixId();

  for (RowUpdateSplit split : incParam.getUpdates()) {
    int rowId = split.getRowId();
    ServerRow serverRow = psContext.getMatrixStorageManager().getRow(incParam.getPartKey(), rowId);

    // Take the lock before try so endWrite only runs after a successful startWrite
    serverRow.startWrite();
    try {
      Vector partVector = getVector(matrixId, rowId, incParam.getPartKey());
      Ufuncs.imin(partVector, split.getVector());
    } finally {
      serverRow.endWrite();
    }
  }
}
Use of `com.tencent.angel.ml.math2.vector.Vector` in project angel by Tencent.
Class `GetRowHandler`, method `handle`.
/**
 * Merges the row splits carried by all cached sub-responses into one vector and publishes it
 * as the final result.
 *
 * @param finalResult receives the merged vector
 * @param userRequest expected to be a {@code GetRowRequest}
 * @param responseCache expected to be a {@code MapResponseCache}
 */
@Override
public void handle(FutureResult finalResult, UserRequest userRequest, ResponseCache responseCache) {
  GetRowRequest request = (GetRowRequest) userRequest;
  MapResponseCache mapCache = (MapResponseCache) responseCache;

  // Collect the row split from every sub-response
  List<ServerRow> rowSplits = new ArrayList<>(mapCache.expectedResponseNum);
  for (Response subResponse : mapCache.getResponses().values()) {
    GetRowSplitResponse splitResponse = (GetRowSplitResponse) subResponse.getData();
    rowSplits.add(splitResponse.getRowSplit());
  }

  Vector merged =
      MergeUtils.combineServerRowSplits(rowSplits, request.getMatrixId(), request.getRowId());

  // Stamp the merged vector with the requested matrix and row ids
  merged.setMatrixId(request.getMatrixId());
  merged.setRowId(request.getRowId());

  finalResult.set(merged);
}
Use of `com.tencent.angel.ml.math2.vector.Vector` in project angel by Tencent.
Class `AdaGradUpdateFunc`, method `update`.
/**
 * Applies one AdaGrad-style step to every factor stored in the partition.
 *
 * <p>For factor {@code f} this method uses:
 * <ul>
 *   <li>row {@code f} as the weight,</li>
 *   <li>row {@code f + factor} as the accumulated squares,</li>
 *   <li>row {@code f + 2 * factor} as the gradient.</li>
 * </ul>
 * The gradient is averaged by batch size and folded into the square accumulator. Optional L2
 * and L1 terms are applied, the weight is updated, and the gradient row is cleared.
 *
 * <p>The {@code scalars} layout, as read below, is:
 * <ul>
 *   <li>[0] epsilon (not used here)</li>
 *   <li>[1] beta</li>
 *   <li>[2] learning rate</li>
 *   <li>[3] L1 regularization</li>
 *   <li>[4] L2 regularization</li>
 *   <li>[5] epoch (not used here)</li>
 *   <li>[6] batch size</li>
 * </ul>
 *
 * @param partition the partition holding weight, square and gradient rows
 * @param factor number of factors (row groups) to update
 * @param scalars optimizer hyper-parameters, laid out as above
 */
@Override
public void update(RowBasedPartition partition, int factor, double[] scalars) {
  double beta = scalars[1];
  double lr = scalars[2];
  double l1RegParam = scalars[3];
  double l2RegParam = scalars[4];
  // Batch size is truncated to an integer before use
  double batchSize = (int) scalars[6];

  for (int f = 0; f < factor; f++) {
    ServerRow gradientServerRow = partition.getRow(f + 2 * factor);
    // Acquire before try: if startWrite() fails, endWrite() must not run on a lock never taken
    gradientServerRow.startWrite();
    try {
      // NOTE(review): only the gradient row is write-locked, yet weight and square are mutated too
      Vector weight = ServerRowUtils.getVector(partition.getRow(f));
      Vector square = ServerRowUtils.getVector(partition.getRow(f + factor));
      Vector gradient = ServerRowUtils.getVector(gradientServerRow);

      // Average the accumulated gradient over the mini-batch
      if (batchSize > 1) {
        gradient.idiv(batchSize);
      }

      OptFuncs.iexpsmoothing2(square, gradient, beta);

      if (l2RegParam != 0) {
        gradient.iaxpy(weight, l2RegParam);
      }

      // NOTE(review): epsilon (scalars[0]) is never used; confirm l2RegParam is the intended arg here
      OptFuncs.iadagraddelta(gradient, square, l2RegParam, lr);
      weight.isub(gradient);

      if (l1RegParam != 0) {
        OptFuncs.iadagradthredshold(weight, square, l1RegParam, l2RegParam, lr);
      }

      // Reset the gradient row for the next accumulation round
      gradient.clear();
    } finally {
      gradientServerRow.endWrite();
    }
  }
}
Aggregations