Usage of com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam in the Angel project by Tencent.
Class LongKeysUpdateParam, method split.
/**
 * Splits this update into one parameter per matrix partition.
 *
 * <p>The (nodeIds, neighbors) data is routed with {@code RouterUtils.split} using row index 0.
 * The routed piece at index i is paired with partition key i. Partitions whose piece is null
 * or empty are omitted from the result.
 *
 * @return the per-partition update parameters (possibly empty)
 */
@Override
public List<PartitionUpdateParam> split() {
  MatrixMeta matrixMeta = PSAgentContext.get().getMatrixMetaManager().getMatrixMeta(matrixId);
  PartitionKey[] partKeys = matrixMeta.getPartitionKeys();
  KeyValuePart[] routed = RouterUtils.split(matrixMeta, 0, nodeIds, neighbors);
  // Only checked when assertions are enabled (-ea); the loop below relies on equal lengths.
  assert partKeys.length == routed.length;

  List<PartitionUpdateParam> result = new ArrayList<>(partKeys.length);
  int index = 0;
  while (index < partKeys.length) {
    KeyValuePart piece = routed[index];
    boolean hasData = piece != null && piece.size() > 0;
    if (hasData) {
      result.add(new GeneralPartUpdateParam(matrixId, partKeys[index], piece));
    }
    index++;
  }
  return result;
}
Usage of com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam in the Angel project by Tencent.
Class UpdateColsParam, method split.
/**
 * Splits this column update into one {@code PartitionUpdateColsParam} per matrix partition that
 * actually receives at least one column.
 *
 * <p>The column ids in {@code cols} are consumed in order, and each partition takes the run of
 * columns lying in {@code [startCol, endCol)}. NOTE(review): this assumes {@code cols} is sorted
 * ascending, since the scan never moves backwards — confirm with callers.
 *
 * @return the per-partition update parameters (possibly empty)
 * @throws AngelException if the update vectors are neither IntDoubleVector nor IntFloatVector
 */
@Override
public List<PartitionUpdateParam> split() {
  List<PartitionKey> pkeys = PSAgentContext.get().getMatrixMetaManager().getPartitions(matrixId);
  List<PartitionUpdateParam> params = new ArrayList<>();
  int colNum = ((IntKeyVector) cols).getDim();

  // The value type is decided by the entry of the first key in the map. This lookup does not
  // depend on the partition, so do it once instead of re-scanning the map per partition.
  // For an empty map this yields null, which makes the helper throw the type error.
  Vector sample = values.isEmpty() ? null : values.get(values.keySet().iterator().next());

  int start = 0;
  int end = 0;
  for (PartitionKey pkey : pkeys) {
    if (start >= colNum || VectorUtils.getLong(cols, start) < pkey.getStartCol()) {
      continue;
    }
    long endCol = pkey.getEndCol();
    while (end < colNum && VectorUtils.getLong(cols, end) < endCol) {
      end++;
    }
    int len = end - start;
    if (len == 0) {
      // The next column lies beyond this partition's range. Skip it instead of emitting an
      // empty update, consistent with the other split() implementations.
      continue;
    }
    long[] part = new long[len];
    if (cols instanceof IntIntVector) {
      ArrayCopy.copy(((IntIntVector) cols).getStorage().getValues(), start, part, 0, len);
    } else {
      System.arraycopy(((IntLongVector) cols).getStorage().getValues(), start, part, 0, len);
    }
    params.add(buildColsPartParam(pkey, part, sample));
    start = end;
  }
  return params;
}

/** Builds the partition parameter for {@code part}, packing its update vectors by value type. */
private PartitionUpdateParam buildColsPartParam(PartitionKey pkey, long[] part, Vector sample) {
  if (sample instanceof IntDoubleVector) {
    IntDoubleVector[] updates = new IntDoubleVector[part.length];
    for (int i = 0; i < part.length; i++) {
      updates[i] = (IntDoubleVector) values.get(part[i]);
    }
    return new PartitionUpdateColsParam(matrixId, pkey, rows, part,
        VFactory.compIntDoubleVector(rows.length, updates, part.length), op);
  }
  if (sample instanceof IntFloatVector) {
    IntFloatVector[] updates = new IntFloatVector[part.length];
    for (int i = 0; i < part.length; i++) {
      updates[i] = (IntFloatVector) values.get(part[i]);
    }
    return new PartitionUpdateColsParam(matrixId, pkey, rows, part,
        VFactory.compIntFloatVector(rows.length, updates, part.length), op);
  }
  throw new AngelException("Update data type should be float or double!");
}
Usage of com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam in the Angel project by Tencent.
Class QuantifyFloatParam, method split.
/**
 * Creates one {@code QuantifyFloatPartParam} for every partition that holds row {@code rowId}.
 *
 * <p>Each parameter carries the partition's column range (narrowed to int), the shared
 * {@code array} and {@code numBits}.
 *
 * @return one parameter per partition returned for {@code rowId}
 * @throws RuntimeException if a returned partition does not actually span {@code rowId}
 */
