Usage of `com.tencent.angel.ml.matrix.psf.get.base.GetFunc` in Tencent's Angel project.
Class `WorkerPool`, method `getSplit`.
/**
 * Runs a get-type PS function (PSF) against a partition and serializes the outcome.
 *
 * <p>The PSF class named by {@code request.getGetFuncClass()} is loaded reflectively, instantiated
 * through its public no-argument constructor, bound to this PS context and asked for the partition
 * result. Any failure (class loading, a class that is not a {@link GetFunc}, instantiation, or the
 * function itself) is logged and turned into a {@code SERVER_HANDLE_FATAL} response rather than
 * propagated, so a reply is always produced for {@code seqId}.
 *
 * @param seqId rpc request id, written as the first int of the returned buffer
 * @param request request carrying the PSF class name and its partition parameter
 * @return buffer holding {@code seqId} followed by the serialized {@link GetUDFResponse}
 */
private ByteBuf getSplit(int seqId, GetUDFRequest request) {
  GetUDFResponse response;
  try {
    // asSubclass replaces the unchecked generic cast: a class that is not a GetFunc is rejected
    // here with a ClassCastException naming it, instead of failing later at the assignment.
    Class<? extends GetFunc> funcClass =
        Class.forName(request.getGetFuncClass()).asSubclass(GetFunc.class);
    Constructor<? extends GetFunc> constructor = funcClass.getConstructor();
    constructor.setAccessible(true);
    GetFunc func = constructor.newInstance();
    func.setPsContext(context);
    PartitionGetResult partResult = func.partitionGet(request.getPartParam());
    response = new GetUDFResponse(partResult);
    response.setResponseType(ResponseType.SUCCESS);
  } catch (Throwable e) {
    // RPC boundary: report every failure to the caller instead of letting it escape the worker.
    LOG.fatal("get udf request " + request + " failed ", e);
    response = new GetUDFResponse();
    response.setDetail("get udf request failed " + e.getMessage());
    response.setResponseType(ResponseType.SERVER_HANDLE_FATAL);
  }
  long startTs = System.currentTimeMillis();
  // 4 extra bytes for the int seqId written ahead of the response body.
  ByteBuf buf = ByteBufUtils.newByteBuf(4 + response.bufferLen(), useDirectorBuffer);
  buf.writeInt(seqId);
  response.serialize(buf);
  // Per-request timing is diagnostic detail; keep it out of INFO on this hot path.
  if (LOG.isDebugEnabled()) {
    LOG.debug("Serialize use time=" + (System.currentTimeMillis() - startTs));
  }
  return buf;
}
Usage of `com.tencent.angel.ml.matrix.psf.get.base.GetFunc` in Tencent's Angel project.
Class `AggrFuncTest`, method `testAmax`.
/**
 * Checks the {@code Amax} PSF against the largest absolute value of {@code localArray1},
 * computed locally.
 */
@Test
public void testAmax() throws InvalidParameterException, InterruptedException, ExecutionException {
  GetFunc func = new Amax(w2Client.getMatrixId(), 1);
  double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
  // Seed with 0.0: every candidate is an absolute value (>= 0). Double.MIN_VALUE is the smallest
  // *positive* double, not a neutral starting point for a max.
  double max = 0.0;
  for (double x : localArray1) {
    max = Math.max(max, Math.abs(x));
  }
  // JUnit signature is assertEquals(expected, actual, delta).
  Assert.assertEquals(max, result, delta);
}
Usage of `com.tencent.angel.ml.matrix.psf.get.base.GetFunc` in Tencent's Angel project.
Class `AggrFuncTest`, method `testPull`.
/**
 * Checks that the {@code Pull} PSF returns a row whose first {@code dim} values match
 * {@code localArray1}.
 */
@Test
public void testPull() throws InvalidParameterException, InterruptedException, ExecutionException {
  GetFunc func = new Pull(w2Client.getMatrixId(), 1);
  GetRowResult rowResult = (GetRowResult) w2Client.get(func);
  double[] result = ((DenseDoubleVector) rowResult.getRow()).getValues();
  for (int i = 0; i < dim; i++) {
    // JUnit signature is assertEquals(expected, actual, delta).
    Assert.assertEquals(localArray1[i], result[i], delta);
  }
}
Usage of `com.tencent.angel.ml.matrix.psf.get.base.GetFunc` in Tencent's Angel project.
Class `AggrFuncTest`, method `testSum`.
/**
 * Checks the {@code Sum} PSF against the sum of {@code localArray1}, computed locally.
 */
@Test
public void testSum() throws InvalidParameterException, InterruptedException, ExecutionException {
  GetFunc func = new Sum(w2Client.getMatrixId(), 1);
  double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
  double sum = 0.0;
  for (double x : localArray1) {
    sum += x;
  }
  // JUnit signature is assertEquals(expected, actual, delta); delta absorbs summation-order
  // differences between the local loop and the server-side aggregation.
  Assert.assertEquals(sum, result, delta);
}
Usage of `com.tencent.angel.ml.matrix.psf.get.base.GetFunc` in Tencent's Angel project.
Class `AggrFuncTest`, method `testDot`.
/**
 * Checks the {@code Dot} PSF against the dot product of {@code localArray0} and
 * {@code localArray1} over the first {@code dim} elements, computed locally.
 */
@Test
public void testDot() throws InvalidParameterException, InterruptedException, ExecutionException {
  GetFunc func = new Dot(w2Client.getMatrixId(), 0, 1);
  double result = ((ScalarAggrResult) w2Client.get(func)).getResult();
  double dot = 0.0;
  for (int i = 0; i < dim; i++) {
    dot += localArray0[i] * localArray1[i];
  }
  // JUnit signature is assertEquals(expected, actual, delta).
  Assert.assertEquals(dot, result, delta);
}
Aggregations