Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.
The class RecordReaderDataSetiteratorTest, method testRecordReader.
/**
 * Reads the bundled CSV example through a {@link RecordReaderDataSetIterator} and checks that the
 * first batch holds exactly the requested 34 examples.
 */
@Test
public void testRecordReader() throws Exception {
    RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(new ClassPathResource("csv-example.csv").getTempFileFromArchive()));

    DataSetIterator batches = new RecordReaderDataSetIterator(csvReader, 34);
    DataSet firstBatch = batches.next();

    assertEquals(34, firstBatch.numExamples());
}
Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.
The class RecordReaderDataSetiteratorTest, method testSequenceRecordReaderRegression.
/**
 * Checks sequence regression through {@link SequenceRecordReaderDataSetIterator} in two setups.
 *
 * <p>The first setup uses separate feature and label readers. Both point at the same three
 * sequence files, so features and labels must be identical. The second setup uses a single
 * reader in which one column is the regression target. It also verifies that {@code reset()}
 * allows a full second pass.
 */
@Test
public void testSequenceRecordReaderRegression() throws Exception {
    // NumberedFileInputSplit reads csvsequence_0..2.txt from disk, so extract all three from the
    // classpath first. The pattern built below assumes they are extracted into the same directory.
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
    }
    String firstPath = new ClassPathResource("csvsequence_0.txt").getTempFileFromArchive().getAbsolutePath();
    String featuresPath = toNumberedPathPattern(firstPath);
    // Labels deliberately come from the same files as the features; equality is asserted below.
    String labelsPath = featuresPath;

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));

    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 0, true);
    assertEquals(3, iter.inputColumns());
    assertEquals(3, iter.totalOutcomes());

    List<DataSet> dsList = new ArrayList<>();
    while (iter.hasNext()) {
        dsList.add(iter.next());
    }
    // One DataSet per file (3 files, batch size 1).
    assertEquals(3, dsList.size());
    for (int i = 0; i < 3; i++) {
        DataSet ds = dsList.get(i);
        INDArray features = ds.getFeatureMatrix();
        INDArray labels = ds.getLabels();
        // 1 example, 3 values, 4 time steps
        assertArrayEquals(new int[] { 1, 3, 4 }, features.shape());
        assertArrayEquals(new int[] { 1, 3, 4 }, labels.shape());
        assertEquals(features, labels);
    }

    // Also test regression + reset from a single reader. Judging by the size(1) assertions,
    // column 2 becomes the single label column and the remaining 2 columns are features.
    featureReader.reset();
    iter = new SequenceRecordReaderDataSetIterator(featureReader, 1, 0, 2, true);
    int count = 0;
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        assertEquals(2, ds.getFeatureMatrix().size(1));
        assertEquals(1, ds.getLabels().size(1));
        count++;
    }
    assertEquals(3, count);

    iter.reset();
    count = 0;
    while (iter.hasNext()) {
        iter.next();
        count++;
    }
    assertEquals(3, count);
}

/**
 * Turns the absolute path of an extracted "..._0.txt" file into a NumberedFileInputSplit pattern.
 * Only the trailing file index is replaced with {@code %d}, so zeros elsewhere in the path (for
 * example in the temp directory name) are left untouched.
 *
 * @param pathOfFileZero absolute path ending in {@code 0.txt}
 * @return the same path with its final {@code 0} replaced by {@code %d}
 * @throws IllegalStateException if the path does not end in {@code 0.txt}
 */
private static String toNumberedPathPattern(String pathOfFileZero) {
    String pattern = pathOfFileZero.replaceFirst("0(\\.txt)$", "%d$1");
    if (pattern.equals(pathOfFileZero)) {
        throw new IllegalStateException("Expected a path ending in 0.txt but got: " + pathOfFileZero);
    }
    return pattern;
}
Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.
The class DataSetIteratorTest, method testMnist.
/**
 * Cross-checks {@link MnistDataSetIterator} against a CSV dump of the first 200 MNIST examples.
 * Each batch must match after the CSV pixel values are divided by 255. Both iterators must be
 * exhausted together.
 */
