Search in sources:

Example 21 with MnistDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator in project deeplearning4j by deeplearning4j.

Method main of class LoadAndDraw.

/**
 * Loads a Java-serialized {@code BasePretrainNetwork} and, for every MNIST example, shows the
 * original example next to a binomially-sampled reconstruction produced by the network.
 *
 * <p>Each pair of windows stays open for ten seconds and is then disposed.
 *
 * @param args {@code args[0]} is the path of a file containing a Java-serialized
 *             {@code BasePretrainNetwork}
 * @throws IllegalArgumentException if no file path is supplied
 * @throws Exception if the MNIST data or the network cannot be loaded, or if the display
 *                   pause is interrupted
 */
public static void main(String[] args) throws Exception {
    if (args.length < 1) {
        throw new IllegalArgumentException("Usage: LoadAndDraw <serialized-network-file>");
    }
    final long displayMillis = 10000L;

    // Batch size 60, 60000 examples in total.
    MnistDataSetIterator iter = new MnistDataSetIterator(60, 60000);

    // Both streams are declared as resources so the file handle is released even if the
    // ObjectInputStream header check or readObject() fails.
    // NOTE(review): Java-native deserialization can execute code embedded in the stream;
    // only load network files that come from a trusted source.
    BasePretrainNetwork network;
    try (FileInputStream fileIn = new FileInputStream(args[0]);
         ObjectInputStream ois = new ObjectInputStream(fileIn)) {
        network = (BasePretrainNetwork) ois.readObject();
    }

    while (iter.hasNext()) {
        DataSet batch = iter.next();
        INDArray reconstructed = network.activate(batch.getFeatureMatrix());
        for (int i = 0; i < batch.numExamples(); i++) {
            // Scale by 255 for display (presumably features are in [0, 1] — confirm).
            INDArray real = batch.get(i).getFeatureMatrix().mul(255);
            INDArray probabilities = reconstructed.getRow(i);
            // Draw a binomial(n = 1) sample with the reconstruction row as its probabilities.
            INDArray sampled = Nd4j.getDistributions()
                    .createBinomial(1, probabilities)
                    .sample(probabilities.shape())
                    .mul(255);

            DrawReconstruction realView = new DrawReconstruction(real);
            realView.title = "REAL";
            realView.draw();
            DrawReconstruction sampledView = new DrawReconstruction(sampled, 100, 100);
            sampledView.title = "TEST";
            sampledView.draw();
            try {
                Thread.sleep(displayMillis);
            } finally {
                // Release both windows even if the pause is interrupted.
                realView.frame.dispose();
                sampledView.frame.dispose();
            }
        }
    }
}
Also used : MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) BasePretrainNetwork(org.deeplearning4j.nn.layers.BasePretrainNetwork) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) IOException(java.io.IOException) FileInputStream(java.io.FileInputStream) ObjectInputStream(java.io.ObjectInputStream)

Example 22 with MnistDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator in project deeplearning4j by deeplearning4j.

Method testBatchSizeOfOneMnist of class DataSetIteratorTest.

/**
 * Verifies that an MNIST iterator with batch size 1 over 5 examples yields exactly five
 * batches, each containing a single example whose label matrix sums to 1.
 */
@Test
public void testBatchSizeOfOneMnist() throws Exception {
    //MNIST:
    DataSetIterator mnist = new MnistDataSetIterator(1, 5);
    int mnistC = 0;
    while (mnist.hasNext()) {
        mnistC++;
        DataSet ds = mnist.next();
        // The test name promises batches of one; check it directly.
        assertEquals(1, ds.numExamples());
        // assertEquals (expected first, exact delta) reports the actual sum on failure,
        // unlike assertTrue(x == 1.0). Integer.MAX_VALUE presumably requests a sum over
        // all dimensions (ND4J convention) — i.e. one label row summing to 1.
        assertEquals(1.0, ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0), 0.0);
    }
    assertEquals(5, mnistC);
}
Also used : MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) LFWDataSetIterator(org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) CifarDataSetIterator(org.deeplearning4j.datasets.iterator.impl.CifarDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) Test(org.junit.Test)

Example 23 with MnistDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator in project deeplearning4j by deeplearning4j.

Method testSample of class SamplingTest.

/**
 * Verifies that a {@code SamplingDataSetIterator} built from one 10-example MNIST batch
 * returns batches of the requested size.
 */
@Test
public void testSample() throws Exception {
    int batchSize = 10;
    int totalExamples = 10;
    DataSetIterator iter = new MnistDataSetIterator(batchSize, totalExamples);
    // Resample from the first MNIST batch; the trailing arguments are batch size and total.
    DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), batchSize, totalExamples);
    // JUnit's assertEquals takes (expected, actual); this order keeps failure messages correct.
    assertEquals(batchSize, sampling.next().numExamples());
}
Also used : MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) Test(org.junit.Test)

Example 24 with MnistDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator in project deeplearning4j by deeplearning4j.

Method testCNNInputPreProcessorMnist of class CNNProcessorTest.

/**
 * Fits the CNN returned by {@code getCNNMnistConfig()} on a single MNIST example and checks
 * that the inputs seen by layers 0 and 1 are rank-4 arrays.
 */
@Test
public void testCNNInputPreProcessorMnist() throws Exception {
    int numSamples = 1;
    int batchSize = 1;
    // A single DataSet (one batch), not an iterator.
    DataSet mnistBatch = new MnistDataSetIterator(batchSize, numSamples, true).next();
    MultiLayerNetwork model = getCNNMnistConfig();
    model.init();
    model.fit(mnistBatch);
    // assertEquals reports the actual rank on failure, unlike assertTrue(x == 4).
    // NOTE(review): original names (val2to4 / val4to4) suggest layer 0 receives 2d data
    // reshaped to 4d by a preprocessor, and layer 1 receives 4d directly — confirm.
    int layer0InputRank = model.getLayer(0).input().shape().length;
    assertEquals(4, layer0InputRank);
    int layer1InputRank = model.getLayer(1).input().shape().length;
    assertEquals(4, layer1InputRank);
}
Also used : MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)

Example 25 with MnistDataSetIterator

use of org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator in project deeplearning4j by deeplearning4j.

Method getMnistData of class ConvolutionLayerTest.

/**
 * Fetches five MNIST examples and returns their feature matrix reshaped to
 * {@code [examples, 1, 28, 28]}.
 *
 * @return the features of one MNIST batch as a rank-4 array
 * @throws Exception if the MNIST data cannot be loaded
 */
public INDArray getMnistData() throws Exception {
    final int requested = 5;
    final int channels = 1;
    final int height = 28;
    final int width = 28;
    // Batch size equals the total, so one next() call returns the whole sample.
    DataSetIterator source = new MnistDataSetIterator(requested, requested);
    DataSet batch = source.next();
    // Reshape using the count actually returned rather than the count requested.
    int returned = batch.numExamples();
    INDArray flat = batch.getFeatureMatrix();
    return flat.reshape(returned, channels, height, width);
}
Also used : MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)

Aggregations

MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)44 Test (org.junit.Test)41 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)40 DataSet (org.nd4j.linalg.dataset.DataSet)31 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)26 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)26 INDArray (org.nd4j.linalg.api.ndarray.INDArray)22 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)15 ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer)12 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)11 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)11 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)10 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)10 Evaluation (org.deeplearning4j.eval.Evaluation)7 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)7 LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint)6 BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)6 MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)6 File (java.io.File)4 RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator)4