
Example 1 with DistributedCbowDotMessage

Use of org.nd4j.parameterserver.distributed.messages.intercom.DistributedCbowDotMessage in project nd4j by deeplearning4j.

In class CbowTrainer, method startTraining:

@Override
public void startTraining(CbowRequestMessage message) {
    // Register a training chain for this request, keyed by (originatorId, taskId)
    CbowChain chain = new CbowChain(message);
    chain.addElement(message);
    chains.put(RequestDescriptor.createDescriptor(message.getOriginatorId(), message.getTaskId()), chain);
    // syn1 row indices supplied with the request; extended below when negative sampling is on
    int[] row_syn1 = message.getSyn1rows();
    if (message.getNegSamples() > 0) {
        // rows = number of rows in syn0 (the vocabulary size)
        int rows = storage.getArray(WordVectorStorage.SYN_0).rows();
        // Slot 0 holds the target word w1; slots 1..negSamples hold uniformly
        // drawn row indices in [0, rows) that differ from w1
        int[] tempArray = new int[message.getNegSamples() + 1];
        tempArray[0] = message.getW1();
        for (int e = 1; e < message.getNegSamples() + 1; e++) {
            while (true) {
                int rnd = RandomUtils.nextInt(0, rows);
                if (rnd != message.getW1()) {
                    tempArray[e] = rnd;
                    break;
                }
            }
        }
        // Append the sampled rows to the syn1 rows and record them on the message
        row_syn1 = ArrayUtils.addAll(row_syn1, tempArray);
        message.setNegatives(tempArray);
    }
    // A CBOW step needs at least one context (syn0) row
    if (message.getSyn0rows() == null || message.getSyn0rows().length < 1)
        throw new RuntimeException("Empty syn0rows!");
    // The sixth argument is true when the request carries Huffman codes
    DistributedCbowDotMessage dcdm = new DistributedCbowDotMessage(message.getTaskId(), message.getSyn0rows(), row_syn1, message.getW1(), message.getCodes(), message.getCodes().length > 0, (short) message.getNegSamples(), (float) message.getAlpha());
    dcdm.setTargetId((short) -1);
    dcdm.setOriginatorId(message.getOriginatorId());
    // Dispatch according to the configured execution mode
    if (voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
        transport.putMessage(dcdm);
    } else if (voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
        transport.sendMessage(dcdm);
    }
}
Also used: DistributedCbowDotMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedCbowDotMessage), CbowChain (org.nd4j.parameterserver.distributed.training.chains.CbowChain)
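The negative-sampling block in startTraining builds one array: the target word w1 in slot 0, followed by negSamples random rows that are not w1, and then appends it to the request's syn1 rows. The sketch below isolates that logic so it can run without nd4j or commons-lang3. The class and method names are hypothetical, `java.util.Random` stands in for commons-lang3 `RandomUtils`, and `concat` imitates `ArrayUtils.addAll` (including its handling of a null first array).

```java
import java.util.Arrays;
import java.util.Random;

public class NegativeSamplingSketch {

    // Mirrors the loop in startTraining: slot 0 holds the target word w1,
    // slots 1..negSamples hold uniform random rows in [0, rows) other than w1.
    // Negatives may repeat each other, as in the original loop.
    static int[] sampleNegatives(int w1, int negSamples, int rows, Random rng) {
        if (rows < 2) {
            // With fewer than two rows the original while(true) loop can spin forever
            throw new IllegalArgumentException("rows must be >= 2");
        }
        int[] targets = new int[negSamples + 1];
        targets[0] = w1;
        for (int e = 1; e <= negSamples; e++) {
            int rnd;
            do {
                rnd = rng.nextInt(rows); // uniform in [0, rows)
            } while (rnd == w1);
            targets[e] = rnd;
        }
        return targets;
    }

    // Stand-in for commons-lang3 ArrayUtils.addAll(int[], int...):
    // a null first array yields a copy of the second.
    static int[] concat(int[] a, int[] b) {
        if (a == null) {
            return b.clone();
        }
        int[] out = Arrays.copyOf(a, a.length + b.length);
        System.arraycopy(b, 0, out, a.length, b.length);
        return out;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        int[] syn1Rows = {3, 7}; // hypothetical rows from the request
        int[] targets = sampleNegatives(5, 4, 100, rng);
        int[] allRows = concat(syn1Rows, targets);
        System.out.println("targets = " + Arrays.toString(targets));
        System.out.println("syn1 rows = " + Arrays.toString(allRows));
    }
}
```

Note that the original draws negatives uniformly over all syn0 rows. Classic word2vec instead samples negatives from the unigram distribution raised to the 3/4 power, so frequent words are picked more often; the uniform choice here is simpler but not equivalent.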

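DistributedCbowDotMessage is built from the syn0 (context) rows, the extended syn1 rows, the target word, the negative-sample count and the learning rate alpha: the inputs of a word2vec CBOW step. How nd4j splits that step across shards is not shown in this snippet, so the following is only a single-machine sketch of the standard word2vec CBOW update with negative sampling, assuming that is what the distributed messages ultimately compute. The class and method names are hypothetical, plain float arrays replace nd4j's storage, and hierarchical softmax (the Huffman codes) is left out.

```java
public class CbowNegativeStepSketch {

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // One CBOW update with negative sampling, in the style of the original word2vec C code
    // (without its precomputed exp table). syn0 and syn1Neg are [rows][dim] matrices,
    // contextRows index syn0, targets[0] is the true word and targets[1..] are negatives.
    static void step(float[][] syn0, float[][] syn1Neg, int[] contextRows, int[] targets, float alpha) {
        int dim = syn0[0].length;

        // neu1 = mean of the context (syn0) vectors
        float[] neu1 = new float[dim];
        for (int r : contextRows) {
            for (int i = 0; i < dim; i++) {
                neu1[i] += syn0[r][i];
            }
        }
        for (int i = 0; i < dim; i++) {
            neu1[i] /= contextRows.length;
        }

        float[] neu1e = new float[dim]; // accumulated gradient for the context vectors
        for (int d = 0; d < targets.length; d++) {
            float[] out = syn1Neg[targets[d]];
            int label = (d == 0) ? 1 : 0;
            double f = 0.0;
            for (int i = 0; i < dim; i++) {
                f += neu1[i] * out[i]; // dot product of context mean and output row
            }
            float g = (float) ((label - sigmoid(f)) * alpha);
            for (int i = 0; i < dim; i++) {
                neu1e[i] += g * out[i]; // uses the output row before it is updated
            }
            for (int i = 0; i < dim; i++) {
                out[i] += g * neu1[i]; // update the output row in place
            }
        }

        // Every context row receives the same accumulated gradient
        for (int r : contextRows) {
            for (int i = 0; i < dim; i++) {
                syn0[r][i] += neu1e[i];
            }
        }
    }

    public static void main(String[] args) {
        float[][] syn0 = {{1f, 0f}, {0f, 1f}, {0f, 0f}};
        float[][] syn1Neg = new float[3][2]; // zero-initialised, as in word2vec
        step(syn0, syn1Neg, new int[]{0, 1}, new int[]{2, 1}, 0.1f);
        System.out.println(java.util.Arrays.toString(syn1Neg[2]));
        System.out.println(java.util.Arrays.toString(syn1Neg[1]));
    }
}
```

Placing the true word in slot 0 lets one loop handle both cases: the label is 1 only for d == 0, which is why startTraining stores w1 at index 0 of the sampled array.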