Search in sources :

Example 6 with HistogramIterationListener

Use of org.deeplearning4j.ui.weights.HistogramIterationListener in the project deeplearning4j, by deeplearning4j.

The example below is the method testHistograms of the class ManualTests.

/**
 * Manual end-to-end run that pre-trains a stack of three RBM layers
 * (784 -> 500 -> 250 -> 200) topped by a softmax output layer, with score,
 * histogram and flow iteration listeners attached, and then evaluates the
 * network on a second MNIST iterator.
 *
 * <p>The configuration enables pre-training and disables backprop, so
 * {@code model.fit} performs pre-training only.
 *
 * <p>NOTE(review): the method ends with an unconditional
 * {@code fail("Not implemented")}, so it can never pass as an automated test.
 * This looks intentional for a manually-run UI check -- confirm before removing.
 *
 * @throws Exception if loading the data sets or training fails
 */
@Test
public void testHistograms() throws Exception {
    // Input image dimensions; their product is the visible-layer size.
    final int imageHeight = 28;
    final int imageWidth = 28;
    int classCount = 10;
    int trainExamples = 60000;
    int miniBatch = 100;
    int iterationCount = 10;
    int rngSeed = 123;
    // Listeners report every (miniBatch / 5) iterations.
    int reportEvery = miniBatch / 5;

    log.info("Load data....");
    DataSetIterator trainIter = new MnistDataSetIterator(miniBatch, trainExamples, true);

    log.info("Build model....");
    MultiLayerConfiguration netConf = new NeuralNetConfiguration.Builder()
            .seed(rngSeed)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
            .gradientNormalizationThreshold(1.0)
            .iterations(iterationCount)
            .momentum(0.5)
            .momentumAfter(Collections.singletonMap(3, 0.9))
            .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
            .list()
            .layer(0, new RBM.Builder()
                    .nIn(imageHeight * imageWidth)
                    .nOut(500)
                    .weightInit(WeightInit.XAVIER)
                    .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)
                    .visibleUnit(RBM.VisibleUnit.BINARY)
                    .hiddenUnit(RBM.HiddenUnit.BINARY)
                    .build())
            .layer(1, new RBM.Builder()
                    .nIn(500)
                    .nOut(250)
                    .weightInit(WeightInit.XAVIER)
                    .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)
                    .visibleUnit(RBM.VisibleUnit.BINARY)
                    .hiddenUnit(RBM.HiddenUnit.BINARY)
                    .build())
            .layer(2, new RBM.Builder()
                    .nIn(250)
                    .nOut(200)
                    .weightInit(WeightInit.XAVIER)
                    .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE)
                    .visibleUnit(RBM.VisibleUnit.BINARY)
                    .hiddenUnit(RBM.HiddenUnit.BINARY)
                    .build())
            .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .nIn(200)
                    .nOut(classCount)
                    .build())
            .pretrain(true)
            .backprop(false)
            .build();

    // Disabled UI-server session setup, kept for reference:
    //        UiServer server = UiServer.getInstance();
    //        UiConnectionInfo connectionInfo = server.getConnectionInfo();
    //        connectionInfo.setSessionId("my session here");

    MultiLayerNetwork network = new MultiLayerNetwork(netConf);
    network.init();
    network.setListeners(Arrays.asList(
            new ScoreIterationListener(reportEvery),
            new HistogramIterationListener(reportEvery),
            new FlowIterationListener(reportEvery)));

    log.info("Train model....");
    // achieves end to end pre-training
    network.fit(trainIter);

    log.info("Evaluate model....");
    Evaluation evaluation = new Evaluation(classCount);
    DataSetIterator evalIter = new MnistDataSetIterator(100, 10000);
    while (evalIter.hasNext()) {
        DataSet batch = evalIter.next();
        INDArray predicted = network.output(batch.getFeatureMatrix());
        evaluation.eval(batch.getLabels(), predicted);
    }
    log.info(evaluation.stats());
    log.info("****************Example finished********************");

    fail("Not implemented");
}
Also used : Evaluation(org.deeplearning4j.eval.Evaluation) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) HistogramIterationListener(org.deeplearning4j.ui.weights.HistogramIterationListener) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) FlowIterationListener(org.deeplearning4j.ui.flow.FlowIterationListener) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) LFWDataSetIterator(org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MnistDataSetIterator(org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) Test(org.junit.Test)

Aggregations

ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)6 HistogramIterationListener (org.deeplearning4j.ui.weights.HistogramIterationListener)6 Test (org.junit.Test)6 DataSet (org.nd4j.linalg.dataset.DataSet)4 MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)3 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)3 INDArray (org.nd4j.linalg.api.ndarray.INDArray)3 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)3 MnistDataFetcher (org.deeplearning4j.datasets.fetchers.MnistDataFetcher)2 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)2 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)2 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)2 AutoEncoder (org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder)2 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)2 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)1 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)1 LFWDataSetIterator (org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator)1 Evaluation (org.deeplearning4j.eval.Evaluation)1 ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer)1 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)1