Example use of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class RBMTests, method testGradient.
/**
 * Fits a small binary/binary RBM on a 7-row toy dataset and checks that the gradient reported
 * for parameter "W" afterwards has nIn * nOut entries.
 */
@Test
public void testGradient() {
    // Layer dimensions; the input rows below must have nIn columns
    final int nIn = 6;
    final int nOut = 4;
    float[][] data = new float[][] { { 1, 1, 1, 0, 0, 0 }, { 1, 0, 1, 0, 0, 0 }, { 1, 1, 1, 0, 0, 0 },
            { 0, 0, 1, 1, 1, 0 }, { 0, 0, 1, 1, 0, 0 }, { 0, 0, 1, 1, 1, 0 }, { 0, 0, 1, 1, 1, 0 } };
    INDArray input = Nd4j.create(data);
    // Flat parameter row of length nIn * nOut + nIn + nOut (presumably weights plus one bias per
    // visible and per hidden unit -- layout is decided by getRBMLayer, not visible here)
    INDArray params = Nd4j.create(1, nIn * nOut + nIn + nOut);
    RBM rbm = getRBMLayer(nIn, nOut, HiddenUnit.BINARY, VisibleUnit.BINARY, params, true, false, 1,
            LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
    rbm.fit(input);
    // The original test queried the score before the gradient; the call is kept in that order,
    // but its value is not asserted on
    rbm.score();
    Gradient gradient = rbm.gradient();
    assertEquals(nIn * nOut, gradient.getGradientFor("W").length());
}
Example use of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class BatchNormalizationTest, method testCnnForwardPass.
/**
 * Runs a batch normalization layer forward (training mode) on random 4d input, reducing the
 * output over dimensions 0, 2 and 3. With non-fixed gamma/beta the per-channel mean should be 0
 * and the std 1. With gamma/beta fixed, the per-channel mean should equal beta and the std gamma.
 */
@Test
public void testCnnForwardPass() {
    final int channels = 10;
    final int spatial = 15;

    // Gamma, beta, global mean, global var
    Layer layer = getLayer(channels, 0.0, false, -1, -1);
    assertEquals(4 * channels, layer.numParams());

    Nd4j.getRandom().setSeed(12345);
    INDArray in = Nd4j.rand(12345, 100, channels, spatial, spatial);

    INDArray activations = layer.activate(in, true);
    assertEquals(4, activations.rank());

    INDArray perChannelMean = activations.mean(0, 2, 3);
    INDArray perChannelStd = activations.std(false, 0, 2, 3);
    float[] zeros = new float[channels];
    assertArrayEquals(zeros, perChannelMean.data().asFloat(), 1e-6f);
    float[] ones = Nd4j.ones(1, channels).data().asFloat();
    assertArrayEquals(ones, perChannelStd.data().asFloat(), 1e-6f);

    // Fixing gamma/beta should shift the mean to beta and scale the std to gamma
    final double fixedGamma = 2.0;
    final double fixedBeta = 3.0;
    Layer lockedLayer = getLayer(channels, 0.0, true, fixedGamma, fixedBeta);
    // Only the global mean/var parameters remain
    assertEquals(2 * channels, lockedLayer.numParams());

    INDArray lockedActivations = lockedLayer.activate(in, true);
    INDArray lockedMean = lockedActivations.mean(0, 2, 3);
    INDArray lockedStd = lockedActivations.std(false, 0, 2, 3);
    INDArray expectedMean = Nd4j.valueArrayOf(lockedMean.shape(), fixedBeta);
    assertEquals(expectedMean, lockedMean);
    INDArray expectedStd = Nd4j.valueArrayOf(lockedStd.shape(), fixedGamma);
    assertEquals(expectedStd, lockedStd);
}
Example use of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class BatchNormalizationTest, method checkMeanVarianceEstimateCNN.
/**
 * Checks that the global mean/variance estimates stored by a BatchNormalization layer on
 * convolutional input come out approximately equal to the statistics of the training data.
 *
 * <p>Features are random U(0,1) values of shape [minibatch, channels, height, width], so the
 * expected estimate is 0.5 for the mean and 1/12 for the variance, for each of the channels.
 */
@Test
public void checkMeanVarianceEstimateCNN() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    final int channels = 3;
    final int height = 5;
    final int width = 5;
    final int nLabels = 10;
    final int minibatch = 32;
    final int nBatches = 100;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.RMSPROP)
            .seed(12345)
            .list()
            .layer(0, new BatchNormalization.Builder().nIn(channels).nOut(channels).eps(1e-5).decay(0.95).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER)
                    .activation(Activation.IDENTITY).nOut(nLabels).build())
            .backprop(true)
            .pretrain(false)
            .setInputType(InputType.convolutional(height, width, channels))
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    List<DataSet> list = new ArrayList<>(nBatches);
    for (int i = 0; i < nBatches; i++) {
        list.add(new DataSet(Nd4j.rand(new int[] { minibatch, channels, height, width }),
                Nd4j.rand(minibatch, nLabels)));
    }
    DataSetIterator iter = new ListDataSetIterator(list);

    INDArray expMean = Nd4j.valueArrayOf(new int[] { 1, channels }, 0.5);
    // Expected variance of U(0,1) distribution: 1/12 * (1-0)^2 = 0.0833
    INDArray expVar = Nd4j.valueArrayOf(new int[] { 1, channels }, 1 / 12.0);

    // Ten passes over the same minibatches
    for (int epoch = 0; epoch < 10; epoch++) {
        iter.reset();
        net.fit(iter);
    }

    INDArray estMean = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
    INDArray estVar = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);

    assertArrayEquals("Global mean estimate", expMean.data().asFloat(), estMean.data().asFloat(), 0.01f);
    assertArrayEquals("Global variance estimate", expVar.data().asFloat(), estVar.data().asFloat(), 0.01f);
}
Example use of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class BatchNormalizationTest, method checkSerialization.
/**
 * Trains a small network containing BatchNormalization layers on a few MNIST minibatches, then
 * round-trips it through ModelSerializer in memory. The restored network must produce exactly
 * the same inference output as the original, i.e. the trained state is properly stored.
 */
