Example 1 with CompositeReconstructionDistribution

Use of org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution in project deeplearning4j by deeplearning4j.

In class VariationalAutoencoder, method reconstructionError.

/**
 * Return the reconstruction error for this variational autoencoder.<br>
 * <b>NOTE (important):</b> This method is used ONLY for VAEs that have a standard neural network loss function (i.e.,
 * an {@link org.nd4j.linalg.lossfunctions.ILossFunction} instance such as mean squared error) instead of using a
 * probabilistic reconstruction distribution P(x|z) for the reconstructions (as presented in the VAE architecture by
 * Kingma and Welling).<br>
 * Consequently, the reconstruction error is a simple deterministic function (no Monte Carlo sampling is required,
 * unlike {@link #reconstructionProbability(INDArray, int)} and {@link #reconstructionLogProbability(INDArray, int)}).<br>
 * You can check whether the VAE has a loss function using {@link #hasLossFunction()}.
 *
 * @param data The data to calculate the reconstruction error on
 * @return Column vector of reconstruction errors for each example (shape: [numExamples,1])
 */
public INDArray reconstructionError(INDArray data) {
    if (!hasLossFunction()) {
        throw new IllegalStateException("Cannot use reconstructionError method unless the variational autoencoder is "
                + "configured with a standard loss function (via LossFunctionWrapper). For VAEs utilizing a reconstruction "
                + "distribution, use the reconstructionProbability or reconstructionLogProbability methods");
    }
    INDArray pZXMean = activate(data, false);
    //Not probabilistic -> "mean" == output
    INDArray reconstruction = generateAtMeanGivenZ(pZXMean);
    if (reconstructionDistribution instanceof CompositeReconstructionDistribution) {
        CompositeReconstructionDistribution c = (CompositeReconstructionDistribution) reconstructionDistribution;
        return c.computeLossFunctionScoreArray(data, reconstruction);
    } else {
        LossFunctionWrapper lfw = (LossFunctionWrapper) reconstructionDistribution;
        ILossFunction lossFunction = lfw.getLossFunction();
        // The reconstruction array already has the output activation function applied,
        // so we don't want to apply it again, i.e., we are passing the output, not the pre-output.
        return lossFunction.computeScoreArray(data, reconstruction, new ActivationIdentity(), null);
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), ActivationIdentity (org.nd4j.linalg.activations.impl.ActivationIdentity), CompositeReconstructionDistribution (org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution), ILossFunction (org.nd4j.linalg.lossfunctions.ILossFunction), LossFunctionWrapper (org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper)
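
The Javadoc's claim that this reconstruction error is deterministic, and the reason the method passes ActivationIdentity, can be illustrated without DL4J. The following is a hypothetical sketch only: plain double[][] arrays stand in for INDArray, mean squared error is assumed as the loss, and the names ReconstructionErrorSketch and mseScoreArray are invented for illustration. It returns one score per example (the analogue of the [numExamples,1] column vector) and shows how re-applying an activation to an already-activated reconstruction distorts the score.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

// Hypothetical, dependency-free sketch: plain double[][] arrays stand in for INDArray.
public class ReconstructionErrorSketch {

    // Returns one mean-squared-error score per example (row), analogous to a [numExamples, 1] column vector.
    // outputActivation is applied to each reconstruction value before comparing; pass the identity when the
    // reconstruction is already post-activation (the role ActivationIdentity plays in the method above).
    public static double[] mseScoreArray(double[][] data, double[][] reconstruction,
                                         DoubleUnaryOperator outputActivation) {
        if (data.length != reconstruction.length) {
            throw new IllegalArgumentException("data and reconstruction must have the same number of rows");
        }
        double[] scores = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            if (data[i].length != reconstruction[i].length) {
                throw new IllegalArgumentException("Column count mismatch in row " + i);
            }
            double sum = 0.0;
            for (int j = 0; j < data[i].length; j++) {
                double diff = data[i][j] - outputActivation.applyAsDouble(reconstruction[i][j]);
                sum += diff * diff;
            }
            scores[i] = sum / data[i].length;
        }
        return scores;
    }

    public static void main(String[] args) {
        double[][] data = {{0.0, 1.0}, {1.0, 1.0}};
        // Reconstructions that have already been through the output activation.
        double[][] reconstruction = {{0.0, 1.0}, {0.5, 1.0}};

        // Correct: use the reconstruction as-is. No sampling, so repeated calls give the same result.
        double[] scores = mseScoreArray(data, reconstruction, DoubleUnaryOperator.identity());
        System.out.println(Arrays.toString(scores)); // prints [0.0, 0.125]

        // Incorrect: re-applying a sigmoid to an already-activated output distorts the scores
        // (the first example is no longer scored as a perfect reconstruction).
        DoubleUnaryOperator sigmoid = x -> 1.0 / (1.0 + Math.exp(-x));
        double[] distorted = mseScoreArray(data, reconstruction, sigmoid);
        System.out.println(Arrays.toString(distorted));
    }
}
```

Because no sampling is involved, the same inputs always produce the same scores; the probabilistic methods named in the Javadoc instead rely on Monte Carlo sampling, so their results are estimates.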

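
The listing does not show how CompositeReconstructionDistribution.computeLossFunctionScoreArray combines its parts. As a hypothetical sketch only, assume a composite splits the columns into contiguous segments, scores each segment with its own per-example loss, and sums the per-segment scores. Plain arrays stand in for INDArray, and every name here (CompositeScoreSketch, compositeScoreArray, mse, l1) is invented for illustration. The sketch also assumes each data segment and its reconstruction segment have the same width, which need not hold for real reconstruction distributions.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;

// Hypothetical sketch of a composite per-example score; not the DL4J implementation.
public class CompositeScoreSketch {

    // Copies columns [from, to) of every row.
    static double[][] columns(double[][] m, int from, int to) {
        double[][] out = new double[m.length][];
        for (int i = 0; i < m.length; i++) {
            out[i] = Arrays.copyOfRange(m[i], from, to);
        }
        return out;
    }

    // Per-example mean squared error over a block of columns.
    public static double[] mse(double[][] d, double[][] r) {
        double[] s = new double[d.length];
        for (int i = 0; i < d.length; i++) {
            double sum = 0.0;
            for (int j = 0; j < d[i].length; j++) {
                double diff = d[i][j] - r[i][j];
                sum += diff * diff;
            }
            s[i] = sum / d[i].length;
        }
        return s;
    }

    // Per-example sum of absolute errors over a block of columns.
    public static double[] l1(double[][] d, double[][] r) {
        double[] s = new double[d.length];
        for (int i = 0; i < d.length; i++) {
            for (int j = 0; j < d[i].length; j++) {
                s[i] += Math.abs(d[i][j] - r[i][j]);
            }
        }
        return s;
    }

    // Scores each column segment with its own loss and sums the per-example scores.
    public static double[] compositeScoreArray(double[][] data, double[][] reconstruction, int[] segmentSizes,
                                               List<BiFunction<double[][], double[][], double[]>> losses) {
        if (segmentSizes.length != losses.size() || data.length != reconstruction.length) {
            throw new IllegalArgumentException("Need one loss per segment and matching row counts");
        }
        int totalSize = Arrays.stream(segmentSizes).sum();
        for (int i = 0; i < data.length; i++) {
            if (data[i].length != totalSize || reconstruction[i].length != totalSize) {
                throw new IllegalArgumentException("Segment sizes must cover every column of row " + i);
            }
        }
        double[] total = new double[data.length];
        int offset = 0;
        for (int k = 0; k < segmentSizes.length; k++) {
            int end = offset + segmentSizes[k];
            double[] segmentScores = losses.get(k).apply(columns(data, offset, end), columns(reconstruction, offset, end));
            for (int i = 0; i < total.length; i++) {
                total[i] += segmentScores[i];
            }
            offset = end;
        }
        return total;
    }

    public static void main(String[] args) {
        double[][] data = {{0.0, 1.0, 2.0}, {1.0, 1.0, 0.0}};
        double[][] reconstruction = {{0.0, 1.0, 3.0}, {0.0, 1.0, 0.0}};
        List<BiFunction<double[][], double[][], double[]>> losses = new ArrayList<>();
        losses.add(CompositeScoreSketch::mse); // scores columns 0-1
        losses.add(CompositeScoreSketch::l1);  // scores column 2
        double[] scores = compositeScoreArray(data, reconstruction, new int[] {2, 1}, losses);
        System.out.println(Arrays.toString(scores)); // prints [1.0, 0.5]
    }
}
```

Summing per-segment scores keeps the result a single score per example, matching the shape the non-composite branch returns.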