Search in sources:

Example 16 with CSVRecordReader

Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j (by deeplearning4j).

From class RecordReaderDataSetiteratorTest, method testCSVLoadingRegression.

@Test
public void testCSVLoadingRegression() throws Exception {
    // Writes a random CSV to the temp directory, reads it back through RecordReaderDataSetIterator
    // with a single-column (regression) label at labelIdx, and checks every label and feature value
    // in every mini-batch against the generated data. The temp file is always removed afterwards.
    int nLines = 30;
    int nFeatures = 5;
    int miniBatchSize = 10;
    int labelIdx = 0;
    String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "rr_csv_test_rand.csv");
    File csvFile = new File(path);
    double[][] data = makeRandomCSV(path, nLines, nFeatures);
    try {
        RecordReader testReader = new CSVRecordReader();
        testReader.initialize(new FileSplit(csvFile));
        try {
            // numPossibleLabels = 1 and the trailing 'true' request a single-column label
            // (see the labels shape assertion below).
            DataSetIterator iter = new RecordReaderDataSetIterator(testReader, null, miniBatchSize, labelIdx, 1, true);
            int miniBatch = 0;
            while (iter.hasNext()) {
                DataSet test = iter.next();
                INDArray features = test.getFeatureMatrix();
                INDArray labels = test.getLabels();
                assertArrayEquals(new int[] { miniBatchSize, nFeatures }, features.shape());
                assertArrayEquals(new int[] { miniBatchSize, 1 }, labels.shape());
                int startRow = miniBatch * miniBatchSize;
                for (int i = 0; i < miniBatchSize; i++) {
                    double labelExp = data[startRow + i][labelIdx];
                    double labelAct = labels.getDouble(i);
                    assertEquals(labelExp, labelAct, 1e-5);
                    int featureCount = 0;
                    // Each CSV row holds nFeatures feature columns plus the label column at labelIdx
                    // (implied by the feature shape asserted above), hence nFeatures + 1 columns.
                    for (int j = 0; j < nFeatures + 1; j++) {
                        if (j == labelIdx) {
                            continue;
                        }
                        double featureExp = data[startRow + i][j];
                        double featureAct = features.getDouble(i, featureCount++);
                        assertEquals(featureExp, featureAct, 1e-5);
                    }
                }
                miniBatch++;
            }
            assertEquals(nLines / miniBatchSize, miniBatch);
        } finally {
            testReader.close();
        }
    } finally {
        // Best-effort cleanup; a failed delete must not mask the real test outcome.
        csvFile.delete();
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), RecordReader (org.datavec.api.records.reader.RecordReader), CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader), CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader), FileSplit (org.datavec.api.split.FileSplit), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), Test (org.junit.Test)

Example 17 with CSVRecordReader

Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j (by deeplearning4j).

From class MultipleEpochsIteratorTest, method testLoadBatchDataSet.

@Test
public void testLoadBatchDataSet() throws Exception {
    // Takes the first 20 iris examples as one DataSet and checks that MultipleEpochsIterator
    // serves it in batches of 10 and reports the configured number of epochs.
    int epochs = 2;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(20);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    while (multiIter.hasNext()) {
        DataSet batch = multiIter.next(10);
        // Check for null before dereferencing; otherwise a null batch surfaces as an NPE
        // rather than as this assertion failing.
        assertFalse(batch == null);
        // JUnit expects (expected, actual); compare the integer example count exactly.
        assertEquals(10, batch.numExamples());
    }
    assertEquals(epochs, multiIter.epochs);
}
Also used: DataSet (org.nd4j.linalg.dataset.DataSet), RecordReader (org.datavec.api.records.reader.RecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.datavec.api.util.ClassPathResource), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), CifarDataSetIterator (org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator), Test (org.junit.Test)

Example 18 with CSVRecordReader

Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j (by deeplearning4j).

From class TestComputationGraphNetwork, method testIrisFitMultiDataSetIterator.

