Use of org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation in project nd4j by deeplearning4j.
The class ClipboardTest, method testPin3.
/**
 * This test checks how the Clipboard handles single-shard aggregations,
 * which become ready as soon as they are pinned.
 * @throws Exception
 */
@Test
public void testPin3() throws Exception {
    Clipboard clipboard = new Clipboard();
    Random rng = new Random(12345L);
    Long validId = 123L;

    // a single-shard aggregation (1 shard expected, shard index 0) has no
    // missing chunks, so it is complete immediately after pinning
    InitializationAggregation aggregation = new InitializationAggregation(1, 0);

    clipboard.pin(aggregation);

    assertTrue(clipboard.isTracking(0L, aggregation.getTaskId()));
    assertTrue(clipboard.isReady(0L, aggregation.getTaskId()));
}
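For contrast, a multi-shard aggregation should be tracked but not yet ready while shard chunks are still missing. The sketch below is an assumption built on the Clipboard API shown above (pin, isTracking, isReady, getTaskId); it is not a test from the nd4j repository:

@Test
public void testPinMultiShardSketch() throws Exception {
    Clipboard clipboard = new Clipboard();

    // an aggregation expecting 2 shards, with only shard 0 reporting so far
    InitializationAggregation shard0 = new InitializationAggregation((short) 2, (short) 0);
    clipboard.pin(shard0);

    // the task is tracked, but one shard chunk is still missing,
    // so the aggregation should not be ready yet
    assertTrue(clipboard.isTracking(0L, shard0.getTaskId()));
    assertFalse(clipboard.isReady(0L, shard0.getTaskId()));
}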
Use of org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation in project nd4j by deeplearning4j.
The class DistributedInitializationMessage, method processMessage.
/**
 * This method initializes shard storage with the given data.
 */
@Override
public void processMessage() {
    // protection check: we definitely don't want to initialize the same shard twice
    INDArray syn0 = storage.getArray(WordVectorStorage.SYN_0);
    INDArray syn1 = storage.getArray(WordVectorStorage.SYN_1);
    INDArray syn1Neg = storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
    INDArray expTable = storage.getArray(WordVectorStorage.EXP_TABLE);
    if (syn0 == null) {
        log.info("sI_{} is starting initialization...", transport.getShardIndex());

        // we initialize only syn0/syn1/syn1Neg and the expTable here;
        // negTable will be initialized at the driver level and shared via message
        Nd4j.getRandom().setSeed(seed * (shardIndex + 1));

        if (voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
            // each shard holds its own full copy
            columnsPerShard = vectorLength;
        } else if (voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
            // each shard holds only part of the data; the last shard absorbs
            // the remainder when vectorLength doesn't divide evenly
            if (voidConfiguration.getNumberOfShards() - 1 == shardIndex) {
                int modulo = vectorLength % voidConfiguration.getNumberOfShards();
                if (modulo != 0) {
                    columnsPerShard += modulo;
                    log.info("Got unequal split; using higher number of elements: {}", columnsPerShard);
                }
            }
        }

        int[] shardShape = new int[] {numWords, columnsPerShard};

        syn0 = Nd4j.rand(shardShape, 'c').subi(0.5).divi(vectorLength);

        if (useHs)
            syn1 = Nd4j.create(shardShape, 'c');

        if (useNeg)
            syn1Neg = Nd4j.create(shardShape, 'c');

        // we handle the full exp table here
        expTable = initExpTable(100000);

        storage.setArray(WordVectorStorage.SYN_0, syn0);
        if (useHs)
            storage.setArray(WordVectorStorage.SYN_1, syn1);
        if (useNeg)
            storage.setArray(WordVectorStorage.SYN_1_NEGATIVE, syn1Neg);
        storage.setArray(WordVectorStorage.EXP_TABLE, expTable);

        InitializationAggregation ia = new InitializationAggregation((short) voidConfiguration.getNumberOfShards(), transport.getShardIndex());
        ia.setOriginatorId(this.originatorId);
        transport.sendMessage(ia);
    }
}
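The SHARDED branch above splits vectorLength columns evenly across shards (the base columnsPerShard is set elsewhere in the class) and pushes the remainder onto the last shard. Below is a standalone sketch of that arithmetic, together with an exp table of the kind initExpTable(100000) presumably builds; the helper names and the table formula are assumptions based on the standard word2vec sigmoid lookup, not taken from this source:

public class ShardInitSketch {

    // columns owned by one shard under SHARDED mode: an even split,
    // with the last shard absorbing the remainder (hypothetical helper)
    static int columnsForShard(int vectorLength, int numberOfShards, int shardIndex) {
        int columnsPerShard = vectorLength / numberOfShards;
        int modulo = vectorLength % numberOfShards;
        if (shardIndex == numberOfShards - 1 && modulo != 0)
            columnsPerShard += modulo;
        return columnsPerShard;
    }

    // assumed shape of initExpTable: a precomputed sigmoid lookup over
    // [-MAX_EXP, MAX_EXP], as in the classic word2vec implementation
    static double[] initExpTableSketch(int tableSize) {
        double MAX_EXP = 6.0;
        double[] expTable = new double[tableSize];
        for (int i = 0; i < tableSize; i++) {
            double x = Math.exp(((double) i / tableSize * 2.0 - 1.0) * MAX_EXP);
            expTable[i] = x / (x + 1.0); // sigmoid value for this bucket
        }
        return expTable;
    }

    public static void main(String[] args) {
        // 150 columns over 4 shards: shards 0-2 get 37 columns, shard 3 gets 39
        for (int shard = 0; shard < 4; shard++)
            System.out.println("shard " + shard + " -> " + columnsForShard(150, 4, shard));
    }
}

Note that the per-shard counts always sum back to vectorLength, so the concatenation of all shard slices reconstructs the full vector.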
Use of org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation in project nd4j by deeplearning4j.
The class InitializationRequestMessage, method processMessage.
@Override
public void processMessage() {
    DistributedInitializationMessage dim = new DistributedInitializationMessage(vectorLength, numWords, seed, useHs, useNeg, columnsPerShard);

    // pin the aggregation before anything is sent, so that shard responses
    // arriving early are still tracked by the clipboard
    InitializationAggregation aggregation = new InitializationAggregation((short) voidConfiguration.getNumberOfShards(), transport.getShardIndex());
    aggregation.setOriginatorId(this.originatorId);
    clipboard.pin(aggregation);

    dim.setOriginatorId(this.originatorId);
    dim.extractContext(this);

    // initialize the local shard first, then broadcast to the others
    dim.processMessage();
    if (voidConfiguration.getNumberOfShards() > 1)
        transport.sendMessageToAllShards(dim);
}
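Putting the three pieces together: the request pins an InitializationAggregation on the clipboard, each shard's DistributedInitializationMessage replies with its own InitializationAggregation, and the clipboard becomes ready once every shard has reported. A minimal sketch of how a caller might wait on that readiness, assuming the clipboard API shown in ClipboardTest; the polling loop and the awaitInitialization helper are illustrative assumptions, not code from this source:

// illustrative wait loop; in practice originatorId and taskId would come
// from the aggregation pinned in processMessage above
void awaitInitialization(Clipboard clipboard, long originatorId, long taskId) throws InterruptedException {
    // poll until every shard's InitializationAggregation chunk has arrived
    while (!clipboard.isReady(originatorId, taskId)) {
        Thread.sleep(10);
    }
    // at this point all shards have confirmed initialization
}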