Usage of org.datavec.api.records.reader.impl.csv.CSVRecordReader in the deeplearning4j project (by deeplearning4j).
From class RecordReaderDataSetiteratorTest, method testRecordReaderMaxBatchLimit.
@Test
public void testRecordReaderMaxBatchLimit() throws Exception {
    // Load the sample CSV through a plain CSV record reader.
    RecordReader reader = new CSVRecordReader();
    reader.initialize(new FileSplit(new ClassPathResource("csv-example.csv").getTempFileFromArchive()));

    // Batch size 10; the two -1 arguments are label index / label count (presumably "no labels"
    // here — TODO confirm); the final argument caps the iterator at 2 batches.
    DataSetIterator limited = new RecordReaderDataSetIterator(reader, 10, -1, -1, 2);

    // Consume exactly the permitted number of batches...
    for (int batch = 0; batch < 2; batch++) {
        limited.next();
    }
    // ...after which the cap must stop the iterator.
    assertEquals(false, limited.hasNext());
}
Usage of org.datavec.api.records.reader.impl.csv.CSVRecordReader in the deeplearning4j project (by deeplearning4j).
From class RecordReaderDataSetiteratorTest, method testRecordReaderMultiRegression.
@Test
public void testRecordReaderMultiRegression() throws Exception {
    RecordReader reader = new CSVRecordReader();
    reader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    // Multi-output regression: columns 3..4 become the label matrix (two label columns,
    // per the shape check below); the remaining leading columns become features.
    final int miniBatch = 3;
    final int firstLabelCol = 3;
    final int lastLabelCol = 4;
    DataSetIterator it = new RecordReaderDataSetIterator(reader, miniBatch, firstLabelCol, lastLabelCol, true);
    DataSet batch = it.next();

    INDArray features = batch.getFeatureMatrix();
    INDArray labels = batch.getLabels();
    assertArrayEquals(new int[] { 3, 3 }, features.shape());
    assertArrayEquals(new int[] { 3, 2 }, labels.shape());

    // Expected contents: the first three rows of iris.txt split at column 3.
    INDArray expectedFeatures = Nd4j.create(new double[][] {
        { 5.1, 3.5, 1.4 },
        { 4.9, 3.0, 1.4 },
        { 4.7, 3.2, 1.3 }
    });
    INDArray expectedLabels = Nd4j.create(new double[][] {
        { 0.2, 0 },
        { 0.2, 0 },
        { 0.2, 0 }
    });
    assertEquals(expectedFeatures, features);
    assertEquals(expectedLabels, labels);
}
Usage of org.datavec.api.records.reader.impl.csv.CSVRecordReader in the deeplearning4j project (by deeplearning4j).
From class RecordReaderDataSetiteratorTest, method testRecordReader.
@Test
public void testRecordReader() throws Exception {
    // A single batch requested at this size must contain exactly that many examples.
    final int batchSize = 34;
    RecordReader reader = new CSVRecordReader();
    reader.initialize(new FileSplit(new ClassPathResource("csv-example.csv").getTempFileFromArchive()));
    DataSet first = new RecordReaderDataSetIterator(reader, batchSize).next();
    assertEquals(batchSize, first.numExamples());
}
Usage of org.datavec.api.records.reader.impl.csv.CSVRecordReader in the deeplearning4j project (by deeplearning4j).
From class DataSetIteratorTest, method testMnist.
@Test
public void testMnist() throws Exception {
    // Cross-check a CSV dump of MNIST examples against MnistDataSetIterator, batch by batch.
    ClassPathResource resource = new ClassPathResource("mnist_first_200.txt");
    CSVRecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new FileSplit(resource.getTempFileFromArchive()));

    // Batch size 10, label in column 0, 10 possible labels.
    RecordReaderDataSetIterator fromCsv = new RecordReaderDataSetIterator(csvReader, 10, 0, 10);
    // Batch size 10, 200 examples; remaining flags presumably binarize/train/shuffle plus seed — TODO confirm.
    MnistDataSetIterator fromMnist = new MnistDataSetIterator(10, 200, false, true, false, 0);

    while (fromCsv.hasNext()) {
        DataSet expected = fromCsv.next();
        DataSet actual = fromMnist.next();

        INDArray expectedFeatures = expected.getFeatureMatrix();
        // Scale the CSV values in place by 1/255 (presumably raw pixel intensities) to match the iterator.
        expectedFeatures.divi(255);
        INDArray expectedLabels = expected.getLabels();

        assertEquals(expectedFeatures, actual.getFeatureMatrix());
        assertEquals(expectedLabels, actual.getLabels());
    }
    // Both sources must run out at the same time.
    assertFalse(fromMnist.hasNext());
}
Usage of org.datavec.api.records.reader.impl.csv.CSVRecordReader in the deeplearning4j project (by deeplearning4j).
From class MultipleEpochsIteratorTest, method testNextAndReset.
@Test
public void testNextAndReset() throws Exception {
    // Wrap a single-batch iris iterator (150 examples per batch) in a MultipleEpochsIterator and
    // verify it keeps producing non-null batches until all epochs are done.
    // NOTE(review): despite the name, no explicit reset() call is exercised here.
    int epochs = 3;
    RecordReader rr = new CSVRecordReader();
    // getTempFileFromArchive() (as used by the other CSVRecordReader tests) also works when the
    // resource is packaged inside a JAR, where getFile() cannot return a real file.
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);

    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet ds = multiIter.next();
        assertFalse("next() returned null while hasNext() was true", ds == null);
    }
    assertEquals(epochs, multiIter.epochs);
}
Aggregations