Use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
From the class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI_Batched.
@Test
public void testImagesRRDMSI_Batched() throws Exception {
File parentDir = Files.createTempDir();
parentDir.deleteOnExit();
String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
File f1 = new File(str1);
File f2 = new File(str2);
f1.mkdirs();
f2.mkdirs();
writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
int outputNum = 2;
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
//Two readers over the same images at different resolutions: 10x10 and 5x5, single channel
ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
rr1.initialize(new FileSplit(parentDir));
rr1s.initialize(new FileSplit(parentDir));
MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2)
        .addReader("rr1", rr1)
        .addReader("rr1s", rr1s)
        .addInput("rr1", 0, 0)
        .addInput("rr1s", 0, 0)
        .addOutputOneHot("rr1s", 1, outputNum)
        .build();
//Now, load the same images via two separate RecordReaderDataSetIterators, and check we get the same results:
ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
rr1_b.initialize(new FileSplit(parentDir));
rr1s_b.initialize(new FileSplit(parentDir));
//RecordReaderDataSetIterator(reader, batchSize, labelIndex, numPossibleLabels)
DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 2, 1, 2);
DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 2, 1, 2);
MultiDataSet mds = trainDataIterator.next();
DataSet d1 = dsi1.next();
DataSet d2 = dsi2.next();
assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
assertEquals(d1.getLabels(), mds.getLabels(0));
}
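The Builder calls above follow one pattern: register each reader under a name, then map column ranges of that reader onto inputs and outputs. Below is a minimal, self-contained sketch of the same pattern with a single CSV reader; the class name, file, batch size, and column indices are illustrative placeholders, not part of the test above.
import java.io.File;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
public class MultiIteratorSketch {
    public static MultiDataSetIterator build(File csvFile) throws Exception {
        //Skip no header lines, comma-delimited
        RecordReader csv = new CSVRecordReader(0, ",");
        csv.initialize(new FileSplit(csvFile));
        return new RecordReaderMultiDataSetIterator.Builder(32)    //minibatch size 32
                .addReader("csv", csv)
                .addInput("csv", 0, 3)          //columns 0..3 (inclusive) -> features
                .addOutputOneHot("csv", 4, 3)   //column 4, one-hot encoded over 3 classes
                .build();
    }
}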
Use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
From the class ScoreFlatMapFunctionCGMultiDataSetAdapter, method call.
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
if (!dataSetIterator.hasNext()) {
return Collections.singletonList(new Tuple2<>(0, 0.0));
}
//Does batching where appropriate
MultiDataSetIterator iter = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize);
ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
network.init();
//.value() is shared by all executors on single machine -> OK, as params are not changed in score function
INDArray val = params.value().unsafeDuplication();
if (val.length() != network.numParams(false))
    throw new IllegalStateException("Network did not have the same number of parameters as the broadcast parameters");
network.setParams(val);
List<Tuple2<Integer, Double>> out = new ArrayList<>();
while (iter.hasNext()) {
MultiDataSet ds = iter.next();
double score = network.score(ds, false);
int numExamples = ds.getFeatures(0).size(0);
//Weight the score by the example count so the driver can recover a per-example average
out.add(new Tuple2<>(numExamples, score * numExamples));
}
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
return out;
}
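Each tuple pairs an example count with the score multiplied by that count, so a driver-side reduction can recover the per-example average across all partitions. The actual reduction happens elsewhere in the Spark scoring path; the helper below is our own illustration of the arithmetic, not DL4J's implementation.
import java.util.List;
import scala.Tuple2;
public class ScoreAggregationSketch {
    //Sum the weighted scores and the example counts, then divide
    public static double averageScore(List<Tuple2<Integer, Double>> partitionResults) {
        long totalExamples = 0;
        double totalWeightedScore = 0.0;
        for (Tuple2<Integer, Double> t : partitionResults) {
            totalExamples += t._1();
            totalWeightedScore += t._2();
        }
        return totalExamples == 0 ? 0.0 : totalWeightedScore / totalExamples;
    }
}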
Use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
From the class ExecuteWorkerMultiDataSetFlatMapAdapter, method call.
@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
WorkerConfiguration dataConfig = worker.getDataConfiguration();
boolean stats = dataConfig.isCollectTrainingStats();
StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
if (stats)
s.logMethodStartTime();
if (!dataSetIterator.hasNext()) {
if (stats)
s.logReturnTime();
//An empty partition is possible; nothing to process
return Collections.emptyList();
}
int batchSize = dataConfig.getBatchSizePerWorker();
final int prefetchCount = dataConfig.getPrefetchNumBatches();
MultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
if (prefetchCount > 0) {
batchedIterator = new AsyncMultiDataSetIterator(batchedIterator, prefetchCount);
}
try {
if (stats)
s.logInitialModelBefore();
ComputationGraph net = worker.getInitialModelGraph();
if (stats)
s.logInitialModelAfter();
int miniBatchCount = 0;
int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
if (stats)
s.logNextDataSetBefore();
MultiDataSet next = batchedIterator.next();
if (stats)
s.logNextDataSetAfter(next.getFeatures(0).size(0));
if (stats) {
s.logProcessMinibatchBefore();
Pair<R, SparkTrainingStats> result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
s.logProcessMinibatchAfter();
if (result != null) {
//Terminate training immediately
s.logReturnTime();
SparkTrainingStats workerStats = result.getSecond();
SparkTrainingStats returnStats = s.build(workerStats);
result.getFirst().setStats(returnStats);
return Collections.singletonList(result.getFirst());
}
} else {
R result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
if (result != null) {
//Terminate training immediately
return Collections.singletonList(result);
}
}
}
//Iterator was exhausted without an early-return result from the worker; return the final result instead
if (stats) {
s.logReturnTime();
Pair<R, SparkTrainingStats> pair = worker.getFinalResultWithStats(net);
pair.getFirst().setStats(s.build(pair.getSecond()));
return Collections.singletonList(pair.getFirst());
} else {
return Collections.singletonList(worker.getFinalResult(net));
}
} finally {
//Make sure we shut down the async thread properly...
if (batchedIterator instanceof AsyncMultiDataSetIterator) {
((AsyncMultiDataSetIterator) batchedIterator).shutdown();
}
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
}
}
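The try/finally here matters: AsyncMultiDataSetIterator runs a background prefetch thread that must be shut down even if training throws. Below is a minimal sketch of that pattern in isolation; drain is a hypothetical helper, and the prefetch depth is a placeholder.
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
public class AsyncPrefetchSketch {
    public static void drain(MultiDataSetIterator base, int prefetchBatches) {
        AsyncMultiDataSetIterator async = new AsyncMultiDataSetIterator(base, prefetchBatches);
        try {
            while (async.hasNext()) {
                MultiDataSet mds = async.next();
                //... consume the minibatch here ...
            }
        } finally {
            async.shutdown();   //stop the prefetch thread even on failure
        }
    }
}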
Use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
From the class TestSparkComputationGraph, method testBasic.
@Test
public void testBasic() throws Exception {
JavaSparkContext sc = this.sc;
RecordReader rr = new CSVRecordReader(0, ",");
rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(1)
        .addReader("iris", rr)
        .addInput("iris", 0, 3)
        .addOutputOneHot("iris", 4, 3)
        .build();
List<MultiDataSet> list = new ArrayList<>(150);
while (iter.hasNext()) list.add(iter.next());
ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(0.1)
        .graphBuilder()
        .addInputs("in")
        .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
        .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
        .setOutputs("out")
        .pretrain(false).backprop(true)
        .build();
ComputationGraph cg = new ComputationGraph(config);
cg.init();
TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
scg.setListeners(Collections.singleton((IterationListener) new ScoreIterationListener(1)));
JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
scg.fitMultiDataSet(rdd);
//Try: fitting using DataSet
DataSetIterator iris = new IrisDataSetIterator(1, 150);
List<DataSet> list2 = new ArrayList<>();
while (iris.hasNext()) list2.add(iris.next());
JavaRDD<DataSet> rddDS = sc.parallelize(list2);
scg.fit(rddDS);
}
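The six positional arguments to the ParameterAveragingTrainingMaster constructor are hard to read at the call site. Below is what we take to be the equivalent Builder form of the call above; the argument mapping (saveUpdater, number of workers, examples per RDD object, batch size per worker, averaging frequency, prefetch batches) is our reading of that constructor, so verify it against your version.
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
public class TrainingMasterSketch {
    public static TrainingMaster build() {
        //Assumed equivalent of new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
        //the worker count is left to its default here
        return new ParameterAveragingTrainingMaster.Builder(1)   //1 example per object in the RDD
                .saveUpdater(true)                               //keep updater (momentum etc.) state when averaging
                .batchSizePerWorker(10)
                .averagingFrequency(1)                           //average parameters after every minibatch
                .workerPrefetchNumBatches(0)                     //no async prefetch on workers
                .build();
    }
}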
Use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
From the class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVSequence.
@Test
public void testSplittingCSVSequence() throws Exception {
//Need to manually extract the test resources from the archive first
for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
}
ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
resource = new ClassPathResource("csvsequencelabels_0.txt");
String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
        .addSequenceReader("seq1", featureReader2)
        .addSequenceReader("seq2", labelReader2)
        .addInput("seq1", 0, 1)
        .addInput("seq1", 2, 2)
        .addOutputOneHot("seq2", 0, 4)
        .build();
while (iter.hasNext()) {
DataSet ds = iter.next();
INDArray fds = ds.getFeatureMatrix();
INDArray lds = ds.getLabels();
MultiDataSet mds = srrmdsi.next();
assertEquals(2, mds.getFeatures().length);
assertEquals(1, mds.getLabels().length);
assertNull(mds.getFeaturesMaskArrays());
assertNull(mds.getLabelsMaskArrays());
INDArray[] fmds = mds.getFeatures();
INDArray[] lmds = mds.getLabels();
assertNotNull(fmds);
assertNotNull(lmds);
for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());
INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());
assertEquals(expIn1, fmds[0]);
assertEquals(expIn2, fmds[1]);
assertEquals(lds, lmds[0]);
}
assertFalse(srrmdsi.hasNext());
}
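The expIn1/expIn2 assertions rely on NDArrayIndex.interval with the inclusive flag to split the feature dimension of a [minibatch, features, timeSteps] sequence array. Below is a toy, self-contained example of the same slicing; the shapes in the comments are what we expect for this layout.
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
public class IntervalSplitSketch {
    public static void main(String[] args) {
        INDArray seq = Nd4j.rand(new int[] {2, 3, 5});   //[minibatch=2, features=3, timeSteps=5]
        INDArray first = seq.get(NDArrayIndex.all(),
                NDArrayIndex.interval(0, 1, true),       //features 0..1, inclusive
                NDArrayIndex.all());                     //shape [2, 2, 5]
        INDArray second = seq.get(NDArrayIndex.all(),
                NDArrayIndex.interval(2, 2, true),       //feature 2 only
                NDArrayIndex.all());                     //shape [2, 1, 5]
        System.out.println(java.util.Arrays.toString(first.shape()));
        System.out.println(java.util.Arrays.toString(second.shape()));
    }
}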