Search in sources:

Example 21 with FileSplit

Use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

From class RecordReaderDataSetiteratorTest, method testCSVLoadingRegression.

/**
 * Verifies that {@link RecordReaderDataSetIterator} in regression mode maps the CSV column at
 * {@code labelIdx} to a single-column label matrix. The remaining columns, in their original
 * order, must map to the feature matrix. The iterator must yield exactly
 * {@code nLines / miniBatchSize} full minibatches.
 *
 * <p>NOTE(review): this assumes {@code makeRandomCSV} writes {@code nLines} rows of
 * {@code nFeatures + 1} columns to {@code path} and returns exactly those values. Confirm
 * against its definition.
 */
@Test
public void testCSVLoadingRegression() throws Exception {
    int nLines = 30;
    int nFeatures = 5;
    int miniBatchSize = 10;
    int labelIdx = 0;
    // Each CSV row holds one label column plus nFeatures feature columns.
    int nColumns = nFeatures + 1;
    String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "rr_csv_test_rand.csv");
    double[][] data = makeRandomCSV(path, nLines, nFeatures);
    // The file has a fixed name in the shared temp directory. Schedule its removal so that
    // repeated runs do not leave stale files behind.
    new File(path).deleteOnExit();
    RecordReader testReader = new CSVRecordReader();
    testReader.initialize(new FileSplit(new File(path)));
    // Presumably (reader, converter, batchSize, labelIndex, numPossibleLabels, regression).
    DataSetIterator iter = new RecordReaderDataSetIterator(testReader, null, miniBatchSize, labelIdx, 1, true);
    int miniBatch = 0;
    while (iter.hasNext()) {
        DataSet test = iter.next();
        INDArray features = test.getFeatureMatrix();
        INDArray labels = test.getLabels();
        assertArrayEquals(new int[] {miniBatchSize, nFeatures}, features.shape());
        assertArrayEquals(new int[] {miniBatchSize, 1}, labels.shape());
        int startRow = miniBatch * miniBatchSize;
        for (int i = 0; i < miniBatchSize; i++) {
            double[] row = data[startRow + i];
            assertEquals(row[labelIdx], labels.getDouble(i), 1e-5f);
            // Feature columns are the CSV columns minus the label column, with order preserved.
            int featureCount = 0;
            for (int j = 0; j < nColumns; j++) {
                if (j == labelIdx) {
                    continue;
                }
                assertEquals(row[j], features.getDouble(i, featureCount++), 1e-5f);
            }
        }
        miniBatch++;
    }
    assertEquals(nLines / miniBatchSize, miniBatch);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), RecordReader (org.datavec.api.records.reader.RecordReader), CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader), CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader), FileSplit (org.datavec.api.split.FileSplit), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), Test (org.junit.Test)

Example 22 with FileSplit

Use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

From class MultipleEpochsIteratorTest, method testLoadBatchDataSet.

/**
 * Verifies that a {@code MultipleEpochsIterator} over a 20-example {@link DataSet} yields only
 * non-null batches of 10 examples. After exhaustion, its {@code epochs} field must equal the
 * requested epoch count.
 */
@Test
public void testLoadBatchDataSet() throws Exception {
    int epochs = 2;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(20);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    while (multiIter.hasNext()) {
        DataSet batch = multiIter.next(10);
        // Check for null before dereferencing. A null check after calling numExamples() could
        // never fire, because the call would already have thrown an NPE.
        assertFalse(batch == null);
        // JUnit convention is (expected, actual), and an example count is an integer.
        assertEquals(10, batch.numExamples());
    }
    assertEquals(epochs, multiIter.epochs);
}
Also used: DataSet (org.nd4j.linalg.dataset.DataSet), RecordReader (org.datavec.api.records.reader.RecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.datavec.api.util.ClassPathResource), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), CifarDataSetIterator (org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator), Test (org.junit.Test)

Example 23 with FileSplit

Use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

From class ConvolutionLayerSetupTest, method testLRN.

/**
 * Builds the configuration returned by {@code incompleteLRN()} with a 28x28x3 convolutional input
 * type. It then checks that the layer at index 3 is a {@link ConvolutionLayer} whose nIn is 6.
 *
 * <p>NOTE(review): the LFW image iterator below is constructed but never consumed. It only
 * proves that the "lfwtest" resource can be read and initialized. Confirm whether it is still
 * needed.
 */
@Test
public void testLRN() throws Exception {
    // Only the number of labels is used. The list is never mutated, so a fixed-size view suffices.
    List<String> labels = Arrays.asList("Zico", "Ziwang_Xu");
    String rootDir = new ClassPathResource("lfwtest").getFile().getAbsolutePath();
    RecordReader reader = new ImageRecordReader(28, 28, 3);
    reader.initialize(new FileSplit(new File(rootDir)));
    DataSetIterator recordReader = new RecordReaderDataSetIterator(reader, 10, 1, labels.size());
    NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLRN();
    // Presumably setInputType lets the builder infer the missing nIn values; verify in incompleteLRN().
    builder.setInputType(InputType.convolutional(28, 28, 3));
    MultiLayerConfiguration conf = builder.build();
    ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(3).getLayer();
    assertEquals(6, layer2.getNIn());
}
Also used: RecordReader (org.datavec.api.records.reader.RecordReader), ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader), ArrayList (java.util.ArrayList), RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.datavec.api.util.ClassPathResource), ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), File (java.io.File), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator), Test (org.junit.Test)

