Search in sources:

Example 1 with InputSplit

Use of org.datavec.api.split.InputSplit in the project deeplearning4j by deeplearning4j.

From the class TestDataVecDataSetFunctions, method testDataVecDataSetFunction.

/**
 * Compares DataSets produced on Spark (binary files -> RecordReaderFunction -> DataVecDataSetFunction)
 * against DataSets produced locally by a RecordReaderDataSetIterator over the same image folder.
 * Order is ignored, but the match must be one-to-one: every Spark DataSet equals exactly one local
 * DataSet and no local DataSet is matched twice.
 */
@Test
public void testDataVecDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    //Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/imagetest/0/a.bmp");
    //Need this for Spark: can't infer without init call
    List<String> labelsList = Arrays.asList("0", "1");
    //Root image folder (the parent of the label subdirectory "0"), with a trailing separator.
    //Equivalent to stripping the trailing "0/a.bmp" from f.getPath(), without a magic length.
    String folder = f.getParentFile().getParent() + File.separator;
    String path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    //4 images
    assertEquals("Unexpected number of input files under " + path, 4, origData.count());
    ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    rr.setLabels(labelsList);
    org.datavec.spark.functions.RecordReaderFunction rrf = new org.datavec.spark.functions.RecordReaderFunction(rr);
    JavaRDD<List<Writable>> rdd = origData.map(rrf);
    JavaRDD<DataSet> data = rdd.map(new DataVecDataSetFunction(1, 2, false));
    List<DataSet> collected = data.collect();
    //Load normally (i.e., not via Spark), and check that we get the same results (order notwithstanding)
    InputSplit is = new FileSplit(new File(folder), new String[] { "bmp" }, true);
    ImageRecordReader irr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    irr.initialize(is);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(irr, 1, 1, 2);
    List<DataSet> listLocal = new ArrayList<>(4);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    //Compare:
    assertEquals("Unexpected number of Spark DataSets", 4, collected.size());
    assertEquals("Unexpected number of local DataSets", 4, listLocal.size());
    //Check that results are the same (order notwithstanding)
    boolean[] found = new boolean[4];
    for (int i = 0; i < 4; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 4; j++) {
            if (!ds.equals(listLocal.get(j))) {
                continue;
            }
            if (foundIndex != -1) {
                //This Spark value equals two or more local values (shouldn't happen)
                fail("Spark DataSet " + i + " equals more than one local DataSet (" + foundIndex + " and " + j + ")");
            }
            foundIndex = j;
            if (found[foundIndex]) {
                //Another Spark value already matched this local value -> duplicates in the Spark list
                fail("Local DataSet " + j + " is matched by more than one Spark DataSet");
            }
            //mark this one as seen before
            found[foundIndex] = true;
        }
        if (foundIndex == -1) {
            fail("Spark DataSet " + i + " has no equal DataSet in the local results");
        }
    }
    int count = 0;
    for (boolean b : found) {
        if (b) {
            count++;
        }
    }
    //Expect all 4 and exactly 4 pairwise matches between spark and local versions
    assertEquals("Number of local DataSets matched by Spark DataSets", 4, count);
}
Also used: SequenceRecordReaderFunction(org.datavec.spark.functions.SequenceRecordReaderFunction) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) PortableDataStream(org.apache.spark.input.PortableDataStream) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) NumberedFileInputSplit(org.datavec.api.split.NumberedFileInputSplit) InputSplit(org.datavec.api.split.InputSplit) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) File(java.io.File) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 2 with InputSplit

Use of org.datavec.api.split.InputSplit in the project deeplearning4j by deeplearning4j.

From the class TestDataVecDataSetFunctions, method testDataVecSequenceDataSetFunction.

/**
 * Compares sequence DataSets produced on Spark (binary files -> SequenceRecordReaderFunction ->
 * DataVecSequenceDataSetFunction) against DataSets produced locally by a
 * SequenceRecordReaderDataSetIterator over the same CSV sequence folder. Order is ignored, but the
 * match must be one-to-one: every Spark DataSet equals exactly one local DataSet and no local
 * DataSet is matched twice.
 */
