Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j: class RBMTests, method testMnist.
@Test
public void testMnist() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

    // Single RBM layer: 784 binarized MNIST inputs -> 600 hidden units.
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                    .iterations(30)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .learningRate(1e-1f)
                    .layer(new org.deeplearning4j.nn.conf.layers.RBM.Builder()
                                    .nIn(784).nOut(600)
                                    .weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(1, 1e-5))
                                    .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
                                    .build())
                    .build();
    conf.setPretrain(true);
    org.deeplearning4j.nn.conf.layers.RBM layerConf = (org.deeplearning4j.nn.conf.layers.RBM) conf.getLayer();

    // Fetch a mini-batch of 10 examples.
    fetcher.fetch(10);
    DataSet d2 = fetcher.next();

    // Sample a weight-sized matrix from the same distribution used for initialization.
    org.nd4j.linalg.api.rng.distribution.Distribution dist = Nd4j.getDistributions().createNormal(1, 1e-5);
    System.out.println(dist.sample(new int[] {layerConf.getNIn(), layerConf.getNOut()}));
    INDArray input = d2.getFeatureMatrix();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    RBM rbm = (RBM) conf.getLayer().instantiate(conf, null, 0, params, true);
    rbm.fit(input);
}
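The test above drives MnistDataFetcher by hand with fetch(10)/next(). As a minimal sketch of the more usual pattern in deeplearning4j, the same data can be consumed through MnistDataSetIterator, which wraps this fetcher; the batch size and example count below are illustrative values, and the exact import (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) is an assumption about the version in use.

// Sketch only: stream binarized MNIST in mini-batches of 10 (100 examples total)
// and pretrain the same RBM layer instantiated in the test above.
MnistDataSetIterator iter = new MnistDataSetIterator(10, 100, true);
while (iter.hasNext()) {
    rbm.fit(iter.next().getFeatureMatrix());
}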
Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j: class TestRenders, method renderHistogram.
@Test
public void renderHistogram() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);

    // Denoising autoencoder: 784 inputs -> 600 hidden units, 60% input corruption.
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                    .momentum(0.9f)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .iterations(100)
                    .learningRate(1e-1f)
                    .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder()
                                    .nIn(784).nOut(600)
                                    .corruptionLevel(0.6)
                                    .weightInit(WeightInit.XAVIER)
                                    .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
                                    .build())
                    .build();
    fetcher.fetch(100);
    DataSet d2 = fetcher.next();
    INDArray input = d2.getFeatureMatrix();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    AutoEncoder da = (AutoEncoder) conf.getLayer().instantiate(conf, null, 0, params, true);
    // Score every iteration; emit a histogram visualization every 5 iterations.
    da.setListeners(new ScoreIterationListener(1), new HistogramIterationListener(5));
    da.setParams(da.params());
    da.fit(input);
}
Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j: class TestRenders, method renderHistogram2.
@Test
public void renderHistogram2() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    // Two-layer feed-forward net: 784 -> 100 dense (Xavier init), 100 -> 10 MCXENT output.
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .momentum(0.9f).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .iterations(1000).learningRate(1e-1f)
                    .list()
                    .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder()
                                    .nIn(784).nOut(100).weightInit(WeightInit.XAVIER).build())
                    .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder()
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(100).nOut(10).build())
                    .pretrain(false).backprop(true)
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(Arrays.<IterationListener>asList(
                    new ScoreIterationListener(1), new HistogramIterationListener(1, true)));
    fetcher.fetch(100);
    DataSet d2 = fetcher.next();
    net.fit(d2);
}
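The renderHistogram2 test only fits the network. A hedged follow-up, assuming the standard org.deeplearning4j.eval.Evaluation class is on the classpath, would be to score the same batch after training; reusing the training batch here is purely illustrative and not part of the original test.

// Sketch only: evaluate the trained network on the batch it was fitted on.
Evaluation eval = new Evaluation(10);                  // 10 MNIST classes
INDArray output = net.output(d2.getFeatureMatrix());   // forward pass
eval.eval(d2.getLabels(), output);
System.out.println(eval.stats());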
Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j: class DataSets, method mnist.
public static DataSet mnist(int num) {
    try {
        MnistDataFetcher fetcher = new MnistDataFetcher();
        fetcher.fetch(num);
        return fetcher.next();
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}
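A minimal usage sketch for this helper: it fetches num examples in one call and returns them as a single DataSet, giving a num x 784 feature matrix and num x 10 one-hot labels.

DataSet mnist = DataSets.mnist(100);
System.out.println(java.util.Arrays.toString(mnist.getFeatureMatrix().shape())); // expected [100, 784]
System.out.println(java.util.Arrays.toString(mnist.getLabels().shape()));        // expected [100, 10]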
Use of org.deeplearning4j.datasets.fetchers.MnistDataFetcher in project deeplearning4j by deeplearning4j: class AutoEncoderTest, method testBackProp.
@Test
public void testBackProp() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    // LayerFactory layerFactory = LayerFactories.getFactory(new org.deeplearning4j.nn.conf.layers.AutoEncoder());
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                    .momentum(0.9f)
                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                    .iterations(100)
                    .learningRate(1e-1f)
                    .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder()
                                    .nIn(784).nOut(600)
                                    .corruptionLevel(0.6)
                                    .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
                                    .build())
                    .build();

    fetcher.fetch(100);
    DataSet d2 = fetcher.next();
    INDArray input = d2.getFeatureMatrix();

    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    AutoEncoder da = (AutoEncoder) conf.getLayer().instantiate(conf, null, 0, params, true);

    // Build a gradient whose weight entry is the reconstruction error, decode(activate(x)) - x.
    Gradient g = new DefaultGradient();
    g.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, da.decode(da.activate(input)).sub(input));
}
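The test stops after building the Gradient and never applies it. A minimal sketch of what one might inspect at this point, assuming the norm2Number() reduction is available on the ND4J INDArray API in use, is the magnitude of the reconstruction error that was just stored under WEIGHT_KEY:

// Sketch only: decode(activate(x)) is the reconstruction; its difference from the
// input is the error the test places in the gradient above.
INDArray reconstruction = da.decode(da.activate(input));
INDArray error = reconstruction.sub(input);
System.out.println("reconstruction error L2 norm: " + error.norm2Number());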