Search in sources :

Example 21 with CSVRecordReader

Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in the deeplearning4j project (by deeplearning4j).

From the class TestPreProcessedData, method testCsvPreprocessedDataGeneration:

/**
 * Verifies that CSV-formatted Iris examples exported via
 * {@code StringToDataSetExportFunction} round-trip correctly: the exported
 * .bin files together contain all 150 examples with the expected
 * feature/label dimensions, and {@code PortableDataStreamDataSetIterator}
 * reads the same files back with identical counts.
 */
@Test
public void testCsvPreprocessedDataGeneration() throws Exception {
    // Render each Iris example as one CSV line: features plus integer class label.
    List<String> csvLines = new ArrayList<>();
    DataSetIterator irisIter = new IrisDataSetIterator(1, 150);
    while (irisIter.hasNext()) {
        DataSet example = irisIter.next();
        int label = Nd4j.argMax(example.getLabels(), 1).getInt(0);
        csvLines.add(toString(example.getFeatureMatrix(), label));
    }

    JavaRDD<String> linesRdd = sc.parallelize(csvLines);
    int numPartitions = linesRdd.partitions().size();

    // Export target under the system temp dir; clear any leftovers from a prior run.
    URI tempDir = new File(System.getProperty("java.io.tmpdir")).toURI();
    URI outputDir = new URI(tempDir.getPath() + "/dl4j_testPreprocessedData2");
    File temp = new File(outputDir.getPath());
    if (temp.exists()) {
        FileUtils.deleteDirectory(temp);
    }

    int numBinFiles = 0;
    try {
        int batchSize = 5;
        int labelIdx = 4;
        int numPossibleLabels = 3;

        // Each partition writes its examples out as serialized DataSet .bin files.
        linesRdd.foreachPartition(new StringToDataSetExportFunction(outputDir, new CSVRecordReader(0),
                        batchSize, false, labelIdx, numPossibleLabels));

        // Reload every exported .bin file and validate its dimensions.
        int totalExamples = 0;
        for (File exported : new File(outputDir.getPath()).listFiles()) {
            if (!exported.getPath().endsWith(".bin")) {
                continue;
            }
            numBinFiles++;
            DataSet loaded = new DataSet();
            loaded.load(exported);
            assertEquals(4, loaded.numInputs());
            assertEquals(3, loaded.numOutcomes());
            totalExamples += loaded.numExamples();
        }
        assertEquals(150, totalExamples);
        // Expect roughly 150/batchSize (= 30) files; each partition may
        // contribute one extra partial batch, hence the tolerance.
        assertTrue(Math.abs(150 / batchSize - numBinFiles) <= numPartitions);

        // Second pass: read the same binary files back through
        // PortableDataStreamDataSetIterator and confirm identical totals.
        JavaPairRDD<String, PortableDataStream> pds = sc.binaryFiles(outputDir.getPath());
        List<PortableDataStream> streams = pds.values().collect();
        DataSetIterator pdsIter = new PortableDataStreamDataSetIterator(streams);

        int streamCount = 0;
        int totalExamples2 = 0;
        while (pdsIter.hasNext()) {
            DataSet loaded = pdsIter.next();
            streamCount++;
            totalExamples2 += loaded.numExamples();
            assertEquals(4, loaded.numInputs());
            assertEquals(3, loaded.numOutcomes());
        }
        assertEquals(150, totalExamples2);
        assertEquals(numBinFiles, streamCount);
    } finally {
        // Always remove the export directory, even on assertion failure.
        FileUtils.deleteDirectory(temp);
    }
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) PortableDataStream(org.apache.spark.input.PortableDataStream) URI(java.net.URI) StringToDataSetExportFunction(org.deeplearning4j.spark.datavec.export.StringToDataSetExportFunction) PortableDataStreamDataSetIterator(org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) File(java.io.File) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) PortableDataStreamDataSetIterator(org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 22 with CSVRecordReader

Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in the deeplearning4j project (by deeplearning4j).

From the class TestDataVecDataSetFunctions, method testDataVecDataSetFunctionMultiLabelRegression:

/**
 * Checks {@code DataVecDataSetFunction} in multi-label regression mode:
 * for 10 generated CSV rows of 6 values each, columns 0-2 must become the
 * features and columns 3-5 the labels, with every row seen exactly once.
 */
@Test
public void testDataVecDataSetFunctionMultiLabelRegression() throws Exception {
    JavaSparkContext sc = getContext();

    // Row i is "10i,10i+1,...,10i+5" — values encode both row index and column.
    int n = 6;
    List<String> stringData = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        StringBuilder row = new StringBuilder();
        for (int j = 0; j < n; j++) {
            if (j > 0) {
                row.append(",");
            }
            row.append(10 * i + j);
        }
        stringData.add(row.toString());
    }

    JavaRDD<String> stringList = sc.parallelize(stringData);
    JavaRDD<List<Writable>> writables = stringList.map(new StringToWritablesFunction(new CSVRecordReader()));
    // Regression mode: label columns 3..5 inclusive, no time-series window (-1).
    JavaRDD<DataSet> dataSets = writables.map(new DataVecDataSetFunction(3, 5, -1, true, null, null));

    List<DataSet> collected = dataSets.collect();
    assertEquals(10, collected.size());

    boolean[] seen = new boolean[10];
    for (DataSet dataSet : collected) {
        INDArray features = dataSet.getFeatureMatrix();
        INDArray labels = dataSet.getLabels();
        assertEquals(3, features.length());
        assertEquals(3, labels.length());
        // The first feature value (10i) identifies which generated row this is.
        int exampleIdx = ((int) features.getDouble(0)) / 10;
        seen[exampleIdx] = true;
        for (int j = 0; j < 3; j++) {
            assertEquals(10 * exampleIdx + j, (int) features.getDouble(j));
            assertEquals(10 * exampleIdx + j + 3, (int) labels.getDouble(j));
        }
    }

    // Every one of the 10 rows must have appeared exactly once in the output.
    int seenCount = 0;
    for (boolean b : seen) {
        if (b) {
            seenCount++;
        }
    }
    assertEquals(10, seenCount);
}
Also used : StringToWritablesFunction(org.datavec.spark.transform.misc.StringToWritablesFunction) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) INDArray(org.nd4j.linalg.api.ndarray.INDArray) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) ArrayList(java.util.ArrayList) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Aggregations

CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader)22 Test (org.junit.Test)20 FileSplit (org.datavec.api.split.FileSplit)19 RecordReader (org.datavec.api.records.reader.RecordReader)17 DataSet (org.nd4j.linalg.dataset.DataSet)16 ClassPathResource (org.nd4j.linalg.io.ClassPathResource)13 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)11 SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)10 CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)10 INDArray (org.nd4j.linalg.api.ndarray.INDArray)9 RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator)6 CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader)5 CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader)5 ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader)5 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)5 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)5 ArrayList (java.util.ArrayList)4 RecordMetaData (org.datavec.api.records.metadata.RecordMetaData)4 ClassPathResource (org.datavec.api.util.ClassPathResource)4 BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)4