@Test
public void testIrisFitMultiDataSetIterator() throws Exception {
    // Smoke test: a small ComputationGraph must train from a RecordReaderMultiDataSetIterator,
    // first via fit(iterator) and then by fitting each MultiDataSet individually.
    // Nothing is asserted; the test passes if neither training pass throws.
    RecordReader irisReader = new CSVRecordReader(0, ",");
    irisReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    // Columns 0..3 feed the 4 network inputs; column 4 is one-hot encoded into 3 classes.
    MultiDataSetIterator irisIter = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", irisReader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();

    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out")
            .pretrain(false)
            .backprop(true)
            .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    // Pass 1: hand the whole iterator to the network.
    graph.fit(irisIter);

    // Pass 2: rewind the reader, rebuild an identical iterator and fit batch by batch.
    irisReader.reset();
    irisIter = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", irisReader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();
    while (irisIter.hasNext()) {
        graph.fit(irisIter.next());
    }
}
Also used: RecordReaderMultiDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), RecordReader (org.datavec.api.records.reader.RecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Example 19 with CSVRecordReader

Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j (by deeplearning4j).

From class SequenceVectorsTest, method buildGraph.

private static Graph<Blogger, Double> buildGraph() throws IOException, InterruptedException {
    // Builds the BlogCatalog graph: one Blogger vertex per row of nodes.csv, then one edge of
    // weight 1.0 per row of edges.csv, and sanity-checks the degree of a few vertices.
    // NOTE(review): the dataset location is a hard-coded absolute path; the method only works
    // where /ext/Temp/BlogCatalog exists.
    File nodes = new File("/ext/Temp/BlogCatalog/nodes.csv");
    List<Blogger> bloggers = new ArrayList<>();
    CSVRecordReader nodeReader = new CSVRecordReader(0, ",");
    try {
        nodeReader.initialize(new FileSplit(nodes));
        while (nodeReader.hasNext()) {
            List<Writable> fields = new ArrayList<>(nodeReader.next());
            // Column 0 holds the blogger id.
            bloggers.add(new Blogger(fields.get(0).toInt()));
        }
    } finally {
        nodeReader.close();
    }

    Graph<Blogger, Double> graph = new Graph<>(bloggers, true);

    // load edges
    File edges = new File("/ext/Temp/BlogCatalog/edges.csv");
    CSVRecordReader edgeReader = new CSVRecordReader(0, ",");
    try {
        edgeReader.initialize(new FileSplit(edges));
        while (edgeReader.hasNext()) {
            List<Writable> fields = new ArrayList<>(edgeReader.next());
            int from = fields.get(0).toInt();
            int to = fields.get(1).toInt();
            // The -1 shift maps the ids in the file onto vertex indices; presumably the file
            // ids are 1-based while vertex indices are 0-based.
            graph.addEdge(from - 1, to - 1, 1.0, false);
        }
    } finally {
        // Previously this reader was never closed.
        edgeReader.close();
    }

    logger.info("Connected on 0: [" + graph.getConnectedVertices(0).size() + "]");
    logger.info("Connected on 1: [" + graph.getConnectedVertices(1).size() + "]");
    logger.info("Connected on 3: [" + graph.getConnectedVertices(3).size() + "]");
    assertEquals(119, graph.getConnectedVertices(0).size());
    assertEquals(9, graph.getConnectedVertices(1).size());
    assertEquals(6, graph.getConnectedVertices(3).size());
    return graph;
}
Also used: Graph (org.deeplearning4j.models.sequencevectors.graph.primitives.Graph), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), ArrayList (java.util.ArrayList), Writable (org.datavec.api.writable.Writable), FileSplit (org.datavec.api.split.FileSplit), File (java.io.File)

Example 20 with CSVRecordReader

Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j (by deeplearning4j).

From class TestPreProcessedData, method testCsvPreprocessedDataGenerationNoLabel.

