Example 21 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class IteratorMultiDataSetIterator, method next().

@Override
public MultiDataSet next(int num) {
    if (!hasNext())
        throw new NoSuchElementException();
    List<MultiDataSet> list = new ArrayList<>();
    int countSoFar = 0;
    while ((!queued.isEmpty() || iterator.hasNext()) && countSoFar < batchSize) {
        MultiDataSet next;
        if (!queued.isEmpty()) {
            next = queued.removeFirst();
        } else {
            next = iterator.next();
        }
        //Number of examples in this MultiDataSet: size of the minibatch (0th) dimension
        int nExamples = next.getFeatures(0).size(0);
        if (countSoFar + nExamples <= batchSize) {
            //Add the entire MultiDataSet as-is
            list.add(next);
        } else {
            //Split the MultiDataSet
            int nFeatures = next.numFeatureArrays();
            int nLabels = next.numLabelsArrays();
            INDArray[] fToKeep = new INDArray[nFeatures];
            INDArray[] lToKeep = new INDArray[nLabels];
            INDArray[] fToCache = new INDArray[nFeatures];
            INDArray[] lToCache = new INDArray[nLabels];
            INDArray[] fMaskToKeep = (next.getFeaturesMaskArrays() != null ? new INDArray[nFeatures] : null);
            INDArray[] lMaskToKeep = (next.getLabelsMaskArrays() != null ? new INDArray[nLabels] : null);
            INDArray[] fMaskToCache = (next.getFeaturesMaskArrays() != null ? new INDArray[nFeatures] : null);
            INDArray[] lMaskToCache = (next.getLabelsMaskArrays() != null ? new INDArray[nLabels] : null);
            for (int i = 0; i < nFeatures; i++) {
                INDArray fi = next.getFeatures(i);
                fToKeep[i] = getRange(fi, 0, batchSize - countSoFar);
                fToCache[i] = getRange(fi, batchSize - countSoFar, nExamples);
                if (fMaskToKeep != null) {
                    INDArray fmi = next.getFeaturesMaskArray(i);
                    fMaskToKeep[i] = getRange(fmi, 0, batchSize - countSoFar);
                    fMaskToCache[i] = getRange(fmi, batchSize - countSoFar, nExamples);
                }
            }
            for (int i = 0; i < nLabels; i++) {
                INDArray li = next.getLabels(i);
                lToKeep[i] = getRange(li, 0, batchSize - countSoFar);
                lToCache[i] = getRange(li, batchSize - countSoFar, nExamples);
                if (lMaskToKeep != null) {
                    INDArray lmi = next.getLabelsMaskArray(i);
                    lMaskToKeep[i] = getRange(lmi, 0, batchSize - countSoFar);
                    lMaskToCache[i] = getRange(lmi, batchSize - countSoFar, nExamples);
                }
            }
            MultiDataSet toKeep = new org.nd4j.linalg.dataset.MultiDataSet(fToKeep, lToKeep, fMaskToKeep, lMaskToKeep);
            MultiDataSet toCache = new org.nd4j.linalg.dataset.MultiDataSet(fToCache, lToCache, fMaskToCache, lMaskToCache);
            list.add(toKeep);
            queued.add(toCache);
        }
        countSoFar += nExamples;
    }
    MultiDataSet out;
    if (list.size() == 1) {
        out = list.get(0);
    } else {
        out = org.nd4j.linalg.dataset.MultiDataSet.merge(list);
    }
    if (preProcessor != null)
        preProcessor.preProcess(out);
    return out;
}
Also used: MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), INDArray (org.nd4j.linalg.api.ndarray.INDArray)
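
The getRange helper is not included in this listing. Below is a minimal sketch of what it plausibly does, assuming it returns the example range [from, to) along the minibatch dimension via ND4J's NDArrayIndex.interval; the names and rank handling here are assumptions, not the verbatim deeplearning4j implementation:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;

private static INDArray getRange(INDArray arr, int from, int to) {
    if (from == to)
        return null;
    //Keep examples [from, to) along dimension 0; all other dimensions are kept intact
    switch (arr.rank()) {
        case 2: //feed-forward data: [examples, features]
            return arr.get(NDArrayIndex.interval(from, to), NDArrayIndex.all());
        case 3: //time series: [examples, features, timeSteps]
            return arr.get(NDArrayIndex.interval(from, to), NDArrayIndex.all(), NDArrayIndex.all());
        case 4: //images: [examples, channels, height, width]
            return arr.get(NDArrayIndex.interval(from, to), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
        default:
            throw new IllegalStateException("Invalid rank: " + arr.rank());
    }
}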

Example 22 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class DataSetLossCalculatorCG, method calculateScore().

@Override
public double calculateScore(ComputationGraph network) {
    double lossSum = 0.0;
    int exCount = 0;
    if (dataSetIterator != null) {
        dataSetIterator.reset();
        while (dataSetIterator.hasNext()) {
            DataSet dataSet = dataSetIterator.next();
            int nEx = dataSet.getFeatureMatrix().size(0);
            lossSum += network.score(dataSet) * nEx;
            exCount += nEx;
        }
    } else {
        multiDataSetIterator.reset();
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet dataSet = multiDataSetIterator.next();
            int nEx = dataSet.getFeatures(0).size(0);
            lossSum += network.score(dataSet) * nEx;
            exCount += nEx;
        }
    }
    if (average)
        return lossSum / exCount;
    else
        return lossSum;
}
Also used: MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), DataSet (org.nd4j.linalg.dataset.DataSet)
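
A minimal usage sketch, assuming an existing ComputationGraph net and a validation MultiDataSetIterator valIter; the variable names are illustrative, and the (iterator, average) constructor is inferred from the fields used in the method above:

//Score the graph by its average per-example loss over the validation data
DataSetLossCalculatorCG calculator = new DataSetLossCalculatorCG(valIter, true); //average = true
double validationLoss = calculator.calculateScore(net);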

Example 23 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class ComputationGraph, method pretrainLayer().

/**
 * Pretrain a specified layer with the given MultiDataSetIterator
 *
 * @param layerName Layer name
 * @param iter      Training data
 */
public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
    if (!configuration.isPretrain())
        return;
    if (flattenedGradients == null)
        initGradientsView();
    if (!verticesMap.containsKey(layerName)) {
        throw new IllegalStateException("Invalid vertex name: " + layerName);
    }
    if (!verticesMap.get(layerName).hasLayer()) {
        //No op
        return;
    }
    int layerIndex = verticesMap.get(layerName).getVertexIndex();
    //Need to do a partial forward pass. Simply following the topological ordering won't be efficient, as we
    // might end up doing a forward pass on layers we don't need.
    //However, we can start with the topological order, and prune out any layers we don't need to do a forward pass on
    LinkedList<Integer> partialTopoSort = new LinkedList<>();
    Set<Integer> seenSoFar = new HashSet<>();
    partialTopoSort.add(topologicalOrder[layerIndex]);
    seenSoFar.add(topologicalOrder[layerIndex]);
    for (int j = layerIndex - 1; j >= 0; j--) {
        //Do we need to do forward pass on this GraphVertex?
        //If it is input to any other layer we need, then yes. Otherwise: no
        VertexIndices[] outputsTo = vertices[topologicalOrder[j]].getOutputVertices();
        boolean needed = false;
        for (VertexIndices vi : outputsTo) {
            if (seenSoFar.contains(vi.getVertexIndex())) {
                needed = true;
                break;
            }
        }
        if (needed) {
            partialTopoSort.addFirst(topologicalOrder[j]);
            seenSoFar.add(topologicalOrder[j]);
        }
    }
    int[] fwdPassOrder = new int[partialTopoSort.size()];
    int k = 0;
    for (Integer g : partialTopoSort) fwdPassOrder[k++] = g;
    GraphVertex gv = vertices[fwdPassOrder[fwdPassOrder.length - 1]];
    Layer layer = gv.getLayer();
    if (!iter.hasNext() && iter.resetSupported()) {
        iter.reset();
    }
    while (iter.hasNext()) {
        MultiDataSet multiDataSet = iter.next();
        setInputs(multiDataSet.getFeatures());
        for (int j = 0; j < fwdPassOrder.length - 1; j++) {
            GraphVertex current = vertices[fwdPassOrder[j]];
            if (current.isInputVertex()) {
                VertexIndices[] inputsTo = current.getOutputVertices();
                INDArray input = inputs[current.getVertexIndex()];
                for (VertexIndices v : inputsTo) {
                    int vIdx = v.getVertexIndex();
                    int vIdxInputNum = v.getVertexEdgeNumber();
                    //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                    //TODO When to dup?
                    vertices[vIdx].setInput(vIdxInputNum, input.dup());
                }
            } else {
                //Do forward pass:
                INDArray out = current.doForward(true);
                //Now, set the inputs for the next vertices:
                VertexIndices[] outputsTo = current.getOutputVertices();
                if (outputsTo != null) {
                    for (VertexIndices v : outputsTo) {
                        int vIdx = v.getVertexIndex();
                        int inputNum = v.getVertexEdgeNumber();
                        //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                        vertices[vIdx].setInput(inputNum, out);
                    }
                }
            }
        }
        //At this point: have done all of the required forward pass stuff. Can now pretrain layer on current input
        layer.fit(gv.getInputs()[0]);
        layer.conf().setPretrain(false);
    }
}
Also used: Layer (org.deeplearning4j.nn.api.Layer), FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer), RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer), FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer), IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer), GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices)
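
A short usage sketch, assuming a ComputationGraph graph whose configuration was built with pretrain(true) and a MultiDataSetIterator trainIter; the vertex name "encoder" is purely illustrative:

//Pretrain one unsupervised layer (e.g. an autoencoder vertex) by name;
//vertices without a layer are a no-op, as shown in the method above
graph.init();
graph.pretrainLayer("encoder", trainIter);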

Example 24 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class ParallelWrapper, method fit().

/**
 * Fit the model using the given MultiDataSetIterator as the training data source.
 *
 * @param source training data iterator
 */
public synchronized void fit(@NonNull MultiDataSetIterator source) {
    stopFit.set(false);
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            // we pass true here, to tell Trainer to use MultiDataSet queue for training
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), true);
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    } else {
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt].useMDS = true;
        }
    }
    source.reset();
    MultiDataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    AtomicInteger locker = new AtomicInteger(0);
    while (iterator.hasNext() && !stopFit.get()) {
        MultiDataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
        /*
         * Dispatch the dataSet to the next free worker, until all workers are busy;
         * then block until all of them have finished.
         */
        int pos = locker.getAndIncrement();
        zoo[pos].feedMultiDataSet(dataSet);
        /*
         * If all workers have been dispatched, join until all are finished.
         */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
             * Average the model, and propagate it to all workers.
             */
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                } else
                    throw new RuntimeException("MultiDataSet must only be used with ComputationGraph model");
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            locker.set(0);
        }
    }
    // sanity check: warn if parameter averaging never happened during this fit()
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    //            throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
//        iterationsCounter.set(0);
}
Also used: MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)
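
A minimal usage sketch, assuming the standard ParallelWrapper builder and an existing ComputationGraph graph plus a MultiDataSetIterator trainIter; the worker, prefetch, and averaging values are illustrative:

ParallelWrapper wrapper = new ParallelWrapper.Builder<>(graph)
        .workers(4)                //number of concurrent Trainer threads
        .prefetchBuffer(8)         //enables the AsyncMultiDataSetIterator path shown above
        .averagingFrequency(3)     //average parameters every 3 iterations
        .build();
wrapper.fit(trainIter);            //dispatches MultiDataSets to the workers as above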

Example 25 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class MultiDataSetExportFunction, method call().

@Override
public void call(Iterator<MultiDataSet> iter) throws Exception {
    String jvmuid = UIDProvider.getJVMUID();
    uid = Thread.currentThread().getId() + jvmuid.substring(0, Math.min(8, jvmuid.length()));
    while (iter.hasNext()) {
        MultiDataSet next = iter.next();
        String filename = "mds_" + uid + "_" + (outputCount++) + ".bin";
        String path = outputDir.getPath();
        URI uri = new URI(path + (path.endsWith("/") || path.endsWith("\\") ? "" : "/") + filename);
        FileSystem file = FileSystem.get(uri, conf);
        try (FSDataOutputStream out = file.create(new Path(uri))) {
            next.save(out);
        }
    }
}
Also used: Path (org.apache.hadoop.fs.Path), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), FileSystem (org.apache.hadoop.fs.FileSystem), FSDataOutputStream (org.apache.hadoop.fs.FSDataOutputStream), URI (java.net.URI)
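
A usage sketch, assuming this function is applied per partition of a Spark RDD and constructed from the output directory URI referenced as outputDir above; the HDFS path, method, and variable names are illustrative:

import java.net.URI;
import org.apache.spark.api.java.JavaRDD;

static void export(JavaRDD<MultiDataSet> mdsRdd) throws Exception {
    //Each partition writes its MultiDataSets as individual mds_*.bin files
    mdsRdd.foreachPartition(new MultiDataSetExportFunction(new URI("hdfs:///tmp/mds-export/")));
}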

Aggregations

MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 28 uses
Test (org.junit.Test): 12 uses
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 11 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 10 uses
MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 10 uses
SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 8 uses
CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 8 uses
DataSet (org.nd4j.linalg.dataset.DataSet): 8 uses
FileSplit (org.datavec.api.split.FileSplit): 7 uses
ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 6 uses
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 6 uses
RecordReader (org.datavec.api.records.reader.RecordReader): 5 uses
CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 5 uses
NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 5 uses
ArrayList (java.util.ArrayList): 4 uses
RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 4 uses
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 4 uses
AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator): 3 uses
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 3 uses
Random (java.util.Random): 2 uses