
Example 1 with LossFunctionWrapper

Use of org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper in the deeplearning4j project.

From the class VariationalAutoencoder, the method reconstructionLogProbability:

/**
 * Return the log reconstruction probability given the specified number of samples.<br>
 * See {@link #reconstructionLogProbability(INDArray, int)} for more details
 *
 * @param data       The data to calculate the log reconstruction probability for
 * @param numSamples Number of samples on which to base the reconstruction probability
 * @return Column vector of reconstruction log probabilities for each example (shape: [numExamples,1])
 */
public INDArray reconstructionLogProbability(INDArray data, int numSamples) {
    if (numSamples <= 0) {
        throw new IllegalArgumentException("Invalid input: numSamples must be > 0. Got: " + numSamples);
    }
    if (reconstructionDistribution instanceof LossFunctionWrapper) {
        throw new UnsupportedOperationException("Cannot calculate reconstruction log probability when using " + "a LossFunction (via LossFunctionWrapper) instead of a ReconstructionDistribution: ILossFunction " + "instances are not in general probabilistic, hence it is not possible to calculate reconstruction probability");
    }
    //Forward pass through the encoder and mean for P(Z|X)
    setInput(data);
    VAEFwdHelper fwd = doForward(true, true);
    IActivation afn = conf().getLayer().getActivationFn();
    //Forward pass through logStd^2 for P(Z|X)
    INDArray pzxLogStd2W = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
    INDArray pzxLogStd2b = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B);
    INDArray meanZ = fwd.pzxMeanPreOut;
    INDArray logStdev2Z = fwd.encoderActivations[fwd.encoderActivations.length - 1].mmul(pzxLogStd2W).addiRowVector(pzxLogStd2b);
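    //Apply the p(z|x) activation function to the mean and log(sigma^2) pre-activations, in place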
    pzxActivationFn.getActivation(meanZ, false);
    pzxActivationFn.getActivation(logStdev2Z, false);
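    //pzxSigma = sqrt(exp(logStdev2Z)): both Transforms calls modify the array in place (dup == false)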
    INDArray pzxSigma = Transforms.exp(logStdev2Z, false);
    Transforms.sqrt(pzxSigma, false);
    int minibatch = input.size(0);
    int size = fwd.pzxMeanPreOut.size(1);
    INDArray pxzw = params.get(VariationalAutoencoderParamInitializer.PXZ_W);
    INDArray pxzb = params.get(VariationalAutoencoderParamInitializer.PXZ_B);
    INDArray[] decoderWeights = new INDArray[decoderLayerSizes.length];
    INDArray[] decoderBiases = new INDArray[decoderLayerSizes.length];
    for (int i = 0; i < decoderLayerSizes.length; i++) {
        String wKey = "d" + i + WEIGHT_KEY_SUFFIX;
        String bKey = "d" + i + BIAS_KEY_SUFFIX;
        decoderWeights[i] = params.get(wKey);
        decoderBiases[i] = params.get(bKey);
    }
    INDArray sumReconstructionNegLogProbability = null;
    for (int i = 0; i < numSamples; i++) {
        INDArray e = Nd4j.randn(minibatch, size);
        //z = mu + sigma * e, with e ~ N(0,1)
        INDArray z = e.muli(pzxSigma).addi(meanZ);
        //Do forward pass through decoder
        int nDecoderLayers = decoderLayerSizes.length;
        INDArray currentActivations = z;
        for (int j = 0; j < nDecoderLayers; j++) {
            currentActivations = currentActivations.mmul(decoderWeights[j]).addiRowVector(decoderBiases[j]);
            afn.getActivation(currentActivations, false);
        }
        //And calculate reconstruction distribution preOut
        INDArray pxzDistributionPreOut = currentActivations.mmul(pxzw).addiRowVector(pxzb);
        if (i == 0) {
            sumReconstructionNegLogProbability = reconstructionDistribution.exampleNegLogProbability(data, pxzDistributionPreOut);
        } else {
            sumReconstructionNegLogProbability.addi(reconstructionDistribution.exampleNegLogProbability(data, pxzDistributionPreOut));
        }
    }
    setInput(null);
    return sumReconstructionNegLogProbability.divi(-numSamples);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), IActivation (org.nd4j.linalg.activations.IActivation), LossFunctionWrapper (org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper)
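
For reference, a minimal sketch of calling this method from user code; the GaussianReconstructionDistribution, layer sizes, and sample count below are illustrative assumptions, not taken from the deeplearning4j sources above:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ReconstructionLogProbabilityExample {
    public static void main(String[] args) {
        int nIn = 10;
        //A proper reconstruction distribution P(x|z) (not a LossFunctionWrapper), so the
        //probabilistic reconstruction methods are supported
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
                        .reconstructionDistribution(new GaussianReconstructionDistribution(Activation.IDENTITY))
                        .nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        VariationalAutoencoder vae = (VariationalAutoencoder) net.getLayer(0);
        INDArray data = Nd4j.rand(3, nIn);
        //Monte-Carlo estimate: more samples -> lower variance, but higher cost
        INDArray logProb = vae.reconstructionLogProbability(data, 32); //shape [3, 1]
        System.out.println(logProb);
    }
}

Because the estimate is Monte-Carlo based, repeated calls with a small numSamples will not return identical values.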

Example 2 with LossFunctionWrapper

Use of org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper in the deeplearning4j project.

From the class VariationalAutoencoder, the method reconstructionError:

/**
 * Return the reconstruction error for this variational autoencoder.<br>
 * <b>NOTE (important):</b> This method is used ONLY for VAEs that have a standard neural network loss function (i.e.,
 * an {@link org.nd4j.linalg.lossfunctions.ILossFunction} instance such as mean squared error) instead of using a
 * probabilistic reconstruction distribution P(x|z) for the reconstructions (as presented in the VAE architecture by
 * Kingma and Welling).<br>
 * You can check if the VAE has a loss function using {@link #hasLossFunction()}<br>
 * Consequently, the reconstruction error is a simple deterministic function (no Monte-Carlo sampling is required,
 * unlike {@link #reconstructionProbability(INDArray, int)} and {@link #reconstructionLogProbability(INDArray, int)})
 *
 * @param data The data to calculate the reconstruction error on
 * @return Column vector of reconstruction errors for each example (shape: [numExamples,1])
 */
public INDArray reconstructionError(INDArray data) {
    if (!hasLossFunction()) {
        throw new IllegalStateException("Cannot use reconstructionError method unless the variational autoencoder is " + "configured with a standard loss function (via LossFunctionWrapper). For VAEs utilizing a reconstruction " + "distribution, use the reconstructionProbability or reconstructionLogProbability methods");
    }
    INDArray pZXMean = activate(data, false);
    //Not probabilistic -> "mean" == output
    INDArray reconstruction = generateAtMeanGivenZ(pZXMean);
    if (reconstructionDistribution instanceof CompositeReconstructionDistribution) {
        CompositeReconstructionDistribution c = (CompositeReconstructionDistribution) reconstructionDistribution;
        return c.computeLossFunctionScoreArray(data, reconstruction);
    } else {
        LossFunctionWrapper lfw = (LossFunctionWrapper) reconstructionDistribution;
        ILossFunction lossFunction = lfw.getLossFunction();
        //The activation function has already been applied to the reconstruction array (it is the decoder output,
        // not the pre-output), so pass ActivationIdentity to avoid applying it again
        return lossFunction.computeScoreArray(data, reconstruction, new ActivationIdentity(), null);
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), ActivationIdentity (org.nd4j.linalg.activations.impl.ActivationIdentity), CompositeReconstructionDistribution (org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution), ILossFunction (org.nd4j.linalg.lossfunctions.ILossFunction), LossFunctionWrapper (org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper)
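
A minimal sketch of the LossFunctionWrapper path; the network configuration is copied from the Spark test in Example 3 below, while the data shape and main-method wrapper are illustrative assumptions:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;

public class ReconstructionErrorExample {
    public static void main(String[] args) {
        int nIn = 10;
        //LossFunctionWrapper adapts an ILossFunction (here MSE) as the "reconstruction distribution":
        //reconstructionError(...) becomes available, while the probabilistic methods throw
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
                        .reconstructionDistribution(new LossFunctionWrapper(Activation.IDENTITY, new LossMSE()))
                        .nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        VariationalAutoencoder vae = (VariationalAutoencoder) net.getLayer(0);
        INDArray data = Nd4j.rand(4, nIn);
        if (vae.hasLossFunction()) {
            INDArray err = vae.reconstructionError(data); //deterministic, shape [4, 1]
            System.out.println(err);
        }
    }
}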

Example 3 with LossFunctionWrapper

Use of org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper in the deeplearning4j project.

From the class TestMiscFunctions, the test method testVaeReconstructionErrorWithKey:

@Test
public void testVaeReconstructionErrorWithKey() {
    //Simple test. We CAN do a direct comparison here vs. local, as reconstruction error is deterministic
    int nIn = 10;
    MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder().list()
            .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
                    .reconstructionDistribution(new LossFunctionWrapper(Activation.IDENTITY, new LossMSE()))
                    .nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(mlc);
    net.init();
    VariationalAutoencoder vae = (VariationalAutoencoder) net.getLayer(0);
    List<Tuple2<Integer, INDArray>> toScore = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        INDArray arr = Nd4j.rand(1, nIn);
        toScore.add(new Tuple2<Integer, INDArray>(i, arr));
    }
    JavaPairRDD<Integer, INDArray> rdd = sc.parallelizePairs(toScore);
    JavaPairRDD<Integer, Double> reconstrErrors = rdd.mapPartitionsToPair(new VaeReconstructionErrorWithKeyFunction<Integer>(sc.broadcast(net.params()), sc.broadcast(mlc.toJson()), 16));
    Map<Integer, Double> l = reconstrErrors.collectAsMap();
    assertEquals(100, l.size());
    for (int i = 0; i < 100; i++) {
        assertTrue(l.containsKey(i));
        INDArray localToScore = toScore.get(i)._2();
        double localScore = vae.reconstructionError(localToScore).data().asDouble()[0];
        assertEquals(localScore, l.get(i), 1e-6);
    }
}
Also used: VariationalAutoencoder (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), LossMSE (org.nd4j.linalg.lossfunctions.impl.LossMSE), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Tuple2 (scala.Tuple2), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), LossFunctionWrapper (org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest), Test (org.junit.Test)
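
As a follow-up, a brief sketch (continuing with the vae and toScore variables from the test above) of the batched local call that the per-key Spark results should agree with; the variable names introduced here are hypothetical:

//Stack the 100 single-row arrays into one [100, nIn] batch and score them in a single call
INDArray[] rows = new INDArray[toScore.size()];
for (int i = 0; i < rows.length; i++) {
    rows[i] = toScore.get(i)._2();
}
INDArray batch = Nd4j.vstack(rows);
//Deterministic, so this column vector (shape [100, 1]) should match the Spark results row for row
INDArray localErrors = vae.reconstructionError(batch);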

Aggregations

LossFunctionWrapper (org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper): 3 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3 usages
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1 usage
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 1 usage
CompositeReconstructionDistribution (org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution): 1 usage
VariationalAutoencoder (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder): 1 usage
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1 usage
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 1 usage
Test (org.junit.Test): 1 usage
IActivation (org.nd4j.linalg.activations.IActivation): 1 usage
ActivationIdentity (org.nd4j.linalg.activations.impl.ActivationIdentity): 1 usage
ILossFunction (org.nd4j.linalg.lossfunctions.ILossFunction): 1 usage
LossMSE (org.nd4j.linalg.lossfunctions.impl.LossMSE): 1 usage
Tuple2 (scala.Tuple2): 1 usage