
Example 1 with PortableDataStream

Use of org.apache.spark.input.PortableDataStream in project deeplearning4j by deeplearning4j.

The class TestDataVecDataSetFunctions, method testDataVecDataSetFunction.

@Test
public void testDataVecDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    //Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/imagetest/0/a.bmp");
    //Need this for Spark: can't infer without init call
    List<String> labelsList = Arrays.asList("0", "1");
    String path = f.getPath();
    //Strip the trailing "0/a.bmp" (7 chars) to get the parent directory, then glob over it
    String folder = path.substring(0, path.length() - 7);
    path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    //4 images
    assertEquals(4, origData.count());
    ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    rr.setLabels(labelsList);
    org.datavec.spark.functions.RecordReaderFunction rrf = new org.datavec.spark.functions.RecordReaderFunction(rr);
    JavaRDD<List<Writable>> rdd = origData.map(rrf);
    //labelIndex=1 (the label Writable follows the image in each record), numClasses=2, regression=false
    JavaRDD<DataSet> data = rdd.map(new DataVecDataSetFunction(1, 2, false));
    List<DataSet> collected = data.collect();
    //Load normally (i.e., not via Spark), and check that we get the same results (ignoring order)
    InputSplit is = new FileSplit(new File(folder), new String[] { "bmp" }, true);
    ImageRecordReader irr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    irr.initialize(is);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(irr, 1, 1, 2);
    List<DataSet> listLocal = new ArrayList<>(4);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    //Compare:
    assertEquals(4, collected.size());
    assertEquals(4, listLocal.size());
    //Check that results are the same (ignoring order)
    boolean[] found = new boolean[4];
    for (int i = 0; i < 4; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 4; j++) {
            if (ds.equals(listLocal.get(j))) {
                if (foundIndex != -1)
                    //Already matched a local value -> this Spark value equals two or more local values (shouldn't happen)
                    fail();
                foundIndex = j;
                if (found[foundIndex])
                    //Another Spark value already matched this local value -> suggests duplicates in the Spark list
                    fail();
                //Mark this local value as matched
                found[foundIndex] = true;
            }
        }
    }
    int count = 0;
    for (boolean b : found) {
        if (b)
            count++;
    }
    //Expect all 4 and exactly 4 pairwise matches between spark and local versions
    assertEquals(4, count);
}
Also used:
    SequenceRecordReaderFunction (org.datavec.spark.functions.SequenceRecordReaderFunction)
    DataSet (org.nd4j.linalg.dataset.DataSet)
    ArrayList (java.util.ArrayList)
    PortableDataStream (org.apache.spark.input.PortableDataStream)
    SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator)
    RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator)
    FileSplit (org.datavec.api.split.FileSplit)
    List (java.util.List)
    JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)
    NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit)
    InputSplit (org.datavec.api.split.InputSplit)
    ParentPathLabelGenerator (org.datavec.api.io.labels.ParentPathLabelGenerator)
    File (java.io.File)
    ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader)
    BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)
    Test (org.junit.Test)
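
For context, a minimal sketch of what sc.binaryFiles produces (assuming an initialized JavaSparkContext named sc; the class name and glob path are illustrative): each matched file becomes one (path, PortableDataStream) pair, and the stream contents can be materialized with toArray() or opened lazily with open() inside executors.

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.input.PortableDataStream;

public class BinaryFilesSketch {
    //Hedged sketch: list each matched file with its size in bytes.
    //toArray() reads the whole file into memory, which is fine for small test images.
    public static void printFileSizes(JavaSparkContext sc) {
        JavaPairRDD<String, PortableDataStream> files = sc.binaryFiles("src/test/resources/imagetest/*");
        files.mapValues(pds -> pds.toArray().length).collect()
             .forEach(t -> System.out.println(t._1() + " : " + t._2() + " bytes"));
    }
}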

Example 2 with PortableDataStream

Use of org.apache.spark.input.PortableDataStream in project deeplearning4j by deeplearning4j.

The class TestDataVecDataSetFunctions, method testDataVecSequenceDataSetFunction.

@Test
public void testDataVecSequenceDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    //Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/csvsequence/csvsequence_0.txt");
    String path = f.getPath();
    //Strip the trailing "csvsequence_0.txt" (17 chars) to get the parent directory, then glob over it
    String folder = path.substring(0, path.length() - 17);
    path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    //3 CSV sequences
    assertEquals(3, origData.count());
    SequenceRecordReader seqRR = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReaderFunction rrf = new SequenceRecordReaderFunction(seqRR);
    JavaRDD<List<List<Writable>>> rdd = origData.map(rrf);
    //labelIndex=2, numPossibleLabels=-1, regression=true; no preprocessor or writable converter
    JavaRDD<DataSet> data = rdd.map(new DataVecSequenceDataSetFunction(2, -1, true, null, null));
    List<DataSet> collected = data.collect();
    //Load normally (i.e., not via Spark), and check that we get the same results (ignoring order)
    InputSplit is = new FileSplit(new File(folder), new String[] { "txt" }, true);
    SequenceRecordReader seqRR2 = new CSVSequenceRecordReader(1, ",");
    seqRR2.initialize(is);
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(seqRR2, 1, -1, 2, true);
    List<DataSet> listLocal = new ArrayList<>(3);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    //Compare:
    assertEquals(3, collected.size());
    assertEquals(3, listLocal.size());
    //Check that results are the same (ignoring order)
    boolean[] found = new boolean[3];
    for (int i = 0; i < 3; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 3; j++) {
            if (ds.equals(listLocal.get(j))) {
                if (foundIndex != -1)
                    //Already matched a local value -> this Spark value equals two or more local values (shouldn't happen)
                    fail();
                foundIndex = j;
                if (found[foundIndex])
                    //Another Spark value already matched this local value -> suggests duplicates in the Spark list
                    fail();
                //Mark this local value as matched
                found[foundIndex] = true;
            }
        }
    }
    int count = 0;
    for (boolean b : found) {
        if (b)
            count++;
    }
    //Expect all 3 and exactly 3 pairwise matches between spark and local versions
    assertEquals(3, count);
}
Also used:
    CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)
    SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)
    SequenceRecordReaderFunction (org.datavec.spark.functions.SequenceRecordReaderFunction)
    DataSet (org.nd4j.linalg.dataset.DataSet)
    SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator)
    ArrayList (java.util.ArrayList)
    PortableDataStream (org.apache.spark.input.PortableDataStream)
    Writable (org.datavec.api.writable.Writable)
    FileSplit (org.datavec.api.split.FileSplit)
    List (java.util.List)
    JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)
    File (java.io.File)
    NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit)
    InputSplit (org.datavec.api.split.InputSplit)
    BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)
    Test (org.junit.Test)
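
A note on the reader configuration: in CSVSequenceRecordReader(1, ","), the arguments are the number of lines to skip per file and the field delimiter. A minimal local sketch (the class name is illustrative; assumes the same csvsequence test resources):

import java.io.File;
import java.util.List;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;

public class CsvSequenceSketch {
    public static void main(String[] args) throws Exception {
        //Skip 1 header line per file, split fields on commas
        SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
        reader.initialize(new FileSplit(new File("src/test/resources/csvsequence")));
        while (reader.hasNext()) {
            //One call to sequenceRecord() returns the whole time series for one file:
            //a list of time steps, each a list of Writable field values
            List<List<Writable>> sequence = reader.sequenceRecord();
            System.out.println("steps: " + sequence.size() + ", values per step: " + sequence.get(0).size());
        }
    }
}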

Example 3 with PortableDataStream

Use of org.apache.spark.input.PortableDataStream in project deeplearning4j by deeplearning4j.

The class TestPreProcessedData, method testCsvPreprocessedDataGenerationNoLabel.

@Test
public void testCsvPreprocessedDataGenerationNoLabel() throws Exception {
    //Same as testCsvPreprocessedDataGeneration (Example 4 below), but without labels; in that case the features are used as both the input and output arrays
    List<String> list = new ArrayList<>();
    DataSetIterator iter = new IrisDataSetIterator(1, 150);
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        list.add(toString(ds.getFeatureMatrix(), Nd4j.argMax(ds.getLabels(), 1).getInt(0)));
    }
    JavaRDD<String> rdd = sc.parallelize(list);
    int partitions = rdd.partitions().size();
    URI tempDir = new File(System.getProperty("java.io.tmpdir")).toURI();
    URI outputDir = new URI(tempDir.getPath() + "/dl4j_testPreprocessedData3");
    File temp = new File(outputDir.getPath());
    if (temp.exists())
        FileUtils.deleteDirectory(temp);
    int numBinFiles = 0;
    try {
        int batchSize = 5;
        int labelIdx = -1;
        int numPossibleLabels = -1;
        rdd.foreachPartition(new StringToDataSetExportFunction(outputDir, new CSVRecordReader(0), batchSize, false, labelIdx, numPossibleLabels));
        File[] fileList = new File(outputDir.getPath()).listFiles();
        int totalExamples = 0;
        for (File f2 : fileList) {
            if (!f2.getPath().endsWith(".bin"))
                continue;
            //                System.out.println(f2.getPath());
            numBinFiles++;
            DataSet ds = new DataSet();
            ds.load(f2);
            assertEquals(5, ds.numInputs());
            assertEquals(5, ds.numOutcomes());
            totalExamples += ds.numExamples();
        }
        assertEquals(150, totalExamples);
        //Expect about 150/batchSize = 30 files; each partition may end with a partial batch, so allow a deviation of up to 'partitions'
        assertTrue(Math.abs(150 / batchSize - numBinFiles) <= partitions);
        //Test the PortableDataStreamDataSetIterator:
        JavaPairRDD<String, PortableDataStream> pds = sc.binaryFiles(outputDir.getPath());
        List<PortableDataStream> pdsList = pds.values().collect();
        DataSetIterator pdsIter = new PortableDataStreamDataSetIterator(pdsList);
        int pdsCount = 0;
        int totalExamples2 = 0;
        while (pdsIter.hasNext()) {
            DataSet ds = pdsIter.next();
            pdsCount++;
            totalExamples2 += ds.numExamples();
            assertEquals(5, ds.numInputs());
            assertEquals(5, ds.numOutcomes());
        }
        assertEquals(150, totalExamples2);
        assertEquals(numBinFiles, pdsCount);
    } finally {
        FileUtils.deleteDirectory(temp);
    }
}
Also used:
    IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)
    MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)
    DataSet (org.nd4j.linalg.dataset.DataSet)
    ArrayList (java.util.ArrayList)
    PortableDataStream (org.apache.spark.input.PortableDataStream)
    URI (java.net.URI)
    StringToDataSetExportFunction (org.deeplearning4j.spark.datavec.export.StringToDataSetExportFunction)
    PortableDataStreamDataSetIterator (org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator)
    CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader)
    File (java.io.File)
    DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
    BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)
    Test (org.junit.Test)
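
The .bin files checked above are written and read with ND4J's own DataSet serialization, which is also what PortableDataStreamDataSetIterator deserializes. A minimal sketch of that round trip (the class name, random arrays, and file name are illustrative):

import java.io.File;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class DataSetRoundTripSketch {
    public static void main(String[] args) throws Exception {
        //5 examples, 5 feature columns, 5 label columns, mirroring the no-label export above
        DataSet ds = new DataSet(Nd4j.rand(5, 5), Nd4j.rand(5, 5));
        File out = new File(System.getProperty("java.io.tmpdir"), "dl4j_sketch_dataset.bin");
        ds.save(out);                  //one minibatch per file, in ND4J's binary DataSet format
        DataSet loaded = new DataSet();
        loaded.load(out);
        System.out.println(loaded.numExamples() + " examples, " + loaded.numInputs() + " inputs");
    }
}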

Example 4 with PortableDataStream

Use of org.apache.spark.input.PortableDataStream in project deeplearning4j by deeplearning4j.

The class TestPreProcessedData, method testCsvPreprocessedDataGeneration.

@Test
public void testCsvPreprocessedDataGeneration() throws Exception {
    List<String> list = new ArrayList<>();
    DataSetIterator iter = new IrisDataSetIterator(1, 150);
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        list.add(toString(ds.getFeatureMatrix(), Nd4j.argMax(ds.getLabels(), 1).getInt(0)));
    }
    JavaRDD<String> rdd = sc.parallelize(list);
    int partitions = rdd.partitions().size();
    URI tempDir = new File(System.getProperty("java.io.tmpdir")).toURI();
    URI outputDir = new URI(tempDir.getPath() + "/dl4j_testPreprocessedData2");
    File temp = new File(outputDir.getPath());
    if (temp.exists())
        FileUtils.deleteDirectory(temp);
    int numBinFiles = 0;
    try {
        int batchSize = 5;
        int labelIdx = 4;
        int numPossibleLabels = 3;
        rdd.foreachPartition(new StringToDataSetExportFunction(outputDir, new CSVRecordReader(0), batchSize, false, labelIdx, numPossibleLabels));
        File[] fileList = new File(outputDir.getPath()).listFiles();
        int totalExamples = 0;
        for (File f2 : fileList) {
            if (!f2.getPath().endsWith(".bin"))
                continue;
            //                System.out.println(f2.getPath());
            numBinFiles++;
            DataSet ds = new DataSet();
            ds.load(f2);
            assertEquals(4, ds.numInputs());
            assertEquals(3, ds.numOutcomes());
            totalExamples += ds.numExamples();
        }
        assertEquals(150, totalExamples);
        //Expect about 150/batchSize = 30 files; each partition may end with a partial batch, so allow a deviation of up to 'partitions'
        assertTrue(Math.abs(150 / batchSize - numBinFiles) <= partitions);
        //Test the PortableDataStreamDataSetIterator:
        JavaPairRDD<String, PortableDataStream> pds = sc.binaryFiles(outputDir.getPath());
        List<PortableDataStream> pdsList = pds.values().collect();
        DataSetIterator pdsIter = new PortableDataStreamDataSetIterator(pdsList);
        int pdsCount = 0;
        int totalExamples2 = 0;
        while (pdsIter.hasNext()) {
            DataSet ds = pdsIter.next();
            pdsCount++;
            totalExamples2 += ds.numExamples();
            assertEquals(4, ds.numInputs());
            assertEquals(3, ds.numOutcomes());
        }
        assertEquals(150, totalExamples2);
        assertEquals(numBinFiles, pdsCount);
    } finally {
        FileUtils.deleteDirectory(temp);
    }
}
Also used:
    IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)
    MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)
    DataSet (org.nd4j.linalg.dataset.DataSet)
    ArrayList (java.util.ArrayList)
    PortableDataStream (org.apache.spark.input.PortableDataStream)
    URI (java.net.URI)
    StringToDataSetExportFunction (org.deeplearning4j.spark.datavec.export.StringToDataSetExportFunction)
    PortableDataStreamDataSetIterator (org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator)
    CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader)
    File (java.io.File)
    DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
    BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)
    Test (org.junit.Test)
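
For comparison with the no-label variant, here labelIdx = 4 and numPossibleLabels = 3: the exported CSV lines carry the four Iris features followed by a class index in the last column (the test's toString helper is assumed to emit that layout). A minimal sketch of how CSVRecordReader(0) parses such a line (the class name and sample row are illustrative):

import java.util.List;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.StringSplit;
import org.datavec.api.writable.Writable;

public class CsvLineSketch {
    public static void main(String[] args) throws Exception {
        CSVRecordReader csv = new CSVRecordReader(0);          //0 = no header lines to skip
        csv.initialize(new StringSplit("5.1,3.5,1.4,0.2,0"));  //illustrative Iris row
        List<Writable> record = csv.next();                    //[5.1, 3.5, 1.4, 0.2, 0]
        int label = record.get(4).toInt();                     //column 4 holds the class index (0..2)
        System.out.println("features: " + record.subList(0, 4) + ", label: " + label);
    }
}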

Example 5 with PortableDataStream

Use of org.apache.spark.input.PortableDataStream in project deeplearning4j by deeplearning4j.

The class PortableDataStreamMultiDataSetIterator, method next.

@Override
public MultiDataSet next() {
    MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
    PortableDataStream pds = iter.next();
    try (InputStream is = pds.open()) {
        ds.load(is);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    if (preprocessor != null)
        preprocessor.preProcess(ds);
    return ds;
}
Also used:
    MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)
    InputStream (java.io.InputStream)
    PortableDataStream (org.apache.spark.input.PortableDataStream)
    IOException (java.io.IOException)
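
The DataSet-based PortableDataStreamDataSetIterator used in Examples 3 and 4 follows the same open/load/close pattern. A hedged sketch of the equivalent next() (the field names iter and preprocessor mirror the MultiDataSet version above and are assumptions here):

@Override
public DataSet next() {
    DataSet ds = new DataSet();
    PortableDataStream pds = iter.next();
    //try-with-resources closes the stream even if deserialization fails
    try (InputStream is = pds.open()) {
        //One saved minibatch per stream
        ds.load(is);
    } catch (IOException e) {
        throw new RuntimeException("Error loading DataSet from " + pds.getPath(), e);
    }
    if (preprocessor != null)
        //Optional normalization hook, applied after loading
        preprocessor.preProcess(ds);
    return ds;
}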

Aggregations

PortableDataStream (org.apache.spark.input.PortableDataStream): 5 uses
File (java.io.File): 4 uses
ArrayList (java.util.ArrayList): 4 uses
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 4 uses
Test (org.junit.Test): 4 uses
DataSet (org.nd4j.linalg.dataset.DataSet): 4 uses
URI (java.net.URI): 2 uses
List (java.util.List): 2 uses
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 2 uses
CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 2 uses
FileSplit (org.datavec.api.split.FileSplit): 2 uses
InputSplit (org.datavec.api.split.InputSplit): 2 uses
NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 2 uses
SequenceRecordReaderFunction (org.datavec.spark.functions.SequenceRecordReaderFunction): 2 uses
SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator): 2 uses
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 2 uses
StringToDataSetExportFunction (org.deeplearning4j.spark.datavec.export.StringToDataSetExportFunction): 2 uses
PortableDataStreamDataSetIterator (org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator): 2 uses
MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet): 2 uses
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 2 uses