Use of org.deeplearning4j.nn.conf.layers.BatchNormalization in project deeplearning4j by deeplearning4j.
The class BatchNormalizationTest, method getLayer.
protected static Layer getLayer(int nOut, double epsilon, boolean lockGammaBeta, double gamma, double beta) {
    BatchNormalization.Builder b = new BatchNormalization.Builder().nOut(nOut).eps(epsilon);
    if (lockGammaBeta) {
        // Fix gamma/beta to the given constants instead of learning them
        b.lockGammaBeta(true).gamma(gamma).beta(beta);
    }
    BatchNormalization bN = b.build();
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().iterations(1).layer(bN).build();
    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = null;
    if (numParams > 0) {
        // All parameters live in a single flattened row vector
        params = Nd4j.create(1, numParams);
    }
    Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true);
    if (numParams > 0) {
        // The gradient view must match the parameter view's shape
        layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams));
    }
    return layer;
}
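For context, a minimal sketch of how this helper might be exercised in a test (not part of the original snippet; the activate(INDArray, boolean) signature is assumed from the pre-1.0 DL4J Layer API that this code targets, and the shapes are illustrative):

Layer bn = getLayer(4, 1e-5, false, 1.0, 0.0);   // 4 outputs, trainable gamma/beta
INDArray input = Nd4j.rand(2, 4);                // minibatch of 2 examples, 4 features
INDArray out = bn.activate(input, true);         // assumed old-API forward pass, training mode
System.out.println(out.shapeInfoToString());     // expect shape [2, 4]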
Use of org.deeplearning4j.nn.conf.layers.BatchNormalization in project deeplearning4j by deeplearning4j.
The class BatchNormalizationParamInitializer, method getGradientsFromFlattened.
@Override
public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
    BatchNormalization layer = (BatchNormalization) conf.getLayer();
    int nOut = layer.getNOut();
    Map<String, INDArray> out = new LinkedHashMap<>();
    int meanOffset = 0;
    if (!layer.isLockGammaBeta()) {
        // Gamma/beta are trainable, so they occupy the first 2*nOut entries of the flat view
        INDArray gammaView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut));
        INDArray betaView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut));
        out.put(GAMMA, gammaView);
        out.put(BETA, betaView);
        meanOffset = 2 * nOut;
    }
    // Global mean/variance follow, at an offset that depends on whether gamma/beta are present
    out.put(GLOBAL_MEAN, gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset, meanOffset + nOut)));
    out.put(GLOBAL_VAR, gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut)));
    return out;
}
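The flattened layout this method assumes can be shown directly with Nd4j. A sketch for nOut = 3 with gamma/beta unlocked (my own illustration, not part of the initializer):

int nOut = 3;
// Flat view layout: [gamma(nOut) | beta(nOut) | globalMean(nOut) | globalVar(nOut)]
INDArray flat = Nd4j.linspace(1, 4 * nOut, 4 * nOut).reshape(1, 4 * nOut);
INDArray gamma = flat.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut));            // [1, 2, 3]
INDArray beta  = flat.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut));     // [4, 5, 6]
INDArray mean  = flat.get(NDArrayIndex.point(0), NDArrayIndex.interval(2 * nOut, 3 * nOut)); // [7, 8, 9]
INDArray var   = flat.get(NDArrayIndex.point(0), NDArrayIndex.interval(3 * nOut, 4 * nOut)); // [10, 11, 12]

Because each view is a slice of the same backing buffer, writing a gradient into one of these views updates the flattened array in place; no copy is made.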
Use of org.deeplearning4j.nn.conf.layers.BatchNormalization in project deeplearning4j by deeplearning4j.
The class BatchNormalizationParamInitializer, method init.
@Override
public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramView, boolean initializeParams) {
    Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
    // TODO setup for RNN
    BatchNormalization layer = (BatchNormalization) conf.getLayer();
    int nOut = layer.getNOut();
    int meanOffset = 0;
    if (!layer.isLockGammaBeta()) {
        // Gamma/beta are trainable only when not locked; when locked they are fixed
        // constants and have no entries in the parameter view
        INDArray gammaView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, nOut));
        INDArray betaView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut));
        params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
        conf.addVariable(GAMMA);
        params.put(BETA, createBeta(conf, betaView, initializeParams));
        conf.addVariable(BETA);
        meanOffset = 2 * nOut;
    }
    INDArray globalMeanView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset, meanOffset + nOut));
    INDArray globalVarView = paramView.get(NDArrayIndex.point(0), NDArrayIndex.interval(meanOffset + nOut, meanOffset + 2 * nOut));
    if (initializeParams) {
        // Standard batch-norm initialization: running mean 0, running variance 1
        globalMeanView.assign(0);
        globalVarView.assign(1);
    }
    params.put(GLOBAL_MEAN, globalMeanView);
    conf.addVariable(GLOBAL_MEAN);
    params.put(GLOBAL_VAR, globalVarView);
    conf.addVariable(GLOBAL_VAR);
    return params;
}
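The parameter count implied by this layout follows directly: 2 * nOut global statistics always, plus 2 * nOut more when gamma/beta are trainable. A hypothetical helper capturing this (expectedNumParams is my own name, not DL4J API):

// Illustrative only; mirrors the layout used by init and getGradientsFromFlattened above
static int expectedNumParams(int nOut, boolean lockGammaBeta) {
    int gammaBeta = lockGammaBeta ? 0 : 2 * nOut; // gamma + beta, present only when trainable
    int globalStats = 2 * nOut;                   // global mean + global variance
    return gammaBeta + globalStats;
}
// e.g. expectedNumParams(10, false) == 40; expectedNumParams(10, true) == 20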