@Test
public void testCsvPreprocessedDataGenerationNoLabel() throws Exception {
    //Same as above test, but without any labels (in which case: input and output arrays are the same)
    int totalIrisExamples = 150;
    // Expected values per CSV line: presumably the 4 iris features plus the class index written by
    // toString(...). With no label column, all of them become both inputs and outputs.
    int numColumns = 5;

    List<String> list = new ArrayList<>();
    DataSetIterator iter = new IrisDataSetIterator(1, totalIrisExamples);
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        list.add(toString(ds.getFeatureMatrix(), Nd4j.argMax(ds.getLabels(), 1).getInt(0)));
    }
    JavaRDD<String> rdd = sc.parallelize(list);
    int partitions = rdd.partitions().size();
    URI tempDir = new File(System.getProperty("java.io.tmpdir")).toURI();
    URI outputDir = new URI(tempDir.getPath() + "/dl4j_testPreprocessedData3");
    File temp = new File(outputDir.getPath());
    if (temp.exists()) {
        FileUtils.deleteDirectory(temp);
    }
    int numBinFiles = 0;
    try {
        int batchSize = 5;
        // -1 for both: no label column, so no label classes.
        int labelIdx = -1;
        int numPossibleLabels = -1;
        rdd.foreachPartition(new StringToDataSetExportFunction(outputDir, new CSVRecordReader(0), batchSize, false, labelIdx, numPossibleLabels));

        File[] fileList = temp.listFiles();
        // listFiles() returns null when the directory does not exist; fail clearly instead of with an NPE.
        assertTrue("Export produced no output directory: " + temp, fileList != null);
        int totalExamples = 0;
        for (File f2 : fileList) {
            if (!f2.getPath().endsWith(".bin")) {
                continue;
            }
            numBinFiles++;
            DataSet ds = new DataSet();
            ds.load(f2);
            assertEquals(numColumns, ds.numInputs());
            assertEquals(numColumns, ds.numOutcomes());
            totalExamples += ds.numExamples();
        }
        assertEquals(totalIrisExamples, totalExamples);
        // Expect totalIrisExamples / batchSize (= 30) files, give or take one partial batch per partition
        assertTrue("Unexpected number of .bin files: " + numBinFiles,
                Math.abs(totalIrisExamples / batchSize - numBinFiles) <= partitions);

        //Test the PortableDataStreamDataSetIterator:
        JavaPairRDD<String, PortableDataStream> pds = sc.binaryFiles(outputDir.getPath());
        List<PortableDataStream> pdsList = pds.values().collect();
        DataSetIterator pdsIter = new PortableDataStreamDataSetIterator(pdsList);
        int pdsCount = 0;
        int totalExamples2 = 0;
        while (pdsIter.hasNext()) {
            DataSet ds = pdsIter.next();
            pdsCount++;
            totalExamples2 += ds.numExamples();
            assertEquals(numColumns, ds.numInputs());
            assertEquals(numColumns, ds.numOutcomes());
        }
        assertEquals(totalIrisExamples, totalExamples2);
        assertEquals(numBinFiles, pdsCount);
    } finally {
        FileUtils.deleteDirectory(temp);
    }
}
Also used: IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator), MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet), DataSet (org.nd4j.linalg.dataset.DataSet), ArrayList (java.util.ArrayList), PortableDataStream (org.apache.spark.input.PortableDataStream), URI (java.net.URI), StringToDataSetExportFunction (org.deeplearning4j.spark.datavec.export.StringToDataSetExportFunction), PortableDataStreamDataSetIterator (org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), File (java.io.File), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest), Test (org.junit.Test)

Aggregations

CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 22; Test (org.junit.Test): 20; FileSplit (org.datavec.api.split.FileSplit): 19; RecordReader (org.datavec.api.records.reader.RecordReader): 17; DataSet (org.nd4j.linalg.dataset.DataSet): 16; ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 13; DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 11; SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 10; CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 10; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 9; RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator): 6; CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader): 5; CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 5; ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 5; MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 5; MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 5; ArrayList (java.util.ArrayList): 4; RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 4; ClassPathResource (org.datavec.api.util.ClassPathResource): 4; BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 4