Usage of org.nd4j.parameterserver.distributed.logic.completion.RequestDescriptor in project nd4j by deeplearning4j.
Class SkipGramTrainer, method finishTraining.
/**
 * Completes one SkipGram round for the chain registered under {@code (originatorId, taskId)}:
 * applies the optional hierarchical-softmax and negative-sampling gradients to syn1 / syn1Neg
 * and to the syn0 row of {@code w2}, notifies the frame completion handler, and finally
 * removes the chain (even if the update fails) so it does not leak.
 *
 * @param originatorId originator id used to look up the chain
 * @param taskId       task id used to look up the chain
 * @throws IllegalStateException if no chain is registered for the given ids
 */
@Override
public void finishTraining(long originatorId, long taskId) {
    RequestDescriptor chainDesc = RequestDescriptor.createDescriptor(originatorId, taskId);
    SkipGramChain chain = chains.get(chainDesc);
    if (chain == null)
        throw new IllegalStateException("Unable to find chain for specified taskId: [" + taskId + "]");

    try {
        SkipGramRequestMessage sgrm = chain.getRequestMessage();
        double alpha = sgrm.getAlpha();

        // TODO: We DON'T want this code being here
        // TODO: We DO want this algorithm to be native
        INDArray expTable = storage.getArray(WordVectorStorage.EXP_TABLE);
        INDArray dots = chain.getDotAggregation().getAccumulatedResult();
        INDArray syn0 = storage.getArray(WordVectorStorage.SYN_0);
        INDArray syn1 = storage.getArray(WordVectorStorage.SYN_1);
        INDArray syn1Neg = storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);

        // row of syn0 for w2: read by every gradient step, updated in place at the end
        INDArray w2Row = syn0.getRow(sgrm.getW2());
        INDArray neu1e = Nd4j.create(syn0.columns());

        // Scale that maps a dot product from [-HS_MAX_EXP, HS_MAX_EXP) onto an expTable index.
        // The float cast keeps the division floating-point for both the HS and NS lookups.
        final double expScale = (float) expTable.length() / HS_MAX_EXP / 2.0;

        // running index into dots, shared by the HS and NS passes
        int e = 0;
        boolean updated = false;

        // apply optional SkipGram HS gradients: one dot product per code
        for (; e < sgrm.getCodes().length; e++) {
            float dot = dots.getFloat(e);
            if (dot < -HS_MAX_EXP || dot >= HS_MAX_EXP)
                continue;

            int idx = (int) ((dot + HS_MAX_EXP) * expScale);
            if (idx >= expTable.length() || idx < 0)
                continue;

            int code = sgrm.getCodes()[e];
            double f = expTable.getFloat(idx);
            double g = (1 - code - f) * alpha;
            updated = true;

            INDArray syn1Row = syn1.getRow(sgrm.getPoints()[e]);
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn1Row, neu1e);
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), w2Row, syn1Row);
        }

        // apply optional NegSample gradients
        // NOTE(review): assumes dots holds the HS dot products first, followed by
        // negSamples + 1 NS dot products — confirm against the code that fills the aggregation.
        if (sgrm.getNegSamples() > 0) {
            // Bound on cnt, not e: e already starts past the HS entries, so bounding e by
            // negSamples + 1 would truncate (or skip) the NS pass whenever HS is also in use.
            for (int cnt = 0; cnt < sgrm.getNegSamples() + 1; cnt++, e++) {
                float dot = dots.getFloat(e);
                // entry 0 is treated as the positive sample (label 1), the rest as negatives (label 0)
                float label = cnt == 0 ? 1.0f : 0.0f;
                double g;
                if (dot > HS_MAX_EXP) {
                    g = (label - 1) * alpha;
                } else if (dot < -HS_MAX_EXP) {
                    g = label * alpha;
                } else {
                    int idx = (int) ((dot + HS_MAX_EXP) * expScale);
                    if (idx >= expTable.length() || idx < 0)
                        continue;
                    g = (label - expTable.getDouble(idx)) * alpha;
                }
                updated = true;

                INDArray negRow = syn1Neg.getRow(sgrm.getNegatives()[cnt]);
                Nd4j.getBlasWrapper().axpy(Double.valueOf(g), negRow, neu1e);
                Nd4j.getBlasWrapper().axpy(Double.valueOf(g), w2Row, negRow);
            }
        }

        if (updated)
            Nd4j.getBlasWrapper().axpy(Double.valueOf(1.0), neu1e, w2Row);

        // we send back confirmation message only from Shard which received this message
        RequestDescriptor descriptor = RequestDescriptor.createDescriptor(chain.getOriginatorId(), chain.getFrameId());
        if (completionHandler.isTrackingFrame(descriptor)) {
            completionHandler.notifyFrame(chain.getOriginatorId(), chain.getFrameId(), chain.getTaskId());
            if (completionHandler.isCompleted(descriptor)) {
                FrameCompletionHandler.FrameDescriptor frameDescriptor = completionHandler.getCompletedFrameInfo(descriptor);
                // TODO: there is possible race condition here
                if (frameDescriptor != null) {
                    FrameCompleteMessage fcm = new FrameCompleteMessage(chain.getFrameId());
                    fcm.setOriginatorId(frameDescriptor.getFrameOriginatorId());
                    transport.sendMessage(fcm);
                } else {
                    log.warn("Frame double spending detected");
                }
            }
        } else {
            log.info("sI_{} isn't tracking this frame: Originator: {}, frameId: {}, taskId: {}", transport.getShardIndex(), chain.getOriginatorId(), chain.getFrameId(), taskId);
        }

        // use the value returned by incrementAndGet instead of re-reading the shared counter
        long rounds = cntRounds.incrementAndGet();
        if (rounds % 100000 == 0)
            log.info("{} training rounds finished...", rounds);
    } finally {
        // don't forget to remove chain, it'll become a leak otherwise (also on failure)
        chains.remove(chainDesc);
    }
}
Usage of org.nd4j.parameterserver.distributed.logic.completion.RequestDescriptor in project nd4j by deeplearning4j.
Class SkipGramTrainer, method pickTraining.
/**
 * Registers a SkipGram chain for the incoming request, keyed by its originator id and task id.
 * If a chain is already registered under that key, the request is ignored.
 * This method will be called from non-initialized Shard context.
 *
 * @param message SkipGram request to register; must not be null
 */
