Example use of org.datavec.api.split.FileSplit in the deeplearning4j project (by deeplearning4j).
From class RecordReaderDataSetiteratorTest, method testCSVLoadingRegression.
/**
 * Writes a CSV of random values via makeRandomCSV, reads it back through a
 * CSVRecordReader wrapped in a RecordReaderDataSetIterator, and checks that
 * every mini-batch has the expected shape and that each label/feature value
 * matches the matrix makeRandomCSV returned (column labelIdx is the label,
 * the remaining columns are the features, in order).
 */
@Test
public void testCSVLoadingRegression() throws Exception {
    final int nLines = 30;
    final int nFeatures = 5;
    final int miniBatchSize = 10;
    final int labelIdx = 0;

    String tmpDir = System.getProperty("java.io.tmpdir");
    String path = FilenameUtils.concat(tmpDir, "rr_csv_test_rand.csv");
    double[][] expected = makeRandomCSV(path, nLines, nFeatures);

    RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(new File(path)));
    DataSetIterator batches =
            new RecordReaderDataSetIterator(csvReader, null, miniBatchSize, labelIdx, 1, true);

    int batchesSeen = 0;
    while (batches.hasNext()) {
        DataSet batch = batches.next();
        INDArray features = batch.getFeatureMatrix();
        INDArray labels = batch.getLabels();
        assertArrayEquals(new int[] { miniBatchSize, nFeatures }, features.shape());
        assertArrayEquals(new int[] { miniBatchSize, 1 }, labels.shape());

        // Row offset of this mini-batch inside the expected matrix.
        int firstRow = batchesSeen * miniBatchSize;
        for (int row = 0; row < miniBatchSize; row++) {
            double[] expectedRow = expected[firstRow + row];
            assertEquals(expectedRow[labelIdx], labels.getDouble(row), 1e-5f);

            // Walk all nFeatures + 1 CSV columns; the label column is skipped, so the
            // feature-matrix column index advances only for non-label columns.
            int featureCol = 0;
            for (int col = 0; col <= nFeatures; col++) {
                if (col != labelIdx) {
                    assertEquals(expectedRow[col], features.getDouble(row, featureCol), 1e-5f);
                    featureCol++;
                }
            }
        }
        batchesSeen++;
    }
    assertEquals(nLines / miniBatchSize, batchesSeen);
}
Example use of org.datavec.api.split.FileSplit in the deeplearning4j project (by deeplearning4j).
From class MultipleEpochsIteratorTest, method testLoadBatchDataSet.
/**
 * Loads iris.txt through a CSVRecordReader, takes a 20-example DataSet from it and
 * wraps that DataSet in a MultipleEpochsIterator. Every batch requested with
 * next(10) must be non-null and contain 10 examples, and after the iterator is
 * exhausted its epochs field must equal the configured number of epochs.
 */
@Test
public void testLoadBatchDataSet() throws Exception {
    int epochs = 2;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(20);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    while (multiIter.hasNext()) {
        DataSet batch = multiIter.next(10);
        // The null check has to come before the first dereference; otherwise a null
        // batch surfaces as a NullPointerException instead of an assertion failure.
        assertFalse(batch == null);
        // Expected value first, actual second, so failure messages read correctly.
        assertEquals(10, batch.numExamples());
    }
    assertEquals(epochs, multiIter.epochs);
}
Example use of org.datavec.api.split.FileSplit in the deeplearning4j project (by deeplearning4j).
From class ConvolutionLayerSetupTest, method testLRN.
/**
 * Builds the configuration returned by incompleteLRN() with a 28x28x3 convolutional
 * input type and checks that the layer at index 3 is a ConvolutionLayer whose nIn
 * has been set to 6.
 */
