
Example 1 with PortableDataStreamDataSetIterator

Use of org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator in the deeplearning4j project, from the class TestPreProcessedData, method testCsvPreprocessedDataGenerationNoLabel.

@Test
public void testCsvPreprocessedDataGenerationNoLabel() throws Exception {
    //Same as testCsvPreprocessedDataGeneration (Example 2 below), but without any labels (in which case the features and labels arrays are identical)
    List<String> list = new ArrayList<>();
    DataSetIterator iter = new IrisDataSetIterator(1, 150);
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        //toString is a private helper in TestPreProcessedData (not shown on this page) that formats one example as a CSV line: feature values, then the integer class index
        list.add(toString(ds.getFeatureMatrix(), Nd4j.argMax(ds.getLabels(), 1).getInt(0)));
    }
    JavaRDD<String> rdd = sc.parallelize(list);
    int partitions = rdd.partitions().size();
    URI tempDir = new File(System.getProperty("java.io.tmpdir")).toURI();
    URI outputDir = new URI(tempDir.getPath() + "/dl4j_testPreprocessedData3");
    File temp = new File(outputDir.getPath());
    if (temp.exists())
        FileUtils.deleteDirectory(temp);
    int numBinFiles = 0;
    try {
        int batchSize = 5;
        //labelIdx = -1 and numPossibleLabels = -1: no label column, so all 5 CSV columns become features (and labels == features)
        int labelIdx = -1;
        int numPossibleLabels = -1;
        rdd.foreachPartition(new StringToDataSetExportFunction(outputDir, new CSVRecordReader(0), batchSize, false, labelIdx, numPossibleLabels));
        File[] fileList = new File(outputDir.getPath()).listFiles();
        int totalExamples = 0;
        for (File f2 : fileList) {
            if (!f2.getPath().endsWith(".bin"))
                continue;
            numBinFiles++;
            DataSet ds = new DataSet();
            ds.load(f2);
            //With no label column, features and labels are the same 5-column array
            assertEquals(5, ds.numInputs());
            assertEquals(5, ds.numOutcomes());
            totalExamples += ds.numExamples();
        }
        assertEquals(150, totalExamples);
        //Expect 150 / batchSize = 30 files, give or take one partial final batch per partition
        assertTrue(Math.abs(150 / batchSize - numBinFiles) <= partitions);
        //Test the PortableDataStreamDataSetIterator:
        JavaPairRDD<String, PortableDataStream> pds = sc.binaryFiles(outputDir.getPath());
        List<PortableDataStream> pdsList = pds.values().collect();
        DataSetIterator pdsIter = new PortableDataStreamDataSetIterator(pdsList);
        int pdsCount = 0;
        int totalExamples2 = 0;
        while (pdsIter.hasNext()) {
            DataSet ds = pdsIter.next();
            pdsCount++;
            totalExamples2 += ds.numExamples();
            assertEquals(5, ds.numInputs());
            assertEquals(5, ds.numOutcomes());
        }
        assertEquals(150, totalExamples2);
        assertEquals(numBinFiles, pdsCount);
    } finally {
        FileUtils.deleteDirectory(temp);
    }
}
Also used : IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet) DataSet (org.nd4j.linalg.dataset.DataSet) ArrayList (java.util.ArrayList) PortableDataStream (org.apache.spark.input.PortableDataStream) URI (java.net.URI) StringToDataSetExportFunction (org.deeplearning4j.spark.datavec.export.StringToDataSetExportFunction) PortableDataStreamDataSetIterator (org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator) CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader) File (java.io.File) DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator) BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest) Test (org.junit.Test)
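
Both tests rely on a toString(INDArray, int) helper defined elsewhere in TestPreProcessedData, which this page does not show. Given that CSVRecordReader later parses each line as four comma-separated feature values followed by an integer class index, a plausible reconstruction looks like the sketch below; the exact formatting is an assumption, not the project's actual code (requires import org.nd4j.linalg.api.ndarray.INDArray):

//Hypothetical sketch of the helper: formats one example as "f0,f1,f2,f3,<classIdx>"
private static String toString(INDArray features, int label) {
    StringBuilder sb = new StringBuilder();
    //4 feature values per row for Iris
    int n = (int) features.length();
    for (int i = 0; i < n; i++) {
        sb.append(features.getDouble(i)).append(",");
    }
    //class index (0..2) appended as the final column
    sb.append(label);
    return sb.toString();
}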

Example 2 with PortableDataStreamDataSetIterator

Use of org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator in the deeplearning4j project, from the class TestPreProcessedData, method testCsvPreprocessedDataGeneration. This is the labeled variant of Example 1: column 4 of each CSV line is parsed as a class index with 3 possible labels.

@Test
public void testCsvPreprocessedDataGeneration() throws Exception {
    List<String> list = new ArrayList<>();
    DataSetIterator iter = new IrisDataSetIterator(1, 150);
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        list.add(toString(ds.getFeatureMatrix(), Nd4j.argMax(ds.getLabels(), 1).getInt(0)));
    }
    JavaRDD<String> rdd = sc.parallelize(list);
    int partitions = rdd.partitions().size();
    URI tempDir = new File(System.getProperty("java.io.tmpdir")).toURI();
    URI outputDir = new URI(tempDir.getPath() + "/dl4j_testPreprocessedData2");
    File temp = new File(outputDir.getPath());
    if (temp.exists())
        FileUtils.deleteDirectory(temp);
    int numBinFiles = 0;
    try {
        int batchSize = 5;
        //Column index 4 holds the class index; Iris has 3 possible labels
        int labelIdx = 4;
        int numPossibleLabels = 3;
        rdd.foreachPartition(new StringToDataSetExportFunction(outputDir, new CSVRecordReader(0), batchSize, false, labelIdx, numPossibleLabels));
        File[] fileList = new File(outputDir.getPath()).listFiles();
        int totalExamples = 0;
        for (File f2 : fileList) {
            if (!f2.getPath().endsWith(".bin"))
                continue;
            numBinFiles++;
            DataSet ds = new DataSet();
            ds.load(f2);
            //4 feature columns in, 3 one-hot label columns out
            assertEquals(4, ds.numInputs());
            assertEquals(3, ds.numOutcomes());
            totalExamples += ds.numExamples();
        }
        assertEquals(150, totalExamples);
        //Expect 150 / batchSize = 30 files, give or take one partial final batch per partition
        assertTrue(Math.abs(150 / batchSize - numBinFiles) <= partitions);
        //Test the PortableDataStreamDataSetIterator:
        JavaPairRDD<String, PortableDataStream> pds = sc.binaryFiles(outputDir.getPath());
        List<PortableDataStream> pdsList = pds.values().collect();
        DataSetIterator pdsIter = new PortableDataStreamDataSetIterator(pdsList);
        int pdsCount = 0;
        int totalExamples2 = 0;
        while (pdsIter.hasNext()) {
            DataSet ds = pdsIter.next();
            pdsCount++;
            totalExamples2 += ds.numExamples();
            assertEquals(4, ds.numInputs());
            assertEquals(3, ds.numOutcomes());
        }
        assertEquals(150, totalExamples2);
        assertEquals(numBinFiles, pdsCount);
    } finally {
        FileUtils.deleteDirectory(temp);
    }
}
Also used : IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet) DataSet (org.nd4j.linalg.dataset.DataSet) ArrayList (java.util.ArrayList) PortableDataStream (org.apache.spark.input.PortableDataStream) URI (java.net.URI) StringToDataSetExportFunction (org.deeplearning4j.spark.datavec.export.StringToDataSetExportFunction) PortableDataStreamDataSetIterator (org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator) CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader) File (java.io.File) DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator) BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest) Test (org.junit.Test)
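
The same read-back pattern works outside of a test for any directory of DataSet .bin files written by StringToDataSetExportFunction. A minimal sketch, assuming a JavaSparkContext sc and a placeholder path (the path and the net.fit(...) call are illustrative, not taken from the tests above):

//Collect the exported files as PortableDataStreams and iterate over the saved mini-batches
JavaPairRDD<String, PortableDataStream> files = sc.binaryFiles("/path/to/exported/datasets");
List<PortableDataStream> streams = files.values().collect();
DataSetIterator iter = new PortableDataStreamDataSetIterator(streams);
while (iter.hasNext()) {
    DataSet batch = iter.next();
    //e.g. train a local network on each saved mini-batch: net.fit(batch);
}

Note that collect() brings every stream handle to the driver, matching the single-JVM pattern in the tests above; for distributed training the RDD itself would be consumed on the executors instead.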

Aggregations

File (java.io.File) 2
URI (java.net.URI) 2
ArrayList (java.util.ArrayList) 2
PortableDataStream (org.apache.spark.input.PortableDataStream) 2
CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader) 2
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) 2
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest) 2
StringToDataSetExportFunction (org.deeplearning4j.spark.datavec.export.StringToDataSetExportFunction) 2
PortableDataStreamDataSetIterator (org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator) 2
Test (org.junit.Test) 2
DataSet (org.nd4j.linalg.dataset.DataSet) 2
MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet) 2
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator) 2