Example use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project:
class RecordReaderMultiDataSetIteratorTest, method testSplittingCSV.
@Test
public void testSplittingCSV() throws Exception {
//Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
//Inputs: columns 0 and 1-2
//Outputs: columns 3, and 4->OneHot
//need to manually extract
// Baseline iterator: standard DataSet view of Iris (4 feature columns, 3-class label).
RecordReader rr = new CSVRecordReader(0, ",");
rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
// Iterator under test: same CSV split into inputs {col 0}, {cols 1-2} and
// outputs {col 3}, {col 4 one-hot encoded over 3 classes}. Batch size 10 matches the baseline.
RecordReader rr2 = new CSVRecordReader(0, ",");
rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
// Consume both iterators in lockstep; each baseline minibatch must match the
// corresponding column-subsets produced by the multi-dataset iterator.
while (rrdsi.hasNext()) {
DataSet ds = rrdsi.next();
INDArray fds = ds.getFeatureMatrix();
INDArray lds = ds.getLabels();
MultiDataSet mds = rrmdsi.next();
// Two input arrays and two output arrays were configured; no masks expected.
assertEquals(2, mds.getFeatures().length);
assertEquals(2, mds.getLabels().length);
assertNull(mds.getFeaturesMaskArrays());
assertNull(mds.getLabelsMaskArrays());
INDArray[] fmds = mds.getFeatures();
INDArray[] lmds = mds.getLabels();
assertNotNull(fmds);
assertNotNull(lmds);
for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
//Get the subsets of the original iris data
// interval(1, 2, true) is an inclusive interval -> columns 1 and 2.
INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(0));
INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 2, true));
INDArray expOut1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(3));
// The baseline's one-hot label matrix doubles as the expected second output.
INDArray expOut2 = lds;
assertEquals(expIn1, fmds[0]);
assertEquals(expIn2, fmds[1]);
assertEquals(expOut1, lmds[0]);
assertEquals(expOut2, lmds[1]);
}
// Both iterators must be exhausted together.
assertFalse(rrmdsi.hasNext());
}
Example use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project:
class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI_Batched.
@Test
public void testImagesRRDMSI_Batched() throws Exception {
    // Stage two LFW test images into a fresh temp directory, one subdirectory per label.
    File rootDir = Files.createTempDir();
    rootDir.deleteOnExit();
    File zicoDir = new File(FilenameUtils.concat(rootDir.getAbsolutePath(), "Zico/"));
    File ziwangDir = new File(FilenameUtils.concat(rootDir.getAbsolutePath(), "Ziwang_Xu/"));
    zicoDir.mkdirs();
    ziwangDir.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(zicoDir.getPath(), "Zico_0001.jpg")),
                    new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(ziwangDir.getPath(), "Ziwang_Xu_0001.jpg")),
                    new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());

    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

    // Iterator under test: two image readers at different resolutions feeding one
    // MultiDataSet per minibatch of 2; labels come one-hot from the small reader.
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(rootDir));
    rr1s.initialize(new FileSplit(rootDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2)
                    .addReader("rr1", rr1)
                    .addReader("rr1s", rr1s)
                    .addInput("rr1", 0, 0)
                    .addInput("rr1s", 0, 0)
                    .addOutputOneHot("rr1s", 1, outputNum)
                    .build();

    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(rootDir));
    rr1s_b.initialize(new FileSplit(rootDir));
    DataSetIterator baselineLarge = new RecordReaderDataSetIterator(rr1_b, 2, 1, 2);
    DataSetIterator baselineSmall = new RecordReaderDataSetIterator(rr1s_b, 2, 1, 2);

    // One minibatch holds both images; features and labels must agree with the baselines.
    MultiDataSet mds = trainDataIterator.next();
    DataSet expectedLarge = baselineLarge.next();
    DataSet expectedSmall = baselineSmall.next();
    assertEquals(expectedLarge.getFeatureMatrix(), mds.getFeatures(0));
    assertEquals(expectedSmall.getFeatureMatrix(), mds.getFeatures(1));
    assertEquals(expectedLarge.getLabels(), mds.getLabels(0));
}
Example use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project:
class ParameterAveragingTrainingMaster, method executeTraining.
@Override
public void executeTraining(SparkComputationGraph graph, JavaRDD<DataSet> trainingData) {
    // Lazily derive the worker count from Spark's default parallelism when unset.
    if (numWorkers == null) {
        numWorkers = graph.getSparkContext().defaultParallelism();
    }
    // Adapt each DataSet to a MultiDataSet and delegate to the MDS training path.
    executeTrainingMDS(graph, trainingData.map(new DataSetToMultiDataSetFn()));
}
Example use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project:
class TestExport, method testBatchAndExportMultiDataSetsFunction.
@Test
public void testBatchAndExportMultiDataSetsFunction() throws Exception {
    // Prepare a clean export directory under the system temp dir.
    String baseDir = System.getProperty("java.io.tmpdir");
    baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExportMDS/");
    // Normalize Windows separators. Use a literal char replace: replaceAll() treats
    // its first argument as a regex, which is unnecessary (and error-prone) here.
    baseDir = baseDir.replace('\\', '/');
    File f = new File(baseDir);
    if (f.exists())
        FileUtils.deleteDirectory(f);
    // mkdirs() (not mkdir()) so a missing parent directory can't silently fail setup.
    f.mkdirs();
    f.deleteOnExit();
    int minibatchSize = 5;
    int nIn = 4;
    int nOut = 3;
    // Build 500 examples total: 10 + 49*5 (even iterations) + 49*(1+1+3) (odd iterations),
    // i.e. exactly 100 exported minibatches of size 5.
    List<MultiDataSet> dataSets = new ArrayList<>();
    //Larger than minibatch size -> tests splitting
    dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(10, nIn), Nd4j.create(10, nOut)));
    for (int i = 0; i < 98; i++) {
        if (i % 2 == 0) {
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(5, nIn), Nd4j.create(5, nOut)));
        } else {
            // Smaller-than-minibatch sets -> tests merging across MultiDataSets.
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(3, nIn), Nd4j.create(3, nOut)));
        }
    }
    Collections.shuffle(dataSets, new Random(12345));
    JavaRDD<MultiDataSet> rdd = sc.parallelize(dataSets);
    //For testing purposes (should get exactly 100 out, but maybe more with more partitions)
    rdd = rdd.repartition(1);
    JavaRDD<String> pathsRdd = rdd.mapPartitionsWithIndex(
                    new BatchAndExportMultiDataSetsFunction(minibatchSize, "file:///" + baseDir), true);
    List<String> paths = pathsRdd.collect();
    assertEquals(100, paths.size());
    // Every exported .bin file must load back as a full minibatch of exactly 5 examples.
    File[] files = f.listFiles();
    assertNotNull(files);
    int count = 0;
    for (File file : files) {
        if (!file.getPath().endsWith(".bin"))
            continue;
        System.out.println(file);
        MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
        ds.load(file);
        assertEquals(minibatchSize, ds.getFeatures(0).size(0));
        assertEquals(minibatchSize, ds.getLabels(0).size(0));
        count++;
    }
    assertEquals(100, count);
    FileUtils.deleteDirectory(f);
}
Example use of org.nd4j.linalg.dataset.api.MultiDataSet in the deeplearning4j project:
class ScoreExamplesFunctionAdapter, method call.
@Override
public Iterable<Double> call(Iterator<MultiDataSet> iterator) throws Exception {
// Empty partition: nothing to score.
if (!iterator.hasNext()) {
return Collections.emptyList();
}
// Rebuild the network on this worker from the broadcast JSON configuration,
// then overwrite its freshly-initialized parameters with the broadcast values.
ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
network.init();
INDArray val = params.value().unsafeDuplication();
if (val.length() != network.numParams(false))
throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
network.setParams(val);
List<Double> ret = new ArrayList<>();
// Reused scratch buffer for accumulating MultiDataSets up to ~batchSize examples.
List<MultiDataSet> collect = new ArrayList<>(batchSize);
int totalCount = 0;
while (iterator.hasNext()) {
collect.clear();
int nExamples = 0;
// Drain the iterator until at least batchSize examples have been gathered
// (the last MultiDataSet added may push the batch slightly over batchSize).
while (iterator.hasNext() && nExamples < batchSize) {
MultiDataSet ds = iterator.next();
// Example count is taken from the leading dimension of the first feature array.
int n = ds.getFeatures(0).size(0);
collect.add(ds);
nExamples += n;
}
totalCount += nExamples;
// Merge the gathered sets into one batch and score each example individually.
MultiDataSet data = org.nd4j.linalg.dataset.MultiDataSet.merge(collect);
INDArray scores = network.scoreExamples(data, addRegularization);
double[] doubleScores = scores.data().asDouble();
for (double doubleScore : doubleScores) {
ret.add(doubleScore);
}
}
// NOTE(review): flushQueueBlocking appears to ensure all queued GPU/grid ops
// complete before the scores are returned — confirm against GridExecutioner docs.
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
if (log.isDebugEnabled()) {
log.debug("Scored {} examples ", totalCount);
}
return ret;
}
Aggregations