Usage of com.tencent.angel.ml.matrix.psf.update.enhance.UpdateFunc in the Angel project by Tencent.
Class UpdateFuncTest, method testFill.
/**
 * Applies the {@link Fill} PSF with value -1.0 and argument 3 to the matrix behind
 * {@code w2Client}, pulls back the same index (presumably a row id — {@code pull} is called with
 * the same value), and checks that every pulled element equals the fill value.
 */
@Test
public void testFill() throws Exception {
  int rowId = 3;
  double fillValue = -1.0;
  UpdateFunc func = new Fill(w2Client.getMatrixId(), rowId, fillValue);
  w2Client.update(func).get();
  double[] result = pull(w2Client, rowId);
  // Use a JUnit assertion: the `assert` keyword is a no-op unless the JVM runs with -ea.
  Assert.assertEquals("pulled row length", dim, result.length);
  for (int i = 0; i < result.length; i++) {
    // JUnit's order is (expected, actual, delta); passing them swapped garbles failure messages.
    Assert.assertEquals("element " + i, fillValue, result[i], delta);
  }
}
Usage of com.tencent.angel.ml.matrix.psf.update.enhance.UpdateFunc in the Angel project by Tencent.
Class WorkerPool, method update.
/**
 * Update a partition using a PSF.
 *
 * <p>The update is rejected unless the partition exists on this PS and is in
 * {@code READ_AND_WRITE} state. It is also rejected unless {@code request.isComeFromPs()} is true
 * or this PS is the master PS for the partition's location. Otherwise a new instance of the
 * request's updater function class is created reflectively via its public no-arg constructor.
 * The state is re-checked, and then the instance is applied with
 * {@code part.update(func, request.getPartParam())}. After a successful update, when the
 * partition location lists more than one PS, the request and {@code in} are handed to the
 * PS-to-PS pusher.
 *
 * <p>Every response written to the returned buffer starts with {@code seqId}, followed by the
 * serialized {@link UpdaterResponse}. This holds for the early failure paths as well.
 *
 * @param seqId rpc request id, written as the first int of the response
 * @param request rpc request
 * @param in serialized rpc request; passed to the PS-to-PS pusher after a successful update
 * @return serialized rpc response
 */
private ByteBuf update(int seqId, UpdaterRequest request, ByteBuf in) {
  ByteBuf buf = ByteBufUtils.newByteBuf(4 + 8, useDirectorBuffer);
  // Write the request id up front so every response, including early failures, has the same
  // framing. Previously only the final path wrote it, and early returns omitted it.
  buf.writeInt(seqId);
  UpdaterResponse response;

  // Get partition and check the partition state
  PartitionKey partKey = request.getPartKey();
  ServerPartition part =
      context.getMatrixStorageManager().getPart(partKey.getMatrixId(), partKey.getPartitionId());
  if (part == null) {
    String log = "update " + request + " failed. The partition " + partKey + " does not exist";
    LOG.fatal(log);
    response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FATAL, log);
    response.serialize(buf);
    return buf;
  }

  PartitionState state = part.getState();
  if (state != PartitionState.READ_AND_WRITE) {
    String log = "update " + request + " failed. The partition " + partKey + " state is " + state;
    LOG.error(log);
    response = new UpdaterResponse(ResponseType.PARTITION_READ_ONLY, log);
    response.serialize(buf);
    return buf;
  }

  // Get the stored pss for this partition
  PartitionLocation partLoc;
  try {
    partLoc = context.getMatrixMetaManager().getPartLocation(request.getPartKey(), disableRouterCache);
  } catch (Throwable x) {
    String log =
        "update " + request + " failed, get partition location from master failed " + x.getMessage();
    LOG.error(log, x);
    response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FAILED, log);
    response.serialize(buf);
    return buf;
  }

  // Check this ps is the master ps for this location, only master ps can accept the update
  if (!request.isComeFromPs() && !isPartMasterPs(partLoc)) {
    String log = "update " + request + " failed, update to slave ps for partition " + request.getPartKey();
    LOG.error(log);
    response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FAILED, log);
  } else {
    try {
      // asSubclass is a checked conversion; a non-UpdateFunc class name fails here and is
      // reported through the catch below.
      Class<? extends UpdateFunc> funcClass =
          Class.forName(request.getUpdaterFuncClass()).asSubclass(UpdateFunc.class);
      Constructor<? extends UpdateFunc> constructor = funcClass.getConstructor();
      constructor.setAccessible(true);
      UpdateFunc func = constructor.newInstance();
      func.setPsContext(context);

      // Check the partition state again; it may have changed since the first check.
      state = part.getState();
      if (state != PartitionState.READ_AND_WRITE) {
        String log = "update " + request + " failed. The partition " + partKey + " state is " + state;
        LOG.error(log);
        // NOTE(review): the first state check reports PARTITION_READ_ONLY for the same condition;
        // confirm which response type clients expect here.
        response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FAILED, log);
        response.serialize(buf);
        return buf;
      }

      part.update(func, request.getPartParam());
      response = new UpdaterResponse();
      response.setResponseType(ResponseType.SUCCESS);
      if (partLoc.psLocs.size() > 1) {
        // Start to put the update to the slave pss
        context.getPS2PSPusher().put(request, in, partLoc);
      }
    } catch (Throwable e) {
      String log = "update " + request + " failed " + e.getMessage();
      LOG.fatal(log, e);
      response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FATAL, log);
    }
  }

  response.serialize(buf);
  return buf;
}
Aggregations