Example usage of org.datavec.api.records.reader.RecordReader in the project deeplearning4j (by deeplearning4j).
From the class RecordReaderMultiDataSetIterator, method loadFromMetaData:
/**
 * Load multiple examples as a MultiDataSet, using the provided RecordMetaData instances.
 * Each RecordMetaData is expected to be a {@link RecordMetaDataComposableMap} holding one
 * per-reader RecordMetaData entry, keyed by the reader's name.
 *
 * @param list List of RecordMetaData instances to load from. Should have been produced by the
 * record readers used by this RecordReaderMultiDataSetIterator (with metadata collection enabled)
 * @return MultiDataSet with the specified examples
 * @throws IOException If an error occurs during loading of the data
 */
public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
//First: load the next values from the RR / SeqRRs
//Per-reader-name values: non-sequence readers yield one List<Writable> per example,
//sequence readers yield one List<List<Writable>> (time steps x values) per example.
Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
//NOTE(review): nextMetas is allocated when collectMetaData is set but is never populated
//in this method body — TODO confirm whether metadata should be re-collected here.
List<RecordMetaDataComposableMap> nextMetas = (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null);
for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
RecordReader rr = entry.getValue();
//Extract this reader's per-example metadata from each composite entry
List<RecordMetaData> thisRRMeta = new ArrayList<>();
for (RecordMetaData m : list) {
RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m;
thisRRMeta.add(m2.getMeta().get(entry.getKey()));
}
//Batch-load all examples for this reader, keeping only the raw writables
List<Record> fromMeta = rr.loadFromMetaData(thisRRMeta);
List<List<Writable>> writables = new ArrayList<>(list.size());
for (Record r : fromMeta) {
writables.add(r.getRecord());
}
nextRRVals.put(entry.getKey(), writables);
}
//Same as above, but for sequence readers (one extra nesting level for time steps)
for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
SequenceRecordReader rr = entry.getValue();
List<RecordMetaData> thisRRMeta = new ArrayList<>();
for (RecordMetaData m : list) {
RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m;
thisRRMeta.add(m2.getMeta().get(entry.getKey()));
}
List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(thisRRMeta);
List<List<List<Writable>>> writables = new ArrayList<>(list.size());
for (SequenceRecord r : fromMeta) {
writables.add(r.getSequenceRecord());
}
nextSeqRRVals.put(entry.getKey(), writables);
}
//Delegate assembly of the MultiDataSet (feature/label NDArray construction) to the shared path
return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}
Example usage of org.datavec.api.records.reader.RecordReader in the project deeplearning4j (by deeplearning4j).
From the class RecordReaderDataSetiteratorTest, method testRecordReaderMaxBatchLimit:
@Test
public void testRecordReaderMaxBatchLimit() throws Exception {
// Verify that the maxNumBatches parameter (last constructor arg = 2) caps iteration:
// after two next() calls the iterator must be exhausted, regardless of remaining data.
RecordReader recordReader = new CSVRecordReader();
FileSplit csv = new FileSplit(new ClassPathResource("csv-example.csv").getTempFileFromArchive());
recordReader.initialize(csv);
// batchSize=10, labelIndexFrom=-1, labelIndexTo=-1 (no labels), maxNumBatches=2
DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, 10, -1, -1, 2);
iter.next();
iter.next();
// Idiomatic boolean assertion instead of assertEquals(false, ...)
assertFalse(iter.hasNext());
}
Example usage of org.datavec.api.records.reader.RecordReader in the project deeplearning4j (by deeplearning4j).
From the class RecordReaderDataSetiteratorTest, method testRecordReaderDataSetIteratorNDArrayWritableLabels:
@Test
public void testRecordReaderDataSetIteratorNDArrayWritableLabels() {
// Case 1: scalar features (columns 0-1) with an NDArrayWritable label (column 2)
Collection<Collection<Writable>> records = new ArrayList<>();
records.add(Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }))));
records.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(3), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }))));
records.add(Arrays.<Writable>asList(new DoubleWritable(4), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }))));
final int batchSize = 3;
final int labelIndexFrom = 2;
final int labelIndexTo = 2;
final boolean regression = true;
// Expected tensors are shared by both cases below
INDArray expectedFeatures = Nd4j.create(new double[][] { { 0, 1 }, { 2, 3 }, { 4, 5 } });
INDArray expectedLabels = Nd4j.create(new double[][] { { 1.1, 2.1, 3.1 }, { 4.1, 5.1, 6.1 }, { 7.1, 8.1, 9.1 } });
RecordReader reader = new CollectionRecordReader(records);
DataSetIterator iterator = new RecordReaderDataSetIterator(reader, batchSize, labelIndexFrom, labelIndexTo, regression);
DataSet batch = iterator.next();
assertEquals(expectedFeatures, batch.getFeatures());
assertEquals(expectedLabels, batch.getLabels());
// Case 2: NDArrayWritables for BOTH the features and the labels
records = new ArrayList<>();
records.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 0, 1 })), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }))));
records.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 2, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }))));
records.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 4, 5 })), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }))));
reader = new CollectionRecordReader(records);
iterator = new RecordReaderDataSetIterator(reader, batchSize, labelIndexFrom, labelIndexTo, regression);
batch = iterator.next();
assertEquals(expectedFeatures, batch.getFeatures());
assertEquals(expectedLabels, batch.getLabels());
}
Example usage of org.datavec.api.records.reader.RecordReader in the project deeplearning4j (by deeplearning4j).
From the class RecordReaderDataSetiteratorTest, method testRecordReaderMultiRegression:
@Test
public void testRecordReaderMultiRegression() throws Exception {
// Regression mode with a multi-column label range: columns 3..4 become the labels,
// columns 0..2 the features, using the first 3 rows of iris.txt.
RecordReader csv = new CSVRecordReader();
csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
int batchSize = 3;
int labelIdxFrom = 3;
int labelIdxTo = 4;
DataSetIterator iter = new RecordReaderDataSetIterator(csv, batchSize, labelIdxFrom, labelIdxTo, true);
DataSet ds = iter.next();
// Use getFeatures() instead of the deprecated getFeatureMatrix() alias,
// consistent with the other tests in this file.
INDArray f = ds.getFeatures();
INDArray l = ds.getLabels();
assertArrayEquals(new int[] { 3, 3 }, f.shape());
assertArrayEquals(new int[] { 3, 2 }, l.shape());
//Check values:
double[][] fExpD = new double[][] { { 5.1, 3.5, 1.4 }, { 4.9, 3.0, 1.4 }, { 4.7, 3.2, 1.3 } };
double[][] lExpD = new double[][] { { 0.2, 0 }, { 0.2, 0 }, { 0.2, 0 } };
INDArray fExp = Nd4j.create(fExpD);
INDArray lExp = Nd4j.create(lExpD);
assertEquals(fExp, f);
assertEquals(lExp, l);
}
Example usage of org.datavec.api.records.reader.RecordReader in the project deeplearning4j (by deeplearning4j).
From the class RecordReaderDataSetiteratorTest, method testRecordReader:
@Test
public void testRecordReader() throws Exception {
// Basic smoke test: reading csv-example.csv with batch size 34 yields a
// first DataSet containing exactly 34 examples.
RecordReader reader = new CSVRecordReader();
FileSplit inputSplit = new FileSplit(new ClassPathResource("csv-example.csv").getTempFileFromArchive());
reader.initialize(inputSplit);
DataSetIterator iterator = new RecordReaderDataSetIterator(reader, 34);
DataSet firstBatch = iterator.next();
assertEquals(34, firstBatch.numExamples());
}
Aggregations