Search in sources:

Example 6 with RoutedTransport

use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.

Method testPerformanceUnicast1 of class VoidParameterServerStressTest.

/**
 * This is one of the MOST IMPORTANT tests.
 *
 * <p>Starts one SHARD node (127.0.0.1:38380) and one CLIENT node (unicast port 49823), both
 * using {@link RoutedTransport}. It then lets 4 worker threads each issue 200 blocking
 * {@code getVector()} requests against their own slice of the vocabulary, and logs the median
 * (p50) request latency in microseconds.
 *
 * <p>Failures inside worker threads are collected and re-thrown on the test thread, so the
 * test cannot pass silently when requests fail. All transports are shut down in a
 * {@code finally} block, so ports are released even when setup or the benchmark fails.
 */
@Test
public void testPerformanceUnicast1() {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 1; t++) {
        list.add("127.0.0.1:3838" + t);
    }
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823)
            .numberOfShards(list.size()).shardAddresses(list).build();
    VoidParameterServer[] shards = new VoidParameterServer[list.size()];
    // created before the try block so the finally clause can always reach it
    VoidParameterServer clientNode = new VoidParameterServer(NodeRole.CLIENT);
    try {
        for (int t = 0; t < shards.length; t++) {
            shards[t] = new VoidParameterServer(NodeRole.SHARD);
            Transport transport = new RoutedTransport();
            transport.setIpAndPort("127.0.0.1", Integer.parseInt("3838" + t));
            shards[t].setShardIndex((short) t);
            shards[t].init(voidConfiguration, transport, new SkipGramTrainer());
            assertEquals(NodeRole.SHARD, shards[t].getNodeRole());
        }
        RoutedTransport transport = new RoutedTransport();
        ClientRouter router = new InterleavedRouter(0);
        transport.setRouter(router);
        transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
        router.init(voidConfiguration, transport);
        clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
        assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());
        final List<Long> times = new CopyOnWriteArrayList<>();
        // anything thrown inside a worker would otherwise vanish with the thread
        final List<Throwable> failures = new CopyOnWriteArrayList<>();
        // at this point, everything should be started, time for tests
        clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);
        log.info("Initialization finished, going to tests...");
        Thread[] threads = new Thread[4];
        for (int t = 0; t < threads.length; t++) {
            final int e = t;
            threads[t] = new Thread(() -> {
                try {
                    List<Long> results = new ArrayList<>();
                    int chunk = NUM_WORDS / threads.length;
                    int start = e * chunk;
                    // the last worker also covers the remainder of NUM_WORDS / threads.length
                    int end = (e == threads.length - 1) ? NUM_WORDS : (e + 1) * chunk;
                    for (int i = 0; i < 200; i++) {
                        long time1 = System.nanoTime();
                        INDArray array = clientNode.getVector(RandomUtils.nextInt(start, end));
                        long time2 = System.nanoTime();
                        results.add(time2 - time1);
                        if ((i + 1) % 100 == 0) {
                            log.info("Thread {} cnt {}", e, i + 1);
                        }
                    }
                    times.addAll(results);
                } catch (Throwable th) {
                    failures.add(th);
                }
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }
        for (Thread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException ie) {
                // restore the interrupt flag instead of silently dropping it
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for worker threads", ie);
            }
        }
        if (!failures.isEmpty()) {
            throw new AssertionError(failures.size() + " worker thread(s) failed", failures.get(0));
        }
        List<Long> newTimes = new ArrayList<>(times);
        Collections.sort(newTimes);
        log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
    } finally {
        // shutdown everything, even if setup or the benchmark failed part-way
        for (VoidParameterServer shard : shards) {
            if (shard != null && shard.getTransport() != null) {
                shard.getTransport().shutdown();
            }
        }
        if (clientNode.getTransport() != null) {
            clientNode.getTransport().shutdown();
        }
    }
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) SkipGramTrainer(org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicLong(java.util.concurrent.atomic.AtomicLong) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) ArrayList(java.util.ArrayList) List(java.util.List) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) Test(org.junit.Test)

Example 7 with RoutedTransport

use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project deeplearning4j by deeplearning4j.

Method call of class PartitionTrainingFunction.

/**
 * Converts the sequences of one partition to their shallow vocabulary form and trains on them
 * in batches of 8 sequences via {@code trainAllAtOnce}.
 *
 * <p>The configuration, parameter server, vocabulary cache and learning algorithms are
 * initialized lazily: each field is populated only while it is still {@code null}. Both
 * learning algorithms are then (re)configured once per call.
 *
 * <p>NOTE(review): presumably invoked by Spark once per partition (e.g. via foreachPartition);
 * confirm against the caller.
 *
 * @param sequenceIterator the sequences of this partition
 * @throws ND4JIllegalStateException if neither an elements nor a sequence learning algorithm
 *     is configured
 * @throws RuntimeException if a learning algorithm class cannot be instantiated
 * @throws Exception as declared by the implemented interface
 */
