Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class RecordReaderDataSetiteratorTest, method testRecordReaderMultiRegression.
@Test
public void testRecordReaderMultiRegression() throws Exception {
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    int batchSize = 3;
    int labelIdxFrom = 3;
    int labelIdxTo = 4;
    DataSetIterator iter = new RecordReaderDataSetIterator(csv, batchSize, labelIdxFrom, labelIdxTo, true);
    DataSet ds = iter.next();
    INDArray f = ds.getFeatureMatrix();
    INDArray l = ds.getLabels();
    assertArrayEquals(new int[] { 3, 3 }, f.shape());
    assertArrayEquals(new int[] { 3, 2 }, l.shape());
    //Check values:
    double[][] fExpD = new double[][] { { 5.1, 3.5, 1.4 }, { 4.9, 3.0, 1.4 }, { 4.7, 3.2, 1.3 } };
    double[][] lExpD = new double[][] { { 0.2, 0 }, { 0.2, 0 }, { 0.2, 0 } };
    INDArray fExp = Nd4j.create(fExpD);
    INDArray lExp = Nd4j.create(lExpD);
    assertEquals(fExp, f);
    assertEquals(lExp, l);
}
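Here the (labelIdxFrom, labelIdxTo, regression) constructor treats columns 3 through 4 (inclusive) of each CSV row as a two-column regression target and the remaining columns 0-2 as features, which is exactly what the shape assertions verify. As a hedged sketch of the same iterator used outside a test, e.g. to walk the whole file batch by batch (the batch size of 50 is an arbitrary choice):

    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    // Columns 3..4 (inclusive) become the 2-column regression target; columns 0..2 remain features
    DataSetIterator regIter = new RecordReaderDataSetIterator(csv, 50, 3, 4, true);
    while (regIter.hasNext()) {
        DataSet batch = regIter.next();
        // batch.getFeatureMatrix() has shape [batchSize, 3]; batch.getLabels() has shape [batchSize, 2]
    }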
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class RecordReaderDataSetiteratorTest, method testRecordReader.
@Test
public void testRecordReader() throws Exception {
    RecordReader recordReader = new CSVRecordReader();
    FileSplit csv = new FileSplit(new ClassPathResource("csv-example.csv").getTempFileFromArchive());
    recordReader.initialize(csv);
    DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, 34);
    DataSet next = iter.next();
    assertEquals(34, next.numExamples());
}
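The two-argument constructor above only fixes the batch size. For classification the reader is more commonly paired with the (labelIndex, numPossibleLabels) constructor, which one-hot encodes the class column; a minimal sketch, assuming an iris-style CSV with the integer class index in column 4 and three classes (an assumption about the file layout, not part of the test above):

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    int labelIndex = 4;   // column holding the integer class index (assumed file layout)
    int numClasses = 3;   // labels become a one-hot matrix of shape [batchSize, 3]
    DataSetIterator classIter = new RecordReaderDataSetIterator(rr, 50, labelIndex, numClasses);
    DataSet first = classIter.next();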
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class TestRecordReaders, method testClassIndexOutsideOfRangeRRMDSI_MultipleReaders.
@Test
public void testClassIndexOutsideOfRangeRRMDSI_MultipleReaders() {
    Collection<Collection<Collection<Writable>>> c1 = new ArrayList<>();
    Collection<Collection<Writable>> seq1 = new ArrayList<>();
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq1);
    Collection<Collection<Writable>> seq2 = new ArrayList<>();
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq2);
    Collection<Collection<Collection<Writable>>> c2 = new ArrayList<>();
    Collection<Collection<Writable>> seq1a = new ArrayList<>();
    seq1a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq1a.add(Arrays.<Writable>asList(new IntWritable(1)));
    c2.add(seq1a);
    Collection<Collection<Writable>> seq2a = new ArrayList<>();
    seq2a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq2a.add(Arrays.<Writable>asList(new IntWritable(2)));
    c2.add(seq2a);
    CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c1);
    CollectionSequenceRecordReader csrrLabels = new CollectionSequenceRecordReader(c2);
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, csrrLabels, 2, 2);
    try {
        DataSet ds = dsi.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testClassIndexOutsideOfRangeRRMDSI_MultipleReaders(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail();
    }
}
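The expected DL4JException arises because the label reader emits a class index of 2 while the iterator was built with numPossibleLabels = 2, so only indices 0 and 1 are valid. As a hedged sketch of the non-failing counterpart, reusing the c1/c2 collections from the test, declaring three possible labels makes every index in c2 legal:

    CollectionSequenceRecordReader featureReader = new CollectionSequenceRecordReader(c1);
    CollectionSequenceRecordReader labelReader = new CollectionSequenceRecordReader(c2);
    // numPossibleLabels = 3, so the IntWritable value 2 is a valid class index
    DataSetIterator ok = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 2, 3);
    DataSet ds = ok.next();   // should iterate without throwing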
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class BNGradientCheckTest, method testGradient2dSimple.
@Test
public void testGradient2dSimple() {
    DataNormalization scaler = new NormalizerMinMaxScaler();
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    scaler.fit(iter);
    iter.setPreProcessor(scaler);
    DataSet ds = iter.next();
    INDArray input = ds.getFeatureMatrix();
    INDArray labels = ds.getLabels();
    MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().learningRate(1.0)
                    .regularization(false).updater(Updater.NONE).seed(12345L)
                    .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0, 1)).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build())
                    .layer(1, new BatchNormalization.Builder().nOut(3).build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
                    .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(3).nOut(3).build())
                    .pretrain(false).backprop(true);
    MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
    mln.init();
    if (PRINT_RESULTS) {
        for (int j = 0; j < mln.getnLayers(); j++)
            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
    }
    boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
                    PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
    assertTrue(gradOK);
}
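Since the min-max scaler is attached as the iterator's pre-processor, the same iterator can also drive ordinary training once the gradient check passes. A rough sketch, assuming the builder defined above and an arbitrary epoch count:

    iter.reset();   // rewind after the single iter.next() call above
    MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
    net.init();
    for (int epoch = 0; epoch < 10; epoch++) {
        net.fit(iter);   // the pre-processor is applied to each batch by the iterator
        iter.reset();
    }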
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
From the class BNGradientCheckTest, method testGradient2dFixedGammaBeta.
@Test
public void testGradient2dFixedGammaBeta() {
    DataNormalization scaler = new NormalizerMinMaxScaler();
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    scaler.fit(iter);
    iter.setPreProcessor(scaler);
    DataSet ds = iter.next();
    INDArray input = ds.getFeatureMatrix();
    INDArray labels = ds.getLabels();
    MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().learningRate(1.0)
                    .regularization(false).updater(Updater.NONE).seed(12345L)
                    .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0, 1)).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY).build())
                    .layer(1, new BatchNormalization.Builder().lockGammaBeta(true).gamma(2.0).beta(0.5).nOut(3).build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build())
                    .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(3).nOut(3).build())
                    .pretrain(false).backprop(true);
    MultiLayerNetwork mln = new MultiLayerNetwork(builder.build());
    mln.init();
    if (PRINT_RESULTS) {
        for (int j = 0; j < mln.getnLayers(); j++)
            System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
    }
    boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
                    PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
    assertTrue(gradOK);
}
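After a network built from this fixed-gamma/beta configuration has been trained, the same (reset) iterator could be passed to evaluate() for accuracy and F1 statistics; a speculative sketch, not part of the gradient-check test:

    iter.reset();
    Evaluation eval = mln.evaluate(iter);   // assumes mln has been fitted beforehand
    System.out.println(eval.stats());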