Usage of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in the Angel project by Tencent.
Class AggrFuncTest, method testAmax.
/**
 * Checks that the {@code Amax} PS function returns the largest absolute value of the matrix row
 * it is asked about, by comparing it with a local computation over {@code localArray1}.
 *
 * <p>NOTE(review): assumes the second constructor argument of {@code Amax} is the row id and that
 * row 1 holds the contents of {@code localArray1} — confirm against the test setup.
 */
@Test
public void testAmax() throws InvalidParameterException, InterruptedException, ExecutionException {
    GetFunc func = new Amax(w2Client.getMatrixId(), 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();

    // Absolute values are never negative, so 0.0 is the correct starting point for the maximum.
    // (Double.MIN_VALUE is the smallest *positive* double, not the most negative one.)
    double max = 0.0;
    for (double x : localArray1) {
        double abs = Math.abs(x);
        if (max < abs) {
            max = abs;
        }
    }

    // JUnit's signature is assertEquals(expected, actual, delta).
    Assert.assertEquals(max, result, delta);
}
Usage of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in the Angel project by Tencent.
Class SampleNeighbor, method merge.
/**
 * Merges per-partition neighbor-sampling results into this function's edge lists.
 *
 * <p>For every requested key, each sampled neighbor gets a local index. Neighbors seen for the
 * first time are assigned {@code index.size()}. One edge is then appended as (key index, neighbor
 * index) to {@code srcs}/{@code dsts}. When the request asked for sample types, the neighbor's type
 * is also appended to {@code types}. These collections are members declared outside this view.
 *
 * <p>NOTE(review): {@code index.get(keys[i])} assumes every requested key was registered in
 * {@code index} beforehand — confirm with the caller.
 *
 * @param partResults the results returned by each partition, matched to requests by partition id
 * @return a placeholder {@code ScalarAggrResult(0)}; the real output is accumulated in the fields
 * @throws IllegalStateException if a requested partition returned no result
 */
@Override
public GetResult merge(List<PartitionGetResult> partResults) {
    SampleNeighborParam param = (SampleNeighborParam) getParam();
    long[] keys = param.getKeys();
    // Loop-invariant: whether sample types were requested.
    boolean withTypes = param.getSampleTypes();

    // Index the partition results by partition id so each request can find its response.
    Int2ObjectArrayMap<PartitionGetResult> partIdToResult = new Int2ObjectArrayMap<>();
    for (PartitionGetResult result : partResults) {
        partIdToResult.put(((SampleNeighborPartResult) result).getPartId(), result);
    }

    for (PartitionGetParam partParam : param.getParams()) {
        SampleNeighborPartParam param0 = (SampleNeighborPartParam) partParam;
        int start = param0.getStartIndex();
        int end = param0.getEndIndex();
        int partId = param0.getPartKey().getPartitionId();
        SampleNeighborPartResult result = (SampleNeighborPartResult) partIdToResult.get(partId);
        if (result == null) {
            // Fail fast with context instead of an anonymous NPE further down.
            throw new IllegalStateException("No sample result returned for partition " + partId);
        }

        // CSR layout: neighbors of keys[i] are neighbors[indptr[i - start] .. indptr[i - start + 1]).
        int[] indptr = result.getIndptr();
        long[] neighbors = result.getNeighbors();
        int[] sampleTypes = result.getTypes();
        assert indptr.length == (end - start) + 1;

        for (int i = start; i < end; i++) {
            int keyIndex = index.get(keys[i]);
            int from = indptr[i - start];
            int to = indptr[i - start + 1];
            for (int j = from; j < to; j++) {
                long n = neighbors[j];
                if (!index.containsKey(n)) {
                    // First time this neighbor is seen: assign the next local index.
                    index.put(n, index.size());
                }
                srcs.add(keyIndex);
                dsts.add(index.get(n));
                if (withTypes) {
                    types.add(sampleTypes[j]);
                }
            }
        }
    }
    return new ScalarAggrResult(0);
}
Usage of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in the Angel project by Tencent.
Class AggrFuncTest, method testDot.
/**
 * Checks that the {@code Dot} PS function returns the dot product of matrix rows 0 and 1, by
 * comparing it with a local dot product of {@code localArray0} and {@code localArray1} over the
 * first {@code dim} elements.
 *
 * <p>NOTE(review): assumes rows 0 and 1 were populated from {@code localArray0} and
 * {@code localArray1} — confirm against the test setup.
 */
@Test
public void testDot() throws InvalidParameterException, InterruptedException, ExecutionException {
    GetFunc func = new Dot(w2Client.getMatrixId(), 0, 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();

    double dot = 0.0;
    for (int i = 0; i < dim; i++) {
        dot += localArray0[i] * localArray1[i];
    }

    // JUnit's signature is assertEquals(expected, actual, delta).
    Assert.assertEquals(dot, result, delta);
}
Usage of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in the Angel project by Tencent.
Class AggrFuncTest, method testNnz.
/**
 * Checks that the {@code Nnz} PS function returns the number of non-zero entries in matrix row 1,
 * by comparing it with a local count over {@code localArray1}.
 *
 * <p>Locally, a value counts as non-zero only if its magnitude exceeds {@code delta}.
 */
@Test
public void testNnz() throws InvalidParameterException, InterruptedException, ExecutionException {
    GetFunc func = new Nnz(w2Client.getMatrixId(), 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();

    int count = 0;
    for (double x : localArray1) {
        if (Math.abs(x) > delta) {
            count++;
        }
    }

    // JUnit's signature is assertEquals(expected, actual).
    Assert.assertEquals(count, (int) result);
}
Usage of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in the Angel project by Tencent.
Class AggrFuncTest, method testSum.
/**
 * Checks that the {@code Sum} PS function returns the sum of matrix row 1, by comparing it with a
 * local sum over {@code localArray1}.
 */
@Test
public void testSum() throws InvalidParameterException, InterruptedException, ExecutionException {
    GetFunc func = new Sum(w2Client.getMatrixId(), 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();

    double sum = 0.0;
    for (double x : localArray1) {
        sum += x;
    }

    // JUnit's signature is assertEquals(expected, actual, delta).
    Assert.assertEquals(sum, result, delta);
}
Aggregations