Search in sources :

Example 1 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

In class WordVectorSerializer, method readWord2VecModel.

/**
 * Restores a {@link Word2Vec} model from the given file, guessing the storage format.
 * The formats are attempted in this order, each one being tried only if the previous one failed:
 * 1) DL4j compressed (zip) format
 * 2) Popular CSV word2vec text format
 * 3) Binary model, either compressed or not. Like well-known Google Model
 *    (first with line breaks, then without)
 *
 * Please note: if extended data isn't available, only weights will be loaded instead.
 *
 * While the zip/CSV formats are being read, periodic GC is switched off and the occasional GC
 * frequency is raised; the original memory-manager settings are always restored before this
 * method returns or throws.
 *
 * @param file model file to read; must exist and be a regular file
 * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
 * @return the restored model
 * @throws ND4JIllegalStateException if the file doesn't exist or isn't a regular file
 * @throws RuntimeException if none of the supported formats could be read
 */
public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
    InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable<>();
    AbstractCache<VocabWord> vocabCache = new AbstractCache<>();
    Word2Vec vec;
    INDArray syn0 = null;
    VectorsConfiguration configuration = new VectorsConfiguration();
    if (!file.exists() || !file.isFile())
        throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
    int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
    boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
    if (originalPeriodic)
        Nd4j.getMemoryManager().togglePeriodicGc(false);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
    try {
        // try to load zip format
        try {
            if (extendedModel) {
                log.debug("Trying full model restoration...");
                restoreGcSettings(originalPeriodic, originalFreq);
                return readWord2Vec(file);
            } else {
                log.debug("Trying simplified model restoration...");
                File tmpFileSyn0 = File.createTempFile("word2vec", "syn");
                // the archive is closed as soon as we are done with it, even when loading fails
                try (ZipFile zipFile = new ZipFile(file)) {
                    // we don't need full model, so we go directly to syn0 file
                    ZipEntry syn = zipFile.getEntry("syn0.txt");
                    try (InputStream stream = zipFile.getInputStream(syn)) {
                        Files.copy(stream, Paths.get(tmpFileSyn0.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
                    }
                    // now we're restoring configuration saved earlier
                    ZipEntry config = zipFile.getEntry("config.json");
                    if (config != null) {
                        StringBuilder builder = new StringBuilder();
                        try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                            String line;
                            while ((line = reader.readLine()) != null) {
                                builder.append(line);
                            }
                        }
                        configuration = VectorsConfiguration.fromJson(builder.toString().trim());
                    }
                    // optional vocabulary: one space-separated line per word, label is base64-encoded
                    ZipEntry ve = zipFile.getEntry("frequencies.txt");
                    if (ve != null) {
                        AtomicInteger cnt = new AtomicInteger(0);
                        try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(ve)))) {
                            String line;
                            while ((line = reader.readLine()) != null) {
                                String[] split = line.split(" ");
                                VocabWord word = new VocabWord(Double.valueOf(split[1]), decodeB64(split[0]));
                                word.setIndex(cnt.getAndIncrement());
                                word.incrementSequencesCount(Long.valueOf(split[2]));
                                vocabCache.addToken(word);
                                vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                                Nd4j.getMemoryManager().invokeGcOccasionally();
                            }
                        }
                    }
                    List<INDArray> rows = new ArrayList<>();
                    // basically read up everything, call vstack and then return model
                    try (Reader reader = new CSVReader(tmpFileSyn0)) {
                        AtomicInteger cnt = new AtomicInteger(0);
                        while (reader.hasNext()) {
                            Pair<VocabWord, float[]> pair = reader.next();
                            VocabWord word = pair.getFirst();
                            INDArray vector = Nd4j.create(pair.getSecond());
                            if (ve != null) {
                                // vocabulary size is already known, so rows are written in place
                                if (syn0 == null)
                                    syn0 = Nd4j.create(vocabCache.numWords(), vector.length());
                                syn0.getRow(cnt.getAndIncrement()).assign(vector);
                            } else {
                                // no frequencies file: vocabulary is built from the syn0 rows themselves
                                rows.add(vector);
                                vocabCache.addToken(word);
                                vocabCache.addWordToIndex(word.getIndex(), word.getLabel());
                            }
                            Nd4j.getMemoryManager().invokeGcOccasionally();
                        }
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    } finally {
                        restoreGcSettings(originalPeriodic, originalFreq);
                    }
                    if (syn0 == null && vocabCache.numWords() > 0)
                        syn0 = Nd4j.vstack(rows);
                    if (syn0 == null) {
                        log.error("Can't build syn0 table");
                        throw new DL4JInvalidInputException("Can't build syn0 table");
                    }
                    lookupTable = new InMemoryLookupTable.Builder<VocabWord>().cache(vocabCache).vectorLength(syn0.columns()).useHierarchicSoftmax(false).useAdaGrad(false).build();
                    lookupTable.setSyn0(syn0);
                } finally {
                    // best-effort cleanup that also runs when loading failed half-way
                    try {
                        if (!tmpFileSyn0.delete())
                            tmpFileSyn0.deleteOnExit();
                    } catch (Exception ignored) {
                        // temp-file cleanup must never change the outcome of loading
                    }
                }
            }
        } catch (Exception e) {
            // let's try to load this file as csv file
            try {
                log.debug("Trying CSV model restoration...");
                Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
                lookupTable = pair.getFirst();
                vocabCache = (AbstractCache<VocabWord>) pair.getSecond();
            } catch (Exception ex) {
                // we fallback to trying binary model instead
                try {
                    log.debug("Trying binary model restoration...");
                    restoreGcSettings(originalPeriodic, originalFreq);
                    vec = loadGoogleModel(file, true, true);
                    return vec;
                } catch (Exception ey) {
                    // try to load without linebreaks
                    try {
                        restoreGcSettings(originalPeriodic, originalFreq);
                        vec = loadGoogleModel(file, true, false);
                        return vec;
                    } catch (Exception ez) {
                        // keep the last failure as the cause so the problem can be diagnosed
                        throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly", ez);
                    }
                }
            }
        }
    } finally {
        // guarantees the global memory-manager settings never stay modified,
        // e.g. when the zip attempt failed early and the CSV fallback succeeded
        restoreGcSettings(originalPeriodic, originalFreq);
    }
    Word2Vec.Builder builder = new Word2Vec.Builder(configuration).lookupTable(lookupTable).useAdaGrad(false).vocabCache(vocabCache).layerSize(lookupTable.layerSize()).useHierarchicSoftmax(false).resetModel(false);
    /*
            Trying to restore TokenizerFactory & TokenPreProcessor
         */
    TokenizerFactory factory = getTokenizerFactory(configuration);
    if (factory != null)
        builder.tokenizerFactory(factory);
    vec = builder.build();
    return vec;
}

