Use of org.datavec.api.split.NumberedFileInputSplit in project deeplearning4j by deeplearning4j: class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVSequence.
@Test
public void testSplittingCSVSequence() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabels_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter =
            new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("seq1", featureReader2)
            .addSequenceReader("seq2", labelReader2)
            .addInput("seq1", 0, 1)
            .addInput("seq1", 2, 2)
            .addOutputOneHot("seq2", 0, 4)
            .build();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = srrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++)
            assertNotNull(fmds[i]);
        for (int i = 0; i < lmds.length; i++)
            assertNotNull(lmds[i]);
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(lds, lmds[0]);
    }
    assertFalse(srrmdsi.hasNext());
}
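The core pattern exercised here is NumberedFileInputSplit expanding a %d placeholder in a path template over an inclusive index range and handing each matching file to a sequence reader. A minimal standalone sketch of just that expansion, assuming three comma-separated files exist under a hypothetical /tmp path (the path, file names, and class name below are placeholders, not taken from the test):

import java.util.List;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Writable;

public class NumberedFileInputSplitSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical template; %d is expanded to 0, 1 and 2 (both bounds are inclusive).
        String template = "/tmp/csvsequence_%d.txt";
        // Skip one header line per file, comma-delimited values (same settings as the test).
        SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
        reader.initialize(new NumberedFileInputSplit(template, 0, 2));
        while (reader.hasNext()) {
            // One sequenceRecord() call returns all rows of one numbered file as a sequence.
            List<List<Writable>> sequence = reader.sequenceRecord();
            System.out.println("Sequence with " + sequence.size() + " time steps");
        }
    }
}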
Use of org.datavec.api.split.NumberedFileInputSplit in project deeplearning4j by deeplearning4j: class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVSequenceMeta.
@Test
public void testSplittingCSVSequenceMeta() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabels_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("seq1", featureReader2)
            .addSequenceReader("seq2", labelReader2)
            .addInput("seq1", 0, 1)
            .addInput("seq1", 2, 2)
            .addOutputOneHot("seq2", 0, 4)
            .build();
    srrmdsi.setCollectMetaData(true);
    int count = 0;
    while (srrmdsi.hasNext()) {
        MultiDataSet mds = srrmdsi.next();
        MultiDataSet fromMeta = srrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(3, count);
}
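What this test adds over the previous one is the metadata round trip: enable collection with setCollectMetaData(true), then rebuild any minibatch from its RecordMetaData via loadFromMetaData and compare it to the original. A sketch of that check on its own, assuming the caller supplies an already-configured iterator; the class and method names are hypothetical:

import java.util.List;
import org.datavec.api.records.metadata.RecordMetaData;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class MetaDataRoundTripSketch {
    // Hypothetical helper: replays every minibatch from its collected metadata and
    // fails if the reloaded minibatch differs from the one produced by normal iteration.
    static void verifyMetaDataRoundTrip(RecordReaderMultiDataSetIterator iter) throws Exception {
        iter.setCollectMetaData(true);
        while (iter.hasNext()) {
            MultiDataSet mds = iter.next();
            List<RecordMetaData> meta = mds.getExampleMetaData(RecordMetaData.class);
            MultiDataSet reloaded = iter.loadFromMetaData(meta);
            if (!mds.equals(reloaded)) {
                throw new IllegalStateException("Metadata round trip did not reproduce the minibatch");
            }
        }
    }
}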
Use of org.datavec.api.split.NumberedFileInputSplit in project deeplearning4j by deeplearning4j: class RecordReaderMultiDataSetIteratorTest, method testVariableLengthTSMeta.
@Test
public void testVariableLengthTSMeta() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    //Set up SequenceRecordReaderDataSetIterators for comparison
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabelsShort_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    //Set up
    SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
    featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
    featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader3)
            .addSequenceReader("out", labelReader3)
            .addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START)
            .build();
    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader4)
            .addSequenceReader("out", labelReader4)
            .addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
            .build();
    rrmdsiStart.setCollectMetaData(true);
    rrmdsiEnd.setCollectMetaData(true);
    int count = 0;
    while (rrmdsiStart.hasNext()) {
        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();
        MultiDataSet mdsStartFromMeta = rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class));
        MultiDataSet mdsEndFromMeta = rrmdsiEnd.loadFromMetaData(mdsEnd.getExampleMetaData(RecordMetaData.class));
        assertEquals(mdsStart, mdsStartFromMeta);
        assertEquals(mdsEnd, mdsEndFromMeta);
        count++;
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext());
    assertEquals(3, count);
}
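A detail worth calling out is how every test derives the NumberedFileInputSplit template: it extracts csvsequence_0.txt (or csvsequencelabelsShort_0.txt) to a temp file and rewrites the "0" in its absolute path to "%d". A small sketch of that derivation and its caveat, assuming the first extracted file is handed in as a java.io.File; the class and helper names are hypothetical:

import java.io.File;
import org.datavec.api.split.NumberedFileInputSplit;

public class PathTemplateSketch {
    // Hypothetical helper: turn the absolute path of the first extracted file,
    // e.g. ".../csvsequence_0.txt", into the %d template NumberedFileInputSplit expects.
    // Caveat: replaceAll("0", "%d") substitutes every '0' in the path, so this only
    // works when the directory part of the path contains no other '0'.
    static NumberedFileInputSplit splitOverExtractedFiles(File firstFile, int minIdx, int maxIdx) {
        String template = firstFile.getAbsolutePath().replaceAll("0", "%d");
        return new NumberedFileInputSplit(template, minIdx, maxIdx);
    }
}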
Use of org.datavec.api.split.NumberedFileInputSplit in project deeplearning4j by deeplearning4j: class RecordReaderMultiDataSetIteratorTest, method testsBasic.
@Test
public void testsBasic() throws Exception {
    //Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
    RecordReader rr2 = new CSVRecordReader(0, ",");
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("reader", rr2)
            .addInput("reader", 0, 3)
            .addOutputOneHot("reader", 4, 3)
            .build();
    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = rrmdsi.next();
        assertEquals(1, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray fmds = mds.getFeatures(0);
        INDArray lmds = mds.getLabels(0);
        assertNotNull(fmds);
        assertNotNull(lmds);
        assertEquals(fds, fmds);
        assertEquals(lds, lmds);
    }
    assertFalse(rrmdsi.hasNext());
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabels_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter =
            new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader2)
            .addSequenceReader("out", labelReader2)
            .addInput("in")
            .addOutputOneHot("out", 0, 4)
            .build();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = srrmdsi.next();
        assertEquals(1, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray fmds = mds.getFeatures(0);
        INDArray lmds = mds.getLabels(0);
        assertNotNull(fmds);
        assertNotNull(lmds);
        assertEquals(fds, fmds);
        assertEquals(lds, lmds);
    }
    assertFalse(srrmdsi.hasNext());
}
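For the non-sequence half of this test, the builder expresses the same split as RecordReaderDataSetIterator(rr, 10, 4, 3) column-wise: columns 0 to 3 become the single feature array and column 4 is one-hot encoded over 3 classes. A minimal sketch of that configuration, assuming iris.txt is on the classpath; the test's own imports are not shown above, so the ClassPathResource import below is an assumption and the class name is hypothetical:

import java.util.Arrays;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.io.ClassPathResource;   // assumed import; adjust to your project's ClassPathResource

public class IrisMultiDataSetSketch {
    public static void main(String[] args) throws Exception {
        RecordReader rr = new CSVRecordReader(0, ",");
        rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
        // Minibatches of 10: columns 0-3 become the single feature array,
        // column 4 becomes a one-hot label array with 3 classes.
        MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("reader", rr)
                .addInput("reader", 0, 3)
                .addOutputOneHot("reader", 4, 3)
                .build();
        while (iter.hasNext()) {
            MultiDataSet mds = iter.next();
            System.out.println("Feature batch shape: " + Arrays.toString(mds.getFeatures(0).shape()));
        }
    }
}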
Use of org.datavec.api.split.NumberedFileInputSplit in project deeplearning4j by deeplearning4j: class RecordReaderMultiDataSetIteratorTest, method testVariableLengthTS.
@Test
public void testVariableLengthTS() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    //Set up SequenceRecordReaderDataSetIterators for comparison
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    resource = new ClassPathResource("csvsequencelabelsShort_0.txt");
    String labelsPath = resource.getTempFileFromArchive().getAbsolutePath().replaceAll("0", "%d");
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader,
            labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);
    SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2,
            labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
    //Set up
    SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
    featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
    featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader3)
            .addSequenceReader("out", labelReader3)
            .addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START)
            .build();
    RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("in", featureReader4)
            .addSequenceReader("out", labelReader4)
            .addInput("in")
            .addOutputOneHot("out", 0, 4)
            .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
            .build();
    while (iterAlignStart.hasNext()) {
        DataSet dsStart = iterAlignStart.next();
        DataSet dsEnd = iterAlignEnd.next();
        MultiDataSet mdsStart = rrmdsiStart.next();
        MultiDataSet mdsEnd = rrmdsiEnd.next();
        assertEquals(1, mdsStart.getFeatures().length);
        assertEquals(1, mdsStart.getLabels().length);
        //assertEquals(1, mdsStart.getFeaturesMaskArrays().length); //Features data is always longer -> don't need mask arrays for it
        assertEquals(1, mdsStart.getLabelsMaskArrays().length);
        assertEquals(1, mdsEnd.getFeatures().length);
        assertEquals(1, mdsEnd.getLabels().length);
        //assertEquals(1, mdsEnd.getFeaturesMaskArrays().length);
        assertEquals(1, mdsEnd.getLabelsMaskArrays().length);
        assertEquals(dsStart.getFeatureMatrix(), mdsStart.getFeatures(0));
        assertEquals(dsStart.getLabels(), mdsStart.getLabels(0));
        assertEquals(dsStart.getLabelsMaskArray(), mdsStart.getLabelsMaskArray(0));
        assertEquals(dsEnd.getFeatureMatrix(), mdsEnd.getFeatures(0));
        assertEquals(dsEnd.getLabels(), mdsEnd.getLabels(0));
        assertEquals(dsEnd.getLabelsMaskArray(), mdsEnd.getLabelsMaskArray(0));
    }
    assertFalse(rrmdsiStart.hasNext());
    assertFalse(rrmdsiEnd.hasNext());
}
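With label sequences shorter than the feature sequences, sequenceAlignmentMode decides whether the labels are padded at the end (ALIGN_START) or at the start (ALIGN_END), and the padding shows up as zeros in the labels mask arrays. A sketch of just the ALIGN_END builder configuration, assuming the two readers are CSVSequenceRecordReaders already initialized over matching NumberedFileInputSplits as above; the class and helper names are hypothetical:

import org.datavec.api.records.reader.SequenceRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;

public class AlignmentModeSketch {
    // Hypothetical helper: builds an iterator like rrmdsiEnd in the test above.
    static RecordReaderMultiDataSetIterator alignEndIterator(SequenceRecordReader featureReader,
                                                             SequenceRecordReader labelReader) {
        return new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("in", featureReader)
                .addSequenceReader("out", labelReader)
                .addInput("in")                       // all columns of the feature sequences
                .addOutputOneHot("out", 0, 4)         // label column 0, one-hot over 4 classes
                // ALIGN_END right-aligns the shorter label sequence; the leading, padded steps
                // are reported as zeros in getLabelsMaskArray(0). ALIGN_START left-aligns instead.
                .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
                .build();
    }
}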