Use of org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction in the project deeplearning4j by deeplearning4j.
Class TestExport, method testBatchAndExportMultiDataSetsFunction:
/**
 * Checks that {@code BatchAndExportMultiDataSetsFunction} re-batches an RDD of unevenly sized
 * {@code MultiDataSet}s into minibatches of a fixed size and writes each one to disk.
 *
 * <p>The input holds 500 examples in total: one 10-example set, 49 sets of 5 examples, and
 * 49 groups of three sets with 1 + 1 + 3 examples. With a minibatch size of 5 and a single
 * partition, exactly 100 exported files of 5 examples each are expected.
 *
 * <p>The temporary export directory is removed in a {@code finally} block so that a failing
 * assertion does not leave exported files behind.
 *
 * @throws Exception if the export directory cannot be prepared or an exported file cannot be loaded
 */
@Test
public void testBatchAndExportMultiDataSetsFunction() throws Exception {
    String baseDir = System.getProperty("java.io.tmpdir");
    baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExportMDS/");
    // Normalise Windows separators so the path can be embedded in a file:/// URI below
    baseDir = baseDir.replaceAll("\\\\", "/");
    File f = new File(baseDir);
    if (f.exists()) {
        // Start from a clean directory so files from earlier runs are not counted
        FileUtils.deleteDirectory(f);
    }
    // Throws if the directory cannot be created, instead of silently ignoring mkdir()'s result
    FileUtils.forceMkdir(f);

    try {
        int minibatchSize = 5;
        int nIn = 4;
        int nOut = 3;

        List<MultiDataSet> dataSets = new ArrayList<>();
        //Larger than minibatch size -> tests splitting
        dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(10, nIn), Nd4j.create(10, nOut)));
        for (int i = 0; i < 98; i++) {
            if (i % 2 == 0) {
                // Exactly one minibatch worth of examples
                dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(5, nIn), Nd4j.create(5, nOut)));
            } else {
                // Smaller than minibatch size -> tests combining (1 + 1 + 3 = 5 examples)
                dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
                dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
                dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(3, nIn), Nd4j.create(3, nOut)));
            }
        }

        // Fixed seed keeps the shuffled order reproducible between runs
        Collections.shuffle(dataSets, new Random(12345));

        JavaRDD<MultiDataSet> rdd = sc.parallelize(dataSets);
        //For testing purposes (should get exactly 100 out, but maybe more with more partitions)
        rdd = rdd.repartition(1);

        JavaRDD<String> pathsRdd = rdd.mapPartitionsWithIndex(
                        new BatchAndExportMultiDataSetsFunction(minibatchSize, "file:///" + baseDir), true);

        List<String> paths = pathsRdd.collect();
        assertEquals(100, paths.size());

        File[] files = f.listFiles();
        assertNotNull(files);

        int count = 0;
        for (File file : files) {
            // Only the ".bin" files are treated as exported data sets; skip anything else in the directory
            if (!file.getPath().endsWith(".bin")) {
                continue;
            }
            System.out.println(file);
            MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
            ds.load(file);
            assertEquals(minibatchSize, ds.getFeatures(0).size(0));
            assertEquals(minibatchSize, ds.getLabels(0).size(0));
            count++;
        }
        assertEquals(100, count);
    } finally {
        // Best-effort cleanup: runs on failure too, and never masks the original assertion error
        FileUtils.deleteQuietly(f);
    }
}
Use of org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction in the project deeplearning4j by deeplearning4j.
Class ParameterAveragingTrainingMaster, method exportMDS:
/**
 * Exports the given training data through {@code BatchAndExportMultiDataSetsFunction} and
 * saves the resulting path strings as a text file.
 *
 * <p>The export function is given the {@code "data/"} sub-directory of the base directory
 * obtained from {@code getBaseDirForRDD}; the returned paths are saved under {@code "paths/"}.
 * Afterwards the RDD id and the base directory are recorded in {@code lastExportedRDDId} and
 * {@code lastRDDExportPath}.
 *
 * @param trainingData RDD of multi data sets to export
 * @return the base directory used for this export
 */
private String exportMDS(JavaRDD<MultiDataSet> trainingData) {
    final String exportRoot = getBaseDirForRDD(trainingData);
    log.info("Initiating RDD<MultiDataSet> export at {}", exportRoot);

    // Presumably re-batches to batchSizePerWorker examples per file -- confirm against the function
    final BatchAndExportMultiDataSetsFunction exportFn =
                    new BatchAndExportMultiDataSetsFunction(batchSizePerWorker, exportRoot + "data/");
    final JavaRDD<String> exportedPaths = trainingData.mapPartitionsWithIndex(exportFn, true);

    // Saving the paths RDD is the action that triggers the (lazy) export above
    exportedPaths.saveAsTextFile(exportRoot + "paths/");
    log.info("RDD<MultiDataSet> export complete at {}", exportRoot);

    // Record which RDD was exported and where
    lastExportedRDDId = trainingData.id();
    lastRDDExportPath = exportRoot;
    return exportRoot;
}
Aggregations