
Example 1 with RecordReaderDataSetIterator

Use of org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator in project deeplearning4j by deeplearning4j.

From the class TestRecordReaders, method testClassIndexOutsideOfRangeRRDSI.

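@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    // Two records, each with one feature column followed by a class index in column 1
    Collection<Collection<Writable>> c = new ArrayList<>();
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    // Class index 2 is invalid when there are only 2 classes (valid indices: 0 and 1)
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    CollectionRecordReader crr = new CollectionRecordReader(c);
    // batchSize = 2, labelIndex = 1, numPossibleLabels = 2
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(crr, 2, 1, 2);
    try {
        DataSet ds = iter.next();
        fail("Expected a DL4JException for the out-of-range class index");
    } catch (DL4JException e) {
        // Expected path: print the message so it can be inspected
        System.out.println("testClassIndexOutsideOfRange(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail("Unexpected exception type: " + e);
    }
}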
Also used : DataSet(org.nd4j.linalg.dataset.api.DataSet) ArrayList(java.util.ArrayList) Collection(java.util.Collection) CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader) DoubleWritable(org.datavec.api.writable.DoubleWritable) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) DL4JException(org.deeplearning4j.exception.DL4JException) IntWritable(org.datavec.api.writable.IntWritable) Test(org.junit.Test)
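Example 1 depends on the iterator rejecting a class index that falls outside [0, numPossibleLabels - 1]. That mechanism can be sketched in plain Java: split a record into feature columns and a one-hot label, and check the class index first. OneHotRecordSketch and toExample are hypothetical names, and IllegalArgumentException stands in for the DL4JException the test expects. This is a minimal sketch under those assumptions, not DL4J's implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not the DL4J API): turn a flat record into features
 * plus a one-hot label, given a label column index and a class count.
 */
final class OneHotRecordSketch {

    /** Returns {features, label}; features are every column except labelIndex. */
    static double[][] toExample(double[] record, int labelIndex, int numClasses) {
        int classIdx = (int) record[labelIndex];
        // Reject class indices outside [0, numClasses - 1] before building anything
        if (classIdx < 0 || classIdx >= numClasses) {
            throw new IllegalArgumentException("Class index " + classIdx
                    + " is outside the valid range [0, " + (numClasses - 1) + "]");
        }
        double[] features = new double[record.length - 1];
        for (int i = 0, j = 0; i < record.length; i++) {
            if (i != labelIndex) {
                features[j++] = record[i];
            }
        }
        double[] label = new double[numClasses];
        label[classIdx] = 1.0; // one-hot encoding of the class index
        return new double[][] { features, label };
    }

    public static void main(String[] args) {
        // Same two records as the test: labelIndex = 1, two classes
        double[][] ok = toExample(new double[] {0.0, 0}, 1, 2);
        System.out.println(Arrays.toString(ok[0]) + " " + Arrays.toString(ok[1]));
        try {
            toExample(new double[] {0.0, 2}, 1, 2);
        } catch (IllegalArgumentException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}
```

Validating before any allocation means a bad record fails immediately instead of producing a label vector that is silently all zeros.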

Example 2 with RecordReaderDataSetIterator

Use of org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator in project deeplearning4j by deeplearning4j.

From the class DataSetIteratorTest, method testMnist.

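@Test
public void testMnist() throws Exception {
    // CSV of the first 200 MNIST examples: label in column 0, then raw pixel values
    ClassPathResource cpr = new ClassPathResource("mnist_first_200.txt");
    // Skip 0 header lines; comma-delimited
    CSVRecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(cpr.getTempFileFromArchive()));
    // batchSize = 10, labelIndex = 0, numPossibleLabels = 10
    RecordReaderDataSetIterator dsi = new RecordReaderDataSetIterator(rr, 10, 0, 10);
    // batchSize = 10, numExamples = 200, binarize = false, train = true, shuffle = false, seed = 0
    MnistDataSetIterator iter = new MnistDataSetIterator(10, 200, false, true, false, 0);
    while (dsi.hasNext()) {
        DataSet dsExp = dsi.next();
        DataSet dsAct = iter.next();
        INDArray fExp = dsExp.getFeatureMatrix();
        // Scale the raw CSV pixels in place (0-255 -> 0-1) to match the MNIST iterator's features
        fExp.divi(255);
        INDArray lExp = dsExp.getLabels();
        INDArray fAct = dsAct.getFeatureMatrix();
        INDArray lAct = dsAct.getLabels();
        assertEquals(fExp, fAct);
        assertEquals(lExp, lAct);
    }
    // Both sources should be exhausted after the same number of batches
    assertFalse(iter.hasNext());
}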
Also used : MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)
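Because the test divides the CSV features by 255 before comparing, MnistDataSetIterator (with binarize = false) appears to yield pixel intensities scaled into [0, 1], while the CSV stores raw 0-255 values with the label in column 0. The sketch below shows that per-line conversion in plain Java. MnistCsvSketch, features and label are hypothetical names, and the "label,p0,p1,..." line layout is an assumption inferred from labelIndex = 0 in the test.

```java
/**
 * Hypothetical sketch: convert one MNIST-style CSV line ("label,p0,p1,...")
 * into scaled features and a one-hot label over 10 digit classes.
 */
final class MnistCsvSketch {
    static final int NUM_CLASSES = 10;

    /** Pixel columns (everything after column 0), each divided by 255. */
    static double[] features(String csvLine) {
        String[] parts = csvLine.split(",");
        double[] f = new double[parts.length - 1];
        for (int i = 1; i < parts.length; i++) {
            // Raw intensities are 0..255; dividing maps them into [0, 1]
            f[i - 1] = Double.parseDouble(parts[i].trim()) / 255.0;
        }
        return f;
    }

    /** Column 0 holds the digit; encode it one-hot. */
    static double[] label(String csvLine) {
        int classIdx = Integer.parseInt(csvLine.split(",")[0].trim());
        double[] l = new double[NUM_CLASSES];
        l[classIdx] = 1.0;
        return l;
    }

    public static void main(String[] args) {
        String line = "7,0,255,51"; // a made-up 3-pixel line for illustration
        System.out.println(java.util.Arrays.toString(features(line)));
        System.out.println(java.util.Arrays.toString(label(line)));
    }
}
```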

Example 3 with RecordReaderDataSetIterator

Use of org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator in project deeplearning4j by deeplearning4j.

From the class MultipleEpochsIteratorTest, method testNextAndReset.

@Test
public void testNextAndReset() throws Exception {
    int epochs = 3;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    // Batch size 150 (the size of the iris dataset): one DataSet per pass
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    // Replays the underlying iterator 'epochs' times
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);
    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet ds = multiIter.next();
        assertNotNull(ds);
    }
    assertEquals(epochs, multiIter.epochs);
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.datavec.api.util.ClassPathResource) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) CifarDataSetIterator(org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) Test(org.junit.Test)
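This test exercises MultipleEpochsIterator only through its behavior: it keeps returning batches until the wrapped iterator has been traversed `epochs` times. Here is a plain-Java sketch of that idea. MultiEpochIterator and epochsStarted() are hypothetical names, and a "reset" is modelled by requesting a fresh iterator from a Supplier rather than by calling DL4J's reset(). It is a sketch of the concept, not DL4J's class.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.Supplier;

/**
 * Hypothetical sketch (not DL4J's MultipleEpochsIterator): replays a source
 * of batches a fixed number of times; each new pass asks the Supplier for a
 * fresh iterator, which plays the role of a reset.
 */
final class MultiEpochIterator<T> implements Iterator<T> {
    private final int numEpochs;
    private final Supplier<Iterator<T>> source;
    private Iterator<T> current;
    private int epochsStarted;

    MultiEpochIterator(int numEpochs, Supplier<Iterator<T>> source) {
        if (numEpochs < 1) {
            throw new IllegalArgumentException("numEpochs must be >= 1");
        }
        this.numEpochs = numEpochs;
        this.source = source;
        this.current = source.get(); // the first pass begins immediately
        this.epochsStarted = 1;
    }

    @Override
    public boolean hasNext() {
        // When the current pass is exhausted, "reset" and start the next pass
        while (!current.hasNext() && epochsStarted < numEpochs) {
            current = source.get();
            epochsStarted++;
        }
        return current.hasNext();
    }

    @Override
    public T next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        return current.next();
    }

    /** Number of passes started so far; equals numEpochs once fully consumed. */
    int epochsStarted() {
        return epochsStarted;
    }

    public static void main(String[] args) {
        // One batch per pass, like reading a 150-row file with batch size 150
        MultiEpochIterator<String> it = new MultiEpochIterator<String>(
                3, () -> Arrays.asList("all-150-rows").iterator());
        int batches = 0;
        while (it.hasNext()) {
            it.next();
            batches++;
        }
        System.out.println(batches + " batches over " + it.epochsStarted() + " epochs");
    }
}
```

The epoch counter advances lazily inside hasNext(), so it reaches numEpochs only after the final pass has begun. That mirrors how the test reads the `epochs` field after the loop.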

Example 4 with RecordReaderDataSetIterator

Use of org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator in project deeplearning4j by deeplearning4j.

From the class MultipleEpochsIteratorTest, method testLoadFullDataSet.

@Test
public void testLoadFullDataSet() throws Exception {
    int epochs = 3;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    // Request only the first 50 examples instead of the default batch size of 150
    DataSet ds = iter.next(50);
    // Replay this single in-memory DataSet 'epochs' times
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet batch = multiIter.next();
        // Check for null before dereferencing; JUnit's argument order is (expected, actual)
        assertNotNull(batch);
        assertEquals(50, batch.numExamples());
    }
    assertEquals(epochs, multiIter.epochs);
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.datavec.api.util.ClassPathResource) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) CifarDataSetIterator(org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) Test(org.junit.Test)
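In this test the iterator is built with a default batch size of 150, yet iter.next(50) asks for only 50 examples, and the test then checks that each replayed DataSet holds 50 examples. The sketch below shows how a per-call size can override a default minibatch size when grouping records. It uses only the standard library, and MiniBatcher is a hypothetical class, not RecordReaderDataSetIterator's implementation.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Hypothetical sketch: group a stream of records into minibatches.
 * next() returns up to batchSize records; next(num) returns up to num instead.
 */
final class MiniBatcher<T> {
    private final Iterator<T> records;
    private final int batchSize;

    MiniBatcher(Iterator<T> records, int batchSize) {
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be >= 1");
        }
        this.records = records;
        this.batchSize = batchSize;
    }

    boolean hasNext() {
        return records.hasNext();
    }

    /** Uses the default batch size chosen at construction time. */
    List<T> next() {
        return next(batchSize);
    }

    /** Overrides the batch size for this call only; the last batch may be smaller. */
    List<T> next(int num) {
        if (!records.hasNext()) {
            throw new NoSuchElementException();
        }
        List<T> batch = new ArrayList<>(num);
        while (batch.size() < num && records.hasNext()) {
            batch.add(records.next());
        }
        return batch;
    }

    public static void main(String[] args) {
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < 150; i++) {
            rows.add(i);
        }
        MiniBatcher<Integer> batcher = new MiniBatcher<Integer>(rows.iterator(), 150);
        List<Integer> first = batcher.next(50); // 50 rows, as with iter.next(50)
        List<Integer> rest = batcher.next();    // the remaining 100 rows
        System.out.println(first.size() + " " + rest.size());
    }
}
```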

Example 5 with RecordReaderDataSetIterator

Use of org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator in project deeplearning4j by deeplearning4j.

From the class TestDataVecDataSetFunctions, method testDataVecDataSetFunction.

@Test
public void testDataVecDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    // Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/imagetest/0/a.bmp");
    // Labels must be set explicitly for Spark: they can't be inferred without an initialize() call
    List<String> labelsList = Arrays.asList("0", "1");
    String path = f.getPath();
    // Strip "0/a.bmp" (7 characters) to get the parent "imagetest/" folder
    String folder = path.substring(0, path.length() - 7);
    path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    // 4 images
    assertEquals(4, origData.count());
    // 28x28 pixels, 1 channel; the label comes from the parent directory name
    ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    rr.setLabels(labelsList);
    org.datavec.spark.functions.RecordReaderFunction rrf = new org.datavec.spark.functions.RecordReaderFunction(rr);
    JavaRDD<List<Writable>> rdd = origData.map(rrf);
    // labelIndex = 1, numPossibleLabels = 2, regression = false
    JavaRDD<DataSet> data = rdd.map(new DataVecDataSetFunction(1, 2, false));
    List<DataSet> collected = data.collect();
    // Load normally (i.e., not via Spark) and check that we get the same results (order notwithstanding)
    InputSplit is = new FileSplit(new File(folder), new String[] { "bmp" }, true);
    ImageRecordReader irr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    irr.initialize(is);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(irr, 1, 1, 2);
    List<DataSet> listLocal = new ArrayList<>(4);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    // Compare:
    assertEquals(4, collected.size());
    assertEquals(4, listLocal.size());
    // Check that the results are the same (order notwithstanding)
    boolean[] found = new boolean[4];
    for (int i = 0; i < 4; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 4; j++) {
            if (ds.equals(listLocal.get(j))) {
                if (foundIndex != -1) {
                    // This Spark value equals two or more local values (shouldn't happen)
                    fail("Spark DataSet matches more than one local DataSet");
                }
                foundIndex = j;
                if (found[foundIndex]) {
                    // Another Spark value already matched this local value -> duplicates in the Spark list
                    fail("Duplicate DataSet in the Spark results");
                }
                // Mark this local value as matched
                found[foundIndex] = true;
            }
        }
    }
    int count = 0;
    for (boolean b : found) {
        if (b) {
            count++;
        }
    }
    // Expect exactly 4 pairwise matches between the Spark and local versions
    assertEquals(4, count);
}
Also used : SequenceRecordReaderFunction(org.datavec.spark.functions.SequenceRecordReaderFunction) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) PortableDataStream(org.apache.spark.input.PortableDataStream) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) NumberedFileInputSplit(org.datavec.api.split.NumberedFileInputSplit) InputSplit(org.datavec.api.split.InputSplit) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) File(java.io.File) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)
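The nested loop at the end of this test checks that the Spark and local results match one-to-one, regardless of order. That check can be factored into a reusable helper. UnorderedMatch.sameElementsIgnoringOrder is a hypothetical name, and this is a plain-Java generalization of the test's logic rather than anything DL4J or Spark provides. Like the test, it relies only on equals(), so it does not assume a consistent hashCode(), and it treats an element that matches more than one counterpart as a failure.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Hypothetical helper: true if every element of 'actual' equals exactly one
 * element of 'expected' and no element of 'expected' is matched twice.
 * Uses only equals(), at O(n^2) cost, which is fine for small test lists.
 */
final class UnorderedMatch {

    static <T> boolean sameElementsIgnoringOrder(List<T> actual, List<T> expected) {
        if (actual.size() != expected.size()) {
            return false;
        }
        boolean[] matched = new boolean[expected.size()];
        for (T a : actual) {
            int foundIndex = -1;
            for (int j = 0; j < expected.size(); j++) {
                if (Objects.equals(a, expected.get(j))) {
                    if (foundIndex != -1 || matched[j]) {
                        return false; // ambiguous match, or duplicate in 'actual'
                    }
                    foundIndex = j;
                    matched[j] = true;
                }
            }
            if (foundIndex == -1) {
                return false; // no counterpart in 'expected'
            }
        }
        // Equal sizes and n distinct matches mean every expected element was matched
        return true;
    }

    public static void main(String[] args) {
        List<String> spark = Arrays.asList("img-a", "img-c", "img-b", "img-d");
        List<String> local = Arrays.asList("img-a", "img-b", "img-c", "img-d");
        System.out.println(sameElementsIgnoringOrder(spark, local));
    }
}
```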

Aggregations

RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) 10
FileSplit (org.datavec.api.split.FileSplit) 8
Test (org.junit.Test) 8
DataSet (org.nd4j.linalg.dataset.DataSet) 8
RecordReader (org.datavec.api.records.reader.RecordReader) 7
CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader) 6
ClassPathResource (org.datavec.api.util.ClassPathResource) 5
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator) 5
ArrayList (java.util.ArrayList) 3
CifarDataSetIterator (org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) 3
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration) 3
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 3
File (java.io.File) 2
CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader) 2
ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader) 2
SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) 2
MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) 2
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration) 2
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork) 2
NormalizerStandardize (org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize) 2