
Example 1 with InitializationAggregation

Use of org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation in project nd4j by deeplearning4j.

Source: class ClipboardTest, method testPin3.

/**
 * This test checks how the clipboard handles single-part aggregations (aggregation width 1)
 * @throws Exception
 */
@Test
public void testPin3() throws Exception {
    Clipboard clipboard = new Clipboard();
    Random rng = new Random(12345L);
    Long validId = 123L;
    InitializationAggregation aggregation = new InitializationAggregation(1, 0);
    clipboard.pin(aggregation);
    assertTrue(clipboard.isTracking(0L, aggregation.getTaskId()));
    assertTrue(clipboard.isReady(0L, aggregation.getTaskId()));
}
Also used : InitializationAggregation(org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) Test(org.junit.Test)
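The test relies on a simple contract: an aggregation pinned to the clipboard is tracked under its originator id and task id, and it becomes ready once as many parts have arrived as its aggregation width. With a width of 1, a single pin completes it. The plain-Java sketch below models that contract so it runs without nd4j. `ClipboardSketch` and its nested `Aggregation` are hypothetical stand-ins, not the real `Clipboard` API, and the completion rule (ready when received parts reach the width) is an assumption inferred from the test, not taken from the library source.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for the nd4j Clipboard, written only to illustrate the pin/track/ready contract.
public class ClipboardSketch {

    /** Hypothetical aggregation part: identity plus the number of parts needed to complete it. */
    static final class Aggregation {
        final long originatorId;
        final long taskId;
        final int aggregationWidth;

        Aggregation(long originatorId, long taskId, int aggregationWidth) {
            this.originatorId = originatorId;
            this.taskId = taskId;
            this.aggregationWidth = aggregationWidth;
        }
    }

    // Expected width and received-part count, both keyed by "originatorId:taskId".
    private final Map<String, Integer> expected = new HashMap<>();
    private final Map<String, Integer> received = new HashMap<>();

    private static String key(long originatorId, long taskId) {
        return originatorId + ":" + taskId;
    }

    /** Records one part; the first part for a key starts tracking it. */
    public void pin(Aggregation part) {
        String k = key(part.originatorId, part.taskId);
        expected.putIfAbsent(k, part.aggregationWidth);
        received.merge(k, 1, Integer::sum);
    }

    public boolean isTracking(long originatorId, long taskId) {
        return expected.containsKey(key(originatorId, taskId));
    }

    /** Ready once the number of received parts reaches the aggregation width. */
    public boolean isReady(long originatorId, long taskId) {
        String k = key(originatorId, taskId);
        Integer width = expected.get(k);
        return width != null && received.get(k) >= width;
    }

    public static void main(String[] args) {
        ClipboardSketch clipboard = new ClipboardSketch();
        // Width 1, as in testPin3: one pinned part is enough.
        clipboard.pin(new Aggregation(0L, 42L, 1));
        System.out.println(clipboard.isTracking(0L, 42L)); // true
        System.out.println(clipboard.isReady(0L, 42L));    // true
    }
}
```

With a width greater than 1, the same sketch keeps `isReady` false until the remaining parts for that originator and task id have been pinned.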

Example 2 with InitializationAggregation

Use of org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation in project nd4j by deeplearning4j.

Source: class DistributedInitializationMessage, method processMessage.

/**
 * This method initializes shard storage with given data
 */
@Override
public void processMessage() {
    // guard against double initialization: only initialize if syn0 has not been allocated yet
    INDArray syn0 = storage.getArray(WordVectorStorage.SYN_0);
    INDArray syn1 = storage.getArray(WordVectorStorage.SYN_1);
    INDArray syn1Neg = storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
    INDArray expTable = storage.getArray(WordVectorStorage.EXP_TABLE);
    if (syn0 == null) {
        log.info("sI_{} is starting initialization...", transport.getShardIndex());
        // we initialize only syn0/syn1/syn1neg and expTable
        // negTable will be initialized at driver level and will be shared via message
        Nd4j.getRandom().setSeed(seed * (shardIndex + 1));
        if (voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
            // each shard has full own copy
            columnsPerShard = vectorLength;
        } else if (voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
            // each shard will have only part of the data
            if (voidConfiguration.getNumberOfShards() - 1 == shardIndex) {
                int modulo = vectorLength % voidConfiguration.getNumberOfShards();
                if (modulo != 0) {
                    columnsPerShard += modulo;
                    log.info("Got unequal split, using higher number of elements: {}", columnsPerShard);
                }
            }
        }
        int[] shardShape = new int[] { numWords, columnsPerShard };
        syn0 = Nd4j.rand(shardShape, 'c').subi(0.5).divi(vectorLength);
        if (useHs)
            syn1 = Nd4j.create(shardShape, 'c');
        if (useNeg)
            syn1Neg = Nd4j.create(shardShape, 'c');
        // we handle full exp table here
        expTable = initExpTable(100000);
        storage.setArray(WordVectorStorage.SYN_0, syn0);
        if (useHs)
            storage.setArray(WordVectorStorage.SYN_1, syn1);
        if (useNeg)
            storage.setArray(WordVectorStorage.SYN_1_NEGATIVE, syn1Neg);
        storage.setArray(WordVectorStorage.EXP_TABLE, expTable);
        InitializationAggregation ia = new InitializationAggregation((short) voidConfiguration.getNumberOfShards(), transport.getShardIndex());
        ia.setOriginatorId(this.originatorId);
        transport.sendMessage(ia);
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) InitializationAggregation(org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation)
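The shard layout in `processMessage()` comes down to a little arithmetic. In AVERAGING mode every shard holds all `vectorLength` columns; in SHARDED mode the last shard additionally absorbs `vectorLength % numberOfShards` columns. The method receives `columnsPerShard` already set, so the base width `vectorLength / numberOfShards` used below is an assumption. The sketch also mirrors the `syn0` initialization, (uniform value - 0.5) / vectorLength with a per-shard seed of `seed * (shardIndex + 1)`, but uses `java.util.Random` in place of the Nd4j generator, so the actual values will differ. `initExpTable` here follows the classic word2vec sigmoid lookup table; whether nd4j's `initExpTable(100000)` uses the same formula and range is not shown in the source, so treat it as an assumption. `ShardInitSketch` and all its methods are hypothetical.

```java
import java.util.Random;

// Hypothetical helper illustrating the arithmetic in processMessage(); not nd4j code.
public class ShardInitSketch {

    /**
     * Columns held by one shard. AVERAGING (sharded == false): every shard keeps all columns.
     * SHARDED: assumes a base width of vectorLength / numShards; the last shard also takes the remainder.
     */
    static int columnsForShard(boolean sharded, int vectorLength, int numShards, int shardIndex) {
        if (!sharded) {
            return vectorLength;
        }
        int columns = vectorLength / numShards;
        if (shardIndex == numShards - 1) {
            columns += vectorLength % numShards;
        }
        return columns;
    }

    /** syn0-style init: uniform [0, 1), shifted by -0.5, divided by vectorLength, per-shard seed. */
    static double[][] initSyn0(int numWords, int columns, int vectorLength, long seed, int shardIndex) {
        Random rng = new Random(seed * (shardIndex + 1)); // mirrors setSeed(seed * (shardIndex + 1))
        double[][] syn0 = new double[numWords][columns];
        for (int w = 0; w < numWords; w++) {
            for (int c = 0; c < columns; c++) {
                syn0[w][c] = (rng.nextDouble() - 0.5) / vectorLength;
            }
        }
        return syn0;
    }

    /** Classic word2vec sigmoid table over [-maxExp, maxExp); assumed, not confirmed for nd4j. */
    static double[] initExpTable(int size, double maxExp) {
        double[] table = new double[size];
        for (int i = 0; i < size; i++) {
            double e = Math.exp((i / (double) size * 2 - 1) * maxExp);
            table[i] = e / (e + 1); // sigmoid of the sampled point
        }
        return table;
    }

    public static void main(String[] args) {
        int vectorLength = 100;
        int numShards = 3;
        for (int s = 0; s < numShards; s++) {
            System.out.println("shard " + s + ": "
                    + columnsForShard(true, vectorLength, numShards, s) + " columns");
        }
        // shard 0: 33 columns
        // shard 1: 33 columns
        // shard 2: 34 columns
    }
}
```

For `vectorLength = 100` across 3 shards the split is 33, 33 and 34 columns, which sums back to 100, so no column is lost when the width does not divide evenly.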

Example 3 with InitializationAggregation

Use of org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation in project nd4j by deeplearning4j.

Source: class InitializationRequestMessage, method processMessage.

@Override
public void processMessage() {
    DistributedInitializationMessage dim = new DistributedInitializationMessage(vectorLength, numWords, seed, useHs, useNeg, columnsPerShard);
    InitializationAggregation aggregation = new InitializationAggregation((short) voidConfiguration.getNumberOfShards(), transport.getShardIndex());
    aggregation.setOriginatorId(this.originatorId);
    clipboard.pin(aggregation);
    dim.setOriginatorId(this.originatorId);
    dim.extractContext(this);
    dim.processMessage();
    if (voidConfiguration.getNumberOfShards() > 1)
        transport.sendMessageToAllShards(dim);
}
Also used : DistributedInitializationMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedInitializationMessage) InitializationAggregation(org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation)
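This example is the requesting side of the handshake: it pins its own `InitializationAggregation` to the clipboard, initializes locally, and broadcasts a `DistributedInitializationMessage` only when there is more than one shard. The `syn0 == null` guard in Example 2 is what keeps a shard from initializing twice if it receives the message again. The single-process simulation below wires those steps together. `InitHandshakeSketch`, `Tracker` and `Shard` are hypothetical; the assumptions that the broadcast reaches every shard including the requester, and that the tracker counts each shard index at most once, are simplifications for illustration and are not confirmed by the source.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical single-process simulation of the initialization handshake; not nd4j's transport or Clipboard.
public class InitHandshakeSketch {

    /** Collects one report per shard; complete when all shards have reported. */
    static final class Tracker {
        final int width;
        final Set<Integer> reported = new HashSet<>();

        Tracker(int width) { this.width = width; }

        void pin(int shardIndex) { reported.add(shardIndex); } // repeat reports from one shard are ignored

        boolean isReady() { return reported.size() >= width; }
    }

    /** Hypothetical shard that initializes at most once, like the syn0 == null guard. */
    static final class Shard {
        final int index;
        int initCount = 0;

        Shard(int index) { this.index = index; }

        void processInit(Tracker tracker) {
            if (initCount > 0) {
                return;            // already initialized: skip, no second report
            }
            initCount++;           // stands in for allocating syn0/syn1/syn1Neg/expTable
            tracker.pin(index);    // stands in for sending an InitializationAggregation
        }
    }

    /** Handles an initialization request on shard 'origin'. */
    static Tracker runRequest(List<Shard> shards, int origin) {
        Tracker tracker = new Tracker(shards.size());
        tracker.pin(origin);                      // requester pins its own part first
        shards.get(origin).processInit(tracker);  // then initializes locally
        if (shards.size() > 1) {
            for (Shard s : shards) {              // broadcast reaches every shard, including the requester
                s.processInit(tracker);
            }
        }
        return tracker;
    }

    static List<Shard> cluster(int numShards) {
        List<Shard> shards = new ArrayList<>();
        for (int i = 0; i < numShards; i++) {
            shards.add(new Shard(i));
        }
        return shards;
    }

    public static void main(String[] args) {
        List<Shard> shards = cluster(4);
        Tracker tracker = runRequest(shards, 0);
        System.out.println("ready: " + tracker.isReady()); // ready: true
        for (Shard s : shards) {
            System.out.println("shard " + s.index + " initialized " + s.initCount + " time(s)");
        }
    }
}
```

In this model every shard ends up initialized exactly once even though the requester sees the broadcast too, which is the situation the guard in Example 2 is there to handle.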

Classes used across the examples above (number of examples using each):

InitializationAggregation (org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation): 3
Test (org.junit.Test): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1
Clipboard (org.nd4j.parameterserver.distributed.logic.completion.Clipboard): 1
DistributedInitializationMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedInitializationMessage): 1