Use of org.datavec.api.records.reader.SequenceRecordReader in project deeplearning4j by deeplearning4j.
The RecordReaderMultiDataSetIterator class, method next(int num):
@Override
public MultiDataSet next(int num) {
    if (!hasNext())
        throw new NoSuchElementException("No next elements");

    //First: load the next values from the RR / SeqRRs
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
    List<RecordMetaDataComposableMap> nextMetas =
                    (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null);

    //Standard (non-sequence) record readers: one List<Writable> per example
    for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
        RecordReader rr = entry.getValue();
        List<List<Writable>> writables = new ArrayList<>(num);
        for (int i = 0; i < num && rr.hasNext(); i++) {
            List<Writable> record;
            if (collectMetaData) {
                Record r = rr.nextRecord();
                record = r.getRecord();
                if (nextMetas.size() <= i) {
                    nextMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                }
                RecordMetaDataComposableMap map = nextMetas.get(i);
                map.getMeta().put(entry.getKey(), r.getMetaData());
            } else {
                record = rr.next();
            }
            writables.add(record);
        }
        nextRRVals.put(entry.getKey(), writables);
    }

    //Sequence record readers: one List<List<Writable>> (time steps x values) per example
    for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
        SequenceRecordReader rr = entry.getValue();
        List<List<List<Writable>>> writables = new ArrayList<>(num);
        for (int i = 0; i < num && rr.hasNext(); i++) {
            List<List<Writable>> sequence;
            if (collectMetaData) {
                SequenceRecord r = rr.nextSequence();
                sequence = r.getSequenceRecord();
                if (nextMetas.size() <= i) {
                    nextMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                }
                RecordMetaDataComposableMap map = nextMetas.get(i);
                map.getMeta().put(entry.getKey(), r.getMetaData());
            } else {
                sequence = rr.sequenceRecord();
            }
            writables.add(sequence);
        }
        nextSeqRRVals.put(entry.getKey(), writables);
    }

    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}
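The method above only fills the per-reader value maps; the readers themselves are registered under names through the Builder. A minimal usage sketch, not taken from the project's own examples (the file name, reader name and column indices are placeholders), assuming a CSV file with four feature columns and an integer class label in column 4:

RecordReader csv = new CSVRecordReader(0, ",");
csv.initialize(new FileSplit(new java.io.File("my_data.csv")));   // placeholder path
RecordReaderMultiDataSetIterator it = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("csv", csv)            // populates the recordReaders map under key "csv"
                .addInput("csv", 0, 3)            // columns 0..3 -> first features array
                .addOutputOneHot("csv", 4, 3)     // column 4 -> one-hot labels over 3 classes
                .build();
it.setCollectMetaData(true);      // next(num) will then also fill nextMetas as shown above
MultiDataSet first = it.next();   // delegates to next(10)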
Use of org.datavec.api.records.reader.SequenceRecordReader in project deeplearning4j by deeplearning4j.
The RecordReaderMultiDataSetIteratorTest class, method testSplittingCSVSequence:
@Test
public void testSplittingCSVSequence() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabels_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", featureReader2)
                    .addSequenceReader("seq2", labelReader2)
                    .addInput("seq1", 0, 1)
                    .addInput("seq1", 2, 2)
                    .addOutputOneHot("seq2", 0, 4)
                    .build();

    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = srrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++)
            assertNotNull(fmds[i]);
        for (int i = 0; i < lmds.length; i++)
            assertNotNull(lmds[i]);
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(lds, lmds[0]);
    }
    assertFalse(srrmdsi.hasNext());
}
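The interval indexing in the test is inclusive when the third argument is true, so interval(0, 1, true) selects feature channels 0 and 1 and interval(2, 2, true) selects channel 2, matching the addInput("seq1", 0, 1) / addInput("seq1", 2, 2) split in the Builder. A small, self-contained sketch of that indexing on an arbitrary [examples, channels, time] array (the array values are placeholders, not project data):

INDArray x = Nd4j.arange(36).reshape(3, 3, 4);   // 3 examples, 3 channels, 4 time steps
INDArray firstTwo = x.get(NDArrayIndex.all(),
                NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());   // channels 0 and 1 (inclusive)
INDArray lastOnly = x.get(NDArrayIndex.all(),
                NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());   // channel 2 only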
Use of org.datavec.api.records.reader.SequenceRecordReader in project deeplearning4j by deeplearning4j.
The RecordReaderMultiDataSetIteratorTest class, method testSplittingCSVSequenceMeta:
@Test
public void testSplittingCSVSequenceMeta() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabels_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", featureReader2)
                    .addSequenceReader("seq2", labelReader2)
                    .addInput("seq1", 0, 1)
                    .addInput("seq1", 2, 2)
                    .addOutputOneHot("seq2", 0, 4)
                    .build();
    srrmdsi.setCollectMetaData(true);

    int count = 0;
    while (srrmdsi.hasNext()) {
        MultiDataSet mds = srrmdsi.next();
        MultiDataSet fromMeta = srrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(3, count);
}
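Beyond the equality assertion, the same metadata round trip can be used to re-fetch specific examples later, for instance ones a model misclassifies. A hedged sketch of that idea, assumed to sit inside the loop above (the subList call just picks the first example of the minibatch; with batch size 1 it is the whole batch):

List<RecordMetaData> meta = mds.getExampleMetaData(RecordMetaData.class);
MultiDataSet firstExample = srrmdsi.loadFromMetaData(meta.subList(0, 1));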
Use of org.datavec.api.records.reader.SequenceRecordReader in project deeplearning4j by deeplearning4j.
The RecordReaderMultiDataSetIteratorTest class, method testVariableLengthTSMeta:
@Test
public void testVariableLengthTSMeta() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    //Set up SequenceRecordReaderDataSetIterators for comparison
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabelsShort_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");

    //Set up
    SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
    featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
    featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", featureReader3)
                    .addSequenceReader("out", labelReader3)
                    .addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START)
                    .build();
    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", featureReader4)
                    .addSequenceReader("out", labelReader4)
                    .addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
                    .build();
    rrmdsiStart.setCollectMetaData(true);
    rrmdsiEnd.setCollectMetaData(true);

    int count = 0;
    while (rrmdsiStart.hasNext()) {
        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();
        MultiDataSet mdsStartFromMeta = rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class));
        MultiDataSet mdsEndFromMeta = rrmdsiEnd.loadFromMetaData(mdsEnd.getExampleMetaData(RecordMetaData.class));
        assertEquals(mdsStart, mdsStartFromMeta);
        assertEquals(mdsEnd, mdsEndFromMeta);
        count++;
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext());
    assertEquals(3, count);
}
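The csvsequencelabelsShort files give label sequences that are shorter than the feature sequences, so the two alignment modes differ in where the padding ends up: ALIGN_START places the masked-out steps at the end of each short sequence, ALIGN_END at the beginning. A minimal sketch (assumed to sit inside the loop above) that makes this visible through the label mask arrays:

INDArray maskStart = mdsStart.getLabelsMaskArray(0);   // zeros expected at the end of short label sequences
INDArray maskEnd = mdsEnd.getLabelsMaskArray(0);       // zeros expected at the start of short label sequences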
Use of org.datavec.api.records.reader.SequenceRecordReader in project deeplearning4j by deeplearning4j.
The RecordReaderMultiDataSetIteratorTest class, method testsBasic:
@Test
public void testsBasic() throws Exception {
    //Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);

    RecordReader rr2 = new CSVRecordReader(0, ",");
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                    .addReader("reader", rr2)
                    .addInput("reader", 0, 3)
                    .addOutputOneHot("reader", 4, 3)
                    .build();

    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = rrmdsi.next();
        assertEquals(1, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray fmds = mds.getFeatures(0);
        INDArray lmds = mds.getLabels(0);
        assertNotNull(fmds);
        assertNotNull(lmds);
        assertEquals(fds, fmds);
        assertEquals(lds, lmds);
    }
    assertFalse(rrmdsi.hasNext());

    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }

    //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabels_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter =
                    new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("in", featureReader2)
                    .addSequenceReader("out", labelReader2)
                    .addInput("in")
                    .addOutputOneHot("out", 0, 4)
                    .build();

    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = srrmdsi.next();
        assertEquals(1, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray fmds = mds.getFeatures(0);
        INDArray lmds = mds.getLabels(0);
        assertNotNull(fmds);
        assertNotNull(lmds);
        assertEquals(fds, fmds);
        assertEquals(lds, lmds);
    }
    assertFalse(srrmdsi.hasNext());
}