Search in sources:

Example 1 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

In the class ComputationGraphTestRNN, method checkMaskArrayClearance.

/**
 * Checks that no mask arrays remain set on a ComputationGraph (or on any of its layers) after
 * training, for each of the fit(...) overloads exercised below.
 * <p>
 * The whole sequence is run twice: once with truncated BPTT and once with standard backprop.
 * The network is a single RnnOutputLayer with one input and one output; the data is one
 * example reshaped to (1, 1, 10) with all-ones feature and label masks.
 */
@Test
public void checkMaskArrayClearance() {
    for (boolean tbptt : new boolean[] { true, false }) {
        //Simple "does it throw an exception" type test...
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(1)
                .seed(12345)
                .graphBuilder()
                .addInputs("in")
                .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(1).nOut(1).build(), "in")
                .setOutputs("out")
                .backpropType(tbptt ? BackpropType.TruncatedBPTT : BackpropType.Standard)
                .tBPTTForwardLength(8)
                .tBPTTBackwardLength(8)
                .build();
        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        MultiDataSet data = new MultiDataSet(
                new INDArray[] { Nd4j.linspace(1, 10, 10).reshape(1, 1, 10) },
                new INDArray[] { Nd4j.linspace(2, 20, 10).reshape(1, 1, 10) },
                new INDArray[] { Nd4j.ones(10) },
                new INDArray[] { Nd4j.ones(10) });

        // fit(MultiDataSet)
        net.fit(data);
        assertMaskArraysCleared(net);

        // fit(DataSet), built from the first (only) arrays of the MultiDataSet
        DataSet ds = new DataSet(data.getFeatures(0), data.getLabels(0), data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
        net.fit(ds);
        assertMaskArraysCleared(net);

        // fit(features, labels, featureMasks, labelMasks)
        net.fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
        assertMaskArraysCleared(net);

        // fit(MultiDataSetIterator)
        MultiDataSetIterator iter = new IteratorMultiDataSetIterator(Collections.singletonList((org.nd4j.linalg.dataset.api.MultiDataSet) data).iterator(), 1);
        net.fit(iter);
        assertMaskArraysCleared(net);

        // fit(DataSetIterator)
        DataSetIterator iter2 = new IteratorDataSetIterator(Collections.singletonList(ds).iterator(), 1);
        net.fit(iter2);
        assertMaskArraysCleared(net);
    }
}

/**
 * Asserts that the graph reports null input and label mask arrays, and that every layer
 * reports a null mask array.
 *
 * @param net the network to inspect after a fit call
 */
private static void assertMaskArraysCleared(ComputationGraph net) {
    assertNull(net.getInputMaskArrays());
    assertNull(net.getLabelMaskArrays());
    for (Layer l : net.getLayers()) {
        assertNull(l.getMaskArray());
    }
}
Also used : MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) Layer(org.deeplearning4j.nn.api.Layer) RnnOutputLayer(org.deeplearning4j.nn.conf.layers.RnnOutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) BaseRecurrentLayer(org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) IteratorDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorDataSetIterator) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) Test(org.junit.Test)

Example 2 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

In the class ComputationGraph, method fit.

/**
 * Trains this ComputationGraph on the batches supplied by a MultiDataSetIterator.
 * <p>
 * If the supplied iterator reports async support, it is wrapped in an
 * AsyncMultiDataSetIterator before use. Pretraining is run first when the configuration
 * enables it. When backprop is enabled, each batch is then fitted either via truncated BPTT
 * or via a call to the solver's optimize(), depending on the configured backprop type.
 * Iteration stops at the first batch whose features or labels are null.
 *
 * @param multi source of training batches
 */
public void fit(MultiDataSetIterator multi) {
    if (flattenedGradients == null) {
        initGradientsView();
    }

    // Second constructor argument is presumably the prefetch queue size - TODO confirm
    MultiDataSetIterator source = multi.asyncSupported() ? new AsyncMultiDataSetIterator(multi, 2) : multi;

    if (configuration.isPretrain()) {
        pretrain(source);
    }
    if (!configuration.isBackprop()) {
        return;
    }

    while (source.hasNext()) {
        MultiDataSet batch = source.next();
        if (batch.getFeatures() == null || batch.getLabels() == null) {
            break;
        }

        if (configuration.getBackpropType() != BackpropType.TruncatedBPTT) {
            // Mask arrays are only set (and later cleared) when the batch actually has them
            boolean masked = batch.hasMaskArrays();
            if (masked) {
                setLayerMaskArrays(batch.getFeaturesMaskArrays(), batch.getLabelsMaskArrays());
            }
            setInputs(batch.getFeatures());
            setLabels(batch.getLabels());

            // The solver is created lazily on first use and reused afterwards
            if (solver == null) {
                solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners).model(this).build();
            }
            solver.optimize();

            if (masked) {
                clearLayerMaskArrays();
            }
        } else {
            doTruncatedBPTT(batch.getFeatures(), batch.getLabels(), batch.getFeaturesMaskArrays(), batch.getLabelsMaskArrays());
        }

        Nd4j.getMemoryManager().invokeGcOccasionally();
    }
}
Also used : SingletonMultiDataSetIterator(org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) Solver(org.deeplearning4j.optimize.Solver) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)

