Example 1 with CollectionRecordReader

Use of org.datavec.api.records.reader.impl.collection.CollectionRecordReader in the deeplearning4j project by deeplearning4j.

Class RecordReaderDataSetiteratorTest, method testRecordReaderDataSetIteratorNDArrayWritableLabels.

@Test
public void testRecordReaderDataSetIteratorNDArrayWritableLabels() {
    Collection<Collection<Writable>> data = new ArrayList<>();
    data.add(Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }))));
    data.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(3), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }))));
    data.add(Arrays.<Writable>asList(new DoubleWritable(4), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }))));
    RecordReader rr = new CollectionRecordReader(data);
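    int batchSize = 3;
    // Column 2 (the NDArrayWritable) is the label; its three values become three label columns
    int labelIndexFrom = 2;
    int labelIndexTo = 2;
    // Regression: label values are used as-is rather than converted to one-hot vectors
    boolean regression = true;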
    DataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, batchSize, labelIndexFrom, labelIndexTo, regression);
    DataSet ds = rrdsi.next();
    INDArray expFeatures = Nd4j.create(new double[][] { { 0, 1 }, { 2, 3 }, { 4, 5 } });
    INDArray expLabels = Nd4j.create(new double[][] { { 1.1, 2.1, 3.1 }, { 4.1, 5.1, 6.1 }, { 7.1, 8.1, 9.1 } });
    assertEquals(expFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());
    // Also test the case where NDArrayWritables hold both the features and the labels
    data = new ArrayList<>();
    data.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 0, 1 })), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }))));
    data.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 2, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }))));
    data.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 4, 5 })), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }))));
    rr = new CollectionRecordReader(data);
    rrdsi = new RecordReaderDataSetIterator(rr, batchSize, labelIndexFrom, labelIndexTo, regression);
    ds = rrdsi.next();
    assertEquals(expFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());
}
Also used: DataSet(org.nd4j.linalg.dataset.DataSet), RecordReader(org.datavec.api.records.reader.RecordReader), CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader), CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader), SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader), CollectionSequenceRecordReader(org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader), DoubleWritable(org.datavec.api.writable.DoubleWritable), NDArrayWritable(org.datavec.common.data.NDArrayWritable), INDArray(org.nd4j.linalg.api.ndarray.INDArray), DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator), Test(org.junit.Test)
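
The test above relies on the iterator splitting each record into features and labels by column index. A minimal plain-Java sketch of that idea follows. It is an illustration only, not DL4J's implementation: the class LabelSplitSketch and its split method are hypothetical names, and it handles scalar columns only, whereas in the test a single NDArrayWritable column contributes three label values.

```java
import java.util.Arrays;

/** Hypothetical illustration of a feature/label split by column index; not DL4J code. */
class LabelSplitSketch {

    /**
     * Splits each row into features and labels.
     * Columns labelFrom..labelTo (inclusive) become labels; all other columns become features.
     * Returns a two-element array: {features, labels}.
     */
    static double[][][] split(double[][] rows, int labelFrom, int labelTo) {
        int numLabels = labelTo - labelFrom + 1;
        double[][] features = new double[rows.length][];
        double[][] labels = new double[rows.length][];
        for (int r = 0; r < rows.length; r++) {
            double[] row = rows[r];
            features[r] = new double[row.length - numLabels];
            labels[r] = new double[numLabels];
            int f = 0; // next free position in the feature row
            for (int c = 0; c < row.length; c++) {
                if (c >= labelFrom && c <= labelTo) {
                    labels[r][c - labelFrom] = row[c];
                } else {
                    features[r][f++] = row[c];
                }
            }
        }
        return new double[][][] { features, labels };
    }

    public static void main(String[] args) {
        // Scalar stand-ins for the records in the test: two feature columns, one label column
        double[][] rows = { { 0, 1, 1.1 }, { 2, 3, 4.1 }, { 4, 5, 7.1 } };
        double[][][] out = split(rows, 2, 2);
        System.out.println("features = " + Arrays.deepToString(out[0]));
        System.out.println("labels   = " + Arrays.deepToString(out[1]));
    }
}
```

With labelFrom and labelTo both set to 2, columns 0 and 1 end up in the feature matrix and column 2 in the label matrix, which mirrors the shape of expFeatures in the test.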

Example 2 with CollectionRecordReader

Use of org.datavec.api.records.reader.impl.collection.CollectionRecordReader in the deeplearning4j project by deeplearning4j.

Class TestRecordReaders, method testClassIndexOutsideOfRangeRRDSI.

@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    Collection<Collection<Writable>> c = new ArrayList<>();
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    CollectionRecordReader crr = new CollectionRecordReader(c);
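    // batchSize = 2, labelIndex = 1, numPossibleLabels = 2: valid class indices are 0 and 1,
    // so the label 2 in the second record is out of range and next() is expected to throw
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(crr, 2, 1, 2);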
    try {
        DataSet ds = iter.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testClassIndexOutsideOfRange(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail();
    }
}
Also used: DataSet(org.nd4j.linalg.dataset.api.DataSet), ArrayList(java.util.ArrayList), Collection(java.util.Collection), CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader), DoubleWritable(org.datavec.api.writable.DoubleWritable), SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator), RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator), DL4JException(org.deeplearning4j.exception.DL4JException), IntWritable(org.datavec.api.writable.IntWritable), Test(org.junit.Test)
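
The failure this test expects comes from a class index that does not fit the declared number of classes. The plain-Java sketch below shows the kind of range check involved when a class index is turned into a one-hot vector. It is a hedged illustration, not DL4J's code: OneHotSketch and oneHot are hypothetical names, and IllegalArgumentException stands in for the DL4JException caught in the test.

```java
import java.util.Arrays;

/** Hypothetical illustration of one-hot encoding with a range check; not DL4J code. */
class OneHotSketch {

    /** Returns a vector of length numClasses with a 1.0 at classIndex. */
    static double[] oneHot(int classIndex, int numClasses) {
        if (classIndex < 0 || classIndex >= numClasses) {
            // Stand-in for the exception the real iterator raises in this situation
            throw new IllegalArgumentException("Class index " + classIndex
                    + " is outside the valid range 0.." + (numClasses - 1));
        }
        double[] vector = new double[numClasses];
        vector[classIndex] = 1.0;
        return vector;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(oneHot(0, 2)));
        try {
            oneHot(2, 2); // same situation as the second record in the test
        } catch (IllegalArgumentException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}
```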

Example 3 with CollectionRecordReader

Use of org.datavec.api.records.reader.impl.collection.CollectionRecordReader in the deeplearning4j project by deeplearning4j.

Class StringToDataSetExportFunction, method processBatchIfRequired.

private void processBatchIfRequired(List<List<Writable>> list, boolean finalRecord) throws Exception {
    // Nothing buffered: nothing to export
    if (list.isEmpty())
        return;
    // Wait for a full batch unless this is the final record
    if (list.size() < batchSize && !finalRecord)
        return;
    RecordReader rr = new CollectionRecordReader(list);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(rr, new SelfWritableConverter(), batchSize, labelIndex, numPossibleLabels, regression);
    DataSet ds = iter.next();
    String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";
    URI uri = new URI(outputDir.getPath() + "/" + filename);
    FileSystem file = FileSystem.get(uri, conf);
    try (FSDataOutputStream out = file.create(new Path(uri))) {
        ds.save(out);
    }
    list.clear();
}
Also used: Path(org.apache.hadoop.fs.Path), SelfWritableConverter(org.datavec.api.io.converters.SelfWritableConverter), DataSet(org.nd4j.linalg.dataset.DataSet), RecordReader(org.datavec.api.records.reader.RecordReader), CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader), FileSystem(org.apache.hadoop.fs.FileSystem), RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator), FSDataOutputStream(org.apache.hadoop.fs.FSDataOutputStream), URI(java.net.URI)
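
The two guard clauses at the top of processBatchIfRequired implement a common buffering pattern: flush when the buffer is full, or when the final record arrives with a partial batch. A minimal plain-Java sketch of that pattern follows. BatchFlushSketch and its members are hypothetical names, and the sketch collects batches in memory where the original saves a DataSet to a Hadoop file system.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical illustration of flush-on-full-or-final batching; not DL4J code. */
class BatchFlushSketch {
    private final int batchSize;
    private final List<String> buffer = new ArrayList<>();
    /** Batches that have been flushed so far, in order. */
    final List<List<String>> flushed = new ArrayList<>();

    BatchFlushSketch(int batchSize) {
        this.batchSize = batchSize;
    }

    /** Buffers one record, then flushes if the batch is full or this is the last record. */
    void add(String record, boolean finalRecord) {
        buffer.add(record);
        flushIfRequired(finalRecord);
    }

    private void flushIfRequired(boolean finalRecord) {
        if (buffer.isEmpty())
            return;
        if (buffer.size() < batchSize && !finalRecord)
            return;
        // Copy the buffer before clearing it, so the flushed batch keeps its contents
        flushed.add(new ArrayList<>(buffer));
        buffer.clear();
    }

    public static void main(String[] args) {
        BatchFlushSketch sketch = new BatchFlushSketch(2);
        for (int i = 0; i < 5; i++) {
            sketch.add("record" + i, i == 4);
        }
        System.out.println(sketch.flushed);
    }
}
```

Five records with a batch size of 2 produce two full batches and one final batch holding the single leftover record.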

Aggregations

CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader): 3
RecordReader (org.datavec.api.records.reader.RecordReader): 2
DoubleWritable (org.datavec.api.writable.DoubleWritable): 2
RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator): 2
Test (org.junit.Test): 2
DataSet (org.nd4j.linalg.dataset.DataSet): 2
URI (java.net.URI): 1
ArrayList (java.util.ArrayList): 1
Collection (java.util.Collection): 1
FSDataOutputStream (org.apache.hadoop.fs.FSDataOutputStream): 1
FileSystem (org.apache.hadoop.fs.FileSystem): 1
Path (org.apache.hadoop.fs.Path): 1
SelfWritableConverter (org.datavec.api.io.converters.SelfWritableConverter): 1
SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 1
CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 1
CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 1
CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 1
IntWritable (org.datavec.api.writable.IntWritable): 1
NDArrayWritable (org.datavec.common.data.NDArrayWritable): 1
SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator): 1