
Example 1 with CbowTrainer

Use of org.nd4j.parameterserver.distributed.training.impl.CbowTrainer in project nd4j by deeplearning4j.

From the class VoidParameterServerStressTest, method testPerformanceUnicast3.

/**
 * This test checks the single-Shard scenario, in which the Shard also acts as a Client.
 *
 * @throws Exception
 */
@Test
public void testPerformanceUnicast3() throws Exception {
    VoidConfiguration voidConfiguration = VoidConfiguration.builder()
            .unicastPort(49823)
            .numberOfShards(1)
            .shardAddresses(Arrays.asList("127.0.0.1:49823"))
            .build();
    Transport transport = new RoutedTransport();
    transport.setIpAndPort("127.0.0.1", 49823);
    VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
    parameterServer.setShardIndex((short) 0);
    parameterServer.init(voidConfiguration, transport, new CbowTrainer());
    parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);
    final List<Long> times = new ArrayList<>();
    log.info("Starting loop...");
    for (int i = 0; i < 200; i++) {
        Frame<CbowRequestMessage> frame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
        // stack 128 CBOW request messages into a single frame before dispatching it
        for (int f = 0; f < 128; f++) {
            frame.stackMessage(getCRM());
        }
        long time1 = System.nanoTime();
        parameterServer.execDistributed(frame);
        long time2 = System.nanoTime();
        times.add(time2 - time1);
        if (i % 50 == 0)
            log.info("{} frames passed...", i);
    }
    Collections.sort(times);
    log.info("p50: {} us", times.get(times.size() / 2) / 1000);
    parameterServer.shutdown();
}
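The getCRM() helper that fills each frame is defined elsewhere in the test class and not shown in this snippet. Below is a minimal sketch of what such a factory might look like, assuming the seven-argument CbowRequestMessage constructor visible in Example 2 (syn0 rows, syn1 rows, current word, codes as a byte[], negative-sample count, learning rate, trailing constant) and java.util.concurrent.ThreadLocalRandom for the row indices; the array sizes, alpha, and random values are illustrative placeholders, not the ones used by the real stress test.

// Hypothetical sketch of the getCRM() factory used in the loop above.
// The constructor call mirrors the one in Example 2; array sizes, alpha,
// and the random row indices are assumptions made for illustration.
private static CbowRequestMessage getCRM() {
    ThreadLocalRandom rng = ThreadLocalRandom.current();
    // current word index
    int w1 = rng.nextInt(NUM_WORDS);
    // context rows in syn0; CBOW averages these before the dot products
    int[] syn0Rows = new int[5];
    for (int i = 0; i < syn0Rows.length; i++)
        syn0Rows[i] = rng.nextInt(NUM_WORDS);
    // hierarchical-softmax target rows in syn1, with their binary codes
    int[] syn1Rows = new int[4];
    byte[] codes = new byte[syn1Rows.length];
    for (int i = 0; i < syn1Rows.length; i++) {
        syn1Rows[i] = rng.nextInt(NUM_WORDS);
        codes[i] = (byte) rng.nextInt(2);
    }
    // no negative samples, alpha 0.025; the trailing 119 matches Example 2's call
    return new CbowRequestMessage(syn0Rows, syn1Rows, w1, codes, 0, 0.025, 119);
}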
Also used: VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration), CbowTrainer (org.nd4j.parameterserver.distributed.training.impl.CbowTrainer), Frame (org.nd4j.parameterserver.distributed.messages.Frame), ArrayList (java.util.ArrayList), CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList), AtomicLong (java.util.concurrent.atomic.AtomicLong), MulticastTransport (org.nd4j.parameterserver.distributed.transport.MulticastTransport), RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport), Transport (org.nd4j.parameterserver.distributed.transport.Transport), CbowRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage), Test (org.junit.Test)

Example 2 with CbowTrainer

Use of org.nd4j.parameterserver.distributed.training.impl.CbowTrainer in project nd4j by deeplearning4j.

From the class DistributedCbowDotMessage, method processMessage.

/**
 * This method calculates the dot product of the given rows, with averaging applied to rowsA first, as required by CBoW.
 */
@Override
public void processMessage() {
    // this only picks up new training round
    // log.info("sI_{} Starting CBOW dot...", transport.getShardIndex());
    CbowRequestMessage cbrm = new CbowRequestMessage(rowsA, rowsB, w1, codes, negSamples, alpha, 119);
    if (negSamples > 0) {
        // unfortunately we have to get copy of negSamples here
        int[] negatives = Arrays.copyOfRange(rowsB, codes.length, rowsB.length);
        cbrm.setNegatives(negatives);
    }
    cbrm.setFrameId(-119L);
    cbrm.setTaskId(this.taskId);
    cbrm.setOriginatorId(this.getOriginatorId());
    // FIXME: get rid of THAT
    CbowTrainer cbt = (CbowTrainer) trainer;
    cbt.pickTraining(cbrm);
    // we calculate dot for all involved rows, and first of all we get mean word
    INDArray words = Nd4j.pullRows(storage.getArray(WordVectorStorage.SYN_0), 1, rowsA, 'c');
    INDArray mean = words.mean(0);
    int resultLength = codes.length + (negSamples > 0 ? (negSamples + 1) : 0);
    INDArray result = Nd4j.createUninitialized(resultLength, 1);
    int e = 0;
    for (; e < codes.length; e++) {
        double dot = Nd4j.getBlasWrapper().dot(mean, storage.getArray(WordVectorStorage.SYN_1).getRow(rowsB[e]));
        result.putScalar(e, dot);
    }
    // negSampling round
    for (; e < resultLength; e++) {
        double dot = Nd4j.getBlasWrapper().dot(mean, storage.getArray(WordVectorStorage.SYN_1_NEGATIVE).getRow(rowsB[e]));
        result.putScalar(e, dot);
    }
    if (voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
        DotAggregation dot = new DotAggregation(taskId, (short) 1, shardIndex, result);
        dot.setTargetId((short) -1);
        dot.setOriginatorId(getOriginatorId());
        transport.putMessage(dot);
    } else if (voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
        // send this message to everyone
        DotAggregation dot = new DotAggregation(taskId, (short) voidConfiguration.getNumberOfShards(), shardIndex, result);
        dot.setTargetId((short) -1);
        dot.setOriginatorId(getOriginatorId());
        transport.sendMessage(dot);
    }
}
Also used: CbowTrainer (org.nd4j.parameterserver.distributed.training.impl.CbowTrainer), DotAggregation (org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation), INDArray (org.nd4j.linalg.api.ndarray.INDArray), CbowRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage)
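The numeric core of processMessage() above — pull the context rows of syn0, average them, then dot the mean against each target row of syn1 — can be exercised in isolation with plain Nd4j calls. A minimal, self-contained sketch, using small random matrices as stand-ins for the shard's WordVectorStorage arrays:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CbowDotSketch {
    public static void main(String[] args) {
        // stand-ins for WordVectorStorage.SYN_0 and SYN_1: 10 words, 4 dimensions
        INDArray syn0 = Nd4j.rand(10, 4);
        INDArray syn1 = Nd4j.rand(10, 4);
        // context rows (averaged by CBOW) and hierarchical-softmax target rows
        int[] rowsA = { 1, 3, 5 };
        int[] rowsB = { 2, 7 };
        // mean word vector over the context rows, as in processMessage()
        INDArray words = Nd4j.pullRows(syn0, 1, rowsA, 'c');
        INDArray mean = words.mean(0);
        // dot of the mean against every target row
        INDArray result = Nd4j.create(rowsB.length, 1);
        for (int e = 0; e < rowsB.length; e++) {
            double dot = Nd4j.getBlasWrapper().dot(mean, syn1.getRow(rowsB[e]));
            result.putScalar(e, dot);
        }
        System.out.println(result);
    }
}

The two branches at the end of processMessage() differ only in routing: in AVERAGING mode the DotAggregation is handed to the local transport queue via putMessage(), while in SHARDED mode sendMessage() broadcasts it so the partial dot products computed on each shard can be combined.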

Aggregations

CbowRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage): 2 usages
CbowTrainer (org.nd4j.parameterserver.distributed.training.impl.CbowTrainer): 2 usages
ArrayList (java.util.ArrayList): 1 usage
CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList): 1 usage
AtomicLong (java.util.concurrent.atomic.AtomicLong): 1 usage
Test (org.junit.Test): 1 usage
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1 usage
VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration): 1 usage
Frame (org.nd4j.parameterserver.distributed.messages.Frame): 1 usage
DotAggregation (org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation): 1 usage
MulticastTransport (org.nd4j.parameterserver.distributed.transport.MulticastTransport): 1 usage
RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport): 1 usage
Transport (org.nd4j.parameterserver.distributed.transport.Transport): 1 usage