Search in sources :

Example 1 with Record

Use of org.datavec.api.records.Record in the deeplearning4j project (by deeplearning4j).

Class RecordReaderDataSetIterator, method next.

/**
 * Returns the next mini-batch of up to {@code num} examples, merged into a single DataSet.
 *
 * <p>If {@code useCurrent} is set, the flag is cleared and the previously produced {@code last}
 * DataSet is returned again (after running the preprocessor on it, if one is set) without reading
 * any records. Otherwise records are read one by one until {@code num} examples have been collected
 * or {@link #hasNext()} returns false. For a SequenceRecordReader, each element of a sequence
 * (presumably one time step — TODO confirm) is converted into its own example. When
 * {@code collectMetaData} is enabled, per-record RecordMetaData is attached to the result, but only
 * on the non-sequence path.
 *
 * @param num maximum number of examples to put in the returned batch
 * @return the merged batch; an empty {@code new DataSet()} if no records could be read
 */
@Override
public DataSet next(int num) {
    if (useCurrent) {
        // Hand out the stored batch once more instead of reading new data.
        // NOTE(review): every assignment to `last` visible here (below, and in loadFromMetaData)
        // happens right before preProcess() runs on the same object, so this call may apply the
        // preprocessor a second time to an already-preprocessed DataSet — confirm this is intended.
        useCurrent = false;
        if (preProcessor != null)
            preProcessor.preProcess(last);
        return last;
    }
    List<DataSet> dataSets = new ArrayList<>();
    // Only allocated when metadata collection is enabled; stays null otherwise.
    List<RecordMetaData> meta = (collectMetaData ? new ArrayList<RecordMetaData>() : null);
    for (int i = 0; i < num; i++) {
        if (!hasNext())
            break;
        if (recordReader instanceof SequenceRecordReader) {
            // Fetch a fresh sequence when there is none yet or the current one is exhausted.
            // NOTE(review): if sequenceRecord() returns an empty list, sequenceIter.next() below
            // would throw NoSuchElementException — confirm readers never yield empty sequences.
            if (sequenceIter == null || !sequenceIter.hasNext()) {
                List<List<Writable>> sequenceRecord = ((SequenceRecordReader) recordReader).sequenceRecord();
                sequenceIter = sequenceRecord.iterator();
            }
            // NOTE(review): no metadata is recorded on this path, so with collectMetaData enabled
            // `meta` ends up with fewer entries than there are examples.
            List<Writable> record = sequenceIter.next();
            dataSets.add(getDataSet(record));
        } else {
            if (collectMetaData) {
                // nextRecord() returns the values together with their metadata.
                Record record = recordReader.nextRecord();
                dataSets.add(getDataSet(record.getRecord()));
                meta.add(record.getMetaData());
            } else {
                List<Writable> record = recordReader.next();
                dataSets.add(getDataSet(record));
            }
        }
    }
    // Counted even when the batch turns out to be empty.
    batchNum++;
    if (dataSets.isEmpty())
        return new DataSet();
    DataSet ret = DataSet.merge(dataSets);
    if (collectMetaData) {
        ret.setExampleMetaData(meta);
    }
    // Remember the batch (the same object the preprocessor mutates below).
    last = ret;
    if (preProcessor != null)
        preProcessor.preProcess(ret);
    //Add label name values to dataset
    if (recordReader.getLabels() != null)
        ret.setLabelNames(recordReader.getLabels());
    return ret;
}
Also used: RecordMetaData (org.datavec.api.records.metadata.RecordMetaData), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), DataSet (org.nd4j.linalg.dataset.DataSet), ArrayList (java.util.ArrayList), NDArrayWritable (org.datavec.common.data.NDArrayWritable), Writable (org.datavec.api.writable.Writable), List (java.util.List), Record (org.datavec.api.records.Record)

Example 2 with Record

Use of org.datavec.api.records.Record in the deeplearning4j project (by deeplearning4j).

Class RecordReaderMultiDataSetIterator, method loadFromMetaData.

/**
 * Loads multiple examples into a MultiDataSet, using the provided RecordMetaData instances.
 *
 * <p>Every element of {@code list} must be a {@link RecordMetaDataComposableMap}. Its map must hold
 * one entry for each configured RecordReader and SequenceRecordReader, keyed by that reader's name.
 * Such instances are what this iterator attaches to its examples when metadata collection is enabled.
 *
 * <p>When {@code collectMetaData} is enabled, the supplied composable maps are also passed on as the
 * per-example metadata of the result. Previously an empty list was passed, so the metadata was lost.
 *
 * @param list List of RecordMetaData instances to load from. Should have been produced by the record
 *             readers provided to the RecordReaderMultiDataSetIterator constructor
 * @return MultiDataSet with the specified examples
 * @throws IOException If an error occurs during loading of the data
 * @throws ClassCastException if an element of {@code list} is not a RecordMetaDataComposableMap
 */
