Example 1 with BatchAndExportMultiDataSetsFunction

This example shows the use of org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction in the deeplearning4j project.

From the class TestExport, method testBatchAndExportMultiDataSetsFunction:

@Test
public void testBatchAndExportMultiDataSetsFunction() throws Exception {
    String baseDir = System.getProperty("java.io.tmpdir");
    baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExportMDS/");
    baseDir = baseDir.replaceAll("\\\\", "/");
    File f = new File(baseDir);
    if (f.exists())
        FileUtils.deleteDirectory(f);
    f.mkdir();
    f.deleteOnExit();
    int minibatchSize = 5;
    int nIn = 4;
    int nOut = 3;
    List<MultiDataSet> dataSets = new ArrayList<>();
    //Larger than minibatch size -> tests splitting
    dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(10, nIn), Nd4j.create(10, nOut)));
    for (int i = 0; i < 98; i++) {
        if (i % 2 == 0) {
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(5, nIn), Nd4j.create(5, nOut)));
        } else {
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(3, nIn), Nd4j.create(3, nOut)));
        }
    }
    Collections.shuffle(dataSets, new Random(12345));
    JavaRDD<MultiDataSet> rdd = sc.parallelize(dataSets);
    //Repartition to 1 for testing: 500 total rows / minibatch size 5 = exactly 100 batches.
    //With more partitions, each partition's leftover rows would produce extra partial batches.
    rdd = rdd.repartition(1);
    JavaRDD<String> pathsRdd = rdd.mapPartitionsWithIndex(new BatchAndExportMultiDataSetsFunction(minibatchSize, "file:///" + baseDir), true);
    List<String> paths = pathsRdd.collect();
    assertEquals(100, paths.size());
    File[] files = f.listFiles();
    assertNotNull(files);
    int count = 0;
    for (File file : files) {
        if (!file.getPath().endsWith(".bin"))
            continue;
        System.out.println(file);
        MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
        ds.load(file);
        assertEquals(minibatchSize, ds.getFeatures(0).size(0));
        assertEquals(minibatchSize, ds.getLabels(0).size(0));
        count++;
    }
    assertEquals(100, count);
    FileUtils.deleteDirectory(f);
}
Also used: BatchAndExportMultiDataSetsFunction (org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction), ArrayList (java.util.ArrayList), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), Random (java.util.Random), File (java.io.File), Test (org.junit.Test), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)

Example 2 with BatchAndExportMultiDataSetsFunction

This example shows the use of org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction in the deeplearning4j project.

From the class ParameterAveragingTrainingMaster, method exportMDS:

private String exportMDS(JavaRDD<MultiDataSet> trainingData) {
    String baseDir = getBaseDirForRDD(trainingData);
    String dataDir = baseDir + "data/";
    String pathsDir = baseDir + "paths/";
    log.info("Initiating RDD<MultiDataSet> export at {}", baseDir);
    //Batch the examples into minibatches of batchSizePerWorker, writing each batch to dataDir
    JavaRDD<String> paths = trainingData.mapPartitionsWithIndex(new BatchAndExportMultiDataSetsFunction(batchSizePerWorker, dataDir), true);
    //Persist the exported file paths so workers can later locate the serialized minibatches
    paths.saveAsTextFile(pathsDir);
    log.info("RDD<MultiDataSet> export complete at {}", baseDir);
    //Record the RDD id and export location so the same RDD is not re-exported on a later call
    lastExportedRDDId = trainingData.id();
    lastRDDExportPath = baseDir;
    return baseDir;
}
Also used: BatchAndExportMultiDataSetsFunction (org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction)
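The core behavior both examples rely on, buffering variably sized MultiDataSet blocks and re-emitting them as fixed-size minibatches, can be sketched without Spark or ND4J. The class and method names below (BatchSketch, rebatch) are illustrative, not part of the DL4J API; the sketch tracks only row counts, whereas the real function also concatenates features/labels and writes each completed batch to a file under the export directory.

```java
import java.util.ArrayList;
import java.util.List;

//Illustrative sketch (not the DL4J implementation): regroup variably sized
//blocks of examples into fixed-size minibatches, as BatchAndExportMultiDataSetsFunction
//does within each partition before exporting every full batch.
public class BatchSketch {

    //Each input element is the row count of one incoming MultiDataSet;
    //the output lists the row count of each minibatch that would be exported.
    static List<Integer> rebatch(List<Integer> blockSizes, int minibatchSize) {
        List<Integer> batches = new ArrayList<>();
        int buffered = 0;
        for (int size : blockSizes) {
            buffered += size;
            //A block larger than minibatchSize is split; smaller blocks accumulate
            while (buffered >= minibatchSize) {
                batches.add(minibatchSize); //a full minibatch would be written to disk here
                buffered -= minibatchSize;
            }
        }
        if (buffered > 0) {
            batches.add(buffered); //final partial batch at the end of the partition, if any
        }
        return batches;
    }

    public static void main(String[] args) {
        //Mirrors the test data's shape: one 10-row block, then 5-row and 1+1+3-row groups
        List<Integer> blocks = List.of(10, 5, 1, 1, 3, 5);
        System.out.println(rebatch(blocks, 5)); //prints [5, 5, 5, 5, 5]
    }
}
```

With the test's single partition and 500 total rows, every emitted batch has exactly minibatchSize rows, which is why Example 1 can assert 100 files of 5 rows each.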

Aggregations

BatchAndExportMultiDataSetsFunction (org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction): 2 uses
File (java.io.File): 1 use
ArrayList (java.util.ArrayList): 1 use
Random (java.util.Random): 1 use
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 1 use
Test (org.junit.Test): 1 use
MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 1 use