Search in sources :

Example 1 with ScalarAggrResult

use of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in project angel by Tencent.

the class AggrFuncTest method testAmax.

/**
 * Checks the scalar returned by the {@code Amax} function against the largest absolute value
 * found in {@code localArray1}.
 */
@Test
public void testAmax() throws InvalidParameterException, InterruptedException, ExecutionException {
    // NOTE(review): the second constructor argument is presumably the row id that
    // localArray1 mirrors -- confirm against the test fixture.
    GetFunc func = new Amax(w2Client.getMatrixId(), 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
    // Absolute values are never negative, so 0.0 is the right seed for the running maximum.
    // Double.MIN_VALUE is the smallest *positive* double, not a lower bound.
    double max = 0.0;
    for (double x : localArray1) {
        if (max < Math.abs(x)) {
            max = Math.abs(x);
        }
    }
    // JUnit's contract is assertEquals(expected, actual, delta).
    Assert.assertEquals(max, result, delta);
}
Also used : GetFunc(com.tencent.angel.ml.matrix.psf.get.base.GetFunc) ScalarAggrResult(com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult)

Example 2 with ScalarAggrResult

use of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in project angel by Tencent.

the class SampleNeighbor method merge.

/**
 * Folds the per-partition neighbor samples into the enclosing object's {@code index},
 * {@code srcs}, {@code dsts} and (optionally) {@code types} collections.
 *
 * <p>Each partition result is matched to its partition request by partition id. For every
 * requested key in the request's {@code [startIndex, endIndex)} range, the neighbors delimited
 * by consecutive {@code indptr} entries are appended as (source index, destination index)
 * pairs; a neighbor not yet present in {@code index} is assigned {@code index.size()} as its
 * id. When {@code getSampleTypes()} is true, the matching entries of the result's type array
 * are appended to {@code types} as well.
 *
 * @param partResults the results returned by the individual partitions
 * @return a {@link ScalarAggrResult} constructed with 0; the output of this method is the
 *     mutation of the collections named above
 */
@Override
public GetResult merge(List<PartitionGetResult> partResults) {
    SampleNeighborParam request = (SampleNeighborParam) getParam();
    long[] queryKeys = request.getKeys();

    // Look-up table: partition id -> that partition's result.
    Int2ObjectArrayMap<PartitionGetResult> resultsByPartId = new Int2ObjectArrayMap<>();
    for (PartitionGetResult raw : partResults) {
        SampleNeighborPartResult typed = (SampleNeighborPartResult) raw;
        resultsByPartId.put(typed.getPartId(), raw);
    }

    for (PartitionGetParam rawRequest : request.getParams()) {
        SampleNeighborPartParam partRequest = (SampleNeighborPartParam) rawRequest;
        int begin = partRequest.getStartIndex();
        int finish = partRequest.getEndIndex();
        SampleNeighborPartResult partResult =
            (SampleNeighborPartResult) resultsByPartId.get(partRequest.getPartKey().getPartitionId());
        int[] offsets = partResult.getIndptr();
        long[] neighborIds = partResult.getNeighbors();
        int[] neighborTypes = partResult.getTypes();
        // One offset per key in this partition's slice, plus the closing offset.
        assert offsets.length == (finish - begin) + 1;

        for (int pos = begin; pos < finish; pos++) {
            int srcId = index.get(queryKeys[pos]);
            int lo = offsets[pos - begin];
            int hi = offsets[pos - begin + 1];

            for (int k = lo; k < hi; k++) {
                long neighborId = neighborIds[k];
                // First sighting of this neighbor: give it the next free index.
                if (!index.containsKey(neighborId)) {
                    index.put(neighborId, index.size());
                }
                srcs.add(srcId);
                dsts.add(index.get(neighborId));
            }

            if (request.getSampleTypes()) {
                for (int k = lo; k < hi; k++) {
                    types.add(neighborTypes[k]);
                }
            }
        }
    }
    return new ScalarAggrResult(0);
}
Also used : Int2ObjectArrayMap(it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap) ScalarAggrResult(com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult) PartitionGetParam(com.tencent.angel.ml.matrix.psf.get.base.PartitionGetParam) PartitionGetResult(com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult)

Example 3 with ScalarAggrResult

use of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in project angel by Tencent.

the class AggrFuncTest method testDot.

/**
 * Checks the scalar returned by the {@code Dot} function against the dot product of
 * {@code localArray0} and {@code localArray1} computed locally over the first {@code dim}
 * entries.
 */
@Test
public void testDot() throws InvalidParameterException, InterruptedException, ExecutionException {
    // NOTE(review): arguments 0 and 1 are presumably the row ids mirrored by localArray0 and
    // localArray1 -- confirm against the test fixture.
    GetFunc func = new Dot(w2Client.getMatrixId(), 0, 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
    double dot = 0.0;
    for (int i = 0; i < dim; i++) {
        dot += localArray0[i] * localArray1[i];
    }
    // JUnit's contract is assertEquals(expected, actual, delta); the locally computed value
    // is the expectation, so failure messages label the two numbers correctly.
    Assert.assertEquals(dot, result, delta);
}
Also used : GetFunc(com.tencent.angel.ml.matrix.psf.get.base.GetFunc) ScalarAggrResult(com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult)

Example 4 with ScalarAggrResult

use of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in project angel by Tencent.

the class AggrFuncTest method testNnz.

/**
 * Checks the scalar returned by the {@code Nnz} function against the number of entries of
 * {@code localArray1} whose magnitude exceeds {@code delta}.
 */
@Test
public void testNnz() throws InvalidParameterException, InterruptedException, ExecutionException {
    // NOTE(review): the second constructor argument is presumably the row id that
    // localArray1 mirrors -- confirm against the test fixture.
    GetFunc func = new Nnz(w2Client.getMatrixId(), 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
    int count = 0;
    for (double x : localArray1) {
        // An entry counts as non-zero when its magnitude exceeds the test tolerance.
        if (Math.abs(x) > delta) {
            count++;
        }
    }
    // JUnit's contract is assertEquals(expected, actual).
    Assert.assertEquals(count, (int) result);
}
Also used : GetFunc(com.tencent.angel.ml.matrix.psf.get.base.GetFunc) ScalarAggrResult(com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult)

Example 5 with ScalarAggrResult

use of com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult in project angel by Tencent.

the class AggrFuncTest method testSum.

/**
 * Checks the scalar returned by the {@code Sum} function against the sum of
 * {@code localArray1} computed locally.
 */
@Test
public void testSum() throws InvalidParameterException, InterruptedException, ExecutionException {
    // NOTE(review): the second constructor argument is presumably the row id that
    // localArray1 mirrors -- confirm against the test fixture.
    GetFunc func = new Sum(w2Client.getMatrixId(), 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
    double sum = 0.0;
    for (double x : localArray1) {
        sum += x;
    }
    // JUnit's contract is assertEquals(expected, actual, delta).
    Assert.assertEquals(sum, result, delta);
}
Also used : GetFunc(com.tencent.angel.ml.matrix.psf.get.base.GetFunc) ScalarAggrResult(com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult)

Aggregations

ScalarAggrResult (com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult): 10, GetFunc (com.tencent.angel.ml.matrix.psf.get.base.GetFunc): 9, PartitionGetParam (com.tencent.angel.ml.matrix.psf.get.base.PartitionGetParam): 1, PartitionGetResult (com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult): 1, Int2ObjectArrayMap (it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap): 1