public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
    // Cast once up front instead of once per reader; each entry maps reader name -> metadata.
    List<RecordMetaDataComposableMap> composableMetas = new ArrayList<>(list.size());
    for (RecordMetaData m : list) {
        composableMetas.add((RecordMetaDataComposableMap) m);
    }

    //First: load the next values from the RR / SeqRRs
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
        RecordReader rr = entry.getValue();
        List<Record> fromMeta = rr.loadFromMetaData(metaForReader(composableMetas, entry.getKey()));
        List<List<Writable>> writables = new ArrayList<>(fromMeta.size());
        for (Record r : fromMeta) {
            writables.add(r.getRecord());
        }
        nextRRVals.put(entry.getKey(), writables);
    }

    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
    for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
        SequenceRecordReader rr = entry.getValue();
        List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(metaForReader(composableMetas, entry.getKey()));
        List<List<List<Writable>>> writables = new ArrayList<>(fromMeta.size());
        for (SequenceRecord r : fromMeta) {
            writables.add(r.getSequenceRecord());
        }
        nextSeqRRVals.put(entry.getKey(), writables);
    }

    // Keep the metadata of the requested examples, mirroring next(int), which hands
    // nextMultiDataSet one RecordMetaDataComposableMap per example.
    List<RecordMetaDataComposableMap> nextMetas =
            (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>(composableMetas) : null);
    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}

/**
 * Extracts one reader's metadata from each composable map, preserving example order.
 *
 * @param metas      per-example composable metadata maps
 * @param readerName key of the reader whose entries are wanted
 * @return list with one (possibly null, if absent) RecordMetaData per example
 */
private static List<RecordMetaData> metaForReader(List<RecordMetaDataComposableMap> metas, String readerName) {
    List<RecordMetaData> out = new ArrayList<>(metas.size());
    for (RecordMetaDataComposableMap m : metas) {
        out.add(m.getMeta().get(readerName));
    }
    return out;
}
Also used: RecordMetaData (org.datavec.api.records.metadata.RecordMetaData), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), RecordReader (org.datavec.api.records.reader.RecordReader), NDArrayWritable (org.datavec.common.data.NDArrayWritable), Writable (org.datavec.api.writable.Writable), SequenceRecord (org.datavec.api.records.SequenceRecord), Record (org.datavec.api.records.Record), RecordMetaDataComposableMap (org.datavec.api.records.metadata.RecordMetaDataComposableMap)

Example 3 with Record

Use of org.datavec.api.records.Record in the deeplearning4j project (by deeplearning4j).

Class RecordReaderDataSetIterator, method loadFromMetaData.

/**
 * Loads the examples identified by the given RecordMetaData instances and merges them into one DataSet.
 *
 * <p>The result always carries the metadata reported by the loaded records (regardless of the
 * {@code collectMetaData} setting). It is stored as {@code last}, has the preprocessor applied when
 * one is set, and receives the record reader's label names when the reader reports any.
 *
 * @param list RecordMetaData instances to load from; these should have been produced by the record
 *             reader given to the RecordReaderDataSetIterator constructor
 * @return the merged DataSet of the requested examples, or an empty DataSet if nothing was loaded
 * @throws IOException if an error occurs while loading the data
 */
public DataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
    List<Record> loaded = recordReader.loadFromMetaData(list);
    if (loaded.isEmpty()) {
        return new DataSet();
    }

    List<DataSet> parts = new ArrayList<>(loaded.size());
    List<RecordMetaData> exampleMeta = new ArrayList<>(loaded.size());
    for (Record rec : loaded) {
        parts.add(getDataSet(rec.getRecord()));
        exampleMeta.add(rec.getMetaData());
    }

    DataSet merged = DataSet.merge(parts);
    merged.setExampleMetaData(exampleMeta);
    last = merged;
    if (preProcessor != null) {
        preProcessor.preProcess(merged);
    }
    if (recordReader.getLabels() != null) {
        merged.setLabelNames(recordReader.getLabels());
    }
    return merged;
}
Also used : RecordMetaData(org.datavec.api.records.metadata.RecordMetaData) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) Record(org.datavec.api.records.Record)

Example 4 with Record

Use of org.datavec.api.records.Record in the deeplearning4j project (by deeplearning4j).

Class RecordReaderMultiDataSetIterator, method next.

/**
 * Reads the next batch of up to {@code num} examples from every configured reader and assembles
 * them into a MultiDataSet.
 *
 * <p>Each RecordReader contributes up to {@code num} records and each SequenceRecordReader up to
 * {@code num} sequences; a reader that runs out early simply contributes fewer. When
 * {@code collectMetaData} is enabled, example {@code i} gets one RecordMetaDataComposableMap holding
 * every reader's metadata for that example, keyed by reader name.
 *
 * @param num maximum number of examples to read from each reader
 * @return the assembled MultiDataSet
 * @throws NoSuchElementException if {@link #hasNext()} returns false
 */
