Example 6 with GaussianReconstructionDistribution

Use of org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution in project deeplearning4j by deeplearning4j.

From class TestPlayUI, method testUI_VAE.

@Test
@Ignore // manual test: starts the training UI and then blocks for 100 seconds
public void testUI_VAE() throws Exception {
    // Variational autoencoder layers, used for unsupervised layerwise pretraining
    StatsStorage ss = new InMemoryStatsStorage();
    UIServer uiServer = UIServer.getInstance();
    uiServer.attach(ss); // the UI reads training statistics from this storage
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1).learningRate(1e-5)
            .list()
            // Layer 0: VAE mapping the 4 Iris features to a 3-dimensional latent space
            .layer(0, new VariationalAutoencoder.Builder().nIn(4).nOut(3)
                    .encoderLayerSizes(10, 11).decoderLayerSizes(12, 13)
                    .weightInit(WeightInit.XAVIER).pzxActivationFunction("identity")
                    .reconstructionDistribution(new GaussianReconstructionDistribution())
                    .activation(Activation.LEAKYRELU).updater(Updater.SGD).build())
            // Layer 1: a second VAE stacked on layer 0's 3-dimensional output
            .layer(1, new VariationalAutoencoder.Builder().nIn(3).nOut(3)
                    .encoderLayerSizes(7).decoderLayerSizes(8)
                    .weightInit(WeightInit.XAVIER).pzxActivationFunction("identity")
                    .reconstructionDistribution(new GaussianReconstructionDistribution())
                    .activation(Activation.LEAKYRELU).updater(Updater.SGD).build())
            .layer(2, new OutputLayer.Builder().nIn(3).nOut(3).build()) // 3 Iris classes
            .pretrain(true).backprop(true) // pretrain the VAE layers, then fine-tune with backprop
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
    DataSetIterator iter = new IrisDataSetIterator(150, 150); // batch size 150 over 150 examples: the whole dataset per fit()
    for (int i = 0; i < 50; i++) {
        net.fit(iter);
        Thread.sleep(100); // short pause between fits
    }
    Thread.sleep(100000); // keep the JVM alive so the UI can be inspected in a browser
}
Also used:
- OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)
- InMemoryStatsStorage (org.deeplearning4j.ui.storage.InMemoryStatsStorage)
- StatsStorage (org.deeplearning4j.api.storage.StatsStorage)
- IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)
- UIServer (org.deeplearning4j.ui.api.UIServer)
- VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder)
- StatsListener (org.deeplearning4j.ui.stats.StatsListener)
- GaussianReconstructionDistribution (org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution)
- MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)
- MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
- ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)
- DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
- Ignore (org.junit.Ignore)
- Test (org.junit.Test)
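GaussianReconstructionDistribution determines how each VAE layer scores its reconstruction of the input during pretraining. As a rough illustration, the stdlib-only Java sketch below (not DL4J code; the class and method names are hypothetical) computes a single-sample VAE loss for one example under the standard VAE objective: the closed-form KL divergence from a diagonal Gaussian q(z|x) to a standard normal prior, plus the Gaussian negative log-likelihood of the input. It assumes the decoder emits a mean and a log-variance per input feature, which is our reading of the DL4J Javadoc for this class rather than something the test above shows. With pzxActivationFunction("identity"), the encoder's mean and log-variance are left unbounded, which is why the sketch treats them as arbitrary real numbers.

```java
import java.util.Random;

/**
 * Illustrative sketch, NOT the DL4J implementation: a single-sample VAE loss
 * (negative ELBO) for one example.
 * Assumptions: diagonal-Gaussian encoder q(z|x), standard normal prior p(z),
 * and a Gaussian reconstruction parameterized by a mean and a log-variance
 * per input feature.
 */
final class VaeLossSketch {

    static final double LOG_2PI = Math.log(2.0 * Math.PI);

    /** -log N(x | mean, exp(logVar)), summed over input features. */
    static double gaussianNll(double[] x, double[] mean, double[] logVar) {
        double nll = 0.0;
        for (int i = 0; i < x.length; i++) {
            double diff = x[i] - mean[i];
            nll += 0.5 * (LOG_2PI + logVar[i] + diff * diff / Math.exp(logVar[i]));
        }
        return nll;
    }

    /** Closed-form KL( N(mu, exp(logVar)) || N(0, I) ), summed over latent dimensions. */
    static double klToStandardNormal(double[] mu, double[] logVar) {
        double kl = 0.0;
        for (int j = 0; j < mu.length; j++) {
            kl += -0.5 * (1.0 + logVar[j] - mu[j] * mu[j] - Math.exp(logVar[j]));
        }
        return kl;
    }

    /** Reparameterization trick: z = mu + exp(logVar / 2) * eps, with eps ~ N(0, 1). */
    static double[] sampleLatent(double[] mu, double[] logVar, Random rng) {
        double[] z = new double[mu.length];
        for (int j = 0; j < mu.length; j++) {
            z[j] = mu[j] + Math.exp(0.5 * logVar[j]) * rng.nextGaussian();
        }
        return z;
    }

    public static void main(String[] args) {
        // Hypothetical values: 4 input features (as in Iris) and a 3-dimensional latent space.
        double[] x = {5.1, 3.5, 1.4, 0.2};
        double[] zMu = {0.1, -0.2, 0.05};
        double[] zLogVar = {0.0, 0.0, 0.0};
        double[] z = sampleLatent(zMu, zLogVar, new Random(42));

        // Placeholder decoder outputs; a real decoder would compute these from z.
        double[] xMean = {5.0, 3.4, 1.5, 0.25};
        double[] xLogVar = {0.0, 0.0, 0.0, 0.0};

        double loss = klToStandardNormal(zMu, zLogVar) + gaussianNll(x, xMean, xLogVar);
        System.out.println("sampled z length = " + z.length);
        System.out.printf("negative ELBO (single sample) = %.4f%n", loss);
    }
}
```

Running main prints the latent size and a single-sample loss. The decoder outputs in main are fixed placeholders rather than a function of the sampled z, so the printed number only illustrates the arithmetic; in an actual VAE the decoder network would map z to xMean and xLogVar.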

Aggregations

- GaussianReconstructionDistribution (org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution): 6
- Test (org.junit.Test): 6
- MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 3
- BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 3
- INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3
- LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint): 2
- BernoulliReconstructionDistribution (org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution): 2
- ExponentialReconstructionDistribution (org.deeplearning4j.nn.conf.layers.variational.ExponentialReconstructionDistribution): 2
- ReconstructionDistribution (org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution): 2
- VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder): 2
- MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 2
- DataSet (org.nd4j.linalg.dataset.DataSet): 2
- MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet): 2
- ArrayList (java.util.ArrayList): 1
- Random (java.util.Random): 1
- NormalDistribution (org.apache.commons.math3.distribution.NormalDistribution): 1
- StatsStorage (org.deeplearning4j.api.storage.StatsStorage): 1
- IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 1
- ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 1
- NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 1