Usage of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in the deeplearning4j project (by deeplearning4j).
Example from class DataSetIteratorTest, method testIteratorDataSetIteratorSplitting.
@Test
public void testIteratorDataSetIteratorSplitting() {
    //Test splitting large data sets into smaller ones:
    //3 source DataSets of 4 examples each (12 examples total) are re-batched into 4 DataSets of 3 examples,
    //so the 2nd and 3rd output batches must straddle the boundary between two source DataSets
    int origBatchSize = 4;
    int origNumDSs = 3;
    int batchSize = 3;
    int numBatches = 4;
    int featureSize = 5;
    int labelSize = 6;
    Nd4j.getRandom().setSeed(12345);
    List<DataSet> orig = new ArrayList<>();
    for (int i = 0; i < origNumDSs; i++) {
        INDArray features = Nd4j.rand(origBatchSize, featureSize);
        INDArray labels = Nd4j.rand(origBatchSize, labelSize);
        orig.add(new DataSet(features, labels));
    }
    List<DataSet> expected = new ArrayList<>();
    //Batch 0: rows 0-2 of source DataSet 0
    expected.add(new DataSet(orig.get(0).getFeatureMatrix().getRows(0, 1, 2), orig.get(0).getLabels().getRows(0, 1, 2)));
    //Batch 1: row 3 of source 0, then rows 0-1 of source 1
    expected.add(new DataSet(Nd4j.vstack(orig.get(0).getFeatureMatrix().getRows(3), orig.get(1).getFeatureMatrix().getRows(0, 1)), Nd4j.vstack(orig.get(0).getLabels().getRows(3), orig.get(1).getLabels().getRows(0, 1))));
    //Batch 2: rows 2-3 of source 1, then row 0 of source 2
    expected.add(new DataSet(Nd4j.vstack(orig.get(1).getFeatureMatrix().getRows(2, 3), orig.get(2).getFeatureMatrix().getRows(0)), Nd4j.vstack(orig.get(1).getLabels().getRows(2, 3), orig.get(2).getLabels().getRows(0))));
    //Batch 3: rows 1-3 of source 2
    expected.add(new DataSet(orig.get(2).getFeatureMatrix().getRows(1, 2, 3), orig.get(2).getLabels().getRows(1, 2, 3)));
    DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize);
    int count = 0;
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        //Fail with a descriptive message, rather than an IndexOutOfBoundsException from expected.get(),
        //if the iterator produces more batches than expected
        assertTrue("Iterator returned more than " + numBatches + " batches", count < numBatches);
        assertEquals(expected.get(count), ds);
        count++;
    }
    //JUnit convention is (expected, actual) so a failure reports the values the right way round
    assertEquals(numBatches, count);
}
Usage of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in the deeplearning4j project (by deeplearning4j).
Example from class DataSetIteratorTest, method testBatchSizeOfOneIris.
@Test
public void testBatchSizeOfOneIris() throws Exception {
    //Test for (a) iterators returning correct number of examples, and
    //(b) Labels are a proper one-hot vector (i.e., sum is 1.0)
    //Iris: batch size 1, 5 examples in total -> expect 5 batches of exactly 1 example each
    DataSetIterator iris = new IrisDataSetIterator(1, 5);
    int irisC = 0;
    while (iris.hasNext()) {
        irisC++;
        DataSet ds = iris.next();
        //(a) each batch holds exactly one example
        assertEquals(1, ds.numExamples());
        //(b) full reduction of the label row; compared with a tolerance rather than exact ==,
        //so a failure reports the actual sum instead of just "false"
        assertEquals(1.0, ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0), 1e-6);
    }
    assertEquals(5, irisC);
}
Usage of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in the deeplearning4j project (by deeplearning4j).
Example from class MultipleEpochsIteratorTest, method testNextAndReset.
@Test
public void testNextAndReset() throws Exception {
    // Wrap a single-batch Iris iterator (batch size 150) in a MultipleEpochsIterator and
    // verify it can be drained without ever yielding a null DataSet.
    final int numEpochs = 3;

    RecordReader reader = new CSVRecordReader();
    reader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator source = new RecordReaderDataSetIterator(reader, 150);

    MultipleEpochsIterator epochIter = new MultipleEpochsIterator(numEpochs, source);
    assertTrue(epochIter.hasNext());

    while (epochIter.hasNext()) {
        assertFalse(epochIter.next() == null);
    }

    // The configured epoch count is exposed unchanged after iteration
    assertEquals(numEpochs, epochIter.epochs);
}
Usage of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in the deeplearning4j project (by deeplearning4j).
Example from class MultipleEpochsIteratorTest, method testLoadFullDataSet.
@Test
public void testLoadFullDataSet() throws Exception {
    //Build a MultipleEpochsIterator from an in-memory DataSet of 50 Iris examples and check
    //that every DataSet it yields is non-null and contains all 50 examples
    int epochs = 3;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(50);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        //Null check must come before dereferencing path; otherwise a null would surface as an NPE
        assertFalse(path == null);
        //Integer comparison in JUnit (expected, actual) order
        assertEquals(50, path.numExamples());
    }
    assertEquals(epochs, multiIter.epochs);
}
Usage of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in the deeplearning4j project (by deeplearning4j).
Example from class ROCTest, method RocEvalSanityCheck.
@Test
public void RocEvalSanityCheck() {
    // Sanity check: the ROCMultiClass returned by MultiLayerNetwork.evaluateROCMultiClass(iterator, steps)
    // should agree (AUC and ROC curve points, per class) with an ROCMultiClass filled in manually
    // from net.output(...) on the same standardized Iris data.
    final int thresholdSteps = 32;
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);

    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
            .layer(1, new OutputLayer.Builder().nIn(4).nOut(3)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
            .build();
    MultiLayerNetwork model = new MultiLayerNetwork(config);
    model.init();

    // Fit the normalizer on the single full batch, standardize that batch in place, and attach
    // the normalizer to the iterator as its preprocessor.
    // NOTE(review): irisIter has already returned its only batch; this relies on
    // evaluateROCMultiClass resetting the iterator before evaluating — confirm.
    NormalizerStandardize normalizer = new NormalizerStandardize();
    DataSet allData = irisIter.next();
    normalizer.fit(allData);
    normalizer.transform(allData);
    irisIter.setPreProcessor(normalizer);

    for (int epoch = 0; epoch < 30; epoch++) {
        model.fit(allData);
    }

    ROCMultiClass rocFromIter = model.evaluateROCMultiClass(irisIter, thresholdSteps);

    INDArray features = allData.getFeatures();
    INDArray labels = allData.getLabels();
    INDArray predictions = model.output(features);
    ROCMultiClass rocManual = new ROCMultiClass(thresholdSteps);
    rocManual.eval(labels, predictions);

    // Compare per class (Iris has 3 output classes)
    for (int cls = 0; cls < 3; cls++) {
        assertEquals(rocManual.calculateAUC(cls), rocFromIter.calculateAUC(cls), 1e-6);
        double[][] iterCurve = rocFromIter.getResultsAsArray(cls);
        double[][] manualCurve = rocManual.getResultsAsArray(cls);
        assertArrayEquals(iterCurve[0], manualCurve[0], 1e-6);
        assertArrayEquals(iterCurve[1], manualCurve[1], 1e-6);
    }
}
Aggregations