
Example 61 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

Method scanClasspath of class TrainerProvider.

protected void scanClasspath() {
    // TODO: reflection stuff to fill trainers
    Reflections reflections = new Reflections("org");
    Set<Class<? extends TrainingDriver>> classes = reflections.getSubTypesOf(TrainingDriver.class);
    for (Class<? extends TrainingDriver> clazz : classes) {
        if (clazz.isInterface() || Modifier.isAbstract(clazz.getModifiers()))
            continue;
        try {
            TrainingDriver driver = clazz.getDeclaredConstructor().newInstance();
            trainers.put(driver.targetMessageClass(), driver);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (trainers.size() < 1)
        throw new ND4JIllegalStateException("No TrainingDrivers were found");
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Reflections(org.reflections.Reflections)
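The scan-then-register pattern in scanClasspath() can be sketched with the JDK alone. In this sketch the candidate classes are passed in explicitly, standing in for reflections.getSubTypesOf(TrainingDriver.class); Driver, AbstractDriver, TextDriver, CountDriver and DriverRegistry are hypothetical names, and a plain IllegalStateException replaces ND4JIllegalStateException so no ND4J dependency is needed. It instantiates drivers through getDeclaredConstructor().newInstance(), because Class.newInstance() has been deprecated since Java 9.

```java
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the project's TrainingDriver interface.
interface Driver {
    Class<?> targetMessageClass();
}

// Abstract helpers must be skipped: they cannot be instantiated.
abstract class AbstractDriver implements Driver { }

class TextDriver extends AbstractDriver {
    public Class<?> targetMessageClass() { return String.class; }
}

class CountDriver implements Driver {
    public Class<?> targetMessageClass() { return Integer.class; }
}

class DriverRegistry {
    private final Map<Class<?>, Driver> drivers = new HashMap<>();

    // Instantiates each concrete candidate and indexes it by the message class it handles.
    void register(List<Class<? extends Driver>> candidates) {
        for (Class<? extends Driver> clazz : candidates) {
            if (clazz.isInterface() || Modifier.isAbstract(clazz.getModifiers()))
                continue;
            try {
                // Use the no-arg constructor instead of the deprecated Class.newInstance().
                Driver driver = clazz.getDeclaredConstructor().newInstance();
                drivers.put(driver.targetMessageClass(), driver);
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException("Cannot instantiate " + clazz.getName(), e);
            }
        }
        if (drivers.isEmpty())
            throw new IllegalStateException("No drivers were found");
    }

    Driver forMessage(Class<?> messageClass) {
        return drivers.get(messageClass);
    }

    int size() {
        return drivers.size();
    }

    public static void main(String[] args) {
        // In the original, this list would come from the classpath scan.
        List<Class<? extends Driver>> candidates =
                Arrays.asList(Driver.class, AbstractDriver.class, TextDriver.class, CountDriver.class);
        DriverRegistry registry = new DriverRegistry();
        registry.register(candidates);
        System.out.println(registry.size() + " drivers registered");
    }
}
```

Running main prints "2 drivers registered": the interface and the abstract class are skipped, and the two concrete drivers are keyed by String and Integer respectively.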

Example 62 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

Method launch of class BaseTransport.

/**
 * This method starts transport mechanisms.
 *
 * PLEASE NOTE: init() method should be called prior to launch() call
 */
@Override
public void launch(@NonNull ThreadingModel threading) {
    this.threadingModel = threading;
    switch(threading) {
        case SINGLE_THREAD:
            {
                log.warn("SINGLE_THREAD model is used, performance will be significantly reduced");
                // a single thread serves all queues; this model shouldn't be used in production
                threadA = new Thread(() -> {
                    while (runner.get()) {
                        if (subscriptionForShards != null)
                            subscriptionForShards.poll(messageHandlerForShards, 512);
                        idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
                    }
                });
                threadA.start();
            }
            break;
        case DEDICATED_THREADS:
            {
                // we start a separate thread for each handler
                /*
                 * We could definitely use less conditional code here, BUT I'll keep it as is,
                 * only because we want the code to be obvious to readers
                 */
                final AtomicBoolean localRunner = new AtomicBoolean(false);
                if (nodeRole == NodeRole.NONE) {
                    throw new ND4JIllegalStateException("No role is set for current node!");
                } else if (nodeRole == NodeRole.SHARD || nodeRole == NodeRole.BACKUP || nodeRole == NodeRole.MASTER) {
                    // setting up thread for shard->client communication listener
                    if (messageHandlerForShards != null)
                        threadB = new Thread(() -> {
                            while (runner.get()) idler.idle(subscriptionForShards.poll(messageHandlerForShards, 512));
                        });
                    // setting up thread for inter-shard communication listener
                    threadA = new Thread(() -> {
                        localRunner.set(true);
                        while (runner.get()) idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
                    });
                    if (threadB != null) {
                        Nd4j.getAffinityManager().attachThreadToDevice(threadB, Nd4j.getAffinityManager().getDeviceForCurrentThread());
                        threadB.setDaemon(true);
                        threadB.setName("VoidParamServer subscription threadB [" + nodeRole + "]");
                        threadB.start();
                    }
                } else {
                    // setting up thread for shard->client communication listener
                    threadA = new Thread(() -> {
                        localRunner.set(true);
                        while (runner.get()) idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512));
                    });
                }
                // all roles have threadA anyway
                Nd4j.getAffinityManager().attachThreadToDevice(threadA, Nd4j.getAffinityManager().getDeviceForCurrentThread());
                threadA.setDaemon(true);
                threadA.setName("VoidParamServer subscription threadA [" + nodeRole + "]");
                threadA.start();
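                // wait until threadA reports that it is running
                while (!localRunner.get()) {
                    try {
                        Thread.sleep(50);
                    } catch (InterruptedException e) {
                        // restore the interrupt flag instead of silently swallowing it
                        Thread.currentThread().interrupt();
                        break;
                    }
                }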
            }
            break;
        case SAME_THREAD:
            {
                // no additional threads at all, we do poll within takeMessage loop
                log.warn("SAME_THREAD model is used, performance will be dramatically reduced");
            }
            break;
        default:
            throw new IllegalStateException("Unknown thread model: [" + threading.toString() + "]");
    }
}
Also used : AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)
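In the DEDICATED_THREADS branch, launch() waits for threadA by sleeping in 50 ms steps until the thread sets localRunner. The same start-up handshake can be sketched with a CountDownLatch, which releases the caller as soon as the listener signals and adds a timeout instead of waiting indefinitely. This swaps in a different synchronization primitive than the original uses; ListenerLauncher is a hypothetical name, the counter increment is a placeholder for the subscription.poll(...) call, and Thread.yield() is only a crude stand-in for the idler.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

class ListenerLauncher {
    private final AtomicBoolean runner = new AtomicBoolean(true);
    private final AtomicLong polls = new AtomicLong();

    // Starts a daemon listener thread and returns only once that thread is running.
    Thread launch(String name) {
        CountDownLatch started = new CountDownLatch(1);
        Thread listener = new Thread(() -> {
            started.countDown(); // tell launch() the loop is live (plays the role of localRunner.set(true))
            while (runner.get()) {
                polls.incrementAndGet(); // placeholder for subscription.poll(handler, 512)
                Thread.yield();          // crude stand-in for idler.idle(...)
            }
        });
        listener.setDaemon(true);
        listener.setName(name);
        listener.start();
        try {
            if (!started.await(5, TimeUnit.SECONDS))
                throw new IllegalStateException("Listener did not start: " + name);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // keep the interrupt visible to callers
            throw new IllegalStateException("Interrupted while waiting for " + name, e);
        }
        return listener;
    }

    // Stops the polling loop and waits for the thread to exit.
    void shutdown(Thread listener) {
        runner.set(false);
        try {
            listener.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    long polls() {
        return polls.get();
    }
}
```

Compared with polling a flag, the latch avoids the fixed 50 ms granularity, and the timeout turns a listener that never starts into an explicit error rather than a hang.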

Example 63 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

Method consume of class InMemoryLookupTable.

/**
 * This method consumes the weights of a given InMemoryLookupTable.
 *
 * PLEASE NOTE: this method explicitly resets the current weights
 *
 * @param srcTable the lookup table whose weights are copied into this one
 */
public void consume(InMemoryLookupTable<T> srcTable) {
    if (srcTable.vectorLength != this.vectorLength)
        throw new IllegalStateException("You can't consume a lookupTable with a different vector length");
    if (srcTable.syn0 == null)
        throw new IllegalStateException("Source lookupTable Syn0 is NULL");
    this.resetWeights(true);
    AtomicInteger cntHs = new AtomicInteger(0);
    AtomicInteger cntNg = new AtomicInteger(0);
    if (srcTable.syn0.rows() > this.syn0.rows())
        throw new IllegalStateException("You can't consume a lookupTable built for a larger vocabulary without updating your vocabulary first");
    for (int x = 0; x < srcTable.syn0.rows(); x++) {
        this.syn0.putRow(x, srcTable.syn0.getRow(x));
        if (this.syn1 != null && srcTable.syn1 != null)
            this.syn1.putRow(x, srcTable.syn1.getRow(x));
        else if (cntHs.incrementAndGet() == 1)
            log.info("Skipping syn1 merge");
        if (this.syn1Neg != null && srcTable.syn1Neg != null) {
            this.syn1Neg.putRow(x, srcTable.syn1Neg.getRow(x));
        } else if (cntNg.incrementAndGet() == 1)
            log.info("Skipping syn1Neg merge");
        if (cntHs.get() > 0 && cntNg.get() > 0)
            throw new ND4JIllegalStateException("srcTable has no syn1/syn1neg");
    }
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)
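The row-by-row copy in consume() can be sketched with plain double[][] arrays standing in for the INDArray syn0 matrix, assuming (as the original does by copying row x to row x) that both tables share the same word indices. LookupTableMerge is a hypothetical name; the sketch zeroes the destination instead of calling resetWeights(true), handles only syn0, and performs all validation before modifying the destination, so a rejected call leaves it untouched.

```java
import java.util.Arrays;

class LookupTableMerge {
    // Copies every row of src into the same row index of dst.
    // All checks run before dst is modified, so a rejected call leaves dst untouched.
    static void consume(double[][] dst, double[][] src, int vectorLength) {
        if (src.length > dst.length)
            throw new IllegalStateException("Source table was built for a larger vocabulary");
        for (double[] row : src)
            if (row.length != vectorLength)
                throw new IllegalStateException("Source vector length differs from " + vectorLength);
        for (double[] row : dst)
            if (row.length != vectorLength)
                throw new IllegalStateException("Destination vector length differs from " + vectorLength);

        // Reset the destination (zeros here, instead of resetWeights(true)).
        for (double[] row : dst)
            Arrays.fill(row, 0.0);

        // Row x holds the vector of the word with index x in the shared vocabulary.
        for (int x = 0; x < src.length; x++)
            System.arraycopy(src[x], 0, dst[x], 0, vectorLength);
    }
}
```

Rows of the destination beyond the source vocabulary end up reset but not overwritten, which matches the loop bound of srcTable.syn0.rows() in the snippet.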

Example 64 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

Method inferVector of class ParagraphVectors.

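/**
 * This method calculates the inferred vector for a given document.
 *
 * @param document the document, as a non-empty list of words
 * @param learningRate initial learning rate for inference
 * @param minLearningRate minimal learning rate for inference
 * @param iterations number of inference iterations
 * @return the inferred vector
 */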
public INDArray inferVector(@NonNull List<VocabWord> document, double learningRate, double minLearningRate, int iterations) {
    SequenceLearningAlgorithm<VocabWord> learner = sequenceLearningAlgorithm;
    if (learner == null) {
        synchronized (this) {
            if (sequenceLearningAlgorithm == null) {
                log.info("Creating new PV-DM learner...");
                learner = new DM<VocabWord>();
                learner.configure(vocab, lookupTable, configuration);
                sequenceLearningAlgorithm = learner;
            } else {
                learner = sequenceLearningAlgorithm;
            }
        }
    }
    if (document.isEmpty())
        throw new ND4JIllegalStateException("Impossible to apply inference to empty list of words");
    Sequence<VocabWord> sequence = new Sequence<>();
    sequence.addElements(document);
    sequence.setSequenceLabel(new VocabWord(1.0, String.valueOf(new Random().nextInt())));
    initLearners();
    INDArray inf = learner.inferSequence(sequence, seed, learningRate, minLearningRate, iterations);
    return inf;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence)
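inferVector() creates its PV-DM learner lazily: it checks sequenceLearningAlgorithm once without a lock and again inside synchronized (this). The general double-checked locking idiom can be sketched as below. Under the Java memory model the unlocked first check is only safe when the field is volatile; whether the real sequenceLearningAlgorithm field is declared volatile is not visible in this snippet. LazyLearner and its factory are hypothetical names.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Hypothetical lazy holder illustrating double-checked locking.
class LazyLearner<T> {
    private volatile T instance;          // volatile is what makes the unlocked first check safe
    private final Supplier<T> factory;    // builds and fully configures the object
    private final AtomicInteger creations = new AtomicInteger();

    LazyLearner(Supplier<T> factory) {
        this.factory = factory;
    }

    T get() {
        T local = instance;               // fast path: one volatile read, no lock
        if (local == null) {
            synchronized (this) {
                local = instance;         // re-check: another thread may have won the race
                if (local == null) {
                    local = factory.get();
                    creations.incrementAndGet();
                    instance = local;     // publish only after construction is complete
                }
            }
        }
        return local;
    }

    int creations() {
        return creations.get();
    }
}
```

Reading the field into a local variable keeps the fast path to a single volatile read and guarantees the method returns the same reference it checked.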

Example 65 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

Method fit of class ParallelWrapper.

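/**
 * Trains the wrapped model on the given MultiDataSetIterator, dispatching batches across the parallel workers.
 *
 * @param source iterator providing the MultiDataSets to train on
 */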
public synchronized void fit(@NonNull MultiDataSetIterator source) {
    stopFit.set(false);
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            // we pass true here, to tell Trainer to use MultiDataSet queue for training
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), true);
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    } else {
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt].useMDS = true;
        }
    }
    source.reset();
    MultiDataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    AtomicInteger locker = new AtomicInteger(0);
    while (iterator.hasNext() && !stopFit.get()) {
        MultiDataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
        /*
         * Dispatch the dataSet to the next free worker until all workers are busy, then block until all of them have finished.
         */
        int pos = locker.getAndIncrement();
        zoo[pos].feedMultiDataSet(dataSet);
        /*
         * If all workers have been dispatched (or the iterator is exhausted), wait until they have all finished.
         */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
             * Average the model and propagate it to all workers.
             */
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                } else
                    throw new RuntimeException("MultiDataSet must only be used with ComputationGraph model");
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            locker.set(0);
        }
    }
    // sanity check: depending on the configuration, parameters may never be averaged
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    //            throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
//        iterationsCounter.set(0);
}
Also used : MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)
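In fit(), averaging only happens on a round that filled every worker (pos + 1 == workers) and whose round number is a multiple of averagingFrequency, which is why the closing warning points at the ratio of batches, workers and averaging frequency. The hypothetical DispatchSchedule below replays just that bookkeeping, with no threads or models, so the effect of those ratios can be checked. It assumes the iteration counter starts at zero, whereas in the snippet iterationsCounter carries over between fit() calls because its reset is commented out.

```java
class DispatchSchedule {
    // Replays the round/averaging bookkeeping of fit() and returns how many rounds trigger averaging.
    // batches: number of MultiDataSets produced by the iterator; the counter is assumed to start at zero.
    static int averagingEvents(int batches, int workers, int averagingFrequency) {
        int iterations = 0;   // plays the role of iterationsCounter
        int averagings = 0;
        int locker = 0;       // index of the next worker to feed
        for (int b = 0; b < batches; b++) {
            int pos = locker++;
            boolean lastBatch = (b == batches - 1); // mirrors !iterator.hasNext()
            if (pos + 1 == workers || lastBatch) {
                iterations++;
                // averaging requires a full round AND a round number divisible by averagingFrequency
                if (iterations % averagingFrequency == 0 && pos + 1 == workers)
                    averagings++;
                locker = 0;
            }
        }
        return averagings;
    }

    public static void main(String[] args) {
        System.out.println(averagingEvents(3, 4, 1));  // 0: no round ever fills all 4 workers
        System.out.println(averagingEvents(16, 2, 3)); // 2: full rounds 3 and 6 out of 8
    }
}
```

For example, with 3 batches and 4 workers the only round ends on the last batch with 3 workers fed, so averaging is never triggered and fit() would log the "never averaged" warning.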

Aggregations

ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 116
lombok.val (lombok.val): 26
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 23
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 21
AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 19
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 17
CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 15
PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer): 12
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 8
BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer): 7
IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray): 6
Pointer (org.bytedeco.javacpp.Pointer): 5
ArrayList (java.util.ArrayList): 4
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 4
Aeron (io.aeron.Aeron): 3
FragmentAssembler (io.aeron.FragmentAssembler): 3
MediaDriver (io.aeron.driver.MediaDriver): 3
AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean): 3
Slf4j (lombok.extern.slf4j.Slf4j): 3
CloseHelper (org.agrona.CloseHelper): 3