Example 1 with MnistDataFetcher

Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j.

From the class RBMTests, method testMnist.

@Test
public void testMnist() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
            .iterations(30)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(1e-1f)
            .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
                    .nIn(784)
                    .nOut(600)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new NormalDistribution(1, 1e-5))
                    .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
                    .build())
            .build();
    conf.setPretrain(true);
    org.deeplearning4j.nn.conf.layers.RBM layerConf = (org.deeplearning4j.nn.conf.layers.RBM) conf.getLayer();
    fetcher.fetch(10);
    DataSet d2 = fetcher.next();
    org.nd4j.linalg.api.rng.distribution.Distribution dist = Nd4j.getDistributions().createNormal(1, 1e-5);
    System.out.println(dist.sample(new int[] { layerConf.getNIn(), layerConf.getNOut() }));
    INDArray input = d2.getFeatureMatrix();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    RBM rbm = (RBM) conf.getLayer().instantiate(conf, null, 0, params, true);
    rbm.fit(input);
}
Also used: Distribution (org.nd4j.linalg.api.rng.distribution.Distribution), DataSet (org.nd4j.linalg.dataset.DataSet), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher), INDArray (org.nd4j.linalg.api.ndarray.INDArray), NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution), Test (org.junit.Test)
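
The call rbm.fit(input) trains the RBM layer, typically with contrastive divergence. The following is a minimal plain-Java sketch of one textbook CD-1 update for a binary RBM, included only to illustrate the technique; it is not DL4J's implementation. The class Cd1Sketch and its methods are hypothetical. It uses a deterministic mean-field variant (probabilities instead of sampled binary states), and the weight layout w[visible][hidden] is assumed to mirror nIn(784) x nOut(600).

```java
import java.util.Arrays;

// Hypothetical sketch: one mean-field CD-1 step for a binary RBM (not DL4J's code).
class Cd1Sketch {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Hidden probabilities p(h_j = 1 | v) for weights w[visible][hidden] and hidden bias c.
    static double[] hiddenProbs(double[][] w, double[] c, double[] v) {
        double[] h = new double[c.length];
        for (int j = 0; j < c.length; j++) {
            double z = c[j];
            for (int i = 0; i < v.length; i++) z += v[i] * w[i][j];
            h[j] = sigmoid(z);
        }
        return h;
    }

    // Visible probabilities p(v_i = 1 | h) for visible bias b.
    static double[] visibleProbs(double[][] w, double[] b, double[] h) {
        double[] v = new double[b.length];
        for (int i = 0; i < b.length; i++) {
            double z = b[i];
            for (int j = 0; j < h.length; j++) z += w[i][j] * h[j];
            v[i] = sigmoid(z);
        }
        return v;
    }

    // Updates w, b and c in place from one training vector v0.
    static void cd1Update(double[][] w, double[] b, double[] c, double[] v0, double lr) {
        double[] h0 = hiddenProbs(w, c, v0);  // positive phase (driven by the data)
        double[] v1 = visibleProbs(w, b, h0); // one-step reconstruction
        double[] h1 = hiddenProbs(w, c, v1);  // negative phase (driven by the model)
        for (int i = 0; i < v0.length; i++) {
            for (int j = 0; j < c.length; j++) {
                w[i][j] += lr * (v0[i] * h0[j] - v1[i] * h1[j]);
            }
            b[i] += lr * (v0[i] - v1[i]);
        }
        for (int j = 0; j < c.length; j++) c[j] += lr * (h0[j] - h1[j]);
    }

    public static void main(String[] args) {
        double[][] w = new double[3][2]; // 3 visible x 2 hidden, all zeros
        double[] b = new double[3];
        double[] c = new double[2];
        cd1Update(w, b, c, new double[] {1, 0, 1}, 0.1);
        System.out.println(Arrays.deepToString(w));
    }
}
```

With all-zero parameters every hidden and visible probability is 0.5, so one update with lr = 0.1 moves the weights of active visible units to +0.025 and those of inactive units to -0.025.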

Example 2 with MnistDataFetcher

Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j.

From the class TestRenders, method renderHistogram.

@Test
public void renderHistogram() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
            .momentum(0.9f)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(100)
            .learningRate(1e-1f)
            .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder()
                    .nIn(784)
                    .nOut(600)
                    .corruptionLevel(0.6)
                    .weightInit(WeightInit.XAVIER)
                    .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
                    .build())
            .build();
    fetcher.fetch(100);
    DataSet d2 = fetcher.next();
    INDArray input = d2.getFeatureMatrix();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    AutoEncoder da = (AutoEncoder) conf.getLayer().instantiate(conf, null, 0, params, true);
    da.setListeners(new ScoreIterationListener(1), new HistogramIterationListener(5));
    da.setParams(da.params());
    da.fit(input);
}
Also used: MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), HistogramIterationListener (org.deeplearning4j.ui.weights.HistogramIterationListener), AutoEncoder (org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder), ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener), Test (org.junit.Test)
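
Setting corruptionLevel(0.6) makes this a denoising autoencoder: the input is corrupted before encoding, and the network learns to reconstruct the clean version. A common corruption scheme is masking noise, where each input value is zeroed independently with probability equal to the corruption level. The plain-Java sketch below shows that scheme. The class and method names are hypothetical, and it is an assumption, not a confirmed detail, that DL4J's AutoEncoder applies exactly this kind of masking.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch of masking noise for a denoising autoencoder (not DL4J's code).
class MaskingNoiseSketch {
    // Returns a copy of x in which each entry is zeroed independently
    // with probability corruptionLevel; the other entries are kept unchanged.
    static double[] corrupt(double[] x, double corruptionLevel, Random rng) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = rng.nextDouble() < corruptionLevel ? 0.0 : x[i];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
        // With corruptionLevel = 0.6, about 60% of entries are zeroed on average.
        System.out.println(Arrays.toString(corrupt(x, 0.6, new Random(42))));
    }
}
```

Because Random.nextDouble() returns values in [0, 1), a corruption level of 0 leaves the input untouched and a level of 1 zeroes every entry.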

Example 3 with MnistDataFetcher

Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j.

From the class TestRenders, method renderHistogram2.

@Test
public void renderHistogram2() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .momentum(0.9f)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1000).learningRate(1e-1f)
            .list()
            .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder()
                    .nIn(784).nOut(100).weightInit(WeightInit.XAVIER).build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(100).nOut(10).build())
            .pretrain(false).backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(Arrays.<IterationListener>asList(new ScoreIterationListener(1), new HistogramIterationListener(1, true)));
    fetcher.fetch(100);
    DataSet d2 = fetcher.next();
    net.fit(d2);
}
Also used: DataSet (org.nd4j.linalg.dataset.DataSet), HistogramIterationListener (org.deeplearning4j.ui.weights.HistogramIterationListener), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener), Test (org.junit.Test)
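
The output layer here uses LossFunctions.LossFunction.MCXENT, multi-class cross-entropy. For a one-hot label y and predicted class probabilities p, the loss is -sum_k y_k log(p_k). It is normally paired with a softmax output activation, which this example does not set explicitly. The plain-Java sketch below computes a numerically stable softmax and the resulting loss for one example. The class and method names are hypothetical, and this is an illustration of the formula, not DL4J's code.

```java
// Hypothetical sketch: softmax plus multi-class cross-entropy (MCXENT) for one example.
class McxentSketch {
    // Numerically stable softmax: subtract the max logit before exponentiating.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) max = Math.max(max, z);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int k = 0; k < logits.length; k++) {
            p[k] = Math.exp(logits[k] - max);
            sum += p[k];
        }
        for (int k = 0; k < p.length; k++) p[k] /= sum;
        return p;
    }

    // Cross-entropy -sum_k y_k * log(p_k); terms with y_k == 0 are skipped.
    static double mcxent(double[] probs, double[] oneHot) {
        double loss = 0.0;
        for (int k = 0; k < probs.length; k++) {
            if (oneHot[k] != 0.0) loss -= oneHot[k] * Math.log(probs[k]);
        }
        return loss;
    }

    public static void main(String[] args) {
        double[] label = new double[10];
        label[3] = 1.0; // digit "3" as a one-hot vector
        double[] probs = softmax(new double[10]); // ten equal logits
        System.out.println(mcxent(probs, label));
    }
}
```

With ten equal logits every class gets probability 0.1, so the loss for any label is ln 10 (about 2.303), roughly where a 10-class MNIST classifier with near-uniform outputs starts.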

Example 4 with MnistDataFetcher

Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j.

From the class DataSets, method mnist.

public static DataSet mnist(int num) {
    try {
        MnistDataFetcher fetcher = new MnistDataFetcher();
        fetcher.fetch(num);
        return fetcher.next();
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}
Also used: MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher), IOException (java.io.IOException)
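
This helper constructs the fetcher with the no-argument constructor, while the test examples pass true, which as far as we can tell is a binarize flag. Binarizing maps 8-bit grayscale pixels (0 to 255) to 0 or 1; the usual alternative is scaling them into [0, 1]. Below is a plain-Java sketch of both choices. The class name, method names, and the threshold parameter are hypothetical and do not necessarily match what MnistDataFetcher does internally.

```java
import java.util.Arrays;

// Hypothetical sketch of two MNIST pixel preprocessing choices (not MnistDataFetcher's code).
class PixelPreprocessingSketch {
    // Maps 8-bit grayscale values (0..255) to 1.0 if strictly above threshold, else 0.0.
    static double[] binarize(int[] pixels, int threshold) {
        double[] out = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) out[i] = pixels[i] > threshold ? 1.0 : 0.0;
        return out;
    }

    // Alternative: scale 0..255 linearly into [0, 1].
    static double[] scale(int[] pixels) {
        double[] out = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) out[i] = pixels[i] / 255.0;
        return out;
    }

    public static void main(String[] args) {
        int[] pixels = {0, 64, 128, 255};
        System.out.println(Arrays.toString(binarize(pixels, 127)));
        System.out.println(Arrays.toString(scale(pixels)));
    }
}
```

Binary inputs suit models with binary visible units, such as the RBM in the first example, while scaled inputs keep the grayscale detail.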

Example 5 with MnistDataFetcher

Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j.

From the class AutoEncoderTest, method testBackProp.

@Test
public void testBackProp() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    //        LayerFactory layerFactory = LayerFactories.getFactory(new org.deeplearning4j.nn.conf.layers.AutoEncoder());
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
            .momentum(0.9f)
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .iterations(100)
            .learningRate(1e-1f)
            .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder()
                    .nIn(784)
                    .nOut(600)
                    .corruptionLevel(0.6)
                    .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
                    .build())
            .build();
    fetcher.fetch(100);
    DataSet d2 = fetcher.next();
    INDArray input = d2.getFeatureMatrix();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    AutoEncoder da = (AutoEncoder) conf.getLayer().instantiate(conf, null, 0, params, true);
    Gradient g = new DefaultGradient();
    g.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, da.decode(da.activate(input)).sub(input));
}
Also used: Gradient (org.deeplearning4j.nn.gradient.Gradient), DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient), MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), Test (org.junit.Test)
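
testBackProp builds a gradient entry from da.decode(da.activate(input)).sub(input), the difference between the reconstruction and the input. For sigmoid outputs trained with cross-entropy, that difference is exactly the gradient of the loss with respect to the decoder's pre-activation. The plain-Java sketch below assumes a tied-weight autoencoder (the decoder reuses the encoder weights, transposed) with sigmoid units and weights laid out nIn x nOut. The class and method names are hypothetical, and this is not DL4J's code.

```java
import java.util.Arrays;

// Hypothetical sketch: tied-weight sigmoid autoencoder and its reconstruction residual.
class TiedAutoEncoderSketch {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Encoder: h_j = sigmoid(hBias_j + sum_i x_i * w[i][j]), with w laid out nIn x nOut.
    static double[] encode(double[][] w, double[] hBias, double[] x) {
        double[] h = new double[hBias.length];
        for (int j = 0; j < h.length; j++) {
            double z = hBias[j];
            for (int i = 0; i < x.length; i++) z += x[i] * w[i][j];
            h[j] = sigmoid(z);
        }
        return h;
    }

    // Decoder reusing the same weights, transposed: xHat_i = sigmoid(vBias_i + sum_j w[i][j] * h_j).
    static double[] decode(double[][] w, double[] vBias, double[] h) {
        double[] xHat = new double[vBias.length];
        for (int i = 0; i < xHat.length; i++) {
            double z = vBias[i];
            for (int j = 0; j < h.length; j++) z += w[i][j] * h[j];
            xHat[i] = sigmoid(z);
        }
        return xHat;
    }

    // Reconstruction minus input: dLoss/dz at the output for sigmoid units under cross-entropy.
    static double[] residual(double[][] w, double[] hBias, double[] vBias, double[] x) {
        double[] xHat = decode(w, vBias, encode(w, hBias, x));
        double[] r = new double[x.length];
        for (int i = 0; i < x.length; i++) r[i] = xHat[i] - x[i];
        return r;
    }

    public static void main(String[] args) {
        double[][] w = new double[3][2]; // 3 inputs, 2 hidden units, all-zero weights
        double[] r = residual(w, new double[2], new double[3], new double[] {1, 0, 1});
        System.out.println(Arrays.toString(r));
    }
}
```

With all-zero parameters every reconstruction is 0.5, so the residual for the input {1, 0, 1} is {-0.5, 0.5, -0.5}.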

Aggregations

MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher): 6
Test (org.junit.Test): 5
DataSet (org.nd4j.linalg.dataset.DataSet): 5
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 4
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 4
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 3
HistogramIterationListener (org.deeplearning4j.ui.weights.HistogramIterationListener): 2
IOException (java.io.IOException): 1
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1
NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution): 1
DefaultGradient (org.deeplearning4j.nn.gradient.DefaultGradient): 1
Gradient (org.deeplearning4j.nn.gradient.Gradient): 1
AutoEncoder (org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder): 1
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1
IterationListener (org.deeplearning4j.optimize.api.IterationListener): 1
Distribution (org.nd4j.linalg.api.rng.distribution.Distribution): 1