@Override
public List<PartitionUpdateParam> split() {
  List<PartitionKey> rowPartitions =
      PSAgentContext.get().getMatrixMetaManager().getPartitions(matrixId, rowId);
  List<PartitionUpdateParam> result = new ArrayList<>(rowPartitions.size());
  for (PartitionKey pk : rowPartitions) {
    boolean rowInside = rowId >= pk.getStartRow() && rowId < pk.getEndRow();
    if (!rowInside) {
      throw new RuntimeException("Wrong rowId!");
    }
    int colStart = (int) pk.getStartCol();
    int colEnd = (int) pk.getEndCol();
    result.add(new QuantifyFloatPartParam(matrixId, pk, rowId, colStart, colEnd, array, numBits));
  }
  return result;
}
Usage of com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam in the Angel project by Tencent.
Class InitWalkPathParam, method split.
/**
 * Creates one {@code InitWalkPathPartitionParam} per partition of the matrix.
 *
 * <p>Every partition parameter receives the same walk configuration: {@code neighborMatrixId},
 * {@code walkLength}, {@code numParts}, {@code threshold}, {@code keepProba} and {@code isTrunc}.
 *
 * @return one parameter per partition, in partition-list order
 */
@Override
public List<PartitionUpdateParam> split() {
  List<PartitionKey> partitions =
      PSAgentContext.get().getMatrixMetaManager().getPartitions(matrixId);
  List<PartitionUpdateParam> result = new ArrayList<>(partitions.size());
  partitions.forEach(pk -> result.add(new InitWalkPathPartitionParam(
      matrixId, pk, neighborMatrixId, walkLength, numParts, threshold, keepProba, isTrunc)));
  return result;
}
Usage of com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam in the Angel project by Tencent.
Class PushPathTailParam, method split.
/**
 * Splits the path-tail push into one {@code PushPathTailPartitionParam} per partition that
 * owns at least one key.
 *
 * <p>Each partition takes the contiguous run {@code keyIds[start, end)} of keys below its
 * {@code endCol}. NOTE(review): this assumes {@code keyIds} is sorted ascending, since the scan
 * never moves backwards — confirm with callers.
 *
 * @return the per-partition parameters (possibly empty)
 * @throws AngelException if {@code RowUpdateSplitUtils.isInRange} rejects the keys
 */
@Override
public List<PartitionUpdateParam> split() {
  List<PartitionKey> parts = PSAgentContext.get().getMatrixMetaManager().getPartitions(matrixId);
  List<PartitionUpdateParam> partParams = new ArrayList<>(parts.size());
  if (!RowUpdateSplitUtils.isInRange(keyIds, parts)) {
    if (parts.isEmpty()) {
      // Previously this path crashed with IndexOutOfBoundsException while building the message.
      throw new AngelException("node id is not in range: matrix " + matrixId + " has no partitions");
    }
    // endCol is exclusive (see the "< part.getEndCol()" test below), hence the half-open interval.
    throw new AngelException("node id is not in range [" + parts.get(0).getStartCol() + ", "
        + parts.get(parts.size() - 1).getEndCol() + ")");
  }

  int nodeIndex = 0;
  for (PartitionKey part : parts) {
    if (nodeIndex >= keyIds.length) {
      // Every key has been assigned; the remaining partitions would all be empty.
      break;
    }
    int start = nodeIndex; // inclusive
    while (nodeIndex < keyIds.length && keyIds[nodeIndex] < part.getEndCol()) {
      nodeIndex++;
    }
    int end = nodeIndex; // exclusive
    if (end > start) {
      partParams.add(new PushPathTailPartitionParam(matrixId, part, pathTail, keyIds, start, end));
    }
  }
  return partParams;
}
Aggregations