/**
 * Puts the memory manager back into the state captured before model loading started.
 *
 * @param periodicGcWasActive whether periodic GC was active originally (it is re-enabled only then)
 * @param occasionalGcFrequency the occasional GC frequency to reinstate
 */
private static void restoreGcSettings(boolean periodicGcWasActive, int occasionalGcFrequency) {
    if (periodicGcWasActive)
        Nd4j.getMemoryManager().togglePeriodicGc(true);
    Nd4j.getMemoryManager().setOccasionalGcFrequency(occasionalGcFrequency);
}
Also used : ZipEntry(java.util.zip.ZipEntry) ArrayList(java.util.ArrayList) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) AbstractCache(org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) StaticWord2Vec(org.deeplearning4j.models.word2vec.StaticWord2Vec) Word2Vec(org.deeplearning4j.models.word2vec.Word2Vec) Pair(org.deeplearning4j.berkeley.Pair) TokenizerFactory(org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory) GZIPInputStream(java.util.zip.GZIPInputStream) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ZipFile(java.util.zip.ZipFile) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ZipFile(java.util.zip.ZipFile) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException)

Example 2 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

In class ParallelWrapper, method fit.

/**
 * Trains on the given iterator: consecutive DataSets are handed to the worker threads one by one,
 * and once every worker has been fed (or the iterator is exhausted) the method waits for them.
 * On every averagingFrequency-th full round the updater state is averaged across workers.
 *
 * @param source iterator providing the training data; it is reset before training starts
 */
