Example 1 with AutoEncoder

Use of org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder in project deeplearning4j by deeplearning4j.

Source: class TestRenders, method renderHistogram.

@Test
public void renderHistogram() throws Exception {
    // true = binarize the MNIST pixel values
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    // Denoising autoencoder: 784 inputs (28x28 pixels) -> 600 hidden units, 60% input corruption
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
            .momentum(0.9f)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(100).learningRate(1e-1f)
            .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder()
                    .nIn(784).nOut(600)
                    .corruptionLevel(0.6)
                    .weightInit(WeightInit.XAVIER)
                    .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
                    .build())
            .build();
    fetcher.fetch(100); // load the next 100 examples
    DataSet d2 = fetcher.next();
    INDArray input = d2.getFeatureMatrix();
    // Allocate a flat parameter view sized for this layer, then build the layer on top of it
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    AutoEncoder da = (AutoEncoder) conf.getLayer().instantiate(conf, null, 0, params, true);
    // Print the score every iteration; report weight histograms to the UI every 5 iterations
    da.setListeners(new ScoreIterationListener(1), new HistogramIterationListener(5));
    da.setParams(da.params()); // redundant: writes the layer's own parameters back into it
    da.fit(input);
}

Also used: MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataSet (org.nd4j.linalg.dataset.DataSet), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), HistogramIterationListener (org.deeplearning4j.ui.weights.HistogramIterationListener), AutoEncoder (org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder), ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener), Test (org.junit.Test)
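
The layer configured in Example 1 is a denoising autoencoder. As a minimal, dependency-free sketch of what such a layer computes, the class below applies masking noise at a given `corruptionLevel`, encodes with a sigmoid, decodes with the transposed (tied) weights, and counts parameters. It is not DL4J code: the class and method names are hypothetical, the sigmoid activation is assumed, and the parameter layout (weights, hidden bias, visible bias) is an assumption about how DL4J's pretrain layers size the vector returned by `numParams(conf)`.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical, dependency-free sketch of a denoising autoencoder layer; not DL4J code.
class DenoisingAutoencoderSketch {

    // Assumed parameter layout: weights W (nIn x nOut) + hidden bias (nOut) + visible bias (nIn).
    static long numParams(int nIn, int nOut) {
        return (long) nIn * nOut + nOut + nIn;
    }

    // Masking noise: each input value is zeroed independently with probability corruptionLevel.
    static double[] corrupt(double[] x, double corruptionLevel, Random rng) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = rng.nextDouble() < corruptionLevel ? 0.0 : x[i];
        }
        return out;
    }

    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    // Encoder (sigmoid assumed): h_j = sigmoid(b_j + sum_i x_i * w[i][j]).
    static double[] encode(double[] x, double[][] w, double[] b) {
        double[] h = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double s = b[j];
            for (int i = 0; i < x.length; i++) {
                s += x[i] * w[i][j];
            }
            h[j] = sigmoid(s);
        }
        return h;
    }

    // Decoder with tied (transposed) weights: z_i = sigmoid(vb_i + sum_j h_j * w[i][j]).
    static double[] decode(double[] h, double[][] w, double[] vb) {
        double[] z = new double[vb.length];
        for (int i = 0; i < vb.length; i++) {
            double s = vb[i];
            for (int j = 0; j < h.length; j++) {
                s += h[j] * w[i][j];
            }
            z[i] = sigmoid(s);
        }
        return z;
    }

    public static void main(String[] args) {
        System.out.println(numParams(784, 600)); // 471784 under the assumed layout (Example 1's sizes)
        System.out.println(numParams(4, 3));     // 19 under the assumed layout (Example 2's sizes)

        // Toy 4 -> 3 layer with all-zero parameters: every activation is sigmoid(0) = 0.5.
        double[][] w = new double[4][3];
        double[] b = new double[3];
        double[] vb = new double[4];
        double[] x = {1.0, 0.5, 0.25, 0.0};
        double[] noisy = corrupt(x, 0.6, new Random(42));
        double[] z = decode(encode(noisy, w, b), w, vb);
        System.out.println(Arrays.toString(z)); // [0.5, 0.5, 0.5, 0.5]
    }
}
```

Tying the decoder to the transposed encoder weights is why only one weight matrix, plus two bias vectors, appears in the assumed count.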

Example 2 with AutoEncoder

Use of org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder in project deeplearning4j by deeplearning4j.

Source: class TestSerialization, method testModelSerde.

@Test
public void testModelSerde() throws Exception {
    // getMapper() is a helper defined elsewhere in TestSerialization (not shown)
    ObjectMapper mapper = getMapper();
    // Autoencoder over the 4 Iris features: 4 inputs -> 3 hidden units
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
            .momentum(0.9f)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1000).learningRate(1e-1f)
            .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder()
                    .nIn(4).nOut(3)
                    .corruptionLevel(0.6)
                    .sparsity(0.5)
                    .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
                    .build())
            .build();
    DataSet d2 = new IrisDataSetIterator(150, 150).next(); // all 150 examples in one batch
    INDArray input = d2.getFeatureMatrix();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    AutoEncoder da = (AutoEncoder) conf.getLayer().instantiate(conf,
            Arrays.asList(new ScoreIterationListener(1), new HistogramIterationListener(1)),
            0, params, true);
    da.setInput(input);
    // Serialize the model/gradient wrapper to JSON, read it back, and check that nothing was lost
    ModelAndGradient g = new ModelAndGradient(da);
    String json = mapper.writeValueAsString(g);
    ModelAndGradient read = mapper.readValue(json, ModelAndGradient.class);
    assertEquals(g, read);
}

Also used: IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator), DataSet (org.nd4j.linalg.dataset.DataSet), ModelAndGradient (org.deeplearning4j.ui.weights.ModelAndGradient), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), HistogramIterationListener (org.deeplearning4j.ui.weights.HistogramIterationListener), INDArray (org.nd4j.linalg.api.ndarray.INDArray), AutoEncoder (org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder), ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener), ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper), Test (org.junit.Test)
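
Example 2 selects `LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY`. As a minimal sketch of the standard reconstruction cross-entropy that name presumably refers to, the dependency-free class below scores one input row against its reconstruction, assuming both lie in [0, 1]. The class name and the clamping constant `EPS` are hypothetical; DL4J's own implementation may average over the batch or handle edge cases differently.

```java
// Hypothetical sketch of the reconstruction cross-entropy for one example; not DL4J code.
// L(x, z) = -sum_i [ x_i * ln(z_i) + (1 - x_i) * ln(1 - z_i) ], assuming x_i and z_i lie in [0, 1].
class ReconstructionCrossEntropySketch {

    // Clamp reconstructions away from exactly 0 and 1 so ln() stays finite.
    static final double EPS = 1e-12;

    static double loss(double[] x, double[] z) {
        if (x.length != z.length) {
            throw new IllegalArgumentException("input and reconstruction lengths differ");
        }
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double zi = Math.min(Math.max(z[i], EPS), 1.0 - EPS);
            sum += x[i] * Math.log(zi) + (1.0 - x[i]) * Math.log(1.0 - zi);
        }
        return -sum;
    }

    public static void main(String[] args) {
        // A reconstruction of 0.5 everywhere costs ln 2 per element.
        System.out.println(loss(new double[] {1.0, 0.0}, new double[] {0.5, 0.5})); // about 1.3863 (2 ln 2)
        // A close reconstruction costs much less.
        System.out.println(loss(new double[] {1.0, 0.0}, new double[] {0.999, 0.001}));
    }
}
```

The score is zero only when the reconstruction matches a binary input exactly, which is why the clamp is needed to keep the logarithms finite at that limit.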

Aggregations (classes used in the examples above, with the number of examples that use each):

NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 2
AutoEncoder (org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder): 2
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 2
HistogramIterationListener (org.deeplearning4j.ui.weights.HistogramIterationListener): 2
Test (org.junit.Test): 2
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2
DataSet (org.nd4j.linalg.dataset.DataSet): 2
ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 1
MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher): 1
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 1
ModelAndGradient (org.deeplearning4j.ui.weights.ModelAndGradient): 1