Usage of com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult in the Angel project by Tencent.
Class Dot, method merge.
/**
 * Combines the per-partition results of the dot operation into a single dot result.
 *
 * <p>Only {@code DotPartitionResult} entries take part; any other result type in the list is
 * skipped. If the list is empty, or its first entry is not a {@code DotPartitionResult}, nothing is
 * merged and {@code null} is returned.
 *
 * <p>Each partition result is cleared right after its values are merged, even if the merge of that
 * partition throws.
 *
 * @param partResults the results returned by the individual partitions
 * @return an {@code NEDot.NEDotResult} holding the merged values, or {@code null} as described above
 * @throws AngelException if two dot partition results report different value lengths
 */
@Override
public GetResult merge(List<PartitionGetResult> partResults) {
  // Nothing to merge unless the first partition produced a dot result.
  if (partResults.isEmpty() || !(partResults.get(0) instanceof DotPartitionResult)) {
    return null;
  }

  final int expectedLength = ((DotPartitionResult) partResults.get(0)).length;

  // Validate up front: every dot partition must carry the same number of values.
  for (PartitionGetResult part : partResults) {
    if (!(part instanceof DotPartitionResult)) {
      continue;
    }
    int actualLength = ((DotPartitionResult) part).length;
    if (expectedLength != actualLength) {
      throw new AngelException(String.format(
          "length of dot values not same one is %d other is %d", expectedLength, actualLength));
    }
  }

  // Let every dot partition merge its values into one shared array.
  float[] merged = new float[expectedLength];
  for (PartitionGetResult part : partResults) {
    if (!(part instanceof DotPartitionResult)) {
      continue;
    }
    DotPartitionResult dotPart = (DotPartitionResult) part;
    try {
      dotPart.merge(merged);
    } finally {
      // Clear the partition result whether or not its merge succeeded.
      dotPart.clear();
    }
  }
  return new NEDot.NEDotResult(merged);
}
Usage of com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult in the Angel project by Tencent.
Class GetNeighborsWithCount, method merge.
/**
 * Merges the per-partition neighbor lookups into one node-id to neighbor-array map.
 *
 * <p>Every entry of {@code partResults} is cast to {@code PartGetNeighborWithCountResult}. The
 * i-th node id of a partition is mapped to the i-th row of that partition's data. If a node id
 * appears in several partitions, the entry from the later partition wins.
 *
 * @param partResults the results returned by the individual partitions
 * @return a {@code GetLongsResult} wrapping the merged map
 */
@Override
public GetResult merge(List<PartitionGetResult> partResults) {
  // Size the map with the total number of node ids over all partitions.
  int totalNodes = 0;
  for (PartitionGetResult part : partResults) {
    totalNodes += ((PartGetNeighborWithCountResult) part).getNodeIds().length;
  }

  Long2ObjectOpenHashMap<long[]> nodeIdToNeighbors = new Long2ObjectOpenHashMap<>(totalNodes);
  for (PartitionGetResult part : partResults) {
    PartGetNeighborWithCountResult neighborPart = (PartGetNeighborWithCountResult) part;
    long[] ids = neighborPart.getNodeIds();
    long[][] neighbors = neighborPart.getData();
    // ids and neighbors are parallel arrays: ids[idx] owns neighbors[idx].
    int idx = 0;
    while (idx < ids.length) {
      nodeIdToNeighbors.put(ids[idx], neighbors[idx]);
      idx++;
    }
  }
  return new GetLongsResult(nodeIdToNeighbors);
}
Usage of com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult in the Angel project by Tencent.
Class GetHyperLogLog, method merge.
/**
 * Collects the HyperLogLog sketches returned by every partition into a single map.
 *
 * <p>Every entry of {@code partResults} is cast to {@code GetHyperLogLogPartResult}, and its logs
 * are copied with {@code putAll}. A key present in several partitions keeps the value from the
 * later partition.
 *
 * @param partResults the results returned by the individual partitions
 * @return a {@code GetHyperLogLogResult} wrapping the combined map
 */
@Override
public GetResult merge(List<PartitionGetResult> partResults) {
  // Total number of entries across partitions, used only to pre-size the map.
  int expectedEntries = partResults.stream()
      .mapToInt(part -> ((GetHyperLogLogPartResult) part).getLogs().size())
      .sum();

  Long2ObjectOpenHashMap<HyperLogLogPlus> merged = new Long2ObjectOpenHashMap<>(expectedEntries);
  for (PartitionGetResult part : partResults) {
    merged.putAll(((GetHyperLogLogPartResult) part).getLogs());
  }
  return new GetHyperLogLogResult(merged);
}
Usage of com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult in the Angel project by Tencent.
Class GetUDFHandler, method handle.
/**
 * Completes {@code finalResult} by merging the per-partition results of a get-UDF request.
 *
 * <p>Each cached response is unwrapped as a {@code GetUDFResponse} to obtain its partition
 * result. The collected results are then passed to the request's get function {@code merge}.
 *
 * <p>Both the unwrapping step and the merge run inside the same try block. Any exception therefore
 * completes {@code finalResult} exceptionally instead of escaping this method. Previously a failure
 * while unwrapping responses left the future without a result or an error.
 *
 * @param finalResult the future to complete with the merged result or the failure
 * @param userRequest the originating request; expected to be a {@code GetPSFRequest}
 * @param responseCache the collected responses; expected to be a {@code MapResponseCache}
 */
@Override
public void handle(FutureResult finalResult, UserRequest userRequest, ResponseCache responseCache) {
  GetPSFRequest getPSFRequest = (GetPSFRequest) userRequest;
  MapResponseCache cache = (MapResponseCache) responseCache;
  try {
    // Adaptor to Get PSF merge: unwrap every response into its partition result.
    ConcurrentHashMap<Request, Response> responses = cache.getResponses();
    List<PartitionGetResult> partGetResults = new ArrayList<>(responses.size());
    for (Response response : responses.values()) {
      partGetResults.add(((GetUDFResponse) response.getData()).getPartResult());
    }
    // Merge the sub-results
    finalResult.set(getPSFRequest.getGetFunc().merge(partGetResults));
  } catch (Exception x) {
    LOG.error("merge row failed ", x);
    finalResult.setExecuteException(new ExecutionException(x));
  }
}
Usage of com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult in the Angel project by Tencent.
Class GetNodes, method merge.
/**
 * Concatenates the long values returned by every partition into one dense long vector.
 *
 * <p>Only {@code IndexPartGetLongResult} entries contribute; other result types are ignored.
 * Values appear in the order of {@code partResults}, and within a partition in its original order.
 *
 * @param partResults the results returned by the individual partitions
 * @return a successful {@code GetRowResult} wrapping the concatenated values
 */
@Override
public GetResult merge(List<PartitionGetResult> partResults) {
  // Total number of values, so the output array is allocated exactly once.
  int total = partResults.stream()
      .filter(part -> part instanceof IndexPartGetLongResult)
      .mapToInt(part -> ((IndexPartGetLongResult) part).getValues().length)
      .sum();

  long[] merged = new long[total];
  int offset = 0;
  for (PartitionGetResult part : partResults) {
    if (!(part instanceof IndexPartGetLongResult)) {
      continue;
    }
    long[] chunk = ((IndexPartGetLongResult) part).getValues();
    System.arraycopy(chunk, 0, merged, offset, chunk.length);
    offset += chunk.length;
  }
  return new GetRowResult(ResponseType.SUCCESS, VFactory.denseLongVector(merged));
}
Aggregations