Example 24 with FileSplit

Use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

From class TestComputationGraphNetwork, method testIrisFitMultiDataSetIterator.

/**
 * Smoke test: trains a small dense-to-softmax ComputationGraph on iris. It does this twice, once
 * by passing a whole {@code MultiDataSetIterator} to {@code fit}, and once by fitting each
 * {@code MultiDataSet} individually after resetting the underlying reader. It makes no
 * assertions; it passes if nothing throws.
 */
@Test
public void testIrisFitMultiDataSetIterator() throws Exception {
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out")
            .pretrain(false)
            .backprop(true)
            .build();
    ComputationGraph cg = new ComputationGraph(config);
    cg.init();

    // Pass 1: hand the whole iterator to the network.
    cg.fit(irisMultiDataSetIterator(rr));

    // Pass 2: rewind the reader and feed minibatches one at a time.
    rr.reset();
    MultiDataSetIterator batches = irisMultiDataSetIterator(rr);
    while (batches.hasNext()) {
        cg.fit(batches.next());
    }
}

/**
 * Builds an iris iterator with minibatch size 10 over {@code rr}. Columns 0..3 are the input,
 * and column 4 is one-hot encoded into 3 output classes.
 */
private static MultiDataSetIterator irisMultiDataSetIterator(RecordReader rr) {
    return new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", rr)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();
}
Also used: RecordReaderMultiDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), RecordReader (org.datavec.api.records.reader.RecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Example 25 with FileSplit

Use of org.datavec.api.split.FileSplit in project deeplearning4j by deeplearning4j.

From class SequenceVectorsTest, method buildGraph.

/**
 * Loads the BlogCatalog graph from the hard-coded files {@code /ext/Temp/BlogCatalog/nodes.csv}
 * and {@code /ext/Temp/BlogCatalog/edges.csv}.
 *
 * <p>Each node row contributes one {@code Blogger} built from its first column. Each edge row
 * contributes an edge between the vertices at indices {@code from - 1} and {@code to - 1} with
 * weight 1.0. The method then logs and asserts the connected-vertex counts for vertices 0, 1
 * and 3.
 *
 * <p>NOTE(review): the {@code - 1} suggests the CSV ids are 1-based. The meanings of the boolean
 * arguments to {@code new Graph<>(..., true)} and {@code addEdge(..., false)} are not visible
 * here; confirm them against {@code Graph}.
 *
 * @return the populated graph
 * @throws IOException if reading or closing either CSV file fails
 * @throws InterruptedException propagated from the record reader initialization
 */
private static Graph<Blogger, Double> buildGraph() throws IOException, InterruptedException {
    File nodes = new File("/ext/Temp/BlogCatalog/nodes.csv");
    List<Blogger> bloggers = new ArrayList<>();
    CSVRecordReader reader = new CSVRecordReader(0, ",");
    // Close the reader even if parsing fails part-way through.
    try {
        reader.initialize(new FileSplit(nodes));
        while (reader.hasNext()) {
            List<Writable> lines = new ArrayList<>(reader.next());
            bloggers.add(new Blogger(lines.get(0).toInt()));
        }
    } finally {
        reader.close();
    }
    Graph<Blogger, Double> graph = new Graph<>(bloggers, true);

    // load edges
    File edges = new File("/ext/Temp/BlogCatalog/edges.csv");
    CSVRecordReader edgeReader = new CSVRecordReader(0, ",");
    // The original never closed this second reader; release it deterministically.
    try {
        edgeReader.initialize(new FileSplit(edges));
        while (edgeReader.hasNext()) {
            List<Writable> lines = new ArrayList<>(edgeReader.next());
            int from = lines.get(0).toInt();
            int to = lines.get(1).toInt();
            graph.addEdge(from - 1, to - 1, 1.0, false);
        }
    } finally {
        edgeReader.close();
    }

    logger.info("Connected on 0: [" + graph.getConnectedVertices(0).size() + "]");
    logger.info("Connected on 1: [" + graph.getConnectedVertices(1).size() + "]");
    logger.info("Connected on 3: [" + graph.getConnectedVertices(3).size() + "]");
    assertEquals(119, graph.getConnectedVertices(0).size());
    assertEquals(9, graph.getConnectedVertices(1).size());
    assertEquals(6, graph.getConnectedVertices(3).size());
    return graph;
}
Also used: Graph (org.deeplearning4j.models.sequencevectors.graph.primitives.Graph), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), ArrayList (java.util.ArrayList), Writable (org.datavec.api.writable.Writable), FileSplit (org.datavec.api.split.FileSplit), File (java.io.File)

Aggregations

FileSplit (org.datavec.api.split.FileSplit): 26; Test (org.junit.Test): 25; CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 18; ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 18; RecordReader (org.datavec.api.records.reader.RecordReader): 17; DataSet (org.nd4j.linalg.dataset.DataSet): 16; SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 14; CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 14; DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 11; ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 9; CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 8; RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator): 7; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 7; MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 7; MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 7; CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader): 5; File (java.io.File): 4; ArrayList (java.util.ArrayList): 4; RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 4; ClassPathResource (org.datavec.api.util.ClassPathResource): 4