Example usage of org.deeplearning4j.nn.conf.distribution.GaussianDistribution from the deeplearning4j project.
Source: method testWeightInitDistNotSet of class LayerConfigValidationTest.
@Test
public void testWeightInitDistNotSet() {
    // Only a warning is expected here (not an exception): a distribution is set
    // globally without weightInit(DISTRIBUTION), but individual layers may still
    // override the weight init locally, so the configuration remains valid.
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            .learningRate(0.3)
            .dist(new GaussianDistribution(1e-3, 2))
            .list()
            .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(config);
    network.init();
}
Example usage of org.deeplearning4j.nn.conf.distribution.GaussianDistribution from the deeplearning4j project.
Source: method testBasicCenterLoss of class GradientCheckTestsComputationGraph.
@Test
public void testBasicCenterLoss() {
    // Gradient-check a center-loss output layer, both on a freshly initialized
    // network and after a few fitting steps, across several lambda values.
    Nd4j.getRandom().setSeed(12345);
    int numLabels = 2;

    boolean[] trainFirst = new boolean[] { false, true };
    for (boolean train : trainFirst) {
        for (double lambda : new double[] { 0.0, 0.5, 2.0 }) {
            ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .weightInit(WeightInit.DISTRIBUTION)
                    .dist(new GaussianDistribution(0, 1))
                    .updater(Updater.NONE)
                    .learningRate(1.0)
                    .graphBuilder()
                    .addInputs("input1")
                    .addLayer("l1", new DenseLayer.Builder()
                            .nIn(4).nOut(5)
                            .activation(Activation.TANH)
                            .build(), "input1")
                    .addLayer("cl", new CenterLossOutputLayer.Builder()
                            .lossFunction(LossFunctions.LossFunction.MCXENT)
                            .nIn(5).nOut(numLabels)
                            .alpha(1.0).lambda(lambda)
                            .gradientCheck(true)
                            .activation(Activation.SOFTMAX)
                            .build(), "l1")
                    .setOutputs("cl")
                    .pretrain(false).backprop(true)
                    .build();

            ComputationGraph graph = new ComputationGraph(configuration);
            graph.init();

            // Random features with one-hot labels for the gradient check itself.
            INDArray features = Nd4j.rand(150, 4);
            INDArray oneHotLabels = Nd4j.zeros(150, numLabels);
            Random rng = new Random(12345);
            for (int row = 0; row < 150; row++) {
                oneHotLabels.putScalar(row, rng.nextInt(numLabels), 1.0);
            }

            if (train) {
                // Fit a few mini-batches first so the center-loss class centers
                // move away from their initial values before checking gradients.
                for (int batch = 0; batch < 10; batch++) {
                    INDArray batchFeatures = Nd4j.rand(10, 4);
                    INDArray batchLabels = Nd4j.zeros(10, numLabels);
                    for (int row = 0; row < 10; row++) {
                        batchLabels.putScalar(row, rng.nextInt(numLabels), 1.0);
                    }
                    graph.fit(new INDArray[] { batchFeatures }, new INDArray[] { batchLabels });
                }
            }

            String msg = "testBasicCenterLoss() - lambda = " + lambda + ", trainFirst = " + train;
            if (PRINT_RESULTS) {
                System.out.println(msg);
                for (int layerIdx = 0; layerIdx < graph.getNumLayers(); layerIdx++) {
                    System.out.println("Layer " + layerIdx + " # params: " + graph.getLayer(layerIdx).numParams());
                }
            }

            boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE,
                    new INDArray[] { features }, new INDArray[] { oneHotLabels });
            assertTrue(msg, gradOK);
        }
    }
}
Aggregations