@SuppressWarnings("unchecked")
@Override
public void call(Iterator<Sequence<T>> sequenceIterator) throws Exception {
    // first we initialize
    if (vectorsConfiguration == null) {
        vectorsConfiguration = configurationBroadcast.getValue();
    }
    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        if (elementsLearningAlgorithm == null) {
            // the training driver below is taken from the elements learning algorithm
            elementsLearningAlgorithm = instantiateLearningAlgorithm(
                    vectorsConfiguration.getElementsLearningAlgorithm(), SparkElementsLearningAlgorithm.class);
        }
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }
    if (shallowVocabCache == null) {
        shallowVocabCache = vocabCacheBroadcast.getValue();
    }
    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        elementsLearningAlgorithm = instantiateLearningAlgorithm(
                vectorsConfiguration.getElementsLearningAlgorithm(), SparkElementsLearningAlgorithm.class);
    }
    if (elementsLearningAlgorithm != null) {
        elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
    }
    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        // configure() is intentionally not called here: it runs exactly once, just below
        sequenceLearningAlgorithm = instantiateLearningAlgorithm(
                vectorsConfiguration.getSequenceLearningAlgorithm(), SparkSequenceLearningAlgorithm.class);
    }
    if (sequenceLearningAlgorithm != null) {
        sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
    }
    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }
    List<Sequence<ShallowSequenceElement>> sequences = new ArrayList<>();
    // now we roll through the Sequences and prepare/convert/"learn" them
    while (sequenceIterator.hasNext()) {
        Sequence<T> sequence = sequenceIterator.next();
        Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
        for (T element : sequence.getElements()) {
            // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
            if (reduced != null) {
                mergedSequence.addElement(reduced);
            }
        }
        // do the same with labels, transfer them, if any
        if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
            for (T label : sequence.getSequenceLabels()) {
                ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
                if (reduced != null) {
                    mergedSequence.addSequenceLabel(reduced);
                }
            }
        }
        sequences.add(mergedSequence);
        // train in batches of 8 sequences
        if (sequences.size() >= 8) {
            trainAllAtOnce(sequences);
            sequences.clear();
        }
    }
    if (!sequences.isEmpty()) {
        // finishing training round, to make sure we don't have trails
        trainAllAtOnce(sequences);
        sequences.clear();
    }
}

/**
 * Instantiates {@code className} reflectively through its no-argument constructor and casts
 * the result to {@code type}.
 *
 * @param className fully qualified class name; {@code null} leads to a RuntimeException
 * @param type expected type of the new instance
 * @return the new instance
 * @throws RuntimeException wrapping any lookup, constructor, access or cast failure
 */
private static <A> A instantiateLearningAlgorithm(String className, Class<A> type) {
    try {
        return type.cast(Class.forName(className).getDeclaredConstructor().newInstance());
    } catch (Exception e) {
        throw new RuntimeException("Unable to instantiate learning algorithm: " + className, e);
    }
}
Also used : ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) ArrayList(java.util.ArrayList) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 8 with RoutedTransport

use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.

Method setUp of class InterleavedRouterTest.

/**
 * Prepares a four-shard configuration and a RoutedTransport whose IP and port are set to
 * 8.9.10.11:87312. It then derives {@code originator} as the long hash of that "ip:port" string.
 *
 * <p>NOTE(review): 87312 exceeds the TCP port range (max 65535). Presumably harmless because
 * the transport is never started here; confirm before reusing this address elsewhere.
 */
@Before
public void setUp() {
    // the shard count is given explicitly instead of being derived from the address list
    configuration = VoidConfiguration.builder()
            .shardAddresses(Arrays.asList("1.2.3.4", "2.3.4.5", "3.4.5.6", "4.5.6.7"))
            .numberOfShards(4)
            .build();

    transport = new RoutedTransport();
    transport.setIpAndPort("8.9.10.11", 87312);

    String localAddress = transport.getIp() + ":" + transport.getPort();
    originator = HashUtil.getLongHash(localAddress);
}
Also used : RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Before(org.junit.Before)

Example 9 with RoutedTransport

use of org.nd4j.parameterserver.distributed.transport.RoutedTransport in project nd4j by deeplearning4j.

Method testPerformanceUnicast2 of class VoidParameterServerStressTest.

