Search in sources :

Example 11 with FileSplit

use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIteratorTest method testSplittingCSV.

@Test
public void testSplittingCSV() throws Exception {
    //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
    //Inputs: columns 0 and 1-2
    //Outputs: columns 3, and 4->OneHot
    //need to manually extract
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
    RecordReader rr2 = new CSVRecordReader(0, ",");
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2).addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = rrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(2, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
        for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
        //Get the subsets of the original iris data
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(0));
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 2, true));
        INDArray expOut1 = fds.get(NDArrayIndex.all(), NDArrayIndex.point(3));
        INDArray expOut2 = lds;
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(expOut1, lmds[0]);
        assertEquals(expOut2, lmds[1]);
    }
    assertFalse(rrmdsi.hasNext());
}
Also used : MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordReader(org.datavec.api.records.reader.RecordReader) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 12 with FileSplit

use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIteratorTest method testImagesRRDMSI_Batched.

@Test
public void testImagesRRDMSI_Batched() throws Exception {
    File parentDir = Files.createTempDir();
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")), new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")), new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
    int outputNum = 2;
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2).addReader("rr1", rr1).addReader("rr1s", rr1s).addInput("rr1", 0, 0).addInput("rr1s", 0, 0).addOutputOneHot("rr1s", 1, outputNum).build();
    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 2, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 2, 1, 2);
    MultiDataSet mds = trainDataIterator.next();
    DataSet d1 = dsi1.next();
    DataSet d2 = dsi2.next();
    assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
    assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
    assertEquals(d1.getLabels(), mds.getLabels(0));
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) Test(org.junit.Test)

Example 13 with FileSplit

use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

the class TestDataVecDataSetFunctions method testDataVecDataSetFunction.

@Test
public void testDataVecDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    //Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/imagetest/0/a.bmp");
    //Need this for Spark: can't infer without init call
    List<String> labelsList = Arrays.asList("0", "1");
    String path = f.getPath();
    String folder = path.substring(0, path.length() - 7);
    path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    //4 images
    assertEquals(4, origData.count());
    ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    rr.setLabels(labelsList);
    org.datavec.spark.functions.RecordReaderFunction rrf = new org.datavec.spark.functions.RecordReaderFunction(rr);
    JavaRDD<List<Writable>> rdd = origData.map(rrf);
    JavaRDD<DataSet> data = rdd.map(new DataVecDataSetFunction(1, 2, false));
    List<DataSet> collected = data.collect();
    //Load normally (i.e., not via Spark), and check that we get the same results (order not withstanding)
    InputSplit is = new FileSplit(new File(folder), new String[] { "bmp" }, true);
    ImageRecordReader irr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
    irr.initialize(is);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(irr, 1, 1, 2);
    List<DataSet> listLocal = new ArrayList<>(4);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    //Compare:
    assertEquals(4, collected.size());
    assertEquals(4, listLocal.size());
    //Check that results are the same (order not withstanding)
    boolean[] found = new boolean[4];
    for (int i = 0; i < 4; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 4; j++) {
            if (ds.equals(listLocal.get(j))) {
                if (foundIndex != -1)
                    //Already found this value -> suggests this spark value equals two or more of local version? (Shouldn't happen)
                    fail();
                foundIndex = j;
                if (found[foundIndex])
                    //One of the other spark values was equal to this one -> suggests duplicates in Spark list
                    fail();
                //mark this one as seen before
                found[foundIndex] = true;
            }
        }
    }
    int count = 0;
    for (boolean b : found) if (b)
        count++;
    //Expect all 4 and exactly 4 pairwise matches between spark and local versions
    assertEquals(4, count);
}
Also used : SequenceRecordReaderFunction(org.datavec.spark.functions.SequenceRecordReaderFunction) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) PortableDataStream(org.apache.spark.input.PortableDataStream) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) FileSplit(org.datavec.api.split.FileSplit) ArrayList(java.util.ArrayList) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) NumberedFileInputSplit(org.datavec.api.split.NumberedFileInputSplit) InputSplit(org.datavec.api.split.InputSplit) ParentPathLabelGenerator(org.datavec.api.io.labels.ParentPathLabelGenerator) File(java.io.File) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 14 with FileSplit

use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

the class TestSparkComputationGraph method testBasic.

@Test
public void testBasic() throws Exception {
    JavaSparkContext sc = this.sc;
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).addInput("iris", 0, 3).addOutputOneHot("iris", 4, 3).build();
    List<MultiDataSet> list = new ArrayList<>(150);
    while (iter.hasNext()) list.add(iter.next());
    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(0.1).graphBuilder().addInputs("in").addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense").setOutputs("out").pretrain(false).backprop(true).build();
    ComputationGraph cg = new ComputationGraph(config);
    cg.init();
    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm);
    scg.setListeners(Collections.singleton((IterationListener) new ScoreIterationListener(1)));
    JavaRDD<MultiDataSet> rdd = sc.parallelize(list);
    scg.fitMultiDataSet(rdd);
    //Try: fitting using DataSet
    DataSetIterator iris = new IrisDataSetIterator(1, 150);
    List<DataSet> list2 = new ArrayList<>();
    while (iris.hasNext()) list2.add(iris.next());
    JavaRDD<DataSet> rddDS = sc.parallelize(list2);
    scg.fit(rddDS);
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) RecordReaderMultiDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) IterationListener(org.deeplearning4j.optimize.api.IterationListener) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) RecordReaderMultiDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 15 with FileSplit

use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIteratorTest method testInputValidation.

@Test
public void testInputValidation() {
    //Test: no readers
    try {
        MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addInput("something").addOutput("something").build();
        fail("Should have thrown exception");
    } catch (Exception e) {
    }
    //Test: reference to reader that doesn't exist
    try {
        RecordReader rr = new CSVRecordReader(0, ",");
        rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
        MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).addInput("thisDoesntExist", 0, 3).addOutputOneHot("iris", 4, 3).build();
        fail("Should have thrown exception");
    } catch (Exception e) {
    }
    //Test: no inputs or outputs
    try {
        RecordReader rr = new CSVRecordReader(0, ",");
        rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
        MultiDataSetIterator r = new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", rr).build();
        fail("Should have thrown exception");
    } catch (Exception e) {
    }
}
Also used : MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) RecordReader(org.datavec.api.records.reader.RecordReader) ImageRecordReader(org.datavec.image.recordreader.ImageRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Aggregations

FileSplit (org.datavec.api.split.FileSplit)26 Test (org.junit.Test)25 CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader)18 ClassPathResource (org.nd4j.linalg.io.ClassPathResource)18 RecordReader (org.datavec.api.records.reader.RecordReader)17 DataSet (org.nd4j.linalg.dataset.DataSet)16 SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)14 CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)14 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)11 ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader)9 CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader)8 RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator)7 INDArray (org.nd4j.linalg.api.ndarray.INDArray)7 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)7 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)7 CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader)5 File (java.io.File)4 ArrayList (java.util.ArrayList)4 RecordMetaData (org.datavec.api.records.metadata.RecordMetaData)4 ClassPathResource (org.datavec.api.util.ClassPathResource)4