Usage of org.nd4j.parameterserver.distributed.messages.intercom.DistributedCbowDotMessage in the nd4j project (deeplearning4j).
Class CbowTrainer, method startTraining.
/**
 * Starts a distributed CBOW training round for the given request.
 *
 * <p>Validates the request, registers a new {@code CbowChain} for it (keyed by the message's
 * originator id and task id), optionally draws negative samples, and builds a
 * {@code DistributedCbowDotMessage}. That message is handed to {@code transport.putMessage} in
 * {@code AVERAGING} mode and to {@code transport.sendMessage} in {@code SHARDED} mode. In any
 * other execution mode it is not dispatched.
 *
 * @param message CBOW training request; must carry at least one syn0 row
 * @throws RuntimeException if the message has no syn0 rows (checked before any state is touched)
 * @throws IllegalStateException if negative sampling is requested but the syn0 table has fewer
 *     than two rows, so no row distinct from the target word could be drawn
 */
@Override
public void startTraining(CbowRequestMessage message) {
    // Validate up front. This check used to run after the chain was already stored in `chains`
    // (and after the message had been mutated), leaving an orphaned chain behind on failure.
    int[] syn0rows = message.getSyn0rows();
    if (syn0rows == null || syn0rows.length < 1)
        throw new RuntimeException("Empty syn0rows!");

    CbowChain chain = new CbowChain(message);
    chain.addElement(message);
    chains.put(RequestDescriptor.createDescriptor(message.getOriginatorId(), message.getTaskId()), chain);

    int[] row_syn1 = message.getSyn1rows();
    if (message.getNegSamples() > 0) {
        int[] negatives = drawNegativeSamples(message.getW1(), message.getNegSamples());
        // Negative-sample rows are appended after the rows supplied by getSyn1rows().
        row_syn1 = ArrayUtils.addAll(row_syn1, negatives);
        message.setNegatives(negatives);
    }

    DistributedCbowDotMessage dcdm = new DistributedCbowDotMessage(message.getTaskId(), message.getSyn0rows(), row_syn1, message.getW1(), message.getCodes(), message.getCodes().length > 0, (short) message.getNegSamples(), (float) message.getAlpha());
    // NOTE(review): -1 presumably means "no specific target shard" - confirm against the transport.
    dcdm.setTargetId((short) -1);
    dcdm.setOriginatorId(message.getOriginatorId());

    if (voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
        transport.putMessage(dcdm);
    } else if (voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
        transport.sendMessage(dcdm);
    }
}

/**
 * Draws {@code count} random syn0 row indices, none of which equals {@code target}.
 *
 * @param target row index of the target word; stored at position 0 of the result
 * @param count number of negative samples to draw (expected to be positive)
 * @return array of length {@code count + 1}: {@code target} followed by the sampled rows
 * @throws IllegalStateException if the syn0 table has fewer than two rows, in which case the
 *     rejection loop below could never terminate
 */
private int[] drawNegativeSamples(int target, int count) {
    int rows = storage.getArray(WordVectorStorage.SYN_0).rows();
    if (rows < 2)
        throw new IllegalStateException("Negative sampling requires at least 2 syn0 rows, got " + rows);

    int[] sampled = new int[count + 1];
    sampled[0] = target;
    for (int e = 1; e <= count; e++) {
        int rnd;
        // Rejection sampling: redraw until the row differs from the target word.
        do {
            rnd = RandomUtils.nextInt(0, rows);
        } while (rnd == target);
        sampled[e] = rnd;
    }
    return sampled;
}
Aggregations