Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j by deeplearning4j.
From the class RecordReaderDataSetiteratorTest, method testCSVLoadingRegression.
@Test
public void testCSVLoadingRegression() throws Exception {
    int nLines = 30;
    int nFeatures = 5;
    int miniBatchSize = 10;
    int labelIdx = 0;
    String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "rr_csv_test_rand.csv");
    double[][] data = makeRandomCSV(path, nLines, nFeatures);
    RecordReader testReader = new CSVRecordReader();
    testReader.initialize(new FileSplit(new File(path)));
    DataSetIterator iter = new RecordReaderDataSetIterator(testReader, null, miniBatchSize, labelIdx, 1, true);
    int miniBatch = 0;
    while (iter.hasNext()) {
        DataSet test = iter.next();
        INDArray features = test.getFeatureMatrix();
        INDArray labels = test.getLabels();
        assertArrayEquals(new int[] { miniBatchSize, nFeatures }, features.shape());
        assertArrayEquals(new int[] { miniBatchSize, 1 }, labels.shape());
        int startRow = miniBatch * miniBatchSize;
        for (int i = 0; i < miniBatchSize; i++) {
            double labelExp = data[startRow + i][labelIdx];
            double labelAct = labels.getDouble(i);
            assertEquals(labelExp, labelAct, 1e-5f);
            int featureCount = 0;
            for (int j = 0; j < nFeatures + 1; j++) {
                if (j == labelIdx)
                    continue;
                double featureExp = data[startRow + i][j];
                double featureAct = features.getDouble(i, featureCount++);
                assertEquals(featureExp, featureAct, 1e-5f);
            }
        }
        miniBatch++;
    }
    assertEquals(nLines / miniBatchSize, miniBatch);
}
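The helper makeRandomCSV is defined elsewhere in the same test class and is not shown in this excerpt. A minimal sketch of what it needs to do, given how the test reads the values back (nFeatures + 1 comma-separated doubles per line, with the regression label in column labelIdx = 0), might look as follows; the real helper may differ in detail:

// Hedged sketch only: writes nLines rows of (nFeatures + 1) random doubles to tempFile
// and returns the same values so the test can compare them against the parsed CSV.
private static double[][] makeRandomCSV(String tempFile, int nLines, int nFeatures) throws IOException {
    File f = new File(tempFile);
    f.deleteOnExit();
    Random rng = new Random(12345);
    double[][] data = new double[nLines][nFeatures + 1];
    try (PrintWriter out = new PrintWriter(new FileWriter(f))) {
        for (int i = 0; i < nLines; i++) {
            StringBuilder line = new StringBuilder();
            for (int j = 0; j <= nFeatures; j++) {
                data[i][j] = rng.nextDouble();
                if (j > 0)
                    line.append(",");
                line.append(data[i][j]);
            }
            out.println(line.toString());
        }
    }
    return data;
}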
Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j by deeplearning4j.
From the class MultipleEpochsIteratorTest, method testLoadBatchDataSet.
@Test
public void testLoadBatchDataSet() throws Exception {
    int epochs = 2;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(20);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next(10);
        assertEquals(path.numExamples(), 10, 0.0);
        assertFalse(path == null);
    }
    assertEquals(epochs, multiIter.epochs);
}
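The test above wraps a single pre-loaded DataSet. MultipleEpochsIterator can also wrap a DataSetIterator directly and reset it at each epoch boundary; a rough sketch of that usage, assuming the usual iris.txt layout (4 feature columns, class label in column 4 with 3 classes), could look like this:

// Sketch under stated assumptions, not taken from the original test.
RecordReader irisReader = new CSVRecordReader();
irisReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
DataSetIterator irisIter = new RecordReaderDataSetIterator(irisReader, 50, 4, 3);
MultipleEpochsIterator multi = new MultipleEpochsIterator(2, irisIter);
while (multi.hasNext()) {
    DataSet batch = multi.next();
    // up to 50 examples per call; the wrapped iterator is reset between epochs
}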
Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j by deeplearning4j.
From the class TestComputationGraphNetwork, method testIrisFitMultiDataSetIterator.
@Test
public void testIrisFitMultiDataSetIterator() throws Exception {
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", rr).addInput("iris", 0, 3).addOutputOneHot("iris", 4, 3).build();
    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(0.1)
            .graphBuilder().addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out").pretrain(false).backprop(true).build();
    ComputationGraph cg = new ComputationGraph(config);
    cg.init();
    cg.fit(iter);
    rr.reset();
    iter = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", rr).addInput("iris", 0, 3).addOutputOneHot("iris", 4, 3).build();
    while (iter.hasNext()) {
        cg.fit(iter.next());
    }
}
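For reference, a small sketch (not part of the original test) of what the builder above yields per iteration: each MultiDataSet carries one feature array built from columns 0-3 of iris.txt and one one-hot label array built from column 4.

// Illustrative sketch only; uses a fresh reader so it is self-contained.
RecordReader checkReader = new CSVRecordReader(0, ",");
checkReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
MultiDataSetIterator check = new RecordReaderMultiDataSetIterator.Builder(10)
        .addReader("iris", checkReader).addInput("iris", 0, 3).addOutputOneHot("iris", 4, 3).build();
MultiDataSet mds = check.next();
// with a minibatch size of 10: one feature array of shape [10, 4], one label array of shape [10, 3]
System.out.println(Arrays.toString(mds.getFeatures(0).shape()));
System.out.println(Arrays.toString(mds.getLabels(0).shape()));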
Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j by deeplearning4j.
From the class SequenceVectorsTest, method buildGraph.
private static Graph<Blogger, Double> buildGraph() throws IOException, InterruptedException {
    File nodes = new File("/ext/Temp/BlogCatalog/nodes.csv");
    CSVRecordReader reader = new CSVRecordReader(0, ",");
    reader.initialize(new FileSplit(nodes));
    List<Blogger> bloggers = new ArrayList<>();
    int cnt = 0;
    while (reader.hasNext()) {
        List<Writable> lines = new ArrayList<>(reader.next());
        Blogger blogger = new Blogger(lines.get(0).toInt());
        bloggers.add(blogger);
        cnt++;
    }
    reader.close();
    Graph<Blogger, Double> graph = new Graph<>(bloggers, true);
    // load edges
    File edges = new File("/ext/Temp/BlogCatalog/edges.csv");
    reader = new CSVRecordReader(0, ",");
    reader.initialize(new FileSplit(edges));
    while (reader.hasNext()) {
        List<Writable> lines = new ArrayList<>(reader.next());
        int from = lines.get(0).toInt();
        int to = lines.get(1).toInt();
        graph.addEdge(from - 1, to - 1, 1.0, false);
    }
    logger.info("Connected on 0: [" + graph.getConnectedVertices(0).size() + "]");
    logger.info("Connected on 1: [" + graph.getConnectedVertices(1).size() + "]");
    logger.info("Connected on 3: [" + graph.getConnectedVertices(3).size() + "]");
    assertEquals(119, graph.getConnectedVertices(0).size());
    assertEquals(9, graph.getConnectedVertices(1).size());
    assertEquals(6, graph.getConnectedVertices(3).size());
    return graph;
}
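This method expects the BlogCatalog dump under /ext/Temp/BlogCatalog, which is not bundled with the project. Judging only from how the files are parsed above, nodes.csv holds one integer vertex id per line and edges.csv holds "from,to" pairs of 1-based ids (the code subtracts 1 before adding each edge). A hedged sketch for generating tiny stand-in files for local experiments; the hard-coded connectivity assertions will of course not hold against such data:

// Assumed file layout only; adjust the directory to match the path used in buildGraph().
File dir = new File(System.getProperty("java.io.tmpdir"), "BlogCatalog");
dir.mkdirs();
try (PrintWriter nodesOut = new PrintWriter(new FileWriter(new File(dir, "nodes.csv")));
     PrintWriter edgesOut = new PrintWriter(new FileWriter(new File(dir, "edges.csv")))) {
    for (int i = 1; i <= 4; i++) {
        nodesOut.println(i);  // one 1-based vertex id per line
    }
    edgesOut.println("1,2");  // undirected edges as 1-based "from,to" pairs
    edgesOut.println("2,3");
    edgesOut.println("3,4");
}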
Use of org.datavec.api.records.reader.impl.csv.CSVRecordReader in project deeplearning4j by deeplearning4j.
From the class TestPreProcessedData, method testCsvPreprocessedDataGenerationNoLabel.
@Test
public void testCsvPreprocessedDataGenerationNoLabel() throws Exception {
    //Same as above test, but without any labels (in which case: input and output arrays are the same)
    List<String> list = new ArrayList<>();
    DataSetIterator iter = new IrisDataSetIterator(1, 150);
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        list.add(toString(ds.getFeatureMatrix(), Nd4j.argMax(ds.getLabels(), 1).getInt(0)));
    }
    JavaRDD<String> rdd = sc.parallelize(list);
    int partitions = rdd.partitions().size();
    URI tempDir = new File(System.getProperty("java.io.tmpdir")).toURI();
    URI outputDir = new URI(tempDir.getPath() + "/dl4j_testPreprocessedData3");
    File temp = new File(outputDir.getPath());
    if (temp.exists())
        FileUtils.deleteDirectory(temp);
    int numBinFiles = 0;
    try {
        int batchSize = 5;
        int labelIdx = -1;
        int numPossibleLabels = -1;
        rdd.foreachPartition(new StringToDataSetExportFunction(outputDir, new CSVRecordReader(0), batchSize, false, labelIdx, numPossibleLabels));
        File[] fileList = new File(outputDir.getPath()).listFiles();
        int totalExamples = 0;
        for (File f2 : fileList) {
            if (!f2.getPath().endsWith(".bin"))
                continue;
            // System.out.println(f2.getPath());
            numBinFiles++;
            DataSet ds = new DataSet();
            ds.load(f2);
            assertEquals(5, ds.numInputs());
            assertEquals(5, ds.numOutcomes());
            totalExamples += ds.numExamples();
        }
        assertEquals(150, totalExamples);
        //Expect 30, give or take due to partitioning randomness
        assertTrue(Math.abs(150 / batchSize - numBinFiles) <= partitions);
        //Test the PortableDataStreamDataSetIterator:
        JavaPairRDD<String, PortableDataStream> pds = sc.binaryFiles(outputDir.getPath());
        List<PortableDataStream> pdsList = pds.values().collect();
        DataSetIterator pdsIter = new PortableDataStreamDataSetIterator(pdsList);
        int pdsCount = 0;
        int totalExamples2 = 0;
        while (pdsIter.hasNext()) {
            DataSet ds = pdsIter.next();
            pdsCount++;
            totalExamples2 += ds.numExamples();
            assertEquals(5, ds.numInputs());
            assertEquals(5, ds.numOutcomes());
        }
        assertEquals(150, totalExamples2);
        assertEquals(numBinFiles, pdsCount);
    } finally {
        FileUtils.deleteDirectory(temp);
    }
}
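The toString(INDArray, int) call on the feature matrix is a private helper in TestPreProcessedData, not shown here. A plausible sketch, given that the generated lines are later parsed by a plain CSVRecordReader into 5 columns (4 iris features plus the class index), is the following; the real helper may format the values differently:

// Hedged sketch only: joins the feature values and the integer label into one CSV line.
private static String toString(INDArray rowVector, int label) {
    StringBuilder sb = new StringBuilder();
    int nCols = (int) rowVector.length();
    for (int j = 0; j < nCols; j++) {
        sb.append(rowVector.getDouble(j)).append(",");
    }
    sb.append(label);
    return sb.toString();
}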