public synchronized void fit(@NonNull DataSetIterator source) {
    stopFit.set(false);
    // worker threads are created lazily, on the first call
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int w = 0; w < workers; w++) {
            zoo[w] = new Trainer(w, model, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // with MQ enabled every worker thread is attached to a device, round-robin
            if (isMQ) {
                Nd4j.getAffinityManager().attachThreadToDevice(zoo[w], w % Nd4j.getAffinityManager().getNumberOfDevices());
            }
            zoo[w].setUncaughtExceptionHandler(handler);
            zoo[w].start();
        }
    }
    source.reset();
    // wrap the source into an async prefetching iterator when that is both requested and supported
    DataSetIterator iterator;
    if (prefetchSize <= 0 || !source.asyncSupported()) {
        iterator = source;
    } else if (isMQ) {
        if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0) {
            log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers, Nd4j.getAffinityManager().getNumberOfDevices());
        }
        MagicQueue queue = new MagicQueue.Builder().setCapacityPerFlow(8).setMode(MagicQueue.Mode.SEQUENTIAL).setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices()).build();
        iterator = new AsyncDataSetIterator(source, prefetchSize, queue);
    } else {
        iterator = new AsyncDataSetIterator(source, prefetchSize);
    }
    // locker counts how many workers received a DataSet in the current round
    AtomicInteger locker = new AtomicInteger(0);
    int passes = 0;
    while (iterator.hasNext() && !stopFit.get()) {
        passes++;
        DataSet batch = iterator.next();
        if (batch == null) {
            throw new ND4JIllegalStateException("You can't have NULL as DataSet");
        }
        // dispatch the batch to the next free worker
        int pos = locker.getAndIncrement();
        if (zoo == null) {
            throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
        }
        zoo[pos].feedDataSet(batch);
        boolean allWorkersFed = pos + 1 == workers;
        // keep dispatching while there are idle workers and more data to hand out
        if (!allWorkersFed && iterator.hasNext()) {
            continue;
        }
        // round is complete: block until every worker that got data has finished
        iterationsCounter.incrementAndGet();
        for (int w = 0; w < workers && w < locker.get(); w++) {
            try {
                zoo[w].waitTillRunning();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        Nd4j.getMemoryManager().invokeGcOccasionally();
        // averaging happens only after a full round, every averagingFrequency-th iteration
        if (iterationsCounter.get() % averagingFrequency == 0 && allWorkersFed) {
            double score = getScore(locker);
            if (model instanceof MultiLayerNetwork) {
                MultiLayerNetwork network = (MultiLayerNetwork) model;
                if (averageUpdaters) {
                    Updater updater = network.getUpdater();
                    int batchSize = 0;
                    if (updater != null && updater.getStateViewArray() != null) {
                        if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                            // average the workers' updater states directly into the source updater
                            List<INDArray> workerStates = new ArrayList<>();
                            for (int w = 0; w < workers && w < locker.get(); w++) {
                                MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[w].getModel();
                                workerStates.add(workerModel.getUpdater().getStateViewArray());
                                batchSize += workerModel.batchSize();
                            }
                            Nd4j.averageAndPropagate(updater.getStateViewArray(), workerStates);
                        } else {
                            // legacy path: sum copies of the worker states and divide by their count
                            INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                            int summed = 0;
                            while (summed < workers && summed < locker.get()) {
                                MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[summed].getModel();
                                state.addi(workerModel.getUpdater().getStateViewArray().dup());
                                batchSize += workerModel.batchSize();
                                summed++;
                            }
                            state.divi(summed);
                            updater.setStateViewArray(network, state, false);
                        }
                    }
                }
                network.setScore(score);
            } else if (model instanceof ComputationGraph) {
                averageUpdatersState(locker, score);
            }
            // legacy multi-device mode pushes the source model to every worker explicitly
            if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                for (int w = 0; w < workers; w++) {
                    zoo[w].updateModel(model);
                }
            }
        }
        locker.set(0);
    }
    // sanity checks, or the dataset may never average
    if (!wasAveraged) {
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    }
    log.debug("Iterations passed: {}", iterationsCounter.get());
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) DataSet(org.nd4j.linalg.dataset.api.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ComputationGraphUpdater(org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater) Updater(org.deeplearning4j.nn.api.Updater) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)

Example 3 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

In class ParagraphVectors, method inferVector.

/**
 * Calculates the inferred vector for the given text. The text is tokenized with the configured
 * tokenizer factory, tokens missing from the vocabulary are dropped, and the remaining words
 * are passed to the document-based {@code inferVector} overload.
 *
 * @param text raw text to infer a vector for
 * @param learningRate passed through to the document-based overload
 * @param minLearningRate passed through to the document-based overload
 * @param iterations passed through to the document-based overload
 * @return the vector produced by the document-based overload
 * @throws IllegalStateException if no TokenizerFactory is set
 * @throws ND4JIllegalStateException if none of the tokens is present in the vocabulary
 */