Example 3 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

In the class ParallelWrapperMain, method runMain.

/**
 * Parses the command-line arguments, loads a model, trains it through a ParallelWrapper
 * using an iterator obtained from a user-specified factory class, and writes the trained
 * model to the configured output path.
 * <p>
 * If argument parsing fails, the error and usage info are printed and the JVM exits with
 * status 1. A DataSetIterator factory takes precedence over a MultiDataSetIterator factory
 * when both class names are set.
 *
 * @param args command-line arguments, parsed by JCommander against this object
 * @throws IllegalStateException if neither iterator factory class name is set
 * @throws Exception if model loading, factory instantiation, training or model writing fails
 */
public void runMain(String... args) throws Exception {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        System.err.println(e.getMessage());
        //User provides invalid input -> print the usage info
        jcmdr.usage();
        try {
            // Short pause before exiting - presumably to let the output above flush
            Thread.sleep(500);
        } catch (InterruptedException interrupted) {
            // We are exiting anyway, but keep the interrupt flag set rather than swallowing it
            Thread.currentThread().interrupt();
        }
        System.exit(1);
    }
    Model model = ModelGuesser.loadModelGuess(modelPath);
    // ParallelWrapper will take care of load balancing between GPUs.
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
            .prefetchBuffer(prefetchSize)
            .workers(workers)
            .averagingFrequency(averagingFrequency)
            .averageUpdaters(averageUpdaters)
            .reportScoreAfterAveraging(reportScore)
            .useLegacyAveraging(legacyAveraging)
            .build();
    if (dataSetIteratorFactoryClazz != null) {
        DataSetIteratorProviderFactory dataSetIteratorProviderFactory = (DataSetIteratorProviderFactory) Class.forName(dataSetIteratorFactoryClazz).newInstance();
        DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
        attachRemoteUiListeners(wrapper);
        wrapper.fit(dataSetIterator);
    } else if (multiDataSetIteratorFactoryClazz != null) {
        MultiDataSetProviderFactory multiDataSetProviderFactory = (MultiDataSetProviderFactory) Class.forName(multiDataSetIteratorFactoryClazz).newInstance();
        MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
        attachRemoteUiListeners(wrapper);
        wrapper.fit(iterator);
    } else {
        throw new IllegalStateException("Please provide a datasetiterator or multi datasetiterator class");
    }
    // Reached only after a successful fit in one of the two branches above
    ModelSerializer.writeModel(model, new File(modelOutputPath), true);
}

/**
 * Registers a remote UI stats router and a StatsListener on the wrapper when a UI URL is set;
 * does nothing otherwise.
 *
 * @param wrapper the parallel wrapper whose listeners are to be set
 */
private void attachRemoteUiListeners(ParallelWrapper wrapper) {
    if (uiUrl == null) {
        return;
    }
    // it's important that the UI can report results from parallel training
    // there's potential for StatsListener to fail if certain properties aren't set in the model
    StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
    wrapper.setListeners(remoteUIRouter, new StatsListener(null));
}
Also used : ParallelWrapper(org.deeplearning4j.parallelism.ParallelWrapper) StatsStorageRouter(org.deeplearning4j.api.storage.StatsStorageRouter) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) StatsListener(org.deeplearning4j.ui.stats.StatsListener) ParameterException(com.beust.jcommander.ParameterException) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) JCommander(com.beust.jcommander.JCommander) Model(org.deeplearning4j.nn.api.Model) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) ParameterException(com.beust.jcommander.ParameterException) File(java.io.File) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)

Example 4 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

In the class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI.

/**
 * Verifies that a RecordReaderMultiDataSetIterator built over two ImageRecordReaders yields
 * the same feature arrays (and the same labels as the first reader) as two independent
 * RecordReaderDataSetIterators reading the same image files.
 *
 * @throws Exception if the temporary image files cannot be created or read
 */
