Use of org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator in project deeplearning4j by deeplearning4j.
From class TestPreProcessedData, method testCsvPreprocessedDataGenerationNoLabel. The test writes the Iris data out as CSV strings, exports them to pre-processed binary DataSet (.bin) files via StringToDataSetExportFunction with no label column, then reads them back through a PortableDataStreamDataSetIterator:
@Test
public void testCsvPreprocessedDataGenerationNoLabel() throws Exception {
    //Same as testCsvPreprocessedDataGeneration (below), but without any labels (in which case: input and output arrays are the same)
    List<String> list = new ArrayList<>();
    DataSetIterator iter = new IrisDataSetIterator(1, 150);
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        list.add(toString(ds.getFeatureMatrix(), Nd4j.argMax(ds.getLabels(), 1).getInt(0)));
    }

    JavaRDD<String> rdd = sc.parallelize(list);
    int partitions = rdd.partitions().size();

    URI tempDir = new File(System.getProperty("java.io.tmpdir")).toURI();
    URI outputDir = new URI(tempDir.getPath() + "/dl4j_testPreprocessedData3");
    File temp = new File(outputDir.getPath());
    if (temp.exists())
        FileUtils.deleteDirectory(temp);

    int numBinFiles = 0;
    try {
        int batchSize = 5;
        int labelIdx = -1; //No label column: every CSV value becomes both a feature and a label
        int numPossibleLabels = -1;
        rdd.foreachPartition(new StringToDataSetExportFunction(outputDir, new CSVRecordReader(0), batchSize,
                        false, labelIdx, numPossibleLabels));

        File[] fileList = new File(outputDir.getPath()).listFiles();
        int totalExamples = 0;
        for (File f2 : fileList) {
            if (!f2.getPath().endsWith(".bin"))
                continue;
            numBinFiles++;
            DataSet ds = new DataSet();
            ds.load(f2);
            assertEquals(5, ds.numInputs()); //5 CSV columns (4 features + appended class index)
            assertEquals(5, ds.numOutcomes());
            totalExamples += ds.numExamples();
        }
        assertEquals(150, totalExamples);
        //Expect 150/5 = 30 files, give or take due to partitioning randomness
        assertTrue(Math.abs(150 / batchSize - numBinFiles) <= partitions);

        //Test the PortableDataStreamDataSetIterator:
        JavaPairRDD<String, PortableDataStream> pds = sc.binaryFiles(outputDir.getPath());
        List<PortableDataStream> pdsList = pds.values().collect();
        DataSetIterator pdsIter = new PortableDataStreamDataSetIterator(pdsList);
        int pdsCount = 0;
        int totalExamples2 = 0;
        while (pdsIter.hasNext()) {
            DataSet ds = pdsIter.next();
            pdsCount++;
            totalExamples2 += ds.numExamples();
            assertEquals(5, ds.numInputs());
            assertEquals(5, ds.numOutcomes());
        }
        assertEquals(150, totalExamples2);
        assertEquals(numBinFiles, pdsCount);
    } finally {
        FileUtils.deleteDirectory(temp);
    }
}
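Both tests rely on a private toString helper in TestPreProcessedData that renders one example as a CSV line: the feature values followed by the integer class index. That helper is not shown on this page; the following is a minimal sketch of what it plausibly looks like, assuming a single-row features matrix (the signature is inferred from the call sites above, not confirmed):

//Hypothetical reconstruction of the toString(...) helper used above:
//formats a 1-row features matrix plus a class index as one comma-separated line.
private static String toString(INDArray features, int label) {
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < features.length(); i++) {
        sb.append(features.getDouble(i)).append(",");
    }
    sb.append(label); //class index becomes the last CSV column
    return sb.toString();
}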
Use of org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator in project deeplearning4j by deeplearning4j.
From class TestPreProcessedData, method testCsvPreprocessedDataGeneration. Same pipeline as above, but with the class label in CSV column 4 and three possible classes, so each exported DataSet has 4 inputs and 3 one-hot outcomes:
@Test
public void testCsvPreprocessedDataGeneration() throws Exception {
    List<String> list = new ArrayList<>();
    DataSetIterator iter = new IrisDataSetIterator(1, 150);
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        list.add(toString(ds.getFeatureMatrix(), Nd4j.argMax(ds.getLabels(), 1).getInt(0)));
    }

    JavaRDD<String> rdd = sc.parallelize(list);
    int partitions = rdd.partitions().size();

    URI tempDir = new File(System.getProperty("java.io.tmpdir")).toURI();
    URI outputDir = new URI(tempDir.getPath() + "/dl4j_testPreprocessedData2");
    File temp = new File(outputDir.getPath());
    if (temp.exists())
        FileUtils.deleteDirectory(temp);

    int numBinFiles = 0;
    try {
        int batchSize = 5;
        int labelIdx = 4;          //Class label is the 5th CSV column (index 4)
        int numPossibleLabels = 3; //Three Iris classes -> one-hot labels of length 3
        rdd.foreachPartition(new StringToDataSetExportFunction(outputDir, new CSVRecordReader(0), batchSize,
                        false, labelIdx, numPossibleLabels));

        File[] fileList = new File(outputDir.getPath()).listFiles();
        int totalExamples = 0;
        for (File f2 : fileList) {
            if (!f2.getPath().endsWith(".bin"))
                continue;
            numBinFiles++;
            DataSet ds = new DataSet();
            ds.load(f2);
            assertEquals(4, ds.numInputs()); //4 Iris feature columns
            assertEquals(3, ds.numOutcomes());
            totalExamples += ds.numExamples();
        }
        assertEquals(150, totalExamples);
        //Expect 150/5 = 30 files, give or take due to partitioning randomness
        assertTrue(Math.abs(150 / batchSize - numBinFiles) <= partitions);

        //Test the PortableDataStreamDataSetIterator:
        JavaPairRDD<String, PortableDataStream> pds = sc.binaryFiles(outputDir.getPath());
        List<PortableDataStream> pdsList = pds.values().collect();
        DataSetIterator pdsIter = new PortableDataStreamDataSetIterator(pdsList);
        int pdsCount = 0;
        int totalExamples2 = 0;
        while (pdsIter.hasNext()) {
            DataSet ds = pdsIter.next();
            pdsCount++;
            totalExamples2 += ds.numExamples();
            assertEquals(4, ds.numInputs());
            assertEquals(3, ds.numOutcomes());
        }
        assertEquals(150, totalExamples2);
        assertEquals(numBinFiles, pdsCount);
    } finally {
        FileUtils.deleteDirectory(temp);
    }
}
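Outside of a test, the same read-back pattern can be used to feed pre-exported DataSets into driver-side code. A minimal sketch, assuming the mini-batches were already serialized as .bin files under a hypothetical exportPath and that sc is an existing JavaSparkContext:

//Sketch: load previously exported DataSet .bin files back as a DataSetIterator.
//exportPath is a placeholder; sc is assumed to be a live JavaSparkContext.
JavaPairRDD<String, PortableDataStream> files = sc.binaryFiles(exportPath);
List<PortableDataStream> streams = files.values().collect(); //handles only; bytes are read lazily
DataSetIterator iterator = new PortableDataStreamDataSetIterator(streams);
while (iterator.hasNext()) {
    DataSet batch = iterator.next(); //each .bin file deserializes to one mini-batch DataSet
    //... fit a network, compute evaluation statistics, etc.
}

Note that collect() pulls only the PortableDataStream handles to the driver, not the file contents; the underlying files must therefore remain accessible (e.g., on HDFS) when the iterator reads them.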