
Example 16 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From class DataSetIteratorTest, method testIteratorDataSetIteratorSplitting.

@Test
public void testIteratorDataSetIteratorSplitting() {
    // Test that IteratorDataSetIterator re-splits larger DataSets into smaller batches
    int origBatchSize = 4;
    int origNumDSs = 3;
    int batchSize = 3;
    int numBatches = 4;
    int featureSize = 5;
    int labelSize = 6;
    Nd4j.getRandom().setSeed(12345);
    // Source data: 3 DataSets of 4 rows each (12 rows in total)
    List<DataSet> orig = new ArrayList<>();
    for (int i = 0; i < origNumDSs; i++) {
        INDArray features = Nd4j.rand(origBatchSize, featureSize);
        INDArray labels = Nd4j.rand(origBatchSize, labelSize);
        orig.add(new DataSet(features, labels));
    }
    // Expected output: the 12 rows re-cut into 4 batches of 3 rows; batches 2 and 3
    // span the boundary between consecutive source DataSets
    List<DataSet> expected = new ArrayList<>();
    // Batch 1: rows 0-2 of DataSet 0
    expected.add(new DataSet(orig.get(0).getFeatureMatrix().getRows(0, 1, 2),
            orig.get(0).getLabels().getRows(0, 1, 2)));
    // Batch 2: row 3 of DataSet 0 followed by rows 0-1 of DataSet 1
    expected.add(new DataSet(
            Nd4j.vstack(orig.get(0).getFeatureMatrix().getRows(3), orig.get(1).getFeatureMatrix().getRows(0, 1)),
            Nd4j.vstack(orig.get(0).getLabels().getRows(3), orig.get(1).getLabels().getRows(0, 1))));
    // Batch 3: rows 2-3 of DataSet 1 followed by row 0 of DataSet 2
    expected.add(new DataSet(
            Nd4j.vstack(orig.get(1).getFeatureMatrix().getRows(2, 3), orig.get(2).getFeatureMatrix().getRows(0)),
            Nd4j.vstack(orig.get(1).getLabels().getRows(2, 3), orig.get(2).getLabels().getRows(0))));
    // Batch 4: rows 1-3 of DataSet 2
    expected.add(new DataSet(orig.get(2).getFeatureMatrix().getRows(1, 2, 3),
            orig.get(2).getLabels().getRows(1, 2, 3)));
    // Wrap the source DataSets and request batches of batchSize rows
    DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize);
    int count = 0;
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        assertEquals(expected.get(count), ds);
        count++;
    }
    // 12 rows / 3 rows per batch = 4 batches; assertEquals takes (expected, actual)
    assertEquals(numBatches, count);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) LFWDataSetIterator(org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) CifarDataSetIterator(org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) Test(org.junit.Test)
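The expected batches in the test above follow from a simple rule: rows from consecutive source DataSets are streamed in order and cut every batchSize rows. The following is a minimal plain-Java sketch of that rule, assuming each row can be modelled as a double[] and that a trailing partial batch is emitted as-is. Rebatcher and rebatch are hypothetical names for illustration, not DL4J's IteratorDataSetIterator implementation.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical helper (not part of DL4J) that mimics the splitting behaviour
 * checked by testIteratorDataSetIteratorSplitting: rows from consecutive
 * source batches are concatenated and re-cut into batches of batchSize rows.
 */
final class Rebatcher {
    private Rebatcher() {}

    static List<List<double[]>> rebatch(List<List<double[]>> sourceBatches, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<double[]>> out = new ArrayList<>();
        List<double[]> current = new ArrayList<>(batchSize);
        for (List<double[]> batch : sourceBatches) {
            for (double[] row : batch) {
                current.add(row);
                // Emit a batch as soon as it is full, even mid-way through a source batch
                if (current.size() == batchSize) {
                    out.add(current);
                    current = new ArrayList<>(batchSize);
                }
            }
        }
        // Assumption: a trailing partial batch is emitted rather than padded or dropped
        if (!current.isEmpty()) {
            out.add(current);
        }
        return out;
    }
}
```

With 3 source batches of 4 rows and batchSize 3, this yields 4 batches, and the second batch holds the last row of the first source batch plus the first two rows of the second, matching the expected list built in the test.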

Example 17 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From class DataSetIteratorTest, method testBatchSizeOfOneIris.

@Test
public void testBatchSizeOfOneIris() throws Exception {
    // Test for (a) iterators returning the correct number of examples, and
    // (b) labels being a proper one-hot vector (i.e., summing to 1.0)
    // Iris with batch size 1 and 5 examples in total: expect 5 single-example DataSets
    DataSetIterator iris = new IrisDataSetIterator(1, 5);
    int irisC = 0;
    while (iris.hasNext()) {
        irisC++;
        DataSet ds = iris.next();
        // Sum over all dimensions of the label array; compare with a tolerance
        assertEquals(1.0, ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0), 1e-6);
    }
    assertEquals(5, irisC);
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) LFWDataSetIterator(org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) CifarDataSetIterator(org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) Test(org.junit.Test)
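The test above only checks that each label row sums to 1.0, which a row such as [0.5, 0.5] would also pass. A stricter one-hot check also requires every entry to be 0 or 1. The sketch below shows that check in plain Java over a double[] row, using a tolerance because labels are floating-point. OneHot and isOneHot are hypothetical names, not an ND4J or DL4J API.

```java
/**
 * Hypothetical utility (not a DL4J/ND4J API): a label row is one-hot if every
 * entry is (approximately) 0 or 1 and the entries sum to (approximately) 1.
 */
final class OneHot {
    private OneHot() {}

    static boolean isOneHot(double[] labels, double eps) {
        double sum = 0.0;
        for (double v : labels) {
            boolean isZero = Math.abs(v) <= eps;
            boolean isOne = Math.abs(v - 1.0) <= eps;
            // Any entry that is neither 0 nor 1 rules out a one-hot encoding
            if (!isZero && !isOne) {
                return false;
            }
            sum += v;
        }
        // Exactly one entry may be 1, so the row must sum to 1
        return Math.abs(sum - 1.0) <= eps;
    }
}
```

For example, isOneHot(new double[]{0, 1, 0}, 1e-6) is true, while [0.5, 0.5] (sums to 1), [1, 1, 0] (sums to 2) and [0, 0, 0] are all rejected.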

Example 18 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From class MultipleEpochsIteratorTest, method testNextAndReset.

@Test
public void testNextAndReset() throws Exception {
    int epochs = 3;
    // Load the Iris CSV from the classpath, 150 examples per batch
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    // Wrap the iterator so that the data is traversed once per epoch
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);
    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet ds = multiIter.next();
        assertNotNull(ds);
    }
    // After exhaustion, the wrapper should report that all epochs were run
    assertEquals(epochs, multiIter.epochs);
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.datavec.api.util.ClassPathResource) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) CifarDataSetIterator(org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) Test(org.junit.Test)
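The contract testNextAndReset relies on is: keep returning items from the underlying data and, when it runs out, start over until the requested number of epochs has been served. The sketch below models that contract in plain Java, assuming "reset" can be represented by asking a Supplier for a fresh Iterator. EpochIterator and epochsStarted are hypothetical names; this is not DL4J's MultipleEpochsIterator, whose internals may differ.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.Supplier;

/**
 * Hypothetical sketch (not DL4J's MultipleEpochsIterator): iterate the source,
 * and when it is exhausted, "reset" it by requesting a fresh iterator until
 * numEpochs passes have been started.
 */
final class EpochIterator<T> implements Iterator<T> {
    private final int numEpochs;
    private final Supplier<Iterator<T>> source;
    private Iterator<T> current;
    private int epoch; // number of epochs started so far

    EpochIterator(int numEpochs, Supplier<Iterator<T>> source) {
        this.numEpochs = numEpochs;
        this.source = source;
        this.current = source.get();
        this.epoch = numEpochs > 0 ? 1 : 0;
    }

    @Override
    public boolean hasNext() {
        if (epoch == 0) {
            return false; // zero epochs requested
        }
        // When the current pass is exhausted, start the next epoch (if any remain)
        while (!current.hasNext() && epoch < numEpochs) {
            current = source.get();
            epoch++;
        }
        return current.hasNext();
    }

    @Override
    public T next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        return current.next();
    }

    int epochsStarted() {
        return epoch;
    }
}
```

Wrapping a 3-element source for 3 epochs yields 9 items, and epochsStarted() reports 3 once the iterator is exhausted, which mirrors the final assertEquals(epochs, multiIter.epochs) in the test.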

Example 19 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From class MultipleEpochsIteratorTest, method testLoadFullDataSet.

@Test
public void testLoadFullDataSet() throws Exception {
    int epochs = 3;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    // Load only the first 50 examples into memory
    DataSet ds = iter.next(50);
    // Wrap the single in-memory DataSet so it is served for each epoch
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet batch = multiIter.next();
        // Check for null before dereferencing, then check the batch size
        assertNotNull(batch);
        assertEquals(50, batch.numExamples());
    }
    assertEquals(epochs, multiIter.epochs);
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.datavec.api.util.ClassPathResource) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) CifarDataSetIterator(org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) Test(org.junit.Test)
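Unlike the previous example, testLoadFullDataSet takes a fixed slice of the data once (the first 50 of 150 examples) and presents that same in-memory batch for every epoch. The plain-Java sketch below captures that pattern, assuming examples can be modelled as list elements. FixedBatchEpochs and repeat are hypothetical names, not DL4J APIs.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Hypothetical sketch of the pattern in testLoadFullDataSet: copy the first
 * batchSize examples once, then present that same batch once per epoch.
 */
final class FixedBatchEpochs {
    private FixedBatchEpochs() {}

    static <T> List<List<T>> repeat(List<T> all, int batchSize, int epochs) {
        // Take the slice once, analogous to iter.next(50) in the test
        int n = Math.min(batchSize, all.size());
        List<T> batch = Collections.unmodifiableList(new ArrayList<>(all.subList(0, n)));
        List<List<T>> perEpoch = new ArrayList<>(epochs);
        for (int e = 0; e < epochs; e++) {
            perEpoch.add(batch); // the same batch object is reused for every epoch
        }
        return perEpoch;
    }
}
```

Because the slice is copied once and reused, no data is re-read between epochs; every element of the result has the same size, which is what the per-batch assertEquals(50, ...) check in the test expects.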

Example 20 with DataSetIterator

Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.

From class ROCTest, method RocEvalSanityCheck.

@Test
public void RocEvalSanityCheck() {
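    // All 150 Iris examples in a single minibatch
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    // Small 4-4-3 network: tanh hidden layer, softmax output with multi-class cross-entropy
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
            .layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    // Fit the standardizer on the data, normalize the in-memory copy, and attach the
    // same normalizer to the iterator so evaluation sees identically scaled inputs
    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);
    iter.setPreProcessor(ns);
    // Train briefly; the check below compares two ROC computations, not accuracy
    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }
    // ROC computed by the network over the iterator, with 32 threshold steps
    ROCMultiClass roc = net.evaluateROCMultiClass(iter, 32);
    // The same ROC computed manually from the network's predictions on the full data set
    INDArray f = ds.getFeatures();
    INDArray l = ds.getLabels();
    INDArray out = net.output(f);
    ROCMultiClass manual = new ROCMultiClass(32);
    manual.eval(l, out);
    // Per-class AUC and ROC curve arrays from both computations must agree
    for (int i = 0; i < 3; i++) {
        assertEquals(manual.calculateAUC(i), roc.calculateAUC(i), 1e-6);
        double[][] rocCurve = roc.getResultsAsArray(i);
        double[][] rocManual = manual.getResultsAsArray(i);
        assertArrayEquals(rocCurve[0], rocManual[0], 1e-6);
        assertArrayEquals(rocCurve[1], rocManual[1], 1e-6);
    }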
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.api.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) NormalizerStandardize(org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)
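RocEvalSanityCheck compares two one-vs-all ROC computations that both use a fixed number of threshold steps (32 in the test). To make the idea concrete, here is a simplified plain-Java sketch of thresholded one-vs-all ROC with a trapezoidal AUC. It assumes one-hot labels and per-class probabilities as double[][] arrays, sweeps thresholds t = 0, 1/steps, ..., 1, and closes the curve at (0, 0). ThresholdRoc and auc are hypothetical names; DL4J's ROCMultiClass may place thresholds or interpolate differently.

```java
/**
 * Hypothetical, simplified sketch of thresholded one-vs-all ROC (not the
 * ROCMultiClass implementation). For class classIdx, predict "positive" when
 * the class probability is >= t, record (FPR, TPR) for each threshold, and
 * integrate the curve with the trapezoidal rule.
 */
final class ThresholdRoc {
    private ThresholdRoc() {}

    /** labels[i] is one-hot; probs[i] holds the predicted class probabilities. */
    static double auc(double[][] labels, double[][] probs, int classIdx, int steps) {
        int positives = 0;
        int negatives = 0;
        for (double[] l : labels) {
            if (l[classIdx] == 1.0) positives++; else negatives++;
        }
        if (positives == 0 || negatives == 0) {
            throw new IllegalArgumentException("need both positive and negative examples");
        }
        double[] fpr = new double[steps + 1];
        double[] tpr = new double[steps + 1];
        for (int s = 0; s <= steps; s++) {
            double threshold = (double) s / steps;
            int tp = 0;
            int fp = 0;
            for (int i = 0; i < labels.length; i++) {
                if (probs[i][classIdx] >= threshold) {
                    if (labels[i][classIdx] == 1.0) tp++; else fp++;
                }
            }
            tpr[s] = (double) tp / positives;
            fpr[s] = (double) fp / negatives;
        }
        // Raising the threshold never increases FPR, so each segment width is >= 0
        double area = 0.0;
        for (int s = 0; s < steps; s++) {
            area += (fpr[s] - fpr[s + 1]) * (tpr[s] + tpr[s + 1]) / 2.0;
        }
        // Close the curve with a straight line to (0, 0)
        area += fpr[steps] * tpr[steps] / 2.0;
        return area;
    }
}
```

A perfectly separating classifier gives an AUC of 1.0 for the positive class and a perfectly inverted one gives 0.0; with coarser threshold steps, intermediate cases are approximated less finely.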

Aggregations

DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 147
Test (org.junit.Test): 133
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 90
DataSet (org.nd4j.linalg.dataset.DataSet): 79
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 70
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 63
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 61
MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator): 53
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 49
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 43
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 30
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 24
MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 21
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 19
InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 17
MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 17
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 17
ListDataSetIterator (org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator): 16
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 16
MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 14