Search in sources:

Example 1 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class ComputationGraph, method fit.

/**
     * Fit the ComputationGraph using a MultiDataSetIterator
     */
public void fit(MultiDataSetIterator multi) {
    //Lazily initialize the flattened gradient view before any fitting
    if (flattenedGradients == null)
        initGradientsView();
    MultiDataSetIterator multiDataSetIterator;
    if (multi.asyncSupported()) {
        //Wrap with an asynchronous prefetching iterator (queue size 2) where supported
        multiDataSetIterator = new AsyncMultiDataSetIterator(multi, 2);
    } else {
        multiDataSetIterator = multi;
    }
    if (configuration.isPretrain()) {
        pretrain(multiDataSetIterator);
    }
    if (configuration.isBackprop()) {
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet next = multiDataSetIterator.next();
            //Stop if the iterator returns an empty MultiDataSet
            if (next.getFeatures() == null || next.getLabels() == null)
                break;
            if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                //Truncated backpropagation through time, using the mask arrays for variable-length sequences
                doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
            } else {
                boolean hasMaskArrays = next.hasMaskArrays();
                if (hasMaskArrays) {
                    setLayerMaskArrays(next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                }
                setInputs(next.getFeatures());
                setLabels(next.getLabels());
                if (solver == null) {
                    solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners).model(this).build();
                }
                solver.optimize();
                if (hasMaskArrays) {
                    clearLayerMaskArrays();
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    }
}
Also used: SingletonMultiDataSetIterator(org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) Solver(org.deeplearning4j.optimize.Solver) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet)
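
A minimal usage sketch of the fit method above (the names graph, trainIter, and nEpochs are hypothetical, and the method is assumed to live in some training class): fit wraps the iterator asynchronously internally, so the caller only needs to loop over epochs and reset the iterator.

/** Hedged sketch, not from the source: multi-epoch training with fit(MultiDataSetIterator). */
public static void trainForEpochs(ComputationGraph graph, MultiDataSetIterator trainIter, int nEpochs) {
    for (int epoch = 0; epoch < nEpochs; epoch++) {
        //fit(...) wraps the iterator in an AsyncMultiDataSetIterator where supported (see above)
        graph.fit(trainIter);
        //Rewind the iterator so the next epoch sees the full data set again
        trainIter.reset();
    }
}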

Example 2 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class ComputationGraphUtil, method toMultiDataSet.

/** Convert a DataSet to the equivalent MultiDataSet */
public static MultiDataSet toMultiDataSet(DataSet dataSet) {
    INDArray f = dataSet.getFeatures();
    INDArray l = dataSet.getLabels();
    INDArray fMask = dataSet.getFeaturesMaskArray();
    INDArray lMask = dataSet.getLabelsMaskArray();
    INDArray[] fNew = f == null ? null : new INDArray[] { f };
    INDArray[] lNew = l == null ? null : new INDArray[] { l };
    INDArray[] fMaskNew = (fMask != null ? new INDArray[] { fMask } : null);
    INDArray[] lMaskNew = (lMask != null ? new INDArray[] { lMask } : null);
    return new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
}
Also used: INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet)
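
A minimal usage sketch of the conversion above (the iterator parameter and method name are hypothetical): each DataSet from an existing single-input pipeline is wrapped so it can be consumed wherever a MultiDataSet is required.

/** Hedged sketch, not from the source: converting DataSets to MultiDataSets on the fly. */
public static void consumeAsMultiDataSet(DataSetIterator dataSetIter) {
    while (dataSetIter.hasNext()) {
        DataSet ds = dataSetIter.next();
        MultiDataSet mds = ComputationGraphUtil.toMultiDataSet(ds);
        //The single feature/label (and mask) arrays are wrapped in length-1 arrays:
        //mds.getFeatures(0) is ds.getFeatures(), mds.getLabels(0) is ds.getLabels()
    }
}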

Example 3 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class RecordReaderMultiDataSetIterator, method nextMultiDataSet.

private MultiDataSet nextMultiDataSet(Map<String, List<List<Writable>>> nextRRVals, Map<String, List<List<List<Writable>>>> nextSeqRRVals, List<RecordMetaDataComposableMap> nextMetas) {
    int minExamples = Integer.MAX_VALUE;
    for (List<List<Writable>> exampleData : nextRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    for (List<List<List<Writable>>> exampleData : nextSeqRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    //Should never happen: minExamples is only left at MAX_VALUE if there were no readers at all
    if (minExamples == Integer.MAX_VALUE)
        throw new RuntimeException("Error occurred during data set generation: no readers?");
    //In order to align data at the end (for each example individually), we need to know the length of the
    // longest time series for each example
    int[] longestSequence = null;
    if (alignmentMode == AlignmentMode.ALIGN_END) {
        longestSequence = new int[minExamples];
        for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
            List<List<List<Writable>>> list = entry.getValue();
            for (int i = 0; i < list.size() && i < minExamples; i++) {
                longestSequence[i] = Math.max(longestSequence[i], list.get(i).size());
            }
        }
    }
    //Second: create the input arrays
    //To do this, we need to know longest time series length, so we can do padding
    int longestTS = -1;
    if (alignmentMode != AlignmentMode.EQUAL_LENGTH) {
        for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
            List<List<List<Writable>>> list = entry.getValue();
            for (List<List<Writable>> c : list) {
                longestTS = Math.max(longestTS, c.size());
            }
        }
    }
    INDArray[] inputArrs = new INDArray[inputs.size()];
    INDArray[] inputArrMasks = new INDArray[inputs.size()];
    boolean inputMasks = false;
    int i = 0;
    for (SubsetDetails d : inputs) {
        if (nextRRVals.containsKey(d.readerName)) {
            //Standard reader
            List<List<Writable>> list = nextRRVals.get(d.readerName);
            inputArrs[i] = convertWritables(list, minExamples, d);
        } else {
            //Sequence reader
            List<List<List<Writable>>> list = nextSeqRRVals.get(d.readerName);
            Pair<INDArray, INDArray> p = convertWritablesSequence(list, minExamples, longestTS, d, longestSequence);
            inputArrs[i] = p.getFirst();
            inputArrMasks[i] = p.getSecond();
            if (inputArrMasks[i] != null)
                inputMasks = true;
        }
        i++;
    }
    if (!inputMasks)
        inputArrMasks = null;
    //Third: create the outputs
    INDArray[] outputArrs = new INDArray[outputs.size()];
    INDArray[] outputArrMasks = new INDArray[outputs.size()];
    boolean outputMasks = false;
    i = 0;
    for (SubsetDetails d : outputs) {
        if (nextRRVals.containsKey(d.readerName)) {
            //Standard reader
            List<List<Writable>> list = nextRRVals.get(d.readerName);
            outputArrs[i] = convertWritables(list, minExamples, d);
        } else {
            //Sequence reader
            List<List<List<Writable>>> list = nextSeqRRVals.get(d.readerName);
            Pair<INDArray, INDArray> p = convertWritablesSequence(list, minExamples, longestTS, d, longestSequence);
            outputArrs[i] = p.getFirst();
            outputArrMasks[i] = p.getSecond();
            if (outputArrMasks[i] != null)
                outputMasks = true;
        }
        i++;
    }
    if (!outputMasks)
        outputArrMasks = null;
    MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(inputArrs, outputArrs, inputArrMasks, outputArrMasks);
    if (collectMetaData) {
        mds.setExampleMetaData(nextMetas);
    }
    if (preProcessor != null)
        preProcessor.preProcess(mds);
    return mds;
}
Also used: NDArrayWritable(org.datavec.common.data.NDArrayWritable) Writable(org.datavec.api.writable.Writable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordMetaDataComposableMap(org.datavec.api.records.metadata.RecordMetaDataComposableMap)
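
A minimal construction sketch (batch size, reader names, and input/output wiring are hypothetical) of how the iterator whose nextMultiDataSet method is shown above is typically built with two sequence readers and end-alignment, so that unequal-length time series are padded and masked per example as in the ALIGN_END branch:

/** Hedged sketch, not from the source: building a RecordReaderMultiDataSetIterator with ALIGN_END. */
public static MultiDataSetIterator buildSequenceIterator(SequenceRecordReader featureReader, SequenceRecordReader labelReader) {
    return new RecordReaderMultiDataSetIterator.Builder(32)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
            .addSequenceReader("features", featureReader)
            .addSequenceReader("labels", labelReader)
            //Use all columns of the feature reader as input, all columns of the label reader as output
            .addInput("features")
            .addOutput("labels")
            .build();
}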

Example 4 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI.

@Test
public void testImagesRRDMSI() throws Exception {
    File parentDir = Files.createTempDir();
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
    int outputNum = 2;
    Random r = new Random(12345);
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
    //Now, do the same thing with individual RecordReaderDataSetIterators, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);
    for (int i = 0; i < 2; i++) {
        MultiDataSet mds = trainDataIterator.next();
        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();
        assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
        assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
    }
}
Also used: DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) Random(java.util.Random) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) Test(org.junit.Test)

Example 5 with MultiDataSet

Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.

The class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVMeta.

@Test
public void testSplittingCSVMeta() throws Exception {
    //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
    //Inputs: columns 0 and 1-2
    //Outputs: columns 3, and 4->OneHot
    RecordReader rr2 = new CSVRecordReader(0, ",");
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
    rrmdsi.setCollectMetaData(true);
    int count = 0;
    while (rrmdsi.hasNext()) {
        MultiDataSet mds = rrmdsi.next();
        MultiDataSet fromMeta = rrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(150 / 10, count);
}
Also used: RecordMetaData(org.datavec.api.records.metadata.RecordMetaData) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordReader(org.datavec.api.records.reader.RecordReader) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Aggregations

MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 28
Test (org.junit.Test): 12
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 11
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 10
MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 10
SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 8
CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 8
DataSet (org.nd4j.linalg.dataset.DataSet): 8
FileSplit (org.datavec.api.split.FileSplit): 7
ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 6
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 6
RecordReader (org.datavec.api.records.reader.RecordReader): 5
CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 5
NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 5
ArrayList (java.util.ArrayList): 4
RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 4
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 4
AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator): 3
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 3
Random (java.util.Random): 2