Search in sources:

Example 11 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

In class RecordReaderDataSetiteratorTest, method testRecordReader.

@Test
public void testRecordReader() throws Exception {
    // Extract the bundled CSV example and read it with a default CSVRecordReader
    FileSplit split = new FileSplit(new ClassPathResource("csv-example.csv").getTempFileFromArchive());
    RecordReader reader = new CSVRecordReader();
    reader.initialize(split);

    // Request mini-batches of 34 examples; the first one must be completely filled
    DataSetIterator iterator = new RecordReaderDataSetIterator(reader, 34);
    assertEquals(34, iterator.next().numExamples());
}
Also used: DataSet (org.nd4j.linalg.dataset.DataSet), RecordReader (org.datavec.api.records.reader.RecordReader), CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader), CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), Test (org.junit.Test)

Example 12 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

In class RecordReaderDataSetiteratorTest, method testSequenceRecordReaderRegression.

@Test
public void testSequenceRecordReaderRegression() throws Exception {
    // The sequence files live in the classpath archive and must be extracted to temp files
    // before a NumberedFileInputSplit can address them by index.
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
    }
    String firstPath = new ClassPathResource("csvsequence_0.txt").getTempFileFromArchive().getAbsolutePath();
    // Turn ".../csvsequence_0.txt" into ".../csvsequence_%d.txt" by replacing ONLY the index digit
    // in the file name (the last '0' in the path, since ".txt" contains none). A blanket
    // replaceAll("0", "%d") would also rewrite any '0' in the temp directory path (e.g. a user or
    // temp folder name containing a zero) and yield a pattern that matches no files. Literal '%'
    // characters in the directory are doubled so they are not misread as format specifiers when
    // the pattern's %d is expanded.
    int indexPos = firstPath.lastIndexOf('0');
    String sequencePattern = firstPath.substring(0, indexPos).replace("%", "%%") + "%d"
            + firstPath.substring(indexPos + 1);
    // Features and labels deliberately come from the same files: the labels are expected to equal
    // the features (asserted below).
    String featuresPath = sequencePattern;
    String labelsPath = sequencePattern;
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 0, true);
    assertEquals(3, iter.inputColumns());
    assertEquals(3, iter.totalOutcomes());
    List<DataSet> dsList = new ArrayList<>();
    while (iter.hasNext()) {
        dsList.add(iter.next());
    }
    //3 files
    assertEquals(3, dsList.size());
    for (int i = 0; i < 3; i++) {
        DataSet ds = dsList.get(i);
        INDArray features = ds.getFeatureMatrix();
        INDArray labels = ds.getLabels();
        //1 example, 3 values, 4 time steps
        assertArrayEquals(new int[] { 1, 3, 4 }, features.shape());
        assertArrayEquals(new int[] { 1, 3, 4 }, labels.shape());
        assertEquals(features, labels);
    }
    //Also test regression + reset from a single reader: with label column index 2, each batch is
    //expected to hold 2 feature columns and 1 label column
    featureReader.reset();
    iter = new SequenceRecordReaderDataSetIterator(featureReader, 1, 0, 2, true);
    int count = 0;
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        assertEquals(2, ds.getFeatureMatrix().size(1));
        assertEquals(1, ds.getLabels().size(1));
        count++;
    }
    assertEquals(3, count);
    //After reset the iterator must yield the same number of batches again
    iter.reset();
    count = 0;
    while (iter.hasNext()) {
        iter.next();
        count++;
    }
    assertEquals(3, count);
}
Also used: CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit), Test (org.junit.Test)

Example 13 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

In class DataSetIteratorTest, method testMnist.