/**
 * This is second super-important test for unicast transport.
 * Here we send non-blocking messages.
 *
 * <p>Starts five SHARD nodes (127.0.0.1:38380..38384) and one CLIENT node (unicast port 49823),
 * all using {@link RoutedTransport}. It then lets 4 worker threads each send 200 frames of 128
 * skip-gram request messages via {@code execDistributed()}, and logs the median (p50) send time
 * in microseconds.
 *
 * <p>Failures inside worker threads are collected and re-thrown on the test thread. All
 * transports are shut down in a {@code finally} block, so ports are released even when setup
 * or the benchmark fails.
 */
@Test
@Ignore
public void testPerformanceUnicast2() {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 5; t++) {
        list.add("127.0.0.1:3838" + t);
    }
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823)
            .numberOfShards(list.size()).shardAddresses(list).build();
    VoidParameterServer[] shards = new VoidParameterServer[list.size()];
    // role is given explicitly (as in testPerformanceUnicast1), since CLIENT is asserted below
    VoidParameterServer clientNode = new VoidParameterServer(NodeRole.CLIENT);
    try {
        for (int t = 0; t < shards.length; t++) {
            shards[t] = new VoidParameterServer(NodeRole.SHARD);
            Transport transport = new RoutedTransport();
            transport.setIpAndPort("127.0.0.1", Integer.parseInt("3838" + t));
            shards[t].setShardIndex((short) t);
            shards[t].init(voidConfiguration, transport, new SkipGramTrainer());
            assertEquals(NodeRole.SHARD, shards[t].getNodeRole());
        }
        RoutedTransport transport = new RoutedTransport();
        ClientRouter router = new InterleavedRouter(0);
        transport.setRouter(router);
        transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
        router.init(voidConfiguration, transport);
        clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
        assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());
        final List<Long> times = new CopyOnWriteArrayList<>();
        // anything thrown inside a worker would otherwise vanish with the thread
        final List<Throwable> failures = new CopyOnWriteArrayList<>();
        // at this point, everything should be started, time for tests
        clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);
        log.info("Initialization finished, going to tests...");
        Thread[] threads = new Thread[4];
        for (int t = 0; t < threads.length; t++) {
            final int e = t;
            threads[t] = new Thread(() -> {
                try {
                    List<Long> results = new ArrayList<>();
                    for (int i = 0; i < 200; i++) {
                        Frame<SkipGramRequestMessage> frame =
                                new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
                        for (int f = 0; f < 128; f++) {
                            frame.stackMessage(getSGRM());
                        }
                        long time1 = System.nanoTime();
                        clientNode.execDistributed(frame);
                        long time2 = System.nanoTime();
                        results.add(time2 - time1);
                        if ((i + 1) % 100 == 0) {
                            log.info("Thread {} cnt {}", e, i + 1);
                        }
                    }
                    times.addAll(results);
                } catch (Throwable th) {
                    failures.add(th);
                }
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }
        for (Thread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException ie) {
                // restore the interrupt flag instead of silently dropping it
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for worker threads", ie);
            }
        }
        if (!failures.isEmpty()) {
            throw new AssertionError(failures.size() + " worker thread(s) failed", failures.get(0));
        }
        List<Long> newTimes = new ArrayList<>(times);
        Collections.sort(newTimes);
        log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
    } finally {
        // shutdown everything, even if setup or the benchmark failed part-way
        for (VoidParameterServer shard : shards) {
            if (shard != null && shard.getTransport() != null) {
                shard.getTransport().shutdown();
            }
        }
        if (clientNode.getTransport() != null) {
            clientNode.getTransport().shutdown();
        }
    }
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) Frame(org.nd4j.parameterserver.distributed.messages.Frame) SkipGramTrainer(org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) AtomicLong(java.util.concurrent.atomic.AtomicLong) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) ArrayList(java.util.ArrayList) List(java.util.List) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) Ignore(org.junit.Ignore) Test(org.junit.Test)

Aggregations

RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport)9 ArrayList (java.util.ArrayList)5 AtomicLong (java.util.concurrent.atomic.AtomicLong)5 VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration)5 CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList)4 Test (org.junit.Test)4 MulticastTransport (org.nd4j.parameterserver.distributed.transport.MulticastTransport)4 Transport (org.nd4j.parameterserver.distributed.transport.Transport)4 List (java.util.List)3 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 Frame (org.nd4j.parameterserver.distributed.messages.Frame)3 SkipGramTrainer (org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer)3 Counter (org.deeplearning4j.berkeley.Counter)2 Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence)2 ShallowSequenceElement (org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement)2 ClientRouter (org.nd4j.parameterserver.distributed.logic.ClientRouter)2 InterleavedRouter (org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter)2 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)1 Pair (org.deeplearning4j.berkeley.Pair)1 DL4JInvalidConfigException (org.deeplearning4j.exception.DL4JInvalidConfigException)1