Search in sources :

Example 1 with RecordReader

use of org.datavec.api.records.reader.RecordReader in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIterator method loadFromMetaData.

/**
 * Load multiple examples into a single MultiDataSet, using the provided RecordMetaData instances.
 * <p>
 * For every configured record reader and sequence record reader, the reader-specific metadata is looked up
 * (by reader name) inside each composite metadata entry, the reader re-loads the matching records, and the
 * collected writables are passed to {@code nextMultiDataSet} to build the result.
 *
 * @param list List of RecordMetaData instances to load from. Every element is cast to
 *             {@code RecordMetaDataComposableMap}, so the list should have been produced by the record
 *             readers provided to this iterator
 * @return MultiDataSet with the specified examples
 * @throws IOException If an error occurs during loading of the data
 */
public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
    //First: load the next values from the RR / SeqRRs
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
    // NOTE(review): when collectMetaData is set this list is handed on empty - nothing in this method adds to it.
    // Confirm against nextMultiDataSet whether the incoming metadata should be copied in here.
    List<RecordMetaDataComposableMap> nextMetas = (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null);

    // Plain (non-sequence) readers: one List<Writable> per example
    for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
        String readerName = entry.getKey();
        List<Record> fromMeta = entry.getValue().loadFromMetaData(extractReaderMetaData(list, readerName));
        List<List<Writable>> writables = new ArrayList<>(list.size());
        for (Record r : fromMeta) {
            writables.add(r.getRecord());
        }
        nextRRVals.put(readerName, writables);
    }

    // Sequence readers: one List<List<Writable>> (time steps x columns) per example
    for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
        String readerName = entry.getKey();
        List<SequenceRecord> fromMeta = entry.getValue().loadSequenceFromMetaData(extractReaderMetaData(list, readerName));
        List<List<List<Writable>>> writables = new ArrayList<>(list.size());
        for (SequenceRecord r : fromMeta) {
            writables.add(r.getSequenceRecord());
        }
        nextSeqRRVals.put(readerName, writables);
    }
    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}

/**
 * Pull out, for one named reader, the reader-specific metadata held inside each composite metadata entry.
 *
 * @param list       composite metadata entries; each one is cast to {@code RecordMetaDataComposableMap}
 * @param readerName key under which the reader's metadata is stored in the composite map
 * @return the per-reader metadata, in the same order as {@code list}
 */
private static List<RecordMetaData> extractReaderMetaData(List<RecordMetaData> list, String readerName) {
    List<RecordMetaData> readerMeta = new ArrayList<>(list.size());
    for (RecordMetaData m : list) {
        RecordMetaDataComposableMap composite = (RecordMetaDataComposableMap) m;
        readerMeta.add(composite.getMeta().get(readerName));
    }
    return readerMeta;
}
Also used : RecordMetaData(org.datavec.api.records.metadata.RecordMetaData) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) RecordReader(org.datavec.api.records.reader.RecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) NDArrayWritable(org.datavec.common.data.NDArrayWritable) Writable(org.datavec.api.writable.Writable) SequenceRecord(org.datavec.api.records.SequenceRecord) SequenceRecord(org.datavec.api.records.SequenceRecord) Record(org.datavec.api.records.Record) RecordMetaDataComposableMap(org.datavec.api.records.metadata.RecordMetaDataComposableMap) RecordMetaDataComposableMap(org.datavec.api.records.metadata.RecordMetaDataComposableMap)

Example 2 with RecordReader

use of org.datavec.api.records.reader.RecordReader in project deeplearning4j by deeplearning4j.

the class RecordReaderDataSetiteratorTest method testRecordReaderMaxBatchLimit.

/**
 * Builds an iterator over "csv-example.csv" with a trailing constructor argument of 2, pulls two batches,
 * and verifies that the iterator then reports no further data.
 */
