
Example 1 with GaussianDistribution

Use of org.deeplearning4j.nn.conf.distribution.GaussianDistribution in project deeplearning4j by deeplearning4j.

From class LayerConfigValidationTest, method testWeightInitDistNotSet.

@Test
public void testWeightInitDistNotSet() {
    // Only a warning is expected here (no exception): a global dist may legitimately be set
    // even though individual layers can override the weight init locally
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .learningRate(0.3)
            .dist(new GaussianDistribution(1e-3, 2))
            .list()
            .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
}
Also used: GaussianDistribution (org.deeplearning4j.nn.conf.distribution.GaussianDistribution), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), Test (org.junit.Test)
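In testWeightInitDistNotSet a global `dist(new GaussianDistribution(1e-3, 2))` is configured, but nothing selects distribution-based weight initialization, so the distribution goes unused and only a warning is expected. The plain-Java sketch below illustrates that rule under stated assumptions: the `WeightInit` enum, the `Gaussian` holder, `initWeights` and the warning/exception behaviour are hypothetical stand-ins, not DL4J's API or validation code. It reads the two constructor arguments as (mean, std) and draws samples with `java.util.Random.nextGaussian()`; the non-distribution branch uses a Glorot-uniform placeholder, which is not necessarily what DL4J's XAVIER does.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class WeightInitSketch {
    // Hypothetical stand-in for a weight-init setting; not DL4J's WeightInit enum.
    enum WeightInit { XAVIER, DISTRIBUTION }

    // Hypothetical (mean, std) holder, analogous to GaussianDistribution(mean, std).
    static final class Gaussian {
        final double mean;
        final double std;
        Gaussian(double mean, double std) { this.mean = mean; this.std = std; }
    }

    // Hypothetical validation + init: DISTRIBUTION without a dist is an error,
    // while a dist that is set but unused only produces a warning.
    static double[][] initWeights(int nIn, int nOut, WeightInit init, Gaussian dist,
                                  Random rng, List<String> warnings) {
        if (init == WeightInit.DISTRIBUTION && dist == null) {
            throw new IllegalStateException("weightInit DISTRIBUTION requires a distribution");
        }
        if (init != WeightInit.DISTRIBUTION && dist != null) {
            warnings.add("dist is set but weightInit is " + init + "; the distribution is ignored");
        }
        double[][] w = new double[nIn][nOut];
        double a = Math.sqrt(6.0 / (nIn + nOut)); // Glorot-uniform bound, placeholder only
        for (int i = 0; i < nIn; i++) {
            for (int j = 0; j < nOut; j++) {
                if (init == WeightInit.DISTRIBUTION) {
                    // nextGaussian() ~ N(0, 1); shift and scale to N(mean, std^2)
                    w[i][j] = dist.mean + dist.std * rng.nextGaussian();
                } else {
                    // Placeholder scheme: uniform in [-a, a)
                    w[i][j] = (2 * rng.nextDouble() - 1) * a;
                }
            }
        }
        return w;
    }

    public static void main(String[] args) {
        Random rng = new Random(12345);
        Gaussian dist = new Gaussian(1e-3, 2);

        // Mirrors the test: a global dist is set, but the layers use another scheme.
        List<String> warnings = new ArrayList<>();
        initWeights(2, 2, WeightInit.XAVIER, dist, rng, warnings);
        System.out.println("warnings: " + warnings);

        // With DISTRIBUTION selected, the sample mean/std should be close to (1e-3, 2).
        double[][] w = initWeights(200, 200, WeightInit.DISTRIBUTION, dist, rng, new ArrayList<>());
        double sum = 0, sumSq = 0;
        int n = 0;
        for (double[] row : w) {
            for (double v : row) { sum += v; sumSq += v * v; n++; }
        }
        double mean = sum / n;
        System.out.printf("mean=%.3f std=%.3f%n", mean, Math.sqrt(sumSq / n - mean * mean));
    }
}
```

Treating an unused distribution as a warning rather than an error matches the reasoning in the test's comment: a globally set distribution is harmless if layers pick a different scheme, whereas asking for DISTRIBUTION with nothing to sample from cannot proceed.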

Example 2 with GaussianDistribution

Use of org.deeplearning4j.nn.conf.distribution.GaussianDistribution in project deeplearning4j by deeplearning4j.

From class GradientCheckTestsComputationGraph, method testBasicCenterLoss.

@Test
public void testBasicCenterLoss() {
    Nd4j.getRandom().setSeed(12345);
    int numLabels = 2;
    boolean[] trainFirst = new boolean[] { false, true };
    for (boolean train : trainFirst) {
        for (double lambda : new double[] { 0.0, 0.5, 2.0 }) {
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new GaussianDistribution(0, 1))
                    .updater(Updater.NONE)
                    .learningRate(1.0)
                    .graphBuilder()
                    .addInputs("input1")
                    .addLayer("l1", new DenseLayer.Builder().nIn(4).nOut(5)
                            .activation(Activation.TANH).build(), "input1")
                    .addLayer("cl", new CenterLossOutputLayer.Builder()
                            .lossFunction(LossFunctions.LossFunction.MCXENT)
                            .nIn(5).nOut(numLabels)
                            .alpha(1.0).lambda(lambda)
                            .gradientCheck(true)
                            .activation(Activation.SOFTMAX).build(), "l1")
                    .setOutputs("cl")
                    .pretrain(false).backprop(true)
                    .build();
            ComputationGraph graph = new ComputationGraph(conf);
            graph.init();
            INDArray example = Nd4j.rand(150, 4);
            INDArray labels = Nd4j.zeros(150, numLabels);
            Random r = new Random(12345);
            for (int i = 0; i < 150; i++) {
                labels.putScalar(i, r.nextInt(numLabels), 1.0);
            }
            if (train) {
                for (int i = 0; i < 10; i++) {
                    INDArray f = Nd4j.rand(10, 4);
                    INDArray l = Nd4j.zeros(10, numLabels);
                    for (int j = 0; j < 10; j++) {
                        l.putScalar(j, r.nextInt(numLabels), 1.0);
                    }
                    graph.fit(new INDArray[] { f }, new INDArray[] { l });
                }
            }
            String msg = "testBasicCenterLoss() - lambda = " + lambda + ", trainFirst = " + train;
            if (PRINT_RESULTS) {
                System.out.println(msg);
                for (int j = 0; j < graph.getNumLayers(); j++) {
                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
                }
            }
            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] { example }, new INDArray[] { labels });
            assertTrue(msg, gradOK);
        }
    }
}
Also used: GaussianDistribution (org.deeplearning4j.nn.conf.distribution.GaussianDistribution), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Random (java.util.Random), ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), Test (org.junit.Test)
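testBasicCenterLoss compares backpropagated gradients against numerical ones for several values of `lambda`. The self-contained sketch below shows the same idea on just the center-loss penalty, assuming the formulation of Wen et al. (2016), (lambda / 2) * sum_i ||x_i - c_{y_i}||^2, whose gradient with respect to a feature vector is lambda * (x_i - c_{y_i}); DL4J's CenterLossOutputLayer may scale or update centers differently. The pass rule (relative error below a threshold, or absolute error below a floor) is a common gradient-check formulation in the spirit of the DEFAULT_EPS, DEFAULT_MAX_REL_ERROR and DEFAULT_MIN_ABS_ERROR arguments; the class name, method names and threshold values here are hypothetical, not GradientCheckUtil's implementation.

```java
import java.util.Random;

public class CenterLossGradCheck {
    // Center-loss penalty: (lambda / 2) * sum_i ||x_i - c_{y_i}||^2
    static double centerLoss(double[][] x, int[] y, double[][] c, double lambda) {
        double s = 0;
        for (int i = 0; i < x.length; i++) {
            for (int d = 0; d < x[i].length; d++) {
                double diff = x[i][d] - c[y[i]][d];
                s += diff * diff;
            }
        }
        return 0.5 * lambda * s;
    }

    // Analytic gradient w.r.t. the features: dL/dx_i = lambda * (x_i - c_{y_i})
    static double[][] gradX(double[][] x, int[] y, double[][] c, double lambda) {
        double[][] g = new double[x.length][x[0].length];
        for (int i = 0; i < x.length; i++) {
            for (int d = 0; d < x[i].length; d++) {
                g[i][d] = lambda * (x[i][d] - c[y[i]][d]);
            }
        }
        return g;
    }

    // Compares the analytic gradient with a central difference for every feature entry.
    static boolean checkGradients(double[][] x, int[] y, double[][] c, double lambda,
                                  double eps, double maxRelError, double minAbsError) {
        double[][] analytic = gradX(x, y, c, lambda);
        for (int i = 0; i < x.length; i++) {
            for (int d = 0; d < x[i].length; d++) {
                double orig = x[i][d];
                x[i][d] = orig + eps;
                double plus = centerLoss(x, y, c, lambda);
                x[i][d] = orig - eps;
                double minus = centerLoss(x, y, c, lambda);
                x[i][d] = orig; // restore the perturbed entry
                double numeric = (plus - minus) / (2 * eps);
                double absError = Math.abs(analytic[i][d] - numeric);
                double denom = Math.abs(analytic[i][d]) + Math.abs(numeric);
                double relError = denom == 0 ? 0 : absError / denom;
                // Fail only if the values are neither relatively close nor both tiny
                if (relError > maxRelError && absError > minAbsError) return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        Random r = new Random(12345);
        int n = 10, dim = 5, numLabels = 2;
        double[][] x = new double[n][dim];
        double[][] c = new double[numLabels][dim];
        int[] y = new int[n];
        for (int i = 0; i < n; i++) {
            y[i] = r.nextInt(numLabels);
            for (int d = 0; d < dim; d++) x[i][d] = r.nextDouble();
        }
        for (int j = 0; j < numLabels; j++) {
            for (int d = 0; d < dim; d++) c[j][d] = r.nextGaussian();
        }
        // Same lambda values as the test; thresholds are illustrative only.
        for (double lambda : new double[] { 0.0, 0.5, 2.0 }) {
            boolean ok = checkGradients(x, y, c, lambda, 1e-6, 1e-3, 1e-8);
            System.out.println("lambda = " + lambda + ", gradOK = " + ok);
        }
    }
}
```

Central differences are used because their truncation error is O(eps^2) rather than O(eps); for this quadratic penalty the remaining discrepancy is essentially floating-point rounding, which the absolute-error floor absorbs when a gradient entry is near zero.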

Aggregations

GaussianDistribution (org.deeplearning4j.nn.conf.distribution.GaussianDistribution): 2
Test (org.junit.Test): 2
Random (java.util.Random): 1
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 1
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 1
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1