Search in sources:

Example 6 with WeightInit

Use of org.deeplearning4j.nn.weights.WeightInit in the project deeplearning4j by deeplearning4j.

Method init of the class VariationalAutoencoderParamInitializer.

/**
 * Slices {@code paramsView} into the weight and bias arrays of every sub-layer of a
 * {@link VariationalAutoencoder}, optionally initializing them.
 * <p>
 * All parameters live contiguously along row 0 of {@code paramsView}. Each sub-layer stores its
 * weight entries immediately followed by its bias entries, and the sub-layers appear in this order:
 * <ol>
 *   <li>encoder layers {@code "e0"}, {@code "e1"}, ... (input size {@code nIn} for the first one,
 *       otherwise the previous encoder layer's size)</li>
 *   <li>last encoder layer -&gt; p(z|x) mean ({@code PZX_MEAN_W}, {@code PZX_MEAN_B})</li>
 *   <li>last encoder layer -&gt; p(z|x) log stdev^2 ({@code PZX_LOGSTD2_W}, {@code PZX_LOGSTD2_B});
 *       these are the pretrain params</li>
 *   <li>decoder layers {@code "d0"}, {@code "d1"}, ... (input size {@code nOut} for the first one,
 *       otherwise the previous decoder layer's size)</li>
 *   <li>last decoder layer -&gt; p(x|z) ({@code PXZ_W}, {@code PXZ_B}), whose width is the output
 *       distribution's {@code distributionInputSize(nIn)}</li>
 * </ol>
 * Every key is also registered on {@code conf} via {@code addVariable}, in the same order it is
 * inserted into the returned map.
 *
 * @param conf             configuration whose layer is cast to {@link VariationalAutoencoder}
 * @param paramsView       flattened parameter array; its length must equal {@code numParams(conf)}
 * @param initializeParams passed through to {@code createWeightMatrix} / {@code createBias}
 * @return insertion-ordered map from parameter key to the array produced for it
 * @throws IllegalArgumentException if {@code paramsView.length() != numParams(conf)}
 */
@Override
public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
    if (paramsView.length() != numParams(conf)) {
        throw new IllegalArgumentException("Incorrect paramsView length: Expected length " + numParams(conf) + ", got length " + paramsView.length());
    }
    Map<String, INDArray> ret = new LinkedHashMap<>();
    VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer();
    int nIn = layer.getNIn();
    int nOut = layer.getNOut();
    int[] encoderLayerSizes = layer.getEncoderLayerSizes();
    int[] decoderLayerSizes = layer.getDecoderLayerSizes();
    WeightInit weightInit = layer.getWeightInit();
    Distribution dist = Distributions.createDistribution(layer.getDist());

    // Running offset into row 0 of paramsView; each helper call returns the position just past
    // the bias it consumed.
    int soFar = 0;

    // Encoder stack: nIn -> encoderLayerSizes[0] -> ... -> encoderLayerSizes[last]
    for (int i = 0; i < encoderLayerSizes.length; i++) {
        int encoderLayerNIn = (i == 0) ? nIn : encoderLayerSizes[i - 1];
        soFar = initWeightsAndBias(conf, paramsView, soFar, encoderLayerNIn, encoderLayerSizes[i], weightInit, dist,
                        initializeParams, "e" + i + WEIGHT_KEY_SUFFIX, "e" + i + BIAS_KEY_SUFFIX, ret);
    }

    //Last encoder layer -> p(z|x)
    int lastEncoderSize = encoderLayerSizes[encoderLayerSizes.length - 1];
    soFar = initWeightsAndBias(conf, paramsView, soFar, lastEncoderSize, nOut, weightInit, dist, initializeParams,
                    PZX_MEAN_W, PZX_MEAN_B, ret);
    //Pretrain params
    soFar = initWeightsAndBias(conf, paramsView, soFar, lastEncoderSize, nOut, weightInit, dist, initializeParams,
                    PZX_LOGSTD2_W, PZX_LOGSTD2_B, ret);

    // Decoder stack: nOut -> decoderLayerSizes[0] -> ... -> decoderLayerSizes[last]
    for (int i = 0; i < decoderLayerSizes.length; i++) {
        int decoderLayerNIn = (i == 0) ? nOut : decoderLayerSizes[i - 1];
        soFar = initWeightsAndBias(conf, paramsView, soFar, decoderLayerNIn, decoderLayerSizes[i], weightInit, dist,
                        initializeParams, "d" + i + WEIGHT_KEY_SUFFIX, "d" + i + BIAS_KEY_SUFFIX, ret);
    }

    //Finally, p(x|z):
    int nDistributionParams = layer.getOutputDistribution().distributionInputSize(nIn);
    int lastDecoderSize = decoderLayerSizes[decoderLayerSizes.length - 1];
    // Final segment: the returned offset is not needed afterwards.
    initWeightsAndBias(conf, paramsView, soFar, lastDecoderSize, nDistributionParams, weightInit, dist,
                    initializeParams, PXZ_W, PXZ_B, ret);
    return ret;
}

/**
 * Slices one weight matrix ({@code nIn * nOut} entries) followed by its bias vector ({@code nOut}
 * entries) from row 0 of {@code paramsView}, starting at {@code offset}. It then builds the two
 * arrays via {@code createWeightMatrix} / {@code createBias}, stores them in {@code ret} under the
 * given keys and registers both keys on {@code conf} (weight key first).
 *
 * @return the offset immediately after the bias entries just consumed
 */
private int initWeightsAndBias(NeuralNetConfiguration conf, INDArray paramsView, int offset, int nIn, int nOut,
                WeightInit weightInit, Distribution dist, boolean initializeParams, String weightKey, String biasKey,
                Map<String, INDArray> ret) {
    int weightParamCount = nIn * nOut;
    INDArray weightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + weightParamCount));
    offset += weightParamCount;
    INDArray biasView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + nOut));
    offset += nOut;
    INDArray weights = createWeightMatrix(nIn, nOut, weightInit, dist, weightView, initializeParams);
    //TODO don't hardcode 0
    INDArray biases = createBias(nOut, 0.0, biasView, initializeParams);
    ret.put(weightKey, weights);
    ret.put(biasKey, biases);
    conf.addVariable(weightKey);
    conf.addVariable(biasKey);
    return offset;
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), WeightInit (org.deeplearning4j.nn.weights.WeightInit), Distribution (org.nd4j.linalg.api.rng.distribution.Distribution), VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder), LinkedHashMap (java.util.LinkedHashMap)

Aggregations

WeightInit (org.deeplearning4j.nn.weights.WeightInit): 6; ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer): 2; Layer (org.deeplearning4j.nn.conf.layers.Layer): 2; SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer): 2; LinkedHashMap (java.util.LinkedHashMap): 1; Persistable (org.deeplearning4j.api.storage.Persistable): 1; ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 1; MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1; NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 1; Updater (org.deeplearning4j.nn.conf.Updater): 1; GraphVertex (org.deeplearning4j.nn.conf.graph.GraphVertex): 1; LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex): 1; DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 1; FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer): 1; HiddenUnit (org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit): 1; VisibleUnit (org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit): 1; PoolingType (org.deeplearning4j.nn.conf.layers.SubsamplingLayer.PoolingType): 1; VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder): 1; InvalidKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException): 1; UnsupportedKerasConfigurationException (org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException): 1