Usage of `org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator` in the deeplearning4j project (by deeplearning4j).
Class `LoadAndDraw`, method `main`.
/**
 * Loads a Java-serialized {@code BasePretrainNetwork} and, for each MNIST example, shows the
 * original digit (window titled "REAL") alongside a binomial(1, p) sample whose probabilities
 * p come from the network's activation for that example (window titled "TEST"). Each pair of
 * windows stays open for ten seconds and is then disposed.
 *
 * @param args {@code args[0]} is the path of a file containing a Java-serialized
 *             {@code BasePretrainNetwork}
 * @throws IllegalArgumentException if no model path is given
 * @throws Exception if the model cannot be read, or if iteration, activation, drawing,
 *                   or the display sleep fails
 */
public static void main(String[] args) throws Exception {
    if (args.length < 1) {
        throw new IllegalArgumentException("Usage: LoadAndDraw <serialized-network-file>");
    }

    // 60000 MNIST examples, delivered in batches of 60.
    MnistDataSetIterator iter = new MnistDataSetIterator(60, 60000);

    // SECURITY: Java-native deserialization can execute code embedded in the stream.
    // Only load model files from a trusted source.
    BasePretrainNetwork network;
    try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(args[0]))) {
        network = (BasePretrainNetwork) ois.readObject();
    }

    final long displayMillis = 10000L;
    while (iter.hasNext()) {
        DataSet batch = iter.next();
        INDArray reconstructed = network.activate(batch.getFeatureMatrix());
        for (int i = 0; i < batch.numExamples(); i++) {
            // Multiply by 255 for display (presumably features lie in [0, 1] - confirm).
            INDArray draw1 = batch.get(i).getFeatureMatrix().mul(255);
            INDArray reconstructedRow = reconstructed.getRow(i);
            // Sample each pixel from binomial(1, p), with p taken from the activation row.
            INDArray draw2 = Nd4j.getDistributions()
                    .createBinomial(1, reconstructedRow)
                    .sample(reconstructedRow.shape())
                    .mul(255);

            DrawReconstruction real = new DrawReconstruction(draw1);
            real.title = "REAL";
            real.draw();
            DrawReconstruction sampled = new DrawReconstruction(draw2, 100, 100);
            sampled.title = "TEST";
            sampled.draw();
            try {
                Thread.sleep(displayMillis);
            } finally {
                // Dispose even if the sleep is interrupted, so the windows do not leak.
                real.frame.dispose();
                sampled.frame.dispose();
            }
        }
    }
}
Usage of `org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator` in the deeplearning4j project (by deeplearning4j).
Class `DataSetIteratorTest`, method `testBatchSizeOfOneMnist`.
/**
 * Checks that an MNIST iterator with batch size 1 over 5 examples yields exactly five
 * batches. Each batch must hold a single example whose labels sum to 1.
 */
@Test
public void testBatchSizeOfOneMnist() throws Exception {
    //MNIST:
    DataSetIterator mnist = new MnistDataSetIterator(1, 5);
    int mnistC = 0;
    while (mnist.hasNext()) {
        mnistC++;
        DataSet ds = mnist.next();
        // The point of this test: every batch really contains one example.
        assertEquals(1, ds.numExamples());
        // The label sum for the single example must be 1. Use a tolerance rather than ==
        // on a floating-point value, and get a failure message that shows the actual sum.
        assertEquals(1.0, ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0), 1e-6);
    }
    assertEquals(5, mnistC);
}
Usage of `org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator` in the deeplearning4j project (by deeplearning4j).
Class `SamplingTest`, method `testSample`.
/**
 * Checks that a {@code SamplingDataSetIterator} built from a 10-example MNIST batch
 * returns a batch of 10 examples.
 */
@Test
public void testSample() throws Exception {
    DataSetIterator iter = new MnistDataSetIterator(10, 10);
    // Arguments: source DataSet, batch size, total number of samples.
    DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10);
    // JUnit's assertEquals is (expected, actual); put the constant first so failure
    // messages report the values the right way round.
    assertEquals(10, sampling.next().numExamples());
}
Usage of `org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator` in the deeplearning4j project (by deeplearning4j).
Class `CNNProcessorTest`, method `testCNNInputPreProcessorMnist`.
/**
 * Fits the CNN MNIST configuration on a single example. Checks that the inputs recorded by
 * layer 0 and layer 1 are both rank 4.
 */
@Test
public void testCNNInputPreProcessorMnist() throws Exception {
    int numSamples = 1;
    int batchSize = 1;
    // NOTE(review): the third constructor argument is presumably "binarize" - confirm.
    DataSet mnistBatch = new MnistDataSetIterator(batchSize, numSamples, true).next();
    MultiLayerNetwork model = getCNNMnistConfig();
    model.init();
    model.fit(mnistBatch);
    // Layer 0 presumably receives flat MNIST rows that an input preprocessor reshapes to
    // 4D. Layer 1 is expected to see 4D input as well.
    int layer0InputRank = model.getLayer(0).input().shape().length;
    assertEquals(4, layer0InputRank);
    int layer1InputRank = model.getLayer(1).input().shape().length;
    assertEquals(4, layer1InputRank);
}
Usage of `org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator` in the deeplearning4j project (by deeplearning4j).
Class `ConvolutionLayerTest`, method `getMnistData`.
/**
 * Fetches one batch of (up to) five MNIST examples as a 4D array.
 *
 * @return the batch's feature matrix reshaped to
 *         {@code [examples, 1 channel, 28 rows, 28 columns]}, where {@code examples} is the
 *         number of examples actually present in the batch
 * @throws Exception if the MNIST iterator cannot be created or read
 */
public INDArray getMnistData() throws Exception {
    final int requested = 5;
    final int channels = 1;
    final int rows = 28;
    final int cols = 28;

    DataSet batch = new MnistDataSetIterator(requested, requested).next();
    // Size the leading dimension from the batch itself, not from the requested count.
    int actualExamples = batch.numExamples();
    INDArray flat = batch.getFeatureMatrix();
    return flat.reshape(actualExamples, channels, rows, cols);
}
Aggregations