Usage example of org.nd4j.parameterserver.distributed.training.impl.CbowTrainer in the nd4j project by deeplearning4j.
From the class VoidParameterServerStressTest, method testPerformanceUnicast3:
/**
 * This test checks the single-Shard scenario, where the Shard is also a Client.
 *
 * It stacks 128 CBOW request messages into each of 200 frames, executes them
 * through the parameter server, and reports the median (p50) frame latency.
 *
 * @throws Exception if server initialization or distributed execution fails
 */
@Test
public void testPerformanceUnicast3() throws Exception {
    // Single magic port used for the shard address, the unicast config, and the transport.
    final int port = 49823;
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(port).numberOfShards(1)
                    .shardAddresses(Arrays.asList("127.0.0.1:" + port)).build();
    Transport transport = new RoutedTransport();
    // Use the int literal directly instead of parsing+boxing a string literal.
    transport.setIpAndPort("127.0.0.1", port);
    VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
    parameterServer.setShardIndex((short) 0);
    parameterServer.init(voidConfiguration, transport, new CbowTrainer());
    parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);

    // Exactly 200 timings will be collected — presize to avoid resizing.
    final List<Long> times = new ArrayList<>(200);
    log.info("Starting loop...");
    for (int i = 0; i < 200; i++) {
        Frame<CbowRequestMessage> frame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
        // Each frame carries a batch of 128 stacked CBOW requests.
        for (int f = 0; f < 128; f++) {
            frame.stackMessage(getCRM());
        }
        long time1 = System.nanoTime();
        parameterServer.execDistributed(frame);
        long time2 = System.nanoTime();
        times.add(time2 - time1);

        if (i % 50 == 0)
            log.info("{} frames passed...", i);
    }

    Collections.sort(times);
    // Median latency, converted from nanoseconds to microseconds.
    log.info("p50: {} us", times.get(times.size() / 2) / 1000);

    parameterServer.shutdown();
}
Usage example of org.nd4j.parameterserver.distributed.training.impl.CbowTrainer in the nd4j project by deeplearning4j.
From the class DistributedCbowDotMessage, method processMessage:
/**
 * This method calculates dot of gives rows, with averaging applied to rowsA, as required by CBoW
 */
@Override
public void processMessage() {
    // this only picks up new training round
    // log.info("sI_{} Starting CBOW dot...", transport.getShardIndex());

    // Build the training-round request and register it with the trainer before computing dots.
    CbowRequestMessage request = new CbowRequestMessage(rowsA, rowsB, w1, codes, negSamples, alpha, 119);
    if (negSamples > 0) {
        // unfortunately we have to get copy of negSamples here
        request.setNegatives(Arrays.copyOfRange(rowsB, codes.length, rowsB.length));
    }
    request.setFrameId(-119L);
    request.setTaskId(this.taskId);
    request.setOriginatorId(this.getOriginatorId());

    // FIXME: get rid of THAT
    ((CbowTrainer) trainer).pickTraining(request);

    // we calculate dot for all involved rows, and first of all we get mean word
    INDArray contextRows = Nd4j.pullRows(storage.getArray(WordVectorStorage.SYN_0), 1, rowsA, 'c');
    INDArray meanWord = contextRows.mean(0);

    int totalDots = codes.length + (negSamples > 0 ? (negSamples + 1) : 0);
    INDArray dots = Nd4j.createUninitialized(totalDots, 1);

    // Hierarchical-softmax part: dot of the mean word against each SYN_1 row.
    for (int idx = 0; idx < codes.length; idx++) {
        dots.putScalar(idx, Nd4j.getBlasWrapper().dot(meanWord,
                        storage.getArray(WordVectorStorage.SYN_1).getRow(rowsB[idx])));
    }

    // negSampling round
    for (int idx = codes.length; idx < totalDots; idx++) {
        dots.putScalar(idx, Nd4j.getBlasWrapper().dot(meanWord,
                        storage.getArray(WordVectorStorage.SYN_1_NEGATIVE).getRow(rowsB[idx])));
    }

    ExecutionMode mode = voidConfiguration.getExecutionMode();
    if (mode == ExecutionMode.AVERAGING) {
        // Single-shard aggregation: broadcast locally via putMessage.
        DotAggregation aggregation = new DotAggregation(taskId, (short) 1, shardIndex, dots);
        aggregation.setTargetId((short) -1);
        aggregation.setOriginatorId(getOriginatorId());
        transport.putMessage(aggregation);
    } else if (mode == ExecutionMode.SHARDED) {
        // send this message to everyone
        DotAggregation aggregation = new DotAggregation(taskId, (short) voidConfiguration.getNumberOfShards(),
                        shardIndex, dots);
        aggregation.setTargetId((short) -1);
        aggregation.setOriginatorId(getOriginatorId());
        transport.sendMessage(aggregation);
    }
}
Aggregations of related usages follow.