public INDArray inferVector(String text, double learningRate, double minLearningRate, int iterations) {
    if (tokenizerFactory == null) {
        throw new IllegalStateException("TokenizerFactory should be defined, prior to predict() call");
    }
    // a missing or empty vocabulary triggers re-binding to the existing model first
    if (this.vocab == null || this.vocab.numWords() == 0) {
        reassignExistingModel();
    }
    List<String> tokens = tokenizerFactory.create(text).getTokens();
    List<VocabWord> knownWords = new ArrayList<>();
    for (String token : tokens) {
        if (!vocab.containsWord(token)) {
            continue;
        }
        knownWords.add(vocab.wordFor(token));
    }
    if (knownWords.isEmpty()) {
        throw new ND4JIllegalStateException("Text passed for inference has no matches in model vocabulary.");
    }
    return inferVector(knownWords, learningRate, minLearningRate, iterations);
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 4 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

In class TrainingFunction, method call.

/**
 * Processes a single sequence: lazily initializes configuration, parameter server, vocabulary
 * and learning algorithms from the broadcast values, maps the sequence's elements (and labels,
 * when sequence vectors are trained) onto the shallow vocabulary, and submits the resulting
 * frame to the parameter server.
 *
 * @param sequence sequence to train on; an empty sequence is skipped with a warning
 * @throws ND4JIllegalStateException if neither an elements nor a sequence learning algorithm is available
 * @throws RuntimeException if a configured learning algorithm class can't be instantiated
 */
@Override
@SuppressWarnings("unchecked")
public void call(Sequence<T> sequence) throws Exception {
    // Depending on actual training mode, we'll either go for SkipGram/CBOW/PV-DM/PV-DBOW or whatever
    if (vectorsConfiguration == null) {
        vectorsConfiguration = configurationBroadcast.getValue();
    }
    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        if (elementsLearningAlgorithm == null) {
            // NOTE(review): this instance is not passed through configure(), unlike the one created further below - confirm intended
            try {
                Class<?> elaClass = Class.forName(vectorsConfiguration.getElementsLearningAlgorithm());
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) elaClass.newInstance();
            } catch (Exception ex) {
                throw new RuntimeException(ex);
            }
        }
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }
    if (vectorsConfiguration == null) {
        vectorsConfiguration = configurationBroadcast.getValue();
    }
    if (shallowVocabCache == null) {
        shallowVocabCache = vocabCacheBroadcast.getValue();
    }
    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            Class<?> elaClass = Class.forName(vectorsConfiguration.getElementsLearningAlgorithm());
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) elaClass.newInstance();
            elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }
    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            Class<?> slaClass = Class.forName(vectorsConfiguration.getSequenceLearningAlgorithm());
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) slaClass.newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }
    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }
    /*
         at this moment we should have everything ready for actual initialization
         the only limitation we have - our sequence is detached from actual vocabulary, so we need to merge it back virtually
        */
    Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
    for (T item : sequence.getElements()) {
        // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
        ShallowSequenceElement known = shallowVocabCache.tokenFor(item.getStorageId());
        if (known == null) {
            continue;
        }
        mergedSequence.addElement(known);
    }
    // do the same with labels, transfer them, if any
    if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
        for (T label : sequence.getSequenceLabels()) {
            ShallowSequenceElement known = shallowVocabCache.tokenFor(label.getStorageId());
            if (known == null) {
                continue;
            }
            mergedSequence.addSequenceLabel(known);
        }
    }
    // FIXME: temporary hook
    // NOTE(review): emptiness is checked on the incoming sequence, not on mergedSequence - confirm intended
    if (sequence.size() > 0) {
        paramServer.execDistributed(elementsLearningAlgorithm.frameSequence(mergedSequence, new AtomicLong(119), 25e-3));
    } else {
        log.warn("Skipping empty sequence...");
    }
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 5 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

In class Nd4j, method create.

/**
 * Creates an array backed by the given data, with the given shape, ordering and offset.
 *
 * @param data values for the new array
 * @param shape requested shape; it is passed through getEnsuredShape before use
 * @param ordering ordering character used to derive strides and passed to the factory
 * @param offset offset passed to the factory
 * @return the newly created array
 * @throws ND4JIllegalStateException if the ensured shape has a single dimension that differs from the data length
 */
public static INDArray create(double[] data, int[] shape, char ordering, long offset) {
    shape = getEnsuredShape(shape);
    // for a one-dimensional shape the single dimension has to match the data length exactly
    boolean singleDimension = shape.length == 1;
    if (singleDimension && shape[0] != data.length) {
        throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length);
    }
    checkShapeValues(data.length, shape);
    INDArray created = INSTANCE.create(data, shape, getStrides(shape, ordering), offset, ordering);
    logCreationIfNecessary(created);
    return created;
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Aggregations

ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 116; lombok.val (lombok.val): 26; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 23; CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 21; AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 19; DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 17; CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 15; PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer): 12; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 8; BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer): 7; IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray): 6; Pointer (org.bytedeco.javacpp.Pointer): 5; ArrayList (java.util.ArrayList): 4; DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 4; Aeron (io.aeron.Aeron): 3; FragmentAssembler (io.aeron.FragmentAssembler): 3; MediaDriver (io.aeron.driver.MediaDriver): 3; AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean): 3; Slf4j (lombok.extern.slf4j.Slf4j): 3; CloseHelper (org.agrona.CloseHelper): 3