@Test
public void testLRN() throws Exception {
    List<String> labels = new ArrayList<>(Arrays.asList("Zico", "Ziwang_Xu"));
    String rootDir = new ClassPathResource("lfwtest").getFile().getAbsolutePath();
    RecordReader reader = new ImageRecordReader(28, 28, 3);
    reader.initialize(new FileSplit(new File(rootDir)));
    // The iterator is not consumed below; only its construction over the image
    // reader is exercised here.
    DataSetIterator recordReader = new RecordReaderDataSetIterator(reader, 10, 1, labels.size());

    NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLRN();
    builder.setInputType(InputType.convolutional(28, 28, 3));
    MultiLayerConfiguration conf = builder.build();

    ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(3).getLayer();
    assertEquals(6, layer2.getNIn());
}
Example use of org.datavec.api.split.FileSplit in the deeplearning4j project (by deeplearning4j).
From class TestComputationGraphNetwork, method testIrisFitMultiDataSetIterator.
/**
 * Fits a small ComputationGraph (dense layer followed by an MCXENT output layer) on
 * iris.txt twice: first by handing the whole MultiDataSetIterator to fit(), then,
 * after resetting the record reader, by fitting each MultiDataSet from a freshly
 * built iterator one at a time. The test passes if neither path throws.
 */
@Test
public void testIrisFitMultiDataSetIterator() throws Exception {
    RecordReader irisReader = new CSVRecordReader(0, ",");
    irisReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    // Batches of 10: columns 0..3 feed the input, column 4 becomes a 3-way one-hot output.
    MultiDataSetIterator firstPass = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", irisReader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();

    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out")
            .pretrain(false)
            .backprop(true)
            .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    // Path 1: fit directly from the iterator.
    graph.fit(firstPass);

    // Path 2: rewind the reader and fit one MultiDataSet at a time.
    irisReader.reset();
    MultiDataSetIterator secondPass = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", irisReader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();
    while (secondPass.hasNext()) {
        graph.fit(secondPass.next());
    }
}
Example use of org.datavec.api.split.FileSplit in the deeplearning4j project (by deeplearning4j).
From class SequenceVectorsTest, method buildGraph.
/**
 * Builds a graph of Blogger vertices from two CSV files at fixed absolute paths:
 * nodes.csv (first column read as the blogger id) and edges.csv (first two columns
 * read as the endpoints of an edge). Logs and asserts the number of vertices
 * connected to vertices 0, 1 and 3 before returning the graph.
 *
 * @return the populated graph
 * @throws IOException          declared by the reader calls used here
 * @throws InterruptedException declared by the reader calls used here
 */
private static Graph<Blogger, Double> buildGraph() throws IOException, InterruptedException {
    // load nodes
    File nodes = new File("/ext/Temp/BlogCatalog/nodes.csv");
    List<Blogger> bloggers = new ArrayList<>();
    CSVRecordReader reader = new CSVRecordReader(0, ",");
    reader.initialize(new FileSplit(nodes));
    try {
        while (reader.hasNext()) {
            List<Writable> lines = new ArrayList<>(reader.next());
            bloggers.add(new Blogger(lines.get(0).toInt()));
        }
    } finally {
        // Release the reader even if a record fails to parse.
        reader.close();
    }

    Graph<Blogger, Double> graph = new Graph<>(bloggers, true);

    // load edges
    File edges = new File("/ext/Temp/BlogCatalog/edges.csv");
    reader = new CSVRecordReader(0, ",");
    reader.initialize(new FileSplit(edges));
    try {
        while (reader.hasNext()) {
            List<Writable> lines = new ArrayList<>(reader.next());
            int from = lines.get(0).toInt();
            int to = lines.get(1).toInt();
            // Ids from the file are shifted down by one before being used as vertex
            // indices (presumably the file is 1-based -- confirm against the data set).
            graph.addEdge(from - 1, to - 1, 1.0, false);
        }
    } finally {
        // The edges reader was previously never closed.
        reader.close();
    }

    logger.info("Connected on 0: [" + graph.getConnectedVertices(0).size() + "]");
    logger.info("Connected on 1: [" + graph.getConnectedVertices(1).size() + "]");
    logger.info("Connected on 3: [" + graph.getConnectedVertices(3).size() + "]");
    assertEquals(119, graph.getConnectedVertices(0).size());
    assertEquals(9, graph.getConnectedVertices(1).size());
    assertEquals(6, graph.getConnectedVertices(3).size());
    return graph;
}
Aggregations