@Test
public void testRecordReaderMaxBatchLimit() throws Exception {
    RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(new ClassPathResource("csv-example.csv").getTempFileFromArchive()));

    // presumably the last argument is the maximum number of batches - TODO confirm against the constructor
    final int maxBatches = 2;
    DataSetIterator limited = new RecordReaderDataSetIterator(csvReader, 10, -1, -1, maxBatches);

    // Consume exactly as many batches as the limit allows
    for (int consumed = 0; consumed < 2; consumed++) {
        limited.next();
    }
    assertEquals(false, limited.hasNext());
}
Also used : RecordReader(org.datavec.api.records.reader.RecordReader) CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) CollectionSequenceRecordReader(org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Example 3 with RecordReader

use of org.datavec.api.records.reader.RecordReader in project deeplearning4j by deeplearning4j.

the class RecordReaderDataSetiteratorTest method testRecordReaderDataSetIteratorNDArrayWritableLabels.

/**
 * Verifies that RecordReaderDataSetIterator accepts NDArrayWritable columns.
 * <p>
 * Two fixtures are fed through the iterator and must yield the same features and labels:
 * first, records with two scalar feature columns plus one NDArrayWritable label column;
 * second, records where both the features and the labels are NDArrayWritables.
 */
@Test
public void testRecordReaderDataSetIteratorNDArrayWritableLabels() {
    final double[][] featureRows = { { 0, 1 }, { 2, 3 }, { 4, 5 } };
    final double[][] labelRows = { { 1.1, 2.1, 3.1 }, { 4.1, 5.1, 6.1 }, { 7.1, 8.1, 9.1 } };
    final int batchSize = 3;
    final int labelIndexFrom = 2;
    final int labelIndexTo = 2;
    final boolean regression = true;

    // Expected arrays are identical for both fixtures
    INDArray expFeatures = Nd4j.create(featureRows);
    INDArray expLabels = Nd4j.create(labelRows);

    // Fixture 1: scalar feature columns followed by an NDArrayWritable label column
    Collection<Collection<Writable>> scalarFeatureRecords = new ArrayList<>();
    for (int row = 0; row < featureRows.length; row++) {
        scalarFeatureRecords.add(Arrays.<Writable>asList(
                new DoubleWritable(featureRows[row][0]),
                new DoubleWritable(featureRows[row][1]),
                new NDArrayWritable(Nd4j.create(labelRows[row]))));
    }
    RecordReader reader = new CollectionRecordReader(scalarFeatureRecords);
    DataSetIterator iterator = new RecordReaderDataSetIterator(reader, batchSize, labelIndexFrom, labelIndexTo, regression);
    DataSet batch = iterator.next();
    assertEquals(expFeatures, batch.getFeatures());
    assertEquals(expLabels, batch.getLabels());

    //ALSO: test if we have NDArrayWritables for BOTH the features and the labels
    Collection<Collection<Writable>> ndArrayRecords = new ArrayList<>();
    for (int row = 0; row < featureRows.length; row++) {
        ndArrayRecords.add(Arrays.<Writable>asList(
                new NDArrayWritable(Nd4j.create(featureRows[row])),
                new NDArrayWritable(Nd4j.create(labelRows[row]))));
    }
    reader = new CollectionRecordReader(ndArrayRecords);
    iterator = new RecordReaderDataSetIterator(reader, batchSize, labelIndexFrom, labelIndexTo, regression);
    batch = iterator.next();
    assertEquals(expFeatures, batch.getFeatures());
    assertEquals(expLabels, batch.getLabels());
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) CollectionSequenceRecordReader(org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader) CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader) DoubleWritable(org.datavec.api.writable.DoubleWritable) NDArrayWritable(org.datavec.common.data.NDArrayWritable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Example 4 with RecordReader

use of org.datavec.api.records.reader.RecordReader in project deeplearning4j by deeplearning4j.

the class RecordReaderDataSetiteratorTest method testRecordReaderMultiRegression.

/**
 * Verifies multi-column regression labels: reads "iris.txt" in batches of 3 with label columns 3..4 and the
 * final (regression) constructor flag set to true, then checks that the first batch has a 3x3 feature array
 * and a 3x2 label array containing the expected values.
 */
@Test
public void testRecordReaderMultiRegression() throws Exception {
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    int batchSize = 3;
    int labelIdxFrom = 3;
    int labelIdxTo = 4;
    DataSetIterator iter = new RecordReaderDataSetIterator(csv, batchSize, labelIdxFrom, labelIdxTo, true);
    DataSet ds = iter.next();
    // Use getFeatures() like the other tests in this class, rather than the getFeatureMatrix() alias
    INDArray f = ds.getFeatures();
    INDArray l = ds.getLabels();
    // Columns 0..2 are features, columns 3..4 are labels
    assertArrayEquals(new int[] { 3, 3 }, f.shape());
    assertArrayEquals(new int[] { 3, 2 }, l.shape());
    //Check values:
    double[][] fExpD = new double[][] { { 5.1, 3.5, 1.4 }, { 4.9, 3.0, 1.4 }, { 4.7, 3.2, 1.3 } };
    double[][] lExpD = new double[][] { { 0.2, 0 }, { 0.2, 0 }, { 0.2, 0 } };
    INDArray fExp = Nd4j.create(fExpD);
    INDArray lExp = Nd4j.create(lExpD);
    assertEquals(fExp, f);
    assertEquals(lExp, l);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) CollectionSequenceRecordReader(org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Example 5 with RecordReader

use of org.datavec.api.records.reader.RecordReader in project deeplearning4j by deeplearning4j.

the class RecordReaderDataSetiteratorTest method testRecordReader.

/**
 * Reads "csv-example.csv" through a RecordReaderDataSetIterator constructed with a size argument of 34 and
 * verifies that the first batch contains exactly 34 examples.
 */
@Test
public void testRecordReader() throws Exception {
    final int batchSize = 34;

    RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(new ClassPathResource("csv-example.csv").getTempFileFromArchive()));

    DataSetIterator batches = new RecordReaderDataSetIterator(csvReader, batchSize);
    DataSet firstBatch = batches.next();
    assertEquals(batchSize, firstBatch.numExamples());
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) CollectionSequenceRecordReader(org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Aggregations

RecordReader (org.datavec.api.records.reader.RecordReader) 21, Test (org.junit.Test) 18, CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader) 17, FileSplit (org.datavec.api.split.FileSplit) 17, SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader) 13, DataSet (org.nd4j.linalg.dataset.DataSet) 13, ClassPathResource (org.nd4j.linalg.io.ClassPathResource) 12, CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) 11, DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator) 10, CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader) 7, INDArray (org.nd4j.linalg.api.ndarray.INDArray) 7, CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader) 6, ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader) 6, RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) 6, RecordMetaData (org.datavec.api.records.metadata.RecordMetaData) 5, MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet) 5, MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) 5, ClassPathResource (org.datavec.api.util.ClassPathResource) 4, Record (org.datavec.api.records.Record) 3, NDArrayWritable (org.datavec.common.data.NDArrayWritable) 3