Usage of org.datavec.api.records.reader.impl.collection.CollectionRecordReader in project deeplearning4j (by deeplearning4j).
From class RecordReaderDataSetiteratorTest, method testRecordReaderDataSetIteratorNDArrayWritableLabels.
/**
 * Verifies that {@link RecordReaderDataSetIterator} in regression mode accepts
 * {@code NDArrayWritable} label columns.
 *
 * <p>Two record layouts are checked. Both are expected to yield the same features
 * ({@code [[0,1],[2,3],[4,5]]}) and labels ({@code [[1.1,2.1,3.1],[4.1,5.1,6.1],[7.1,8.1,9.1]]}):
 * <ol>
 *   <li>scalar {@code DoubleWritable} features in columns 0-1 and an {@code NDArrayWritable}
 *       label in column 2;</li>
 *   <li>an {@code NDArrayWritable} feature vector in column 0 and an {@code NDArrayWritable}
 *       label in column 1.</li>
 * </ol>
 */
@Test
public void testRecordReaderDataSetIteratorNDArrayWritableLabels() {
    Collection<Collection<Writable>> data = new ArrayList<>();
    data.add(Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }))));
    data.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(3), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }))));
    data.add(Arrays.<Writable>asList(new DoubleWritable(4), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }))));
    RecordReader rr = new CollectionRecordReader(data);
    int batchSize = 3;
    // Layout 1: columns 0-1 are features, column 2 holds the (array-valued) label.
    int labelIndexFrom = 2;
    int labelIndexTo = 2;
    boolean regression = true;
    DataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, batchSize, labelIndexFrom, labelIndexTo, regression);
    DataSet ds = rrdsi.next();
    INDArray expFeatures = Nd4j.create(new double[][] { { 0, 1 }, { 2, 3 }, { 4, 5 } });
    INDArray expLabels = Nd4j.create(new double[][] { { 1.1, 2.1, 3.1 }, { 4.1, 5.1, 6.1 }, { 7.1, 8.1, 9.1 } });
    assertEquals(expFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());
    //ALSO: test if we have NDArrayWritables for BOTH the features and the labels
    data = new ArrayList<>();
    data.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 0, 1 })), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }))));
    data.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 2, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }))));
    data.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 4, 5 })), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }))));
    // Layout 2: each record has only two writables (indices 0 and 1), so the label column
    // is 1. Reusing index 2 from layout 1 would point past the end of every record.
    labelIndexFrom = 1;
    labelIndexTo = 1;
    rr = new CollectionRecordReader(data);
    rrdsi = new RecordReaderDataSetIterator(rr, batchSize, labelIndexFrom, labelIndexTo, regression);
    ds = rrdsi.next();
    assertEquals(expFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());
}
Usage of org.datavec.api.records.reader.impl.collection.CollectionRecordReader in project deeplearning4j (by deeplearning4j).
From class TestRecordReaders, method testClassIndexOutsideOfRangeRRDSI.
/**
 * Verifies that a record whose class index is outside the configured number of classes
 * makes {@link RecordReaderDataSetIterator#next()} throw a {@link DL4JException}.
 *
 * <p>The second record carries class index 2 while the iterator is built with 2 possible
 * classes. Any exception other than {@code DL4JException} fails the test and is attached
 * as the failure's cause.
 */
@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    Collection<Collection<Writable>> c = new ArrayList<>();
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    CollectionRecordReader crr = new CollectionRecordReader(c);
    // Presumably (recordReader, batchSize=2, labelIndex=1, numPossibleLabels=2): valid
    // class indices are 0 and 1, so the IntWritable(2) above is out of range.
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(crr, 2, 1, 2);
    try {
        iter.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testClassIndexOutsideOfRange(): " + e.getMessage());
    } catch (Exception e) {
        // Wrong exception type: fail with the original exception attached so the test
        // report shows its type and stack trace instead of a bare, message-less failure.
        throw new AssertionError("Expected DL4JException but got " + e, e);
    }
}
Usage of org.datavec.api.records.reader.impl.collection.CollectionRecordReader in project deeplearning4j (by deeplearning4j).
From class StringToDataSetExportFunction, method processBatchIfRequired.
/**
 * Converts the buffered records into one {@link DataSet} and writes it under {@code outputDir},
 * but only once a full batch has accumulated or the final record has been reached.
 *
 * <p>Nothing is written while {@code list} is empty, or while it holds fewer than
 * {@code batchSize} records and more records are still to come. After a write the buffer is
 * cleared so the caller can keep accumulating into the same list.
 *
 * <p>NOTE(review): {@code iter.next()} takes a single batch of up to {@code batchSize} records,
 * and {@code list.clear()} then discards everything. This presumably relies on the caller
 * invoking this method after every append, so {@code list} never exceeds {@code batchSize}.
 * Confirm against the caller.
 *
 * @param list        buffered records; cleared after a successful write
 * @param finalRecord {@code true} when no further records will arrive, which forces a
 *                    partial (smaller than {@code batchSize}) batch to be flushed
 * @throws Exception if the output URI cannot be built or the file cannot be written
 */
private void processBatchIfRequired(List<List<Writable>> list, boolean finalRecord) throws Exception {
    if (list.isEmpty())
        return;
    if (list.size() < batchSize && !finalRecord)
        return;
    RecordReader rr = new CollectionRecordReader(list);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(rr, new SelfWritableConverter(), batchSize, labelIndex, numPossibleLabels, regression);
    DataSet ds = iter.next();
    String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";
    String dirPath = outputDir.getPath();
    // Add a separator only if the directory path lacks one. For a root path "/", the naive
    // concatenation yields "//dataset_...", which URI parsing reads as an authority (host)
    // followed by an empty path, so the file would not be written where intended.
    boolean hasTrailingSeparator = dirPath.endsWith("/") || dirPath.endsWith("\\");
    String filePath = dirPath + (hasTrailingSeparator ? "" : "/") + filename;
    // The multi-argument constructor percent-encodes characters that are illegal in a URI
    // string (spaces, '%', ...). new URI(String) would instead throw, or mis-decode '%' sequences.
    URI uri = new URI(null, null, filePath, null);
    FileSystem file = FileSystem.get(uri, conf);
    try (FSDataOutputStream out = file.create(new Path(uri))) {
        ds.save(out);
    }
    list.clear();
}
Aggregations