Search in sources :

Example 6 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIteratorTest method testImagesRRDMSI_Batched.

@Test
public void testImagesRRDMSI_Batched() throws Exception {
    File parentDir = Files.createTempDir();
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 2, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 2, 1, 2);
    MultiDataSet mds = trainDataIterator.next();
    DataSet d1 = dsi1.next();
    DataSet d2 = dsi2.next();
    assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
    assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
    assertEquals(d1.getLabels(), mds.getLabels(0));
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) Test(org.junit.Test)

Example 7 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

the class ScoreFlatMapFunctionCGMultiDataSetAdapter method call.

@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        return Collections.singletonList(new Tuple2<>(0, 0.0));
    }
    //Does batching where appropriate
    MultiDataSetIterator iter = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize);
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    network.init();
    //.value() is shared by all executors on single machine -> OK, as params are not changed in score function
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(val);
    List<Tuple2<Integer, Double>> out = new ArrayList<>();
    while (iter.hasNext()) {
        MultiDataSet ds = iter.next();
        double score = network.score(ds, false);
        int numExamples = ds.getFeatures(0).size(0);
        out.add(new Tuple2<>(numExamples, score * numExamples));
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    return out;
}
Also used : IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) ArrayList(java.util.ArrayList) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) Tuple2(scala.Tuple2) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 8 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

the class ExecuteWorkerMultiDataSetFlatMapAdapter method call.

@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    boolean stats = dataConfig.isCollectTrainingStats();
    StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
    if (stats)
        s.logMethodStartTime();
    if (!dataSetIterator.hasNext()) {
        if (stats)
            s.logReturnTime();
        //Sometimes: no data
        return Collections.emptyList();
    }
    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    MultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batchedIterator = new AsyncMultiDataSetIterator(batchedIterator, prefetchCount);
    }
    try {
        if (stats)
            s.logInitialModelBefore();
        ComputationGraph net = worker.getInitialModelGraph();
        if (stats)
            s.logInitialModelAfter();
        int miniBatchCount = 0;
        int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
        while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
            if (stats)
                s.logNextDataSetBefore();
            MultiDataSet next = batchedIterator.next();
            if (stats)
                s.logNextDataSetAfter(next.getFeatures(0).size(0));
            if (stats) {
                s.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
                s.logProcessMinibatchAfter();
                if (result != null) {
                    //Terminate training immediately
                    s.logReturnTime();
                    SparkTrainingStats workerStats = result.getSecond();
                    SparkTrainingStats returnStats = s.build(workerStats);
                    result.getFirst().setStats(returnStats);
                    return Collections.singletonList(result.getFirst());
                }
            } else {
                R result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
                if (result != null) {
                    //Terminate training immediately
                    return Collections.singletonList(result);
                }
            }
        }
        //For some reason, we didn't return already. Normally this shouldn't happen
        if (stats) {
            s.logReturnTime();
            Pair<R, SparkTrainingStats> pair = worker.getFinalResultWithStats(net);
            pair.getFirst().setStats(s.build(pair.getSecond()));
            return Collections.singletonList(pair.getFirst());
        } else {
            return Collections.singletonList(worker.getFinalResult(net));
        }
    } finally {
        //Make sure we shut down the async thread properly...
        if (batchedIterator instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) batchedIterator).shutdown();
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
}
Also used : IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) WorkerConfiguration(org.deeplearning4j.spark.api.WorkerConfiguration) StatsCalculationHelper(org.deeplearning4j.spark.api.stats.StatsCalculationHelper) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 9 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

the class TestSparkComputationGraph method testBasic.

@Test
public void testBasic() throws Exception {
    JavaSparkContext sc = this.sc;
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).addInput("iris", 0, 3).addOutputOneHot("iris", 4, 3).build();
    List<MultiDataSet> list = new ArrayList<>(150);
    while (iter.hasNext()) list.add(iter.next());
    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(0.1).graphBuilder().addInputs("in").addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense").setOutputs("out").pretrain(false).backprop(true).build();
    ComputationGraph cg = new ComputationGraph(config);
    cg.init();
    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
    scg.setListeners(Collections.singleton((IterationListener) new ScoreIterationListener(1)));
    JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
    scg.fitMultiDataSet(rdd);
    //Try: fitting using DataSet
    DataSetIterator iris = new IrisDataSetIterator(1, 150);
    List<DataSet> list2 = new ArrayList<>();
    while (iris.hasNext()) list2.add(iris.next());
    JavaRDD<DataSet> rddDS = sc.parallelize(list2);
    scg.fit(rddDS);
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) RecordReaderMultiDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) IterationListener(org.deeplearning4j.optimize.api.IterationListener) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) RecordReaderMultiDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 10 with MultiDataSetIterator

use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIteratorTest method testSplittingCSVSequence.

@Test
public void testSplittingCSVSequence() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabels_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("seq1", featureReader2).addSequenceReader("seq2", labelReader2).addInput("seq1", 0, 1).addInput("seq1", 2, 2).addOutputOneHot("seq2", 0, 4).build();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = srrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
        for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(lds, lmds[0]);
    }
    assertFalse(srrmdsi.hasNext());
}
Also used : CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) NumberedFileInputSplit(org.datavec.api.split.NumberedFileInputSplit) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) Test(org.junit.Test)

Aggregations

MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)17 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)10 Test (org.junit.Test)9 ClassPathResource (org.nd4j.linalg.io.ClassPathResource)8 FileSplit (org.datavec.api.split.FileSplit)7 DataSet (org.nd4j.linalg.dataset.DataSet)7 AsyncMultiDataSetIterator (org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)6 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)6 RecordReader (org.datavec.api.records.reader.RecordReader)5 CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader)5 ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader)5 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)5 SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)4 CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)4 IteratorMultiDataSetIterator (org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator)3 ParentPathLabelGenerator (org.datavec.api.io.labels.ParentPathLabelGenerator)2 NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit)2 RecordReaderMultiDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator)2 SingletonMultiDataSetIterator (org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator)2