Example 1 with VariationalAutoencoder

Use of org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder in project deeplearning4j by deeplearning4j.

The class VariationalAutoencoderParamInitializer, method getGradientsFromFlattened:

@Override
public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
    Map<String, INDArray> ret = new LinkedHashMap<>();
    VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer();
    int nIn = layer.getNIn();
    int nOut = layer.getNOut();
    int[] encoderLayerSizes = layer.getEncoderLayerSizes();
    int[] decoderLayerSizes = layer.getDecoderLayerSizes();
    int soFar = 0;
    for (int i = 0; i < encoderLayerSizes.length; i++) {
        int encoderLayerNIn;
        if (i == 0) {
            encoderLayerNIn = nIn;
        } else {
            encoderLayerNIn = encoderLayerSizes[i - 1];
        }
        int weightParamCount = encoderLayerNIn * encoderLayerSizes[i];
        INDArray weightGradView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + weightParamCount));
        soFar += weightParamCount;
        INDArray biasGradView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + encoderLayerSizes[i]));
        soFar += encoderLayerSizes[i];
        INDArray layerWeights = weightGradView.reshape('f', encoderLayerNIn, encoderLayerSizes[i]);
        //Already correct shape (row vector)
        INDArray layerBiases = biasGradView;
        ret.put("e" + i + WEIGHT_KEY_SUFFIX, layerWeights);
        ret.put("e" + i + BIAS_KEY_SUFFIX, layerBiases);
    }
    //Last encoder layer -> p(z|x)
    int nWeightsPzx = encoderLayerSizes[encoderLayerSizes.length - 1] * nOut;
    INDArray pzxWeightsMean = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nWeightsPzx));
    soFar += nWeightsPzx;
    INDArray pzxBiasMean = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nOut));
    soFar += nOut;
    INDArray pzxWeightGradMeanReshaped = pzxWeightsMean.reshape('f', encoderLayerSizes[encoderLayerSizes.length - 1], nOut);
    ret.put(PZX_MEAN_W, pzxWeightGradMeanReshaped);
    ret.put(PZX_MEAN_B, pzxBiasMean);
    ////////////////////////////////////////////////////////
    INDArray pzxWeightsLogStdev2 = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nWeightsPzx));
    soFar += nWeightsPzx;
    INDArray pzxBiasLogStdev2 = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nOut));
    soFar += nOut;
    INDArray pzxWeightsLogStdev2Reshaped = createWeightMatrix(encoderLayerSizes[encoderLayerSizes.length - 1], nOut, null, null, pzxWeightsLogStdev2, false); //TODO
    ret.put(PZX_LOGSTD2_W, pzxWeightsLogStdev2Reshaped);
    ret.put(PZX_LOGSTD2_B, pzxBiasLogStdev2);
    for (int i = 0; i < decoderLayerSizes.length; i++) {
        int decoderLayerNIn;
        if (i == 0) {
            decoderLayerNIn = nOut;
        } else {
            decoderLayerNIn = decoderLayerSizes[i - 1];
        }
        int weightParamCount = decoderLayerNIn * decoderLayerSizes[i];
        INDArray weightView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + weightParamCount));
        soFar += weightParamCount;
        INDArray biasView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + decoderLayerSizes[i]));
        soFar += decoderLayerSizes[i];
        INDArray layerWeights = createWeightMatrix(decoderLayerNIn, decoderLayerSizes[i], null, null, weightView, false);
        //TODO don't hardcode 0
        INDArray layerBiases = createBias(decoderLayerSizes[i], 0.0, biasView, false);
        String sW = "d" + i + WEIGHT_KEY_SUFFIX;
        String sB = "d" + i + BIAS_KEY_SUFFIX;
        ret.put(sW, layerWeights);
        ret.put(sB, layerBiases);
    }
    //Finally, p(x|z):
    int nDistributionParams = layer.getOutputDistribution().distributionInputSize(nIn);
    int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams;
    INDArray pxzWeightView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + pxzWeightCount));
    soFar += pxzWeightCount;
    INDArray pxzBiasView = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nDistributionParams));
    INDArray pxzWeightsReshaped = createWeightMatrix(decoderLayerSizes[decoderLayerSizes.length - 1], nDistributionParams, null, null, pxzWeightView, false);
    INDArray pxzBiasReshaped = createBias(nDistributionParams, 0.0, pxzBiasView, false);
    ret.put(PXZ_W, pxzWeightsReshaped);
    ret.put(PXZ_B, pxzBiasReshaped);
    return ret;
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder), LinkedHashMap (java.util.LinkedHashMap)
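
The method above relies on the sub-arrays obtained from gradientView via interval indexing being views backed by the same flattened buffer, so gradients written through the per-parameter arrays land directly in the full gradient vector. Below is a minimal, self-contained sketch of that behaviour, assuming ND4J is on the classpath; the array length and slice bounds are hypothetical and chosen only for illustration.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class GradientViewSketch {
    public static void main(String[] args) {
        //Hypothetical flattened gradient row vector of length 10
        INDArray gradientView = Nd4j.zeros(1, 10);
        //Slice out the first 6 entries, as the initializer does for the first encoder weights
        INDArray e0Weights = gradientView.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 6));
        //Writing through the slice updates the shared buffer in place
        e0Weights.assign(1.0);
        //The first 6 entries of the full gradient vector are now 1.0
        System.out.println(gradientView);
    }
}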

Example 2 with VariationalAutoencoder

Use of org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder in project deeplearning4j by deeplearning4j.

The class VariationalAutoencoderParamInitializer, method init:

@Override
public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
    if (paramsView.length() != numParams(conf)) {
        throw new IllegalArgumentException("Incorrect paramsView length: Expected length " + numParams(conf) + ", got length " + paramsView.length());
    }
    Map<String, INDArray> ret = new LinkedHashMap<>();
    VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer();
    int nIn = layer.getNIn();
    int nOut = layer.getNOut();
    int[] encoderLayerSizes = layer.getEncoderLayerSizes();
    int[] decoderLayerSizes = layer.getDecoderLayerSizes();
    WeightInit weightInit = layer.getWeightInit();
    Distribution dist = Distributions.createDistribution(layer.getDist());
    int soFar = 0;
    for (int i = 0; i < encoderLayerSizes.length; i++) {
        int encoderLayerNIn;
        if (i == 0) {
            encoderLayerNIn = nIn;
        } else {
            encoderLayerNIn = encoderLayerSizes[i - 1];
        }
        int weightParamCount = encoderLayerNIn * encoderLayerSizes[i];
        INDArray weightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + weightParamCount));
        soFar += weightParamCount;
        INDArray biasView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + encoderLayerSizes[i]));
        soFar += encoderLayerSizes[i];
        INDArray layerWeights = createWeightMatrix(encoderLayerNIn, encoderLayerSizes[i], weightInit, dist, weightView, initializeParams);
        //TODO don't hardcode 0
        INDArray layerBiases = createBias(encoderLayerSizes[i], 0.0, biasView, initializeParams);
        String sW = "e" + i + WEIGHT_KEY_SUFFIX;
        String sB = "e" + i + BIAS_KEY_SUFFIX;
        ret.put(sW, layerWeights);
        ret.put(sB, layerBiases);
        conf.addVariable(sW);
        conf.addVariable(sB);
    }
    //Last encoder layer -> p(z|x)
    int nWeightsPzx = encoderLayerSizes[encoderLayerSizes.length - 1] * nOut;
    INDArray pzxWeightsMean = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nWeightsPzx));
    soFar += nWeightsPzx;
    INDArray pzxBiasMean = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nOut));
    soFar += nOut;
    INDArray pzxWeightsMeanReshaped = createWeightMatrix(encoderLayerSizes[encoderLayerSizes.length - 1], nOut, weightInit, dist, pzxWeightsMean, initializeParams);
    //TODO don't hardcode 0
    INDArray pzxBiasMeanReshaped = createBias(nOut, 0.0, pzxBiasMean, initializeParams);
    ret.put(PZX_MEAN_W, pzxWeightsMeanReshaped);
    ret.put(PZX_MEAN_B, pzxBiasMeanReshaped);
    conf.addVariable(PZX_MEAN_W);
    conf.addVariable(PZX_MEAN_B);
    //Pretrain params
    INDArray pzxWeightsLogStdev2 = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nWeightsPzx));
    soFar += nWeightsPzx;
    INDArray pzxBiasLogStdev2 = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nOut));
    soFar += nOut;
    INDArray pzxWeightsLogStdev2Reshaped = createWeightMatrix(encoderLayerSizes[encoderLayerSizes.length - 1], nOut, weightInit, dist, pzxWeightsLogStdev2, initializeParams);
    //TODO don't hardcode 0
    INDArray pzxBiasLogStdev2Reshaped = createBias(nOut, 0.0, pzxBiasLogStdev2, initializeParams);
    ret.put(PZX_LOGSTD2_W, pzxWeightsLogStdev2Reshaped);
    ret.put(PZX_LOGSTD2_B, pzxBiasLogStdev2Reshaped);
    conf.addVariable(PZX_LOGSTD2_W);
    conf.addVariable(PZX_LOGSTD2_B);
    for (int i = 0; i < decoderLayerSizes.length; i++) {
        int decoderLayerNIn;
        if (i == 0) {
            decoderLayerNIn = nOut;
        } else {
            decoderLayerNIn = decoderLayerSizes[i - 1];
        }
        int weightParamCount = decoderLayerNIn * decoderLayerSizes[i];
        INDArray weightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + weightParamCount));
        soFar += weightParamCount;
        INDArray biasView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + decoderLayerSizes[i]));
        soFar += decoderLayerSizes[i];
        INDArray layerWeights = createWeightMatrix(decoderLayerNIn, decoderLayerSizes[i], weightInit, dist, weightView, initializeParams);
        //TODO don't hardcode 0
        INDArray layerBiases = createBias(decoderLayerSizes[i], 0.0, biasView, initializeParams);
        String sW = "d" + i + WEIGHT_KEY_SUFFIX;
        String sB = "d" + i + BIAS_KEY_SUFFIX;
        ret.put(sW, layerWeights);
        ret.put(sB, layerBiases);
        conf.addVariable(sW);
        conf.addVariable(sB);
    }
    //Finally, p(x|z):
    int nDistributionParams = layer.getOutputDistribution().distributionInputSize(nIn);
    int pxzWeightCount = decoderLayerSizes[decoderLayerSizes.length - 1] * nDistributionParams;
    INDArray pxzWeightView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + pxzWeightCount));
    soFar += pxzWeightCount;
    INDArray pxzBiasView = paramsView.get(NDArrayIndex.point(0), NDArrayIndex.interval(soFar, soFar + nDistributionParams));
    INDArray pxzWeightsReshaped = createWeightMatrix(decoderLayerSizes[decoderLayerSizes.length - 1], nDistributionParams, weightInit, dist, pxzWeightView, initializeParams);
    //TODO don't hardcode 0
    INDArray pxzBiasReshaped = createBias(nDistributionParams, 0.0, pxzBiasView, initializeParams);
    ret.put(PXZ_W, pxzWeightsReshaped);
    ret.put(PXZ_B, pxzBiasReshaped);
    conf.addVariable(PXZ_W);
    conf.addVariable(PXZ_B);
    return ret;
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), WeightInit (org.deeplearning4j.nn.weights.WeightInit), Distribution (org.nd4j.linalg.api.rng.distribution.Distribution), VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder), LinkedHashMap (java.util.LinkedHashMap)
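
Note that init and getGradientsFromFlattened must walk the flat vector in exactly the same order (encoder weights and biases, then the p(z|x) mean and log-variance heads, then decoder layers, then p(x|z)), since both carve up views of a single flattened array. The following standalone sketch mirrors that bookkeeping in plain Java; the layer sizes are hypothetical and the segment labels are descriptive rather than the initializer's actual parameter keys.

import java.util.LinkedHashMap;
import java.util.Map;

public class VaeParamLayoutSketch {
    public static void main(String[] args) {
        //Hypothetical sizes: 4 inputs, one encoder layer of 3, 2 latent units,
        //one decoder layer of 3, and 4 output-distribution parameters
        int nIn = 4, nOut = 2;
        int[] enc = { 3 };
        int[] dec = { 3 };
        int nDistParams = 4;
        Map<String, int[]> layout = new LinkedHashMap<>();
        int soFar = 0;
        for (int i = 0; i < enc.length; i++) {
            int in = (i == 0 ? nIn : enc[i - 1]);
            soFar = add(layout, "encoder " + i + " W", soFar, in * enc[i]);
            soFar = add(layout, "encoder " + i + " b", soFar, enc[i]);
        }
        int lastEnc = enc[enc.length - 1];
        soFar = add(layout, "p(z|x) mean W", soFar, lastEnc * nOut);
        soFar = add(layout, "p(z|x) mean b", soFar, nOut);
        soFar = add(layout, "p(z|x) logStdev2 W", soFar, lastEnc * nOut);
        soFar = add(layout, "p(z|x) logStdev2 b", soFar, nOut);
        for (int i = 0; i < dec.length; i++) {
            int in = (i == 0 ? nOut : dec[i - 1]);
            soFar = add(layout, "decoder " + i + " W", soFar, in * dec[i]);
            soFar = add(layout, "decoder " + i + " b", soFar, dec[i]);
        }
        int lastDec = dec[dec.length - 1];
        soFar = add(layout, "p(x|z) W", soFar, lastDec * nDistParams);
        soFar = add(layout, "p(x|z) b", soFar, nDistParams);
        //Each entry is a [start, end) range into the flat parameter/gradient vector
        for (Map.Entry<String, int[]> e : layout.entrySet()) {
            System.out.println(e.getKey() + " -> [" + e.getValue()[0] + ", " + e.getValue()[1] + ")");
        }
        //Must match numParams(conf); here 56
        System.out.println("total = " + soFar);
    }

    private static int add(Map<String, int[]> layout, String name, int start, int length) {
        layout.put(name, new int[] { start, start + length });
        return start + length;
    }
}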

Example 3 with VariationalAutoencoder

Use of org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder in project deeplearning4j by deeplearning4j.

The class VariationalAutoencoderParamInitializer, method numParams:

@Override
public int numParams(NeuralNetConfiguration conf) {
    VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer();
    int nIn = layer.getNIn();
    int nOut = layer.getNOut();
    int[] encoderLayerSizes = layer.getEncoderLayerSizes();
    int[] decoderLayerSizes = layer.getDecoderLayerSizes();
    int paramCount = 0;
    for (int i = 0; i < encoderLayerSizes.length; i++) {
        int encoderLayerIn;
        if (i == 0) {
            encoderLayerIn = nIn;
        } else {
            encoderLayerIn = encoderLayerSizes[i - 1];
        }
        //weights + bias
        paramCount += (encoderLayerIn + 1) * encoderLayerSizes[i];
    }
    //Between the last encoder layer and the parameters for p(z|x):
    int lastEncLayerSize = encoderLayerSizes[encoderLayerSizes.length - 1];
    //Mean and variance parameters used in unsupervised training
    paramCount += (lastEncLayerSize + 1) * 2 * nOut;
    //Decoder:
    for (int i = 0; i < decoderLayerSizes.length; i++) {
        int decoderLayerNIn;
        if (i == 0) {
            decoderLayerNIn = nOut;
        } else {
            decoderLayerNIn = decoderLayerSizes[i - 1];
        }
        paramCount += (decoderLayerNIn + 1) * decoderLayerSizes[i];
    }
    //Between last decoder layer and parameters for p(x|z):
    int nDistributionParams = layer.getOutputDistribution().distributionInputSize(nIn);
    int lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1];
    paramCount += (lastDecLayerSize + 1) * nDistributionParams;
    return paramCount;
}
Also used: VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder)
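
As a quick worked example of the counting above, consider a hypothetical MNIST-sized configuration (these numbers are illustrative, not taken from the source): nIn = 784, encoder layer sizes {256, 128}, nOut = 32 latent units, decoder layer sizes {128, 256}, and a Bernoulli reconstruction distribution, for which distributionInputSize(784) is 784.

public class VaeParamCountExample {
    public static void main(String[] args) {
        int count = 0;
        //Encoder layers (weights + bias)
        count += (784 + 1) * 256;
        count += (256 + 1) * 128;
        //p(z|x): mean and log-variance heads
        count += (128 + 1) * 2 * 32;
        //Decoder layers
        count += (32 + 1) * 128;
        count += (128 + 1) * 256;
        //p(x|z)
        count += (256 + 1) * 784;
        //Prints 480848
        System.out.println(count);
    }
}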

Example 4 with VariationalAutoencoder

Use of org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder in project deeplearning4j by deeplearning4j.

The class TrainModuleUtils, method buildGraphInfo:

public static GraphInfo buildGraphInfo(NeuralNetConfiguration config) {
    List<String> vertexNames = new ArrayList<>();
    List<String> originalVertexName = new ArrayList<>();
    List<String> layerTypes = new ArrayList<>();
    List<List<Integer>> layerInputs = new ArrayList<>();
    List<Map<String, String>> layerInfo = new ArrayList<>();
    vertexNames.add("Input");
    originalVertexName.add(null);
    layerTypes.add("Input");
    layerInputs.add(Collections.emptyList());
    layerInfo.add(Collections.emptyMap());
    if (config.getLayer() instanceof VariationalAutoencoder) {
        //Special case like this is a bit ugly - but it works
        VariationalAutoencoder va = (VariationalAutoencoder) config.getLayer();
        int[] encLayerSizes = va.getEncoderLayerSizes();
        int[] decLayerSizes = va.getDecoderLayerSizes();
        int layerIndex = 1;
        for (int i = 0; i < encLayerSizes.length; i++) {
            String name = "encoder_" + i;
            vertexNames.add(name);
            originalVertexName.add("e" + i);
            String layerType = "VAE-Encoder";
            layerTypes.add(layerType);
            layerInputs.add(Collections.singletonList(layerIndex - 1));
            layerIndex++;
            Map<String, String> encoderInfo = new LinkedHashMap<>();
            int inputSize = (i == 0 ? va.getNIn() : encLayerSizes[i - 1]);
            int outputSize = encLayerSizes[i];
            encoderInfo.put("Input Size", String.valueOf(inputSize));
            encoderInfo.put("Layer Size", String.valueOf(outputSize));
            encoderInfo.put("Num Parameters", String.valueOf((inputSize + 1) * outputSize));
            encoderInfo.put("Activation Function", va.getActivationFn().toString());
            layerInfo.add(encoderInfo);
        }
        vertexNames.add("z");
        originalVertexName.add(VariationalAutoencoderParamInitializer.PZX_PREFIX);
        layerTypes.add("VAE-LatentVariable");
        layerInputs.add(Collections.singletonList(layerIndex - 1));
        layerIndex++;
        Map<String, String> latentInfo = new LinkedHashMap<>();
        int inputSize = encLayerSizes[encLayerSizes.length - 1];
        int outputSize = va.getNOut();
        latentInfo.put("Input Size", String.valueOf(inputSize));
        latentInfo.put("Layer Size", String.valueOf(outputSize));
        latentInfo.put("Num Parameters", String.valueOf((inputSize + 1) * outputSize * 2));
        latentInfo.put("Activation Function", va.getPzxActivationFn().toString());
        layerInfo.add(latentInfo);
        for (int i = 0; i < decLayerSizes.length; i++) {
            String name = "decoder_" + i;
            vertexNames.add(name);
            originalVertexName.add("d" + i);
            String layerType = "VAE-Decoder";
            layerTypes.add(layerType);
            layerInputs.add(Collections.singletonList(layerIndex - 1));
            layerIndex++;
            Map<String, String> decoderInfo = new LinkedHashMap<>();
            inputSize = (i == 0 ? va.getNOut() : decLayerSizes[i - 1]);
            outputSize = decLayerSizes[i];
            decoderInfo.put("Input Size", String.valueOf(inputSize));
            decoderInfo.put("Layer Size", String.valueOf(outputSize));
            decoderInfo.put("Num Parameters", String.valueOf((inputSize + 1) * outputSize));
            decoderInfo.put("Activation Function", va.getActivationFn().toString());
            layerInfo.add(decoderInfo);
        }
        vertexNames.add("x");
        originalVertexName.add(VariationalAutoencoderParamInitializer.PXZ_PREFIX);
        layerTypes.add("VAE-Reconstruction");
        layerInputs.add(Collections.singletonList(layerIndex - 1));
        layerIndex++;
        Map<String, String> reconstructionInfo = new LinkedHashMap<>();
        inputSize = decLayerSizes[decLayerSizes.length - 1];
        outputSize = va.getNIn();
        reconstructionInfo.put("Input Size", String.valueOf(inputSize));
        reconstructionInfo.put("Layer Size", String.valueOf(outputSize));
        reconstructionInfo.put("Num Parameters", String.valueOf((inputSize + 1) * va.getOutputDistribution().distributionInputSize(va.getNIn())));
        reconstructionInfo.put("Distribution", va.getOutputDistribution().toString());
        layerInfo.add(reconstructionInfo);
    } else {
        //RBM or similar...
        Layer layer = config.getLayer();
        String layerName = layer.getLayerName();
        if (layerName == null)
            layerName = "layer0";
        vertexNames.add(layerName);
        originalVertexName.add("0");
        String layerType = config.getLayer().getClass().getSimpleName().replaceAll("Layer$", "");
        layerTypes.add(layerType);
        layerInputs.add(Collections.singletonList(0));
        //Extract layer info
        Map<String, String> map = getLayerInfo(config, layer);
        layerInfo.add(map);
    }
    return new GraphInfo(vertexNames, layerTypes, layerInputs, layerInfo, originalVertexName);
}
Also used: VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder)
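
For illustration, with a hypothetical configuration of two encoder layers and two decoder layers, buildGraphInfo produces a simple chain in which each vertex takes the previous vertex as its only input (sizes and activation details omitted):

    index  vertexName  originalVertexName  layerType
    0      Input       null                Input
    1      encoder_0   e0                  VAE-Encoder
    2      encoder_1   e1                  VAE-Encoder
    3      z           PZX_PREFIX          VAE-LatentVariable
    4      decoder_0   d0                  VAE-Decoder
    5      decoder_1   d1                  VAE-Decoder
    6      x           PXZ_PREFIX          VAE-Reconstruction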

Aggregations

VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder): 4 uses
LinkedHashMap (java.util.LinkedHashMap): 2 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2 uses
WeightInit (org.deeplearning4j.nn.weights.WeightInit): 1 use
Distribution (org.nd4j.linalg.api.rng.distribution.Distribution): 1 use