@Test
public void testMnist() throws Exception {
    ClassPathResource csvResource = new ClassPathResource("mnist_first_200.txt");
    CSVRecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new FileSplit(csvResource.getTempFileFromArchive()));

    // Reference iterator: batches of 10; the label is read from column 0 and has 10 classes.
    RecordReaderDataSetIterator expectedIter = new RecordReaderDataSetIterator(csvReader, 10, 0, 10);
    MnistDataSetIterator actualIter = new MnistDataSetIterator(10, 200, false, true, false, 0);

    while (expectedIter.hasNext()) {
        DataSet expected = expectedIter.next();
        DataSet actual = actualIter.next();

        // Scale the CSV features in place by 1/255 before comparing them with the MNIST iterator's output.
        INDArray expectedFeatures = expected.getFeatureMatrix();
        expectedFeatures.divi(255);
        INDArray expectedLabels = expected.getLabels();

        assertEquals(expectedFeatures, actual.getFeatureMatrix());
        assertEquals(expectedLabels, actual.getLabels());
    }
    // The MNIST iterator must not produce more batches than the CSV reference.
    assertFalse(actualIter.hasNext());
}
Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.
The class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI.
/**
 * Builds a two-image, two-class dataset on disk. It then checks that a
 * {@link RecordReaderMultiDataSetIterator} fed by two {@link ImageRecordReader}s (10x10 and 5x5)
 * yields the same features and labels as two independent {@link RecordReaderDataSetIterator}s
 * over equivalent readers.
 */
@Test
public void testImagesRRDMSI() throws Exception {
    File parentDir = Files.createTempDir();
    // File.deleteOnExit() cannot remove a non-empty directory, and registered files are deleted in
    // reverse registration order. Register the parent first and its contents afterwards, so the
    // images go first, then the class folders, then the parent.
    parentDir.deleteOnExit();
    File zicoDir = new File(FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/"));
    File ziwangDir = new File(FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/"));
    zicoDir.mkdirs();
    ziwangDir.mkdirs();
    zicoDir.deleteOnExit();
    ziwangDir.deleteOnExit();

    File zicoImage = new File(FilenameUtils.concat(zicoDir.getPath(), "Zico_0001.jpg"));
    File ziwangImage = new File(FilenameUtils.concat(ziwangDir.getPath(), "Ziwang_Xu_0001.jpg"));
    zicoImage.deleteOnExit();
    ziwangImage.deleteOnExit();
    // Close the classpath streams ourselves rather than relying on writeStreamToFile to do it.
    try (java.io.InputStream in = new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream()) {
        writeStreamToFile(zicoImage, in);
    }
    try (java.io.InputStream in = new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream()) {
        writeStreamToFile(ziwangImage, in);
    }

    int outputNum = 2;
    // Labels are derived from each image's parent folder name.
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1)
            .addReader("rr1", rr1)
            .addReader("rr1s", rr1s)
            .addInput("rr1", 0, 0)
            .addInput("rr1s", 0, 0)
            .addOutputOneHot("rr1s", 1, outputNum)
            .build();

    // Now do the same thing with RecordReaderDataSetIterator and check that the results match:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);

    // One batch per image (batch size 1, two images).
    for (int i = 0; i < 2; i++) {
        MultiDataSet mds = trainDataIterator.next();
        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();
        assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
        assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
    }
}
Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.
The class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVMeta.
/**
 * Splits Iris into two inputs and two outputs and collects record metadata. It checks that
 * reloading each batch from that metadata reproduces the batch exactly.
 *
 * <p>Inputs: column 0, and columns 1-2. Outputs: column 3, and column 4 as a 3-class one-hot.
 */
@Test
public void testSplittingCSVMeta() throws Exception {
    RecordReader irisReader = new CSVRecordReader(0, ",");
    irisReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    RecordReaderMultiDataSetIterator iterator = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("reader", irisReader)
            .addInput("reader", 0, 0)
            .addInput("reader", 1, 2)
            .addOutput("reader", 3, 3)
            .addOutputOneHot("reader", 4, 3)
            .build();
    // Metadata must be collected so that each batch can be rebuilt via loadFromMetaData.
    iterator.setCollectMetaData(true);

    int batches = 0;
    for (; iterator.hasNext(); batches++) {
        MultiDataSet batch = iterator.next();
        MultiDataSet reloaded = iterator.loadFromMetaData(batch.getExampleMetaData(RecordMetaData.class));
        assertEquals(batch, reloaded);
    }
    // 150 examples in batches of 10.
    assertEquals(150 / 10, batches);
}
Aggregations