Search in sources :

Example 1 with FrameCompleteMessage

use of org.nd4j.parameterserver.distributed.messages.complete.FrameCompleteMessage in project nd4j by deeplearning4j.

In class SkipGramTrainer, method finishTraining:

/**
 * Finishes one SkipGram training round for the chain identified by
 * {@code (originatorId, taskId)}.
 *
 * <p>Reads the accumulated dot products of the chain, turns them into gradients for the
 * optional hierarchic-softmax (HS) part and the optional negative-sampling (NS) part,
 * applies them to the {@code SYN_1} / {@code SYN_1_NEGATIVE} rows, and finally adds the
 * accumulated error vector to the {@code SYN_0} row returned by {@code getW2()}.
 * Afterwards the frame completion handler is notified and, if the frame is completed, a
 * {@link FrameCompleteMessage} is sent through the transport. The chain is removed from
 * {@code chains} at the end of the method.
 *
 * @param originatorId id of the originator, used together with {@code taskId} to look up the chain
 * @param taskId id of the training task being finished
 * @throws RuntimeException if no chain is registered for the given descriptor
 */
@Override
public void finishTraining(long originatorId, long taskId) {
    RequestDescriptor chainDesc = RequestDescriptor.createDescriptor(originatorId, taskId);
    SkipGramChain chain = chains.get(chainDesc);
    if (chain == null) {
        throw new RuntimeException("Unable to find chain for specified taskId: [" + taskId + "]");
    }
    SkipGramRequestMessage sgrm = chain.getRequestMessage();
    double alpha = sgrm.getAlpha();
    // TODO: We DON'T want this code being here
    // TODO: We DO want this algorithm to be native
    INDArray expTable = storage.getArray(WordVectorStorage.EXP_TABLE);
    INDArray dots = chain.getDotAggregation().getAccumulatedResult();
    INDArray syn0 = storage.getArray(WordVectorStorage.SYN_0);
    INDArray syn1 = storage.getArray(WordVectorStorage.SYN_1);
    INDArray syn1Neg = storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
    // error accumulator, applied to the syn0 row once both gradient passes are done
    INDArray neu1e = Nd4j.create(syn0.columns());
    // e indexes into dots and is shared by the HS and NS passes
    int e = 0;
    boolean updated = false;

    // apply optional SkipGram HS gradients (the loop is a no-op when there are no codes)
    int[] codes = sgrm.getCodes();
    for (; e < codes.length; e++) {
        float dot = dots.getFloat(e);
        if (dot < -HS_MAX_EXP || dot >= HS_MAX_EXP) {
            continue;
        }
        int idx = (int) ((dot + HS_MAX_EXP) * ((float) expTable.length() / HS_MAX_EXP / 2.0));
        if (idx >= expTable.length() || idx < 0) {
            continue;
        }
        int code = codes[e];
        double f = expTable.getFloat(idx);
        double g = (1 - code - f) * alpha;
        updated = true;
        // accumulate the error first, then update the syn1 row from the (still unchanged) syn0 row
        Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn1.getRow(sgrm.getPoints()[e]), neu1e);
        Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn0.getRow(sgrm.getW2()), syn1.getRow(sgrm.getPoints()[e]));
    }

    // apply optional NegSample gradients
    if (sgrm.getNegSamples() > 0) {
        // The loop is bounded by cnt (not by e): e may already have advanced past the HS
        // entries, and bounding on e would skip NS rounds whenever HS codes are present.
        // NOTE(review): assumes dots holds the HS entries followed by negSamples + 1 NS
        // entries, as the shared index e implies - confirm against the dot producer.
        int nsRounds = sgrm.getNegSamples() + 1;
        for (int cnt = 0; cnt < nsRounds; e++, cnt++) {
            float dot = dots.getFloat(e);
            // the first NS entry is treated as the positive example, the rest as negatives
            float code = cnt == 0 ? 1.0f : 0.0f;
            double g;
            if (dot > HS_MAX_EXP) {
                g = (code - 1) * alpha;
            } else if (dot < -HS_MAX_EXP) {
                g = (code - 0) * alpha;
            } else {
                int idx = (int) ((dot + HS_MAX_EXP) * (expTable.length() / HS_MAX_EXP / 2.0));
                if (idx >= expTable.length() || idx < 0) {
                    continue;
                }
                g = (code - expTable.getDouble(idx)) * alpha;
            }
            updated = true;
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn1Neg.getRow(sgrm.getNegatives()[cnt]), neu1e);
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn0.getRow(sgrm.getW2()), syn1Neg.getRow(sgrm.getNegatives()[cnt]));
        }
    }

    if (updated) {
        Nd4j.getBlasWrapper().axpy(Double.valueOf(1.0), neu1e, syn0.getRow(sgrm.getW2()));
    }

    // we send back confirmation message only from Shard which received this message
    RequestDescriptor descriptor = RequestDescriptor.createDescriptor(chain.getOriginatorId(), chain.getFrameId());
    if (completionHandler.isTrackingFrame(descriptor)) {
        completionHandler.notifyFrame(chain.getOriginatorId(), chain.getFrameId(), chain.getTaskId());
        if (completionHandler.isCompleted(descriptor)) {
            FrameCompletionHandler.FrameDescriptor frameDescriptor = completionHandler.getCompletedFrameInfo(descriptor);
            // TODO: there is possible race condition here
            if (frameDescriptor != null) {
                FrameCompleteMessage fcm = new FrameCompleteMessage(chain.getFrameId());
                fcm.setOriginatorId(frameDescriptor.getFrameOriginatorId());
                transport.sendMessage(fcm);
            } else {
                log.warn("Frame double spending detected");
            }
        }
    } else {
        log.info("sI_{} isn't tracking this frame: Originator: {}, frameId: {}, taskId: {}", transport.getShardIndex(), chain.getOriginatorId(), chain.getFrameId(), taskId);
    }

    // log the value this call produced; re-reading the counter could report another thread's increment
    long rounds = cntRounds.incrementAndGet();
    if (rounds % 100000 == 0) {
        log.info("{} training rounds finished...", rounds);
    }

    // don't forget to remove chain, it'll become a leak otherwise
    chains.remove(chainDesc);
}
Also used : SkipGramChain(org.nd4j.parameterserver.distributed.training.chains.SkipGramChain) FrameCompletionHandler(org.nd4j.parameterserver.distributed.logic.completion.FrameCompletionHandler) FrameCompleteMessage(org.nd4j.parameterserver.distributed.messages.complete.FrameCompleteMessage) INDArray(org.nd4j.linalg.api.ndarray.INDArray) RequestDescriptor(org.nd4j.parameterserver.distributed.logic.completion.RequestDescriptor) SkipGramRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage)