@Test
public void testDataVecSequenceDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    //Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/csvsequence/csvsequence_0.txt");
    //Folder containing the CSV sequence files, with a trailing separator.
    //Equivalent to stripping the trailing "csvsequence_0.txt" from f.getPath(), without a magic length.
    String folder = f.getParent() + File.separator;
    String path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    //3 CSV sequences
    assertEquals("Unexpected number of input files under " + path, 3, origData.count());
    SequenceRecordReader seqRR = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReaderFunction rrf = new SequenceRecordReaderFunction(seqRR);
    JavaRDD<List<List<Writable>>> rdd = origData.map(rrf);
    JavaRDD<DataSet> data = rdd.map(new DataVecSequenceDataSetFunction(2, -1, true, null, null));
    List<DataSet> collected = data.collect();
    //Load normally (i.e., not via Spark), and check that we get the same results (order notwithstanding)
    InputSplit is = new FileSplit(new File(folder), new String[] { "txt" }, true);
    SequenceRecordReader seqRR2 = new CSVSequenceRecordReader(1, ",");
    seqRR2.initialize(is);
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(seqRR2, 1, -1, 2, true);
    List<DataSet> listLocal = new ArrayList<>(3);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    //Compare:
    assertEquals("Unexpected number of Spark DataSets", 3, collected.size());
    assertEquals("Unexpected number of local DataSets", 3, listLocal.size());
    //Check that results are the same (order notwithstanding)
    boolean[] found = new boolean[3];
    for (int i = 0; i < 3; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 3; j++) {
            if (!ds.equals(listLocal.get(j))) {
                continue;
            }
            if (foundIndex != -1) {
                //This Spark value equals two or more local values (shouldn't happen)
                fail("Spark DataSet " + i + " equals more than one local DataSet (" + foundIndex + " and " + j + ")");
            }
            foundIndex = j;
            if (found[foundIndex]) {
                //Another Spark value already matched this local value -> duplicates in the Spark list
                fail("Local DataSet " + j + " is matched by more than one Spark DataSet");
            }
            //mark this one as seen before
            found[foundIndex] = true;
        }
        if (foundIndex == -1) {
            fail("Spark DataSet " + i + " has no equal DataSet in the local results");
        }
    }
    int count = 0;
    for (boolean b : found) {
        if (b) {
            count++;
        }
    }
    //Expect all 3 and exactly 3 pairwise matches between spark and local versions
    assertEquals("Number of local DataSets matched by Spark DataSets", 3, count);
}
Also used: CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) SequenceRecordReaderFunction(org.datavec.spark.functions.SequenceRecordReaderFunction) DataSet(org.nd4j.linalg.dataset.DataSet) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) ArrayList(java.util.ArrayList) PortableDataStream(org.apache.spark.input.PortableDataStream) Writable(org.datavec.api.writable.Writable) FileSplit(org.datavec.api.split.FileSplit) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) File(java.io.File) NumberedFileInputSplit(org.datavec.api.split.NumberedFileInputSplit) InputSplit(org.datavec.api.split.InputSplit) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Aggregations

File (java.io.File): 2; ArrayList (java.util.ArrayList): 2; List (java.util.List): 2; JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 2; PortableDataStream (org.apache.spark.input.PortableDataStream): 2; FileSplit (org.datavec.api.split.FileSplit): 2; InputSplit (org.datavec.api.split.InputSplit): 2; NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 2; SequenceRecordReaderFunction (org.datavec.spark.functions.SequenceRecordReaderFunction): 2; SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator): 2; BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 2; Test (org.junit.Test): 2; DataSet (org.nd4j.linalg.dataset.DataSet): 2; ParentPathLabelGenerator (org.datavec.api.io.labels.ParentPathLabelGenerator): 1; SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 1; CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 1; Writable (org.datavec.api.writable.Writable): 1; ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 1; RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator): 1