@Test
public void testMnist() throws Exception {
    // Expected data: the first 200 MNIST examples exported as CSV, read back in batches of 10
    CSVRecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new FileSplit(new ClassPathResource("mnist_first_200.txt").getTempFileFromArchive()));
    RecordReaderDataSetIterator expectedIter = new RecordReaderDataSetIterator(csvReader, 10, 0, 10);

    // Actual data: the built-in MNIST iterator configured for 200 examples in batches of 10
    MnistDataSetIterator actualIter = new MnistDataSetIterator(10, 200, false, true, false, 0);

    while (expectedIter.hasNext()) {
        DataSet expected = expectedIter.next();
        DataSet actual = actualIter.next();
        // Scale the CSV pixel values by 1/255 (in place) to match the iterator's features
        INDArray expectedFeatures = expected.getFeatureMatrix().divi(255);
        assertEquals(expectedFeatures, actual.getFeatureMatrix());
        assertEquals(expected.getLabels(), actual.getLabels());
    }
    // Once the CSV data is exhausted, the MNIST iterator must have nothing left either
    assertFalse(actualIter.hasNext());
}
Also used: MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Example 14 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

In class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI.

@Test
public void testImagesRRDMSI() throws Exception {
    // Lay out one image per label directory under a fresh temp dir. ParentPathLabelGenerator
    // presumably labels each image by its parent directory name, giving the two classes below.
    File parentDir = Files.createTempDir();
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();
    File img1 = new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg"));
    File img2 = new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg"));
    // File.deleteOnExit() cannot remove a non-empty directory and deletes in reverse registration
    // order, so register the subdirectories and images as well: they are removed before parentDir.
    f1.deleteOnExit();
    f2.deleteOnExit();
    img1.deleteOnExit();
    img2.deleteOnExit();
    // Close each classpath stream after copying it; writeStreamToFile is not relied on to do so.
    try (java.io.InputStream is = new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream()) {
        writeStreamToFile(img1, is);
    }
    try (java.io.InputStream is = new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream()) {
        writeStreamToFile(img2, is);
    }
    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    // Two readers over the same images at different sizes (10x10 and 5x5); both image arrays are
    // inputs, and column 1 of "rr1s" (the label) is the one-hot output.
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
    //Now, do the same thing with one RecordReaderDataSetIterator per image size, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);
    // One batch per image (batch size 1, two images)
    for (int i = 0; i < 2; i++) {
        MultiDataSet mds = trainDataIterator.next();
        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();
        assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
        assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
    }
}
Also used: DataSet (org.nd4j.linalg.dataset.DataSet), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), ParentPathLabelGenerator (org.datavec.api.io.labels.ParentPathLabelGenerator), Random (java.util.Random), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader), Test (org.junit.Test)

Example 15 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

In class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVMeta.

@Test
public void testSplittingCSVMeta() throws Exception {
    // Split Iris into two input arrays and two output arrays:
    //   inputs  -> column 0, and columns 1-2
    //   outputs -> column 3, and column 4 as a one-hot label over 3 classes
    RecordReader irisReader = new CSVRecordReader(0, ",");
    irisReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderMultiDataSetIterator iterator = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("reader", irisReader)
            .addInput("reader", 0, 0)
            .addInput("reader", 1, 2)
            .addOutput("reader", 3, 3)
            .addOutputOneHot("reader", 4, 3)
            .build();
    // Keep per-example metadata so every batch can be rebuilt from its source records
    iterator.setCollectMetaData(true);

    int batches = 0;
    for (; iterator.hasNext(); batches++) {
        MultiDataSet batch = iterator.next();
        // Reloading from the metadata must reproduce the batch exactly
        assertEquals(batch, iterator.loadFromMetaData(batch.getExampleMetaData(RecordMetaData.class)));
    }
    // 150 rows in batches of 10
    assertEquals(150 / 10, batches);
}
Also used: RecordMetaData (org.datavec.api.records.metadata.RecordMetaData), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), RecordReader (org.datavec.api.records.reader.RecordReader), ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader), CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Aggregations

ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 64; Test (org.junit.Test): 63; SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 23; CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 23; DataSet (org.nd4j.linalg.dataset.DataSet): 23; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 20; FileSplit (org.datavec.api.split.FileSplit): 18; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 18; File (java.io.File): 17; CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 14; CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 13; RecordReader (org.datavec.api.records.reader.RecordReader): 12; NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 12; MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 12; MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 11; RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 8; MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 8; ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 7; LossMCXENT (org.nd4j.linalg.lossfunctions.impl.LossMCXENT): 7; DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 6