@Override
public MultiDataSet next(int num) {
    if (!hasNext()) {
        throw new NoSuchElementException("No next elements");
    }

    List<RecordMetaDataComposableMap> nextMetas = null;
    if (collectMetaData) {
        nextMetas = new ArrayList<>();
    }

    // Step 1: up to num plain records per RecordReader.
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    for (Map.Entry<String, RecordReader> readerEntry : recordReaders.entrySet()) {
        String readerName = readerEntry.getKey();
        RecordReader reader = readerEntry.getValue();
        List<List<Writable>> batch = new ArrayList<>(num);
        int idx = 0;
        while (idx < num && reader.hasNext()) {
            if (!collectMetaData) {
                batch.add(reader.next());
            } else {
                Record rec = reader.nextRecord();
                List<Writable> values = rec.getRecord();
                metaMapForExample(nextMetas, idx).getMeta().put(readerName, rec.getMetaData());
                batch.add(values);
            }
            idx++;
        }
        nextRRVals.put(readerName, batch);
    }

    // Step 2: up to num sequences per SequenceRecordReader.
    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
    for (Map.Entry<String, SequenceRecordReader> readerEntry : sequenceRecordReaders.entrySet()) {
        String readerName = readerEntry.getKey();
        SequenceRecordReader reader = readerEntry.getValue();
        List<List<List<Writable>>> batch = new ArrayList<>(num);
        int idx = 0;
        while (idx < num && reader.hasNext()) {
            if (!collectMetaData) {
                batch.add(reader.sequenceRecord());
            } else {
                SequenceRecord seq = reader.nextSequence();
                List<List<Writable>> steps = seq.getSequenceRecord();
                metaMapForExample(nextMetas, idx).getMeta().put(readerName, seq.getMetaData());
                batch.add(steps);
            }
            idx++;
        }
        nextSeqRRVals.put(readerName, batch);
    }

    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}

/**
 * Returns the composable metadata map for example {@code index}, appending a new empty map first
 * if the list does not yet reach that index.
 */
private static RecordMetaDataComposableMap metaMapForExample(List<RecordMetaDataComposableMap> metas, int index) {
    if (metas.size() <= index) {
        metas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
    }
    return metas.get(index);
}
Also used: SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), RecordReader (org.datavec.api.records.reader.RecordReader), NDArrayWritable (org.datavec.common.data.NDArrayWritable), Writable (org.datavec.api.writable.Writable), SequenceRecord (org.datavec.api.records.SequenceRecord), Record (org.datavec.api.records.Record), RecordMetaDataComposableMap (org.datavec.api.records.metadata.RecordMetaDataComposableMap)

Example 5 with Record

Use of org.datavec.api.records.Record in the deeplearning4j project (by deeplearning4j).

Class RecordReaderDataSetiteratorTest, method testRecordReaderMetaData.

/**
 * Verifies that metadata collected by RecordReaderDataSetIterator round-trips.
 *
 * <p>Every batch must carry exactly one RecordMetaData per example. Each metadata entry must reload,
 * via the CSVRecordReader, the same raw feature values that appear in the batch's feature matrix.
 * Reloading the whole batch through
 * {@link RecordReaderDataSetIterator#loadFromMetaData(List)} must reproduce an equal DataSet.
 */
@Test
public void testRecordReaderMetaData() throws Exception {
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    int batchSize = 10;
    int labelIdx = 4;
    int numClasses = 3;
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses);
    rrdsi.setCollectMetaData(true);
    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class);
        // One metadata entry per example; otherwise the row index i below would drift.
        assertEquals(ds.numExamples(), meta.size());
        int i = 0;
        for (RecordMetaData m : meta) {
            Record r = csv.loadFromMetaData(m);
            INDArray row = ds.getFeatureMatrix().getRow(i);
            System.out.println(m.getLocation() + "\t" + r.getRecord() + "\t" + row);
            // The feature columns are the ones preceding the label column (0 .. labelIdx-1).
            for (int j = 0; j < labelIdx; j++) {
                double exp = r.getRecord().get(j).toDouble();
                double act = row.getDouble(j);
                assertEquals(exp, act, 1e-6);
            }
            i++;
        }
        System.out.println();
        DataSet fromMeta = rrdsi.loadFromMetaData(meta);
        assertEquals(ds, fromMeta);
    }
}
Also used: RecordMetaData (org.datavec.api.records.metadata.RecordMetaData), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), RecordReader (org.datavec.api.records.reader.RecordReader), CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader), CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader), Record (org.datavec.api.records.Record), FileSplit (org.datavec.api.split.FileSplit), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)

Aggregations

Number of examples in which each class is used:

- Record (org.datavec.api.records.Record): 5
- RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 4
- SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 4
- RecordReader (org.datavec.api.records.reader.RecordReader): 3
- Writable (org.datavec.api.writable.Writable): 3
- NDArrayWritable (org.datavec.common.data.NDArrayWritable): 3
- DataSet (org.nd4j.linalg.dataset.DataSet): 3
- ArrayList (java.util.ArrayList): 2
- SequenceRecord (org.datavec.api.records.SequenceRecord): 2
- RecordMetaDataComposableMap (org.datavec.api.records.metadata.RecordMetaDataComposableMap): 2
- List (java.util.List): 1
- CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader): 1
- CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 1
- CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 1
- CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 1
- FileSplit (org.datavec.api.split.FileSplit): 1
- Test (org.junit.Test): 1
- INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1
- ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 1