
Example 31 with UpdateFunc

Use of com.tencent.angel.ml.matrix.psf.update.enhance.UpdateFunc in project angel by Tencent.

From the class UpdateFuncTest, method testFill.

@Test
public void testFill() throws Exception {
    UpdateFunc func = new Fill(w2Client.getMatrixId(), 3, -1.0);
    w2Client.update(func).get();
    double[] result = pull(w2Client, 3);
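    // Use a JUnit assertion so the check runs even when JVM assertions (-ea) are disabled
    Assert.assertEquals(dim, result.length);
    for (int i = 0; i < result.length; i++) {
        // JUnit's assertEquals takes (expected, actual, delta)
        Assert.assertEquals(-1.0, result[i], delta);
    }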
}
Also used : CompressUpdateFunc(com.tencent.angel.ml.matrix.psf.update.enhance.CompressUpdateFunc) UpdateFunc(com.tencent.angel.ml.matrix.psf.update.enhance.UpdateFunc) Test(org.junit.Test)
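
The test above fills one row with a constant and then checks every element against that constant within a tolerance. A minimal, self-contained sketch of the same pattern on a plain in-memory array follows. FillSketch, fill and allClose are hypothetical names used only for illustration; this is not Angel's Fill class and it involves no parameter server.

```java
import java.util.Arrays;

/** Hypothetical stand-in for a fill-style update; not Angel's Fill class. */
final class FillSketch {

    /** Overwrites every element of row rowId with the given value. */
    static void fill(double[][] matrix, int rowId, double value) {
        Arrays.fill(matrix[rowId], value);
    }

    /** Returns true when every element of row is within delta of expected. */
    static boolean allClose(double[] row, double expected, double delta) {
        for (double v : row) {
            if (Math.abs(v - expected) > delta) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[][] matrix = new double[4][5]; // 4 rows, 5 columns, initially all zeros
        fill(matrix, 3, -1.0);
        System.out.println(allClose(matrix[3], -1.0, 1e-9)); // filled row
        System.out.println(allClose(matrix[2], -1.0, 1e-9)); // untouched row
    }
}
```

Comparing with a tolerance rather than with == mirrors the delta argument in the test and avoids spurious failures from floating-point rounding.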

Example 32 with UpdateFunc

Use of com.tencent.angel.ml.matrix.psf.update.enhance.UpdateFunc in project angel by Tencent.

From the class WorkerPool, method update.

/**
 * Update a partition using a PSF
 * @param seqId rpc request id
 * @param request rpc request
 * @param in serialized rpc request
 * @return serialized rpc response
 */
private ByteBuf update(int seqId, UpdaterRequest request, ByteBuf in) {
    UpdaterResponse response = null;
    ByteBuf buf = ByteBufUtils.newByteBuf(4 + 8, useDirectorBuffer);
    // Get partition and check the partition state
    PartitionKey partKey = request.getPartKey();
    ServerPartition part = context.getMatrixStorageManager().getPart(partKey.getMatrixId(), partKey.getPartitionId());
    if (part == null) {
        String log = "update " + request + " failed. The partition " + partKey + " does not exist";
        LOG.fatal(log);
        response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FATAL, log);
        response.serialize(buf);
        return buf;
    }
    PartitionState state = part.getState();
    if (state != PartitionState.READ_AND_WRITE) {
        String log = "update " + request + " failed. The partition " + partKey + " state is " + state;
        LOG.error(log);
        response = new UpdaterResponse(ResponseType.PARTITION_READ_ONLY, log);
        response.serialize(buf);
        return buf;
    }
    // Get the PSs that store this partition
    PartitionLocation partLoc = null;
    try {
        partLoc = context.getMatrixMetaManager().getPartLocation(request.getPartKey(), disableRouterCache);
    } catch (Throwable x) {
        String log = "update " + request + " failed, get partition location from master failed " + x.getMessage();
        LOG.error(log, x);
        response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FAILED, log);
        response.serialize(buf);
        return buf;
    }
    // Check that this PS is the master PS for this partition; only the master PS can accept the update
    if (!request.isComeFromPs() && !isPartMasterPs(partLoc)) {
        String log = "update " + request + " failed, update to slave ps for partition " + request.getPartKey();
        LOG.error(log);
        response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FAILED, log);
    } else {
        try {
            Class<? extends UpdateFunc> funcClass = (Class<? extends UpdateFunc>) Class.forName(request.getUpdaterFuncClass());
            Constructor<? extends UpdateFunc> constructor = funcClass.getConstructor();
            constructor.setAccessible(true);
            UpdateFunc func = constructor.newInstance();
            func.setPsContext(context);
            // Check the partition state again
            state = part.getState();
            if (state != PartitionState.READ_AND_WRITE) {
                String log = "update " + request + " failed. The partition " + partKey + " state is " + state;
                LOG.error(log);
                response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FAILED, log);
                response.serialize(buf);
                return buf;
            }
            part.update(func, request.getPartParam());
            response = new UpdaterResponse();
            response.setResponseType(ResponseType.SUCCESS);
            if (partLoc.psLocs.size() > 1) {
                // Forward the update to the slave PSs
                context.getPS2PSPusher().put(request, in, partLoc);
            }
        } catch (Throwable e) {
            String log = "update " + request + " failed " + e.getMessage();
            LOG.fatal(log, e);
            response = new UpdaterResponse(ResponseType.SERVER_HANDLE_FATAL, log);
        }
    }
    buf.writeInt(seqId);
    response.serialize(buf);
    return buf;
}
Also used : UpdateFunc(com.tencent.angel.ml.matrix.psf.update.enhance.UpdateFunc) PartitionKey(com.tencent.angel.PartitionKey) ByteBuf(io.netty.buffer.ByteBuf) PartitionLocation(com.tencent.angel.ml.matrix.PartitionLocation)
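
The update method above builds the update function from a class name carried in the request, using Class.forName, getConstructor and newInstance. The sketch below isolates that reflective step using only the standard library. RowUpdate, AddOne and ReflectiveUpdateSketch are hypothetical names for illustration and do not reflect Angel's actual UpdateFunc API. It uses asSubclass in place of the unchecked cast, so a class of the wrong type is rejected before it is instantiated.

```java
import java.lang.reflect.Constructor;

/** Sketch of creating an update function from its class name via reflection. */
final class ReflectiveUpdateSketch {

    /**
     * Loads the named class, verifies it implements RowUpdate, and calls its
     * public no-argument constructor. Throws ClassCastException for a wrong type.
     */
    static RowUpdate instantiate(String className) throws ReflectiveOperationException {
        Class<? extends RowUpdate> cls = Class.forName(className).asSubclass(RowUpdate.class);
        Constructor<? extends RowUpdate> ctor = cls.getConstructor();
        return ctor.newInstance();
    }

    public static void main(String[] args) throws Exception {
        RowUpdate func = instantiate(AddOne.class.getName());
        double[] row = {1.0, 2.0, 3.0};
        func.apply(row);
        System.out.println(java.util.Arrays.toString(row));
    }
}

/** Hypothetical update-function contract (an assumption, not Angel's interface). */
interface RowUpdate {
    void apply(double[] row);
}

/** Example update with the public no-argument constructor that reflection requires. */
class AddOne implements RowUpdate {
    public AddOne() {}

    @Override
    public void apply(double[] row) {
        for (int i = 0; i < row.length; i++) {
            row[i] += 1.0;
        }
    }
}
```

Because getConstructor() only returns public constructors, every class created this way needs a public no-argument constructor; any state it needs has to be supplied after construction, as the original code does with setPsContext and the partition parameter.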

Aggregations

UpdateFunc (com.tencent.angel.ml.matrix.psf.update.enhance.UpdateFunc): 32
CompressUpdateFunc (com.tencent.angel.ml.matrix.psf.update.enhance.CompressUpdateFunc): 31
Test (org.junit.Test): 31
Push (com.tencent.angel.ml.matrix.psf.update.Push): 2
PartitionKey (com.tencent.angel.PartitionKey): 1
PartitionLocation (com.tencent.angel.ml.matrix.PartitionLocation): 1
Increment (com.tencent.angel.ml.matrix.psf.update.Increment): 1
ByteBuf (io.netty.buffer.ByteBuf): 1