Search in sources:

Example 1 with GetFunc

Usage of com.tencent.angel.ml.matrix.psf.get.base.GetFunc in the Tencent project Angel.

Class WorkerPool, method getSplit.

/**
 * Runs a partition-level get PSF on this server and serializes the response.
 *
 * <p>The PSF class named in the request is loaded by name and instantiated through its
 * public no-arg constructor. It is bound to the PS context and executed against the
 * request's partition parameter. Any failure (class lookup, instantiation or execution) is
 * logged and returned to the caller as a {@code SERVER_HANDLE_FATAL} response rather than
 * thrown.
 *
 * @param seqId rpc request id, written as the first int of the returned buffer
 * @param request get-UDF request carrying the PSF class name and the partition parameter
 * @return buffer holding {@code seqId} followed by the serialized {@code GetUDFResponse}
 */
private ByteBuf getSplit(int seqId, GetUDFRequest request) {
    GetUDFResponse response;
    try {
        // The class name comes from the request; the cast is only checked when the
        // instance is used as a GetFunc, which is inside this try block.
        @SuppressWarnings("unchecked")
        Class<? extends GetFunc> funcClass =
            (Class<? extends GetFunc>) Class.forName(request.getGetFuncClass());
        Constructor<? extends GetFunc> constructor = funcClass.getConstructor();
        // getConstructor() only returns public constructors, but the PSF class itself may be
        // non-public; presumably that is why access checks are suppressed here.
        constructor.setAccessible(true);
        GetFunc func = constructor.newInstance();
        func.setPsContext(context);
        PartitionGetResult partResult = func.partitionGet(request.getPartParam());
        response = new GetUDFResponse(partResult);
        response.setResponseType(ResponseType.SUCCESS);
    } catch (Throwable e) {
        // Server boundary: never let a user PSF failure escape; report it in the response.
        LOG.fatal("get udf request " + request + " failed ", e);
        response = new GetUDFResponse();
        response.setDetail("get udf request failed " + e.getMessage());
        response.setResponseType(ResponseType.SERVER_HANDLE_FATAL);
    }

    long startTs = System.currentTimeMillis();
    // 4 extra bytes for the leading seqId int.
    ByteBuf buf = ByteBufUtils.newByteBuf(4 + response.bufferLen(), useDirectorBuffer);
    buf.writeInt(seqId);
    response.serialize(buf);
    // This runs once per request, so keep the timing out of INFO to avoid flooding the log.
    if (LOG.isDebugEnabled()) {
        LOG.debug("Serialize use time=" + (System.currentTimeMillis() - startTs));
    }
    return buf;
}
Also used : GetFunc(com.tencent.angel.ml.matrix.psf.get.base.GetFunc) ByteBuf(io.netty.buffer.ByteBuf) PartitionGetResult(com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult)

Example 2 with GetFunc

Usage of com.tencent.angel.ml.matrix.psf.get.base.GetFunc in the Tencent project Angel.

Class AggrFuncTest, method testAmax.

/**
 * Checks that the Amax PSF on row 1 matches the locally computed maximum absolute value
 * of {@code localArray1}. This assumes row 1 was populated from {@code localArray1}.
 */
@Test
public void testAmax() throws InvalidParameterException, InterruptedException, ExecutionException {
    GetFunc func = new Amax(w2Client.getMatrixId(), 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
    // |x| is never negative, so 0.0 is the correct starting point for the running max.
    // Double.MIN_VALUE is the smallest *positive* double, not the most negative one.
    double expectedMax = 0.0;
    for (double x : localArray1) {
        expectedMax = Math.max(expectedMax, Math.abs(x));
    }
    // JUnit signature is (expected, actual, delta).
    Assert.assertEquals(expectedMax, result, delta);
}
Also used : GetFunc(com.tencent.angel.ml.matrix.psf.get.base.GetFunc) ScalarAggrResult(com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult)

Example 3 with GetFunc

Usage of com.tencent.angel.ml.matrix.psf.get.base.GetFunc in the Tencent project Angel.

Class AggrFuncTest, method testPull.

/**
 * Checks that pulling row 1 via the Pull PSF returns values matching {@code localArray1}
 * for the first {@code dim} entries. This assumes row 1 was populated from
 * {@code localArray1}.
 */
@Test
public void testPull() throws InvalidParameterException, InterruptedException, ExecutionException {
    GetFunc func = new Pull(w2Client.getMatrixId(), 1);
    double[] result = ((DenseDoubleVector) (((GetRowResult) w2Client.get(func)).getRow())).getValues();
    for (int i = 0; i < dim; i++) {
        // JUnit signature is (message, expected, actual, delta).
        Assert.assertEquals("mismatch at index " + i, localArray1[i], result[i], delta);
    }
}
Also used : GetFunc(com.tencent.angel.ml.matrix.psf.get.base.GetFunc) Pull(com.tencent.angel.ml.matrix.psf.aggr.Pull) DenseDoubleVector(com.tencent.angel.ml.math.vector.DenseDoubleVector) GetRowResult(com.tencent.angel.ml.matrix.psf.get.single.GetRowResult)

Example 4 with GetFunc

Usage of com.tencent.angel.ml.matrix.psf.get.base.GetFunc in the Tencent project Angel.

Class AggrFuncTest, method testSum.

/**
 * Checks that the Sum PSF on row 1 matches the locally computed sum of
 * {@code localArray1}. This assumes row 1 was populated from {@code localArray1}.
 */
@Test
public void testSum() throws InvalidParameterException, InterruptedException, ExecutionException {
    GetFunc func = new Sum(w2Client.getMatrixId(), 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
    double expectedSum = 0.0;
    for (double x : localArray1) {
        expectedSum += x;
    }
    // JUnit signature is (expected, actual, delta).
    Assert.assertEquals(expectedSum, result, delta);
}
Also used : GetFunc(com.tencent.angel.ml.matrix.psf.get.base.GetFunc) ScalarAggrResult(com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult)

Example 5 with GetFunc

Usage of com.tencent.angel.ml.matrix.psf.get.base.GetFunc in the Tencent project Angel.

Class AggrFuncTest, method testDot.

/**
 * Checks that the Dot PSF over rows 0 and 1 matches the locally computed dot product of
 * {@code localArray0} and {@code localArray1} over the first {@code dim} entries. This
 * assumes rows 0 and 1 were populated from those arrays.
 */
@Test
public void testDot() throws InvalidParameterException, InterruptedException, ExecutionException {
    GetFunc func = new Dot(w2Client.getMatrixId(), 0, 1);
    double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
    double expectedDot = 0.0;
    for (int i = 0; i < dim; i++) {
        expectedDot += localArray0[i] * localArray1[i];
    }
    // JUnit signature is (expected, actual, delta).
    Assert.assertEquals(expectedDot, result, delta);
}
Also used : GetFunc(com.tencent.angel.ml.matrix.psf.get.base.GetFunc) ScalarAggrResult(com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult)

Aggregations

GetFunc (com.tencent.angel.ml.matrix.psf.get.base.GetFunc): 11 · ScalarAggrResult (com.tencent.angel.ml.matrix.psf.aggr.enhance.ScalarAggrResult): 9 · DenseDoubleVector (com.tencent.angel.ml.math.vector.DenseDoubleVector): 1 · Pull (com.tencent.angel.ml.matrix.psf.aggr.Pull): 1 · PartitionGetResult (com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult): 1 · GetRowResult (com.tencent.angel.ml.matrix.psf.get.single.GetRowResult): 1 · ByteBuf (io.netty.buffer.ByteBuf): 1