Example 2 with FrameCompleteMessage

use of org.nd4j.parameterserver.distributed.messages.complete.FrameCompleteMessage in project nd4j by deeplearning4j.

In class CbowTrainer, method finishTraining:

/**
 * Finishes one CBOW training round for the chain identified by
 * {@code (originatorId, taskId)}.
 *
 * <p>Pulls the {@code SYN_0} rows listed in the request, averages them with
 * {@code mean(0)}, and uses the accumulated dot products of the chain to compute gradients
 * for the optional hierarchic-softmax (HS) part and the optional negative-sampling (NS)
 * part. The gradients are applied to the {@code SYN_1} / {@code SYN_1_NEGATIVE} rows, and
 * the accumulated error vector is then added to every {@code SYN_0} row listed in the
 * request. Afterwards the frame completion handler is notified and, if the frame is
 * completed, a {@link FrameCompleteMessage} is sent through the transport. The chain is
 * removed from {@code chains} at the end of the method.
 *
 * @param originatorId id of the originator, used together with {@code taskId} to look up the chain
 * @param taskId id of the training task being finished
 * @throws RuntimeException if no chain is registered for the given descriptor
 */
@Override
public void finishTraining(long originatorId, long taskId) {
    RequestDescriptor chainDesc = RequestDescriptor.createDescriptor(originatorId, taskId);
    CbowChain chain = chains.get(chainDesc);
    if (chain == null) {
        throw new RuntimeException("Unable to find chain for specified taskId: [" + taskId + "]");
    }
    CbowRequestMessage cbr = chain.getCbowRequest();
    double alpha = cbr.getAlpha();
    // TODO: We DON'T want this code being here
    // TODO: We DO want this algorithm to be native
    INDArray expTable = storage.getArray(WordVectorStorage.EXP_TABLE);
    INDArray dots = chain.getDotAggregation().getAccumulatedResult();
    INDArray syn0 = storage.getArray(WordVectorStorage.SYN_0);
    INDArray syn1 = storage.getArray(WordVectorStorage.SYN_1);
    INDArray syn1Neg = storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
    // averaged input vector built from the requested syn0 rows
    INDArray words = Nd4j.pullRows(storage.getArray(WordVectorStorage.SYN_0), 1, cbr.getSyn0rows(), 'c');
    INDArray neue = words.mean(0);
    // error accumulator, applied to the syn0 rows once both gradient passes are done
    INDArray neu1e = Nd4j.create(syn0.columns());
    // e indexes into dots and is shared by the HS and NS passes
    int e = 0;
    boolean updated = false;

    // apply optional HS gradients (the loop is a no-op when there are no codes)
    int[] codes = cbr.getCodes();
    for (; e < codes.length; e++) {
        float dot = dots.getFloat(e);
        if (dot < -HS_MAX_EXP || dot >= HS_MAX_EXP) {
            continue;
        }
        int idx = (int) ((dot + HS_MAX_EXP) * ((float) expTable.length() / HS_MAX_EXP / 2.0));
        if (idx >= expTable.length() || idx < 0) {
            continue;
        }
        int code = codes[e];
        double f = expTable.getFloat(idx);
        double g = (1 - code - f) * alpha;
        updated = true;
        // accumulate the error first, then update the syn1 row from the averaged input
        Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn1.getRow(cbr.getSyn1rows()[e]), neu1e);
        Nd4j.getBlasWrapper().axpy(Double.valueOf(g), neue, syn1.getRow(cbr.getSyn1rows()[e]));
    }

    // apply optional NegSample gradients
    if (cbr.getNegSamples() > 0) {
        // The loop is bounded by cnt (not by e): e may already have advanced past the HS
        // entries, and bounding on e would skip NS rounds whenever HS codes are present.
        // NOTE(review): assumes dots holds the HS entries followed by negSamples + 1 NS
        // entries, as the shared index e implies - confirm against the dot producer.
        int nsRounds = cbr.getNegSamples() + 1;
        for (int cnt = 0; cnt < nsRounds; e++, cnt++) {
            float dot = dots.getFloat(e);
            // the first NS entry is treated as the positive example, the rest as negatives
            float code = cnt == 0 ? 1.0f : 0.0f;
            double g;
            if (dot > HS_MAX_EXP) {
                g = (code - 1) * alpha;
            } else if (dot < -HS_MAX_EXP) {
                g = (code - 0) * alpha;
            } else {
                int idx = (int) ((dot + HS_MAX_EXP) * (expTable.length() / HS_MAX_EXP / 2.0));
                if (idx >= expTable.length() || idx < 0) {
                    continue;
                }
                g = (code - expTable.getDouble(idx)) * alpha;
            }
            updated = true;
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn1Neg.getRow(cbr.getNegatives()[cnt]), neu1e);
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), neue, syn1Neg.getRow(cbr.getNegatives()[cnt]));
        }
    }

    if (updated) {
        // every context row receives the same accumulated error
        for (int i = 0; i < cbr.getSyn0rows().length; i++) {
            Nd4j.getBlasWrapper().axpy(Double.valueOf(1.0), neu1e, syn0.getRow(cbr.getSyn0rows()[i]));
        }
    }

    // we send back confirmation message only from Shard which received this message;
    // a shard that isn't tracking this frame stays silent
    RequestDescriptor descriptor = RequestDescriptor.createDescriptor(chain.getOriginatorId(), chain.getFrameId());
    if (completionHandler.isTrackingFrame(descriptor)) {
        completionHandler.notifyFrame(chain.getOriginatorId(), chain.getFrameId(), chain.getTaskId());
        if (completionHandler.isCompleted(descriptor)) {
            FrameCompletionHandler.FrameDescriptor frameDescriptor = completionHandler.getCompletedFrameInfo(descriptor);
            // TODO: there is possible race condition here
            if (frameDescriptor != null) {
                FrameCompleteMessage fcm = new FrameCompleteMessage(chain.getFrameId());
                fcm.setOriginatorId(frameDescriptor.getFrameOriginatorId());
                transport.sendMessage(fcm);
            } else {
                log.warn("Frame double spending detected");
            }
        }
    }

    // log the value this call produced; re-reading the counter could report another thread's increment
    long rounds = cntRounds.incrementAndGet();
    if (rounds % 100000 == 0) {
        log.info("{} training rounds finished...", rounds);
    }

    // don't forget to remove chain, it'll become a leak otherwise
    chains.remove(chainDesc);
}
Also used : FrameCompletionHandler(org.nd4j.parameterserver.distributed.logic.completion.FrameCompletionHandler) FrameCompleteMessage(org.nd4j.parameterserver.distributed.messages.complete.FrameCompleteMessage) INDArray(org.nd4j.linalg.api.ndarray.INDArray) CbowChain(org.nd4j.parameterserver.distributed.training.chains.CbowChain) RequestDescriptor(org.nd4j.parameterserver.distributed.logic.completion.RequestDescriptor) CbowRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage)

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2; FrameCompletionHandler (org.nd4j.parameterserver.distributed.logic.completion.FrameCompletionHandler): 2; RequestDescriptor (org.nd4j.parameterserver.distributed.logic.completion.RequestDescriptor): 2; FrameCompleteMessage (org.nd4j.parameterserver.distributed.messages.complete.FrameCompleteMessage): 2; CbowRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage): 1; SkipGramRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage): 1; CbowChain (org.nd4j.parameterserver.distributed.training.chains.CbowChain): 1; SkipGramChain (org.nd4j.parameterserver.distributed.training.chains.SkipGramChain): 1