@Test
public void testImagesRRDMSI() throws Exception {
    // Build a temporary directory tree: two sub-directories, one test image copied into each
    File parentDir = Files.createTempDir();
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

    // Two readers over the same files with different image-size arguments (10x10 and 5x5),
    // combined into one multi-input iterator with a one-hot output taken from "rr1s"
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1)
            .addReader("rr1", rr1)
            .addReader("rr1s", rr1s)
            .addInput("rr1", 0, 0)
            .addInput("rr1s", 0, 0)
            .addOutputOneHot("rr1s", 1, outputNum)
            .build();

    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);

    // One iteration per image written above
    for (int i = 0; i < 2; i++) {
        MultiDataSet mds = trainDataIterator.next();
        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();
        assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
        assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
    }
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) Random(java.util.Random) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) Test(org.junit.Test)

Example 5 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

In the class RecordReaderMultiDataSetIteratorTest, method testSplittingCSV.

/**
 * Splits the Iris CSV into two input arrays and two output arrays with a
 * RecordReaderMultiDataSetIterator, and compares every batch against the matching column
 * subsets of a plain RecordReaderDataSetIterator over the same file.
 *
 * @throws Exception if the Iris resource cannot be read
 */
@Test
public void testSplittingCSV() throws Exception {
    //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
    //Inputs: columns 0 and 1-2
    //Outputs: columns 3, and 4->OneHot
    //need to manually extract
    RecordReader referenceReader = new CSVRecordReader(0, ",");
    referenceReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderDataSetIterator referenceIter = new RecordReaderDataSetIterator(referenceReader, 10, 4, 3);

    RecordReader splitReader = new CSVRecordReader(0, ",");
    splitReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator splitIter = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("reader", splitReader)
            .addInput("reader", 0, 0)
            .addInput("reader", 1, 2)
            .addOutput("reader", 3, 3)
            .addOutputOneHot("reader", 4, 3)
            .build();

    while (referenceIter.hasNext()) {
        DataSet referenceBatch = referenceIter.next();
        INDArray referenceFeatures = referenceBatch.getFeatureMatrix();
        INDArray referenceLabels = referenceBatch.getLabels();

        MultiDataSet splitBatch = splitIter.next();
        assertEquals(2, splitBatch.getFeatures().length);
        assertEquals(2, splitBatch.getLabels().length);
        assertNull(splitBatch.getFeaturesMaskArrays());
        assertNull(splitBatch.getLabelsMaskArrays());

        INDArray[] splitFeatures = splitBatch.getFeatures();
        INDArray[] splitLabels = splitBatch.getLabels();
        assertNotNull(splitFeatures);
        assertNotNull(splitLabels);
        for (INDArray featureArray : splitFeatures) {
            assertNotNull(featureArray);
        }
        for (INDArray labelArray : splitLabels) {
            assertNotNull(labelArray);
        }

        //Get the subsets of the original iris data
        INDArray wantInput0 = referenceFeatures.get(NDArrayIndex.all(), NDArrayIndex.point(0));
        INDArray wantInput1 = referenceFeatures.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 2, true));
        INDArray wantOutput0 = referenceFeatures.get(NDArrayIndex.all(), NDArrayIndex.point(3));
        INDArray wantOutput1 = referenceLabels;

        assertEquals(wantInput0, splitFeatures[0]);
        assertEquals(wantInput1, splitFeatures[1]);
        assertEquals(wantOutput0, splitLabels[0]);
        assertEquals(wantOutput1, splitLabels[1]);
    }

    // Both iterators must be exhausted at the same time
    assertFalse(splitIter.hasNext());
}
Also used : MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordReader(org.datavec.api.records.reader.RecordReader) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Aggregations

MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)17 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)10 Test (org.junit.Test)9 ClassPathResource (org.nd4j.linalg.io.ClassPathResource)8 FileSplit (org.datavec.api.split.FileSplit)7 DataSet (org.nd4j.linalg.dataset.DataSet)7 AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)6 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)6 RecordReader (org.datavec.api.records.reader.RecordReader)5 CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader)5 ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader)5 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)5 SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)4 CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)4 IteratorMultiDataSetIterator (org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator)3 ParentPathLabelGenerator (org.datavec.api.io.labels.ParentPathLabelGenerator)2 NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit)2 RecordReaderMultiDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator)2 SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator)2