@Test
public void checkSerialization() throws Exception {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(2)
            .seed(12345)
            .list()
            .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER)
                    .activation(Activation.IDENTITY).build())
            .layer(1, new BatchNormalization.Builder().build())
            .layer(2, new ActivationLayer.Builder().activation(Activation.LEAKYRELU).build())
            .layer(3, new DenseLayer.Builder().nOut(10).activation(Activation.LEAKYRELU).build())
            .layer(4, new BatchNormalization.Builder().build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER)
                    .activation(Activation.SOFTMAX).nOut(10).build())
            .backprop(true)
            .pretrain(false)
            .setInputType(InputType.convolutionalFlat(28, 28, 1))
            .build();
    MultiLayerNetwork trained = new MultiLayerNetwork(conf);
    trained.init();

    DataSetIterator mnist = new MnistDataSetIterator(16, true, 12345);
    for (int batch = 0; batch < 20; batch++) {
        trained.fit(mnist.next());
    }

    INDArray features = mnist.next().getFeatureMatrix();
    INDArray expected = trained.output(features, false);
    // Inference must be repeatable before it is compared with the restored copy
    INDArray repeated = trained.output(features, false);
    assertEquals(expected, repeated);

    byte[] serialized;
    try (ByteArrayOutputStream sink = new ByteArrayOutputStream()) {
        ModelSerializer.writeModel(trained, sink, true);
        serialized = sink.toByteArray();
    }

    MultiLayerNetwork restored;
    try (ByteArrayInputStream source = new ByteArrayInputStream(serialized)) {
        restored = ModelSerializer.restoreMultiLayerNetwork(source, true);
    }
    assertEquals(expected, restored.output(features, false));
}
Example use of org.nd4j.linalg.api.ndarray.INDArray in the deeplearning4j project (by deeplearning4j).
From class BatchNormalizationTest, method checkMeanVarianceEstimate.
/**
 * Checks that the global mean/variance estimates stored by a BatchNormalization layer on 2d
 * (feed-forward) input come out approximately equal to the statistics of the training data.
 *
 * <p>Features are random U(0,1) values of shape [minibatch, nFeatures], so the expected estimate
 * is 0.5 for the mean and 1/12 for the variance, for every feature column.
 */
@Test
public void checkMeanVarianceEstimate() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    final int nFeatures = 10;
    final int nLabels = 10;
    final int minibatch = 32;
    final int nBatches = 200;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.RMSPROP)
            .seed(12345)
            .list()
            .layer(0, new BatchNormalization.Builder().nIn(nFeatures).nOut(nFeatures).eps(1e-5).decay(0.95).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER)
                    .activation(Activation.IDENTITY).nIn(nFeatures).nOut(nLabels).build())
            .backprop(true)
            .pretrain(false)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    List<DataSet> list = new ArrayList<>(nBatches);
    for (int i = 0; i < nBatches; i++) {
        list.add(new DataSet(Nd4j.rand(minibatch, nFeatures), Nd4j.rand(minibatch, nLabels)));
    }
    DataSetIterator iter = new ListDataSetIterator(list);

    INDArray expMean = Nd4j.valueArrayOf(new int[] { 1, nFeatures }, 0.5);
    // Expected variance of U(0,1) distribution: 1/12 * (1-0)^2 = 0.0833
    INDArray expVar = Nd4j.valueArrayOf(new int[] { 1, nFeatures }, 1 / 12.0);

    // Ten passes over the same minibatches
    for (int epoch = 0; epoch < 10; epoch++) {
        iter.reset();
        net.fit(iter);
    }

    INDArray estMean = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
    INDArray estVar = net.getLayer(0).getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);

    assertArrayEquals("Global mean estimate", expMean.data().asFloat(), estMean.data().asFloat(), 0.02f);
    assertArrayEquals("Global variance estimate", expVar.data().asFloat(), estVar.data().asFloat(), 0.02f);
}
Aggregations