Usage of org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder in the deeplearning4j project (by deeplearning4j).
Class TestRenders, method renderHistogram:
/**
 * Smoke test: trains an AutoEncoder layer (784 inputs, 600 outputs, corruption level 0.6) on one
 * batch from {@link MnistDataFetcher} while a {@link ScoreIterationListener} (every iteration) and a
 * {@link HistogramIterationListener} (every 5 iterations) are attached.
 *
 * <p>The test makes no assertions; it passes as long as configuration, instantiation and
 * {@code fit} complete without throwing.
 */
@Test
public void renderHistogram() throws Exception {
    MnistDataFetcher mnist = new MnistDataFetcher(true);

    // Single-layer configuration: SGD with momentum 0.9 and learning rate 0.1, 100 iterations per fit.
    NeuralNetConfiguration layerConf = new NeuralNetConfiguration.Builder()
            .momentum(0.9f)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(100)
            .learningRate(1e-1f)
            .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder()
                    .nIn(784)
                    .nOut(600)
                    .corruptionLevel(0.6)
                    .weightInit(WeightInit.XAVIER)
                    .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
                    .build())
            .build();

    // Request 100 examples from the fetcher and train on the feature matrix of the resulting batch.
    mnist.fetch(100);
    DataSet batch = mnist.next();
    INDArray features = batch.getFeatureMatrix();

    // One flat 1 x numParams row holding every parameter of the layer, sized by the layer's initializer.
    INDArray flatParams = Nd4j.create(1, layerConf.getLayer().initializer().numParams(layerConf));

    // No listeners are passed at construction time (null); they are attached right after instead.
    AutoEncoder autoEncoder =
            (AutoEncoder) layerConf.getLayer().instantiate(layerConf, null, 0, flatParams, true);
    autoEncoder.setListeners(new ScoreIterationListener(1), new HistogramIterationListener(5));

    // NOTE(review): writes the layer's current parameters back onto itself; looks like a no-op.
    // Kept to preserve behaviour -- confirm whether it is needed.
    autoEncoder.setParams(autoEncoder.params());

    autoEncoder.fit(features);
}
Usage of org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder in the deeplearning4j project (by deeplearning4j).
Class TestSerialization, method testModelSerde:
/**
 * Round-trip test: wraps an untrained AutoEncoder layer (4 inputs, 3 outputs) in a
 * {@link ModelAndGradient}, serializes it to a JSON string and reads it back with the
 * {@code ObjectMapper} returned by {@code getMapper()}. It then asserts that the copy read back
 * equals the original.
 *
 * <p>{@code fit} is never called, so the listeners passed at instantiation are not expected to
 * fire during this test.
 */
@Test
public void testModelSerde() throws Exception {
    ObjectMapper jsonMapper = getMapper();

    NeuralNetConfiguration layerConf = new NeuralNetConfiguration.Builder()
            .momentum(0.9f)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1000)
            .learningRate(1e-1f)
            .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder()
                    .nIn(4)
                    .nOut(3)
                    .corruptionLevel(0.6)
                    .sparsity(0.5)
                    .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
                    .build())
            .build();

    // First batch from IrisDataSetIterator(150, 150); only its feature matrix is used.
    // Presumably this is all 150 Iris examples in a single batch -- confirm against the iterator.
    INDArray irisFeatures = new IrisDataSetIterator(150, 150).next().getFeatureMatrix();

    // One flat 1 x numParams row holding every parameter of the layer, sized by the layer's initializer.
    INDArray flatParams = Nd4j.create(1, layerConf.getLayer().initializer().numParams(layerConf));
    AutoEncoder autoEncoder = (AutoEncoder) layerConf.getLayer().instantiate(
            layerConf,
            Arrays.asList(new ScoreIterationListener(1), new HistogramIterationListener(1)),
            0,
            flatParams,
            true);
    autoEncoder.setInput(irisFeatures);

    // Serialize, read back, and require equality with the original wrapper.
    ModelAndGradient original = new ModelAndGradient(autoEncoder);
    ModelAndGradient roundTripped =
            jsonMapper.readValue(jsonMapper.writeValueAsString(original), ModelAndGradient.class);
    assertEquals(original, roundTripped);
}
Aggregations