use of org.datavec.api.records.Record in project deeplearning4j by deeplearning4j.
the class RecordReaderDataSetIterator method next.
/**
 * Assembles the next mini-batch containing at most {@code num} examples.
 * <p>
 * If a batch has been cached in {@code last} and flagged through {@code useCurrent}, that batch is
 * handed out (after running the pre-processor on it) and no records are read. Otherwise records are
 * pulled from the record reader until {@code num} examples were collected or {@code hasNext()}
 * reports that nothing is left. For a {@link SequenceRecordReader} each time step of a sequence is
 * treated as one example.
 *
 * @param num maximum number of examples to put into the returned batch
 * @return the merged batch, or an empty {@link DataSet} when no example could be read
 */
@Override
public DataSet next(int num) {
    if (useCurrent) {
        // Serve the cached batch exactly once
        useCurrent = false;
        if (preProcessor != null) {
            preProcessor.preProcess(last);
        }
        return last;
    }

    final List<DataSet> examples = new ArrayList<>();
    final List<RecordMetaData> exampleMeta;
    if (collectMetaData) {
        exampleMeta = new ArrayList<RecordMetaData>();
    } else {
        exampleMeta = null;
    }

    int loaded = 0;
    while (loaded < num && hasNext()) {
        if (recordReader instanceof SequenceRecordReader) {
            // Fetch a new sequence only once the current one has been fully consumed
            if (sequenceIter == null || !sequenceIter.hasNext()) {
                sequenceIter = ((SequenceRecordReader) recordReader).sequenceRecord().iterator();
            }
            examples.add(getDataSet(sequenceIter.next()));
        } else if (collectMetaData) {
            final Record current = recordReader.nextRecord();
            examples.add(getDataSet(current.getRecord()));
            exampleMeta.add(current.getMetaData());
        } else {
            examples.add(getDataSet(recordReader.next()));
        }
        loaded++;
    }

    // The batch counter advances even when nothing could be read
    batchNum++;
    if (examples.isEmpty()) {
        return new DataSet();
    }

    final DataSet merged = DataSet.merge(examples);
    if (collectMetaData) {
        merged.setExampleMetaData(exampleMeta);
    }
    last = merged;
    if (preProcessor != null) {
        preProcessor.preProcess(merged);
    }
    // Add label name values to dataset
    if (recordReader.getLabels() != null) {
        merged.setLabelNames(recordReader.getLabels());
    }
    return merged;
}
use of org.datavec.api.records.Record in project deeplearning4j by deeplearning4j.
the class RecordReaderMultiDataSetIterator method loadFromMetaData.
/**
 * Load multiple examples into a single MultiDataSet, using the provided RecordMetaData instances.
 * <p>
 * Every element of {@code list} is cast to {@link RecordMetaDataComposableMap}; the per-reader metadata is
 * looked up in that map under the reader's name and handed to the matching record reader / sequence
 * record reader for loading.
 * <p>
 * When metadata collection is enabled, the supplied metadata is passed on to {@code nextMultiDataSet}
 * (one entry per example, in the same order as {@code list}), mirroring what {@code next(int)} does.
 *
 * @param list List of RecordMetaData instances to load from. Each must be a RecordMetaDataComposableMap
 *             holding one entry per reader name used by this iterator
 * @return MultiDataSet with the specified examples
 * @throws IOException If an error occurs during loading of the data
 */
public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
    //First: load the requested values from the RR / SeqRRs
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();

    // next(int) hands nextMultiDataSet one composable metadata map per example; do the same here
    // instead of passing an empty list, so the metadata stays aligned with the loaded examples.
    List<RecordMetaDataComposableMap> nextMetas = null;
    if (collectMetaData) {
        nextMetas = new ArrayList<>(list.size());
        for (RecordMetaData m : list) {
            nextMetas.add((RecordMetaDataComposableMap) m);
        }
    }

    for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
        List<Record> fromMeta = entry.getValue().loadFromMetaData(metaDataForReader(list, entry.getKey()));
        List<List<Writable>> writables = new ArrayList<>(list.size());
        for (Record r : fromMeta) {
            writables.add(r.getRecord());
        }
        nextRRVals.put(entry.getKey(), writables);
    }

    for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
        List<SequenceRecord> fromMeta =
                        entry.getValue().loadSequenceFromMetaData(metaDataForReader(list, entry.getKey()));
        List<List<List<Writable>>> writables = new ArrayList<>(list.size());
        for (SequenceRecord r : fromMeta) {
            writables.add(r.getSequenceRecord());
        }
        nextSeqRRVals.put(entry.getKey(), writables);
    }

    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}

/**
 * Extracts, for every composable metadata entry in {@code list}, the metadata stored under {@code readerName}.
 */
private static List<RecordMetaData> metaDataForReader(List<RecordMetaData> list, String readerName) {
    List<RecordMetaData> thisRRMeta = new ArrayList<>(list.size());
    for (RecordMetaData m : list) {
        RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m;
        thisRRMeta.add(m2.getMeta().get(readerName));
    }
    return thisRRMeta;
}
use of org.datavec.api.records.Record in project deeplearning4j by deeplearning4j.
the class RecordReaderDataSetIterator method loadFromMetaData.
/**
 * Loads the examples identified by the given RecordMetaData instances and merges them into one DataSet.
 * <p>
 * The metadata of every loaded record is attached to the returned DataSet, the result is remembered in
 * {@code last}, and the pre-processor (if any) as well as the reader's label names (if any) are applied.
 *
 * @param list List of RecordMetaData instances to load from. Should have been produced by the record reader
 *             provided to the RecordReaderDataSetIterator constructor
 * @return DataSet with the specified examples; an empty DataSet if the reader returned no records
 * @throws IOException If an error occurs during loading of the data
 */
public DataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
    final List<Record> loaded = recordReader.loadFromMetaData(list);
    if (loaded.isEmpty()) {
        // Nothing was loaded, so there is nothing to merge
        return new DataSet();
    }

    final List<DataSet> examples = new ArrayList<>();
    final List<RecordMetaData> exampleMeta = new ArrayList<>();
    for (Record current : loaded) {
        examples.add(getDataSet(current.getRecord()));
        exampleMeta.add(current.getMetaData());
    }

    final DataSet merged = DataSet.merge(examples);
    merged.setExampleMetaData(exampleMeta);
    last = merged;
    if (preProcessor != null) {
        preProcessor.preProcess(merged);
    }
    if (recordReader.getLabels() != null) {
        merged.setLabelNames(recordReader.getLabels());
    }
    return merged;
}
use of org.datavec.api.records.Record in project deeplearning4j by deeplearning4j.
the class RecordReaderMultiDataSetIterator method next.
/**
 * Reads up to {@code num} examples from every record reader and sequence record reader and passes
 * the collected values to {@code nextMultiDataSet}.
 * <p>
 * When metadata collection is enabled, one {@link RecordMetaDataComposableMap} is kept per example
 * index; each reader stores its metadata in that map under its own name.
 *
 * @param num maximum number of examples to read from each reader
 * @return the MultiDataSet produced by {@code nextMultiDataSet}
 * @throws NoSuchElementException if {@code hasNext()} is false
 */
@Override
public MultiDataSet next(int num) {
    if (!hasNext()) {
        throw new NoSuchElementException("No next elements");
    }

    final Map<String, List<List<Writable>>> recordsByReader = new HashMap<>();
    final Map<String, List<List<List<Writable>>>> sequencesByReader = new HashMap<>();
    final List<RecordMetaDataComposableMap> exampleMetas;
    if (collectMetaData) {
        exampleMetas = new ArrayList<RecordMetaDataComposableMap>();
    } else {
        exampleMetas = null;
    }

    // Plain (non-sequence) readers
    for (Map.Entry<String, RecordReader> readerEntry : recordReaders.entrySet()) {
        final RecordReader reader = readerEntry.getValue();
        final List<List<Writable>> batch = new ArrayList<>(num);
        int idx = 0;
        while (idx < num && reader.hasNext()) {
            final List<Writable> values;
            if (!collectMetaData) {
                values = reader.next();
            } else {
                final Record loaded = reader.nextRecord();
                values = loaded.getRecord();
                // First reader to reach this example index creates the shared metadata map
                if (exampleMetas.size() <= idx) {
                    exampleMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                }
                exampleMetas.get(idx).getMeta().put(readerEntry.getKey(), loaded.getMetaData());
            }
            batch.add(values);
            idx++;
        }
        recordsByReader.put(readerEntry.getKey(), batch);
    }

    // Sequence readers
    for (Map.Entry<String, SequenceRecordReader> readerEntry : sequenceRecordReaders.entrySet()) {
        final SequenceRecordReader reader = readerEntry.getValue();
        final List<List<List<Writable>>> batch = new ArrayList<>(num);
        int idx = 0;
        while (idx < num && reader.hasNext()) {
            final List<List<Writable>> steps;
            if (!collectMetaData) {
                steps = reader.sequenceRecord();
            } else {
                final SequenceRecord loaded = reader.nextSequence();
                steps = loaded.getSequenceRecord();
                if (exampleMetas.size() <= idx) {
                    exampleMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                }
                exampleMetas.get(idx).getMeta().put(readerEntry.getKey(), loaded.getMetaData());
            }
            batch.add(steps);
            idx++;
        }
        sequencesByReader.put(readerEntry.getKey(), batch);
    }

    return nextMultiDataSet(recordsByReader, sequencesByReader, exampleMetas);
}
use of org.datavec.api.records.Record in project deeplearning4j by deeplearning4j.
the class RecordReaderDataSetiteratorTest method testRecordReaderMetaData.
/**
 * Iterates over the iris data with metadata collection enabled and checks, for every batch, that
 * (a) each example's first four feature values match the record re-loaded from its metadata, and
 * (b) re-loading the whole batch from its metadata yields a DataSet equal to the original batch.
 */
@Test
public void testRecordReaderMetaData() throws Exception {
    final int batchSize = 10;
    final int labelIdx = 4;
    final int numClasses = 3;

    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses);
    rrdsi.setCollectMetaData(true);

    while (rrdsi.hasNext()) {
        DataSet batch = rrdsi.next();
        List<RecordMetaData> batchMeta = batch.getExampleMetaData(RecordMetaData.class);

        for (int exampleIdx = 0; exampleIdx < batchMeta.size(); exampleIdx++) {
            RecordMetaData exampleMeta = batchMeta.get(exampleIdx);
            Record reloaded = csv.loadFromMetaData(exampleMeta);
            INDArray featureRow = batch.getFeatureMatrix().getRow(exampleIdx);
            System.out.println(exampleMeta.getLocation() + "\t" + reloaded.getRecord() + "\t" + featureRow);

            // Compare the four feature columns value by value
            for (int col = 0; col < 4; col++) {
                double expected = reloaded.getRecord().get(col).toDouble();
                double actual = featureRow.getDouble(col);
                assertEquals(expected, actual, 1e-6);
            }
        }
        System.out.println();

        DataSet reloadedBatch = rrdsi.loadFromMetaData(batchMeta);
        assertEquals(batch, reloadedBatch);
    }
}
Aggregations