Search in sources :

Example 16 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class DataSetIteratorTest method testLfwIterator.

@Test
public void testLfwIterator() throws Exception {
    int numExamples = 1;
    int row = 28;
    int col = 28;
    int channels = 1;
    LFWDataSetIterator iter = new LFWDataSetIterator(numExamples, new int[] { row, col, channels }, true);
    assertTrue(iter.hasNext());
    DataSet data = iter.next();
    assertEquals(numExamples, data.getLabels().size(0));
    assertEquals(row, data.getFeatureMatrix().size(2));
}
Also used : LFWDataSetIterator(org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) Test(org.junit.Test)

Example 17 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class MultipleEpochsIteratorTest method testNextAndReset.

@Test
public void testNextAndReset() throws Exception {
    int epochs = 3;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);
    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertFalse(path == null);
    }
    assertEquals(epochs, multiIter.epochs);
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.datavec.api.util.ClassPathResource) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) CifarDataSetIterator(org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) Test(org.junit.Test)

Example 18 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class MultipleEpochsIteratorTest method testLoadFullDataSet.

@Test
public void testLoadFullDataSet() throws Exception {
    int epochs = 3;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(50);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertEquals(path.numExamples(), 50, 0.0);
        assertFalse(path == null);
    }
    assertEquals(epochs, multiIter.epochs);
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.datavec.api.util.ClassPathResource) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) CifarDataSetIterator(org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) Test(org.junit.Test)

Example 19 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIteratorTest method testImagesRRDMSI.

@Test
public void testImagesRRDMSI() throws Exception {
    File parentDir = Files.createTempDir();
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
    int outputNum = 2;
    Random r = new Random(12345);
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);
    for (int i = 0; i < 2; i++) {
        MultiDataSet mds = trainDataIterator.next();
        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();
        assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
        assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
    }
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) Random(java.util.Random) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) Test(org.junit.Test)

Example 20 with DataSet

use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIteratorTest method testSplittingCSV.

@Test
public void testSplittingCSV() throws Exception {
    //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
    //Inputs: columns 0 and 1-2
    //Outputs: columns 3, and 4->OneHot
    //need to manually extract
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
    RecordReader rr2 = new CSVRecordReader(0, ",");
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = rrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(2, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
        for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
        //Get the subsets of the original iris data
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(0));
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 2, true));
        INDArray expOut1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(3));
        INDArray expOut2 = lds;
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(expOut1, lmds[0]);
        assertEquals(expOut2, lmds[1]);
    }
    assertFalse(rrmdsi.hasNext());
}
Also used : MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordReader(org.datavec.api.records.reader.RecordReader) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Aggregations

DataSet (org.nd4j.linalg.dataset.DataSet)334 Test (org.junit.Test)226 INDArray (org.nd4j.linalg.api.ndarray.INDArray)194 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)93 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)82 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)79 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)73 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)62 ArrayList (java.util.ArrayList)50 MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)41 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)38 BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)34 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)32 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)31 MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)31 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)25 SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)24 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)24 CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)23 ClassPathResource (org.nd4j.linalg.io.ClassPathResource)23