@Override
public void pickTraining(@NonNull SkipGramRequestMessage message) {
    RequestDescriptor key = RequestDescriptor.createDescriptor(message.getOriginatorId(), message.getTaskId());

    // first registration wins; duplicates are dropped
    if (chains.containsKey(key))
        return;

    chains.put(key, new SkipGramChain(message));
}
Usage of org.nd4j.parameterserver.distributed.logic.completion.RequestDescriptor in project nd4j by deeplearning4j.
Class CbowTrainer, method finishTraining.
/**
 * Completes one CBOW round for the chain registered under {@code (originatorId, taskId)}:
 * averages the syn0 rows listed in {@code syn0rows}, applies the optional hierarchical-softmax
 * and negative-sampling gradients to syn1 / syn1Neg, adds the accumulated error to every
 * {@code syn0rows} row, notifies the frame completion handler, and finally removes the chain
 * (even if the update fails) so it does not leak.
 *
 * @param originatorId originator id used to look up the chain
 * @param taskId       task id used to look up the chain
 * @throws IllegalStateException if no chain is registered for the given ids
 */
@Override
public void finishTraining(long originatorId, long taskId) {
    RequestDescriptor chainDesc = RequestDescriptor.createDescriptor(originatorId, taskId);
    CbowChain chain = chains.get(chainDesc);
    if (chain == null)
        throw new IllegalStateException("Unable to find chain for specified taskId: [" + taskId + "]");

    try {
        CbowRequestMessage cbr = chain.getCbowRequest();
        double alpha = cbr.getAlpha();

        // TODO: We DON'T want this code being here
        // TODO: We DO want this algorithm to be native
        INDArray expTable = storage.getArray(WordVectorStorage.EXP_TABLE);
        INDArray dots = chain.getDotAggregation().getAccumulatedResult();
        INDArray syn0 = storage.getArray(WordVectorStorage.SYN_0);
        INDArray syn1 = storage.getArray(WordVectorStorage.SYN_1);
        INDArray syn1Neg = storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);

        // mean of the syn0 rows listed in syn0rows; used as the input vector for syn1/syn1Neg updates
        INDArray words = Nd4j.pullRows(syn0, 1, cbr.getSyn0rows(), 'c');
        INDArray neue = words.mean(0);
        INDArray neu1e = Nd4j.create(syn0.columns());

        // Scale that maps a dot product from [-HS_MAX_EXP, HS_MAX_EXP) onto an expTable index.
        // The float cast keeps the division floating-point for both the HS and NS lookups.
        final double expScale = (float) expTable.length() / HS_MAX_EXP / 2.0;

        // running index into dots, shared by the HS and NS passes
        int e = 0;
        boolean updated = false;

        // apply optional HS gradients: one dot product per code
        for (; e < cbr.getCodes().length; e++) {
            float dot = dots.getFloat(e);
            if (dot < -HS_MAX_EXP || dot >= HS_MAX_EXP)
                continue;

            int idx = (int) ((dot + HS_MAX_EXP) * expScale);
            if (idx >= expTable.length() || idx < 0)
                continue;

            int code = cbr.getCodes()[e];
            double f = expTable.getFloat(idx);
            double g = (1 - code - f) * alpha;
            updated = true;

            INDArray syn1Row = syn1.getRow(cbr.getSyn1rows()[e]);
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn1Row, neu1e);
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), neue, syn1Row);
        }

        // apply optional NegSample gradients
        // NOTE(review): assumes dots holds the HS dot products first, followed by
        // negSamples + 1 NS dot products — confirm against the code that fills the aggregation.
        if (cbr.getNegSamples() > 0) {
            // Bound on cnt, not e: e already starts past the HS entries, so bounding e by
            // negSamples + 1 would truncate (or skip) the NS pass whenever HS is also in use.
            for (int cnt = 0; cnt < cbr.getNegSamples() + 1; cnt++, e++) {
                float dot = dots.getFloat(e);
                // entry 0 is treated as the positive sample (label 1), the rest as negatives (label 0)
                float label = cnt == 0 ? 1.0f : 0.0f;
                double g;
                if (dot > HS_MAX_EXP) {
                    g = (label - 1) * alpha;
                } else if (dot < -HS_MAX_EXP) {
                    g = label * alpha;
                } else {
                    int idx = (int) ((dot + HS_MAX_EXP) * expScale);
                    if (idx >= expTable.length() || idx < 0)
                        continue;
                    g = (label - expTable.getDouble(idx)) * alpha;
                }
                updated = true;

                INDArray negRow = syn1Neg.getRow(cbr.getNegatives()[cnt]);
                Nd4j.getBlasWrapper().axpy(Double.valueOf(g), negRow, neu1e);
                Nd4j.getBlasWrapper().axpy(Double.valueOf(g), neue, negRow);
            }
        }

        // propagate the accumulated error back into every syn0 row that formed the mean
        if (updated) {
            for (int i = 0; i < cbr.getSyn0rows().length; i++) {
                Nd4j.getBlasWrapper().axpy(Double.valueOf(1.0), neu1e, syn0.getRow(cbr.getSyn0rows()[i]));
            }
        }

        // we send back confirmation message only from Shard which received this message;
        // shards that aren't tracking this frame send nothing
        RequestDescriptor descriptor = RequestDescriptor.createDescriptor(chain.getOriginatorId(), chain.getFrameId());
        if (completionHandler.isTrackingFrame(descriptor)) {
            completionHandler.notifyFrame(chain.getOriginatorId(), chain.getFrameId(), chain.getTaskId());
            if (completionHandler.isCompleted(descriptor)) {
                FrameCompletionHandler.FrameDescriptor frameDescriptor = completionHandler.getCompletedFrameInfo(descriptor);
                // TODO: there is possible race condition here
                if (frameDescriptor != null) {
                    FrameCompleteMessage fcm = new FrameCompleteMessage(chain.getFrameId());
                    fcm.setOriginatorId(frameDescriptor.getFrameOriginatorId());
                    transport.sendMessage(fcm);
                } else {
                    log.warn("Frame double spending detected");
                }
            }
        }

        // use the value returned by incrementAndGet instead of re-reading the shared counter
        long rounds = cntRounds.incrementAndGet();
        if (rounds % 100000 == 0)
            log.info("{} training rounds finished...", rounds);
    } finally {
        // don't forget to remove chain, it'll become a leak otherwise (also on failure)
        chains.remove(chainDesc);
    }
}
Usage of org.nd4j.parameterserver.distributed.logic.completion.RequestDescriptor in project nd4j by deeplearning4j.
Class CbowTrainer, method pickTraining.
/**
 * Registers a CBOW chain for the incoming request, keyed by its originator id and task id,
 * and adds the request to that chain via {@code addElement}.
 * If a chain is already registered under that key, the request is ignored.
 *
 * @param message CBOW request to register
 */
@Override
public void pickTraining(CbowRequestMessage message) {
    RequestDescriptor key = RequestDescriptor.createDescriptor(message.getOriginatorId(), message.getTaskId());

    // first registration wins; duplicates are dropped
    if (chains.containsKey(key))
        return;

    CbowChain freshChain = new CbowChain(message);
    freshChain.addElement(message);
    chains.put(key, freshChain);
}
Aggregations