Use of org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution in the deeplearning4j project.
From the class VariationalAutoencoder, method reconstructionError:
/**
* Return the reconstruction error for this variational autoencoder.<br>
* <b>NOTE (important):</b> This method is used ONLY for VAEs that have a standard neural network loss function (i.e.,
* an {@link org.nd4j.linalg.lossfunctions.ILossFunction} instance such as mean squared error) instead of using a
* probabilistic reconstruction distribution P(x|z) for the reconstructions (as presented in the VAE architecture by
* Kingma and Welling).<br>
 * You can check whether the VAE has such a loss function using {@link #hasLossFunction()}.<br>
 * Consequently, the reconstruction error is a simple deterministic function (no Monte Carlo sampling is required,
 * unlike {@link #reconstructionProbability(INDArray, int)} and {@link #reconstructionLogProbability(INDArray, int)}).
*
* @param data The data to calculate the reconstruction error on
* @return Column vector of reconstruction errors for each example (shape: [numExamples,1])
*/
public INDArray reconstructionError(INDArray data) {
    if (!hasLossFunction()) {
        throw new IllegalStateException("Cannot use reconstructionError method unless the variational autoencoder is "
                        + "configured with a standard loss function (via LossFunctionWrapper). For VAEs utilizing a reconstruction "
                        + "distribution, use the reconstructionProbability or reconstructionLogProbability methods");
    }

    INDArray pZXMean = activate(data, false);
    //Not probabilistic -> "mean" == output
    INDArray reconstruction = generateAtMeanGivenZ(pZXMean);

    if (reconstructionDistribution instanceof CompositeReconstructionDistribution) {
        //Composite case: the score is computed per component (each with its own loss function), then merged
        CompositeReconstructionDistribution c = (CompositeReconstructionDistribution) reconstructionDistribution;
        return c.computeLossFunctionScoreArray(data, reconstruction);
    } else {
        LossFunctionWrapper lfw = (LossFunctionWrapper) reconstructionDistribution;
        ILossFunction lossFunction = lfw.getLossFunction();

        //Re: the activation identity here - the reconstruction array already has the activation function applied,
        // so we don't want to apply it again. i.e., we are passing the output, not the pre-output.
        return lossFunction.computeScoreArray(data, reconstruction, new ActivationIdentity(), null);
    }
}
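
For context, a minimal usage sketch follows. It is not part of the deeplearning4j source; the network shape, the MSE loss, and the random input are illustrative assumptions. The key requirement is that the reconstruction distribution is a LossFunctionWrapper (or a CompositeReconstructionDistribution composed of them), so that hasLossFunction() returns true:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;

public class VaeReconstructionErrorSketch {
    public static void main(String[] args) {
        //Configure a VAE whose "reconstruction distribution" is a plain loss function (MSE),
        //so hasLossFunction() returns true and reconstructionError(...) is usable
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .list()
                .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
                        .nIn(784).nOut(32)                  //Illustrative sizes
                        .encoderLayerSizes(256)
                        .decoderLayerSizes(256)
                        .reconstructionDistribution(new LossFunctionWrapper(Activation.SIGMOID, new LossMSE()))
                        .build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        //The layer implementation class (not the conf class) exposes reconstructionError
        org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae =
                (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);

        INDArray data = Nd4j.rand(10, 784);               //Dummy data: 10 examples
        INDArray errors = vae.reconstructionError(data);  //Shape [10, 1]: one error per example
        System.out.println(errors);
    }
}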
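The composite branch above exists for mixed data types, where different column ranges of the input use different losses. Here is a sketch of how such a distribution might be configured; the column sizes and loss functions are assumptions. Note that reconstructionError is only valid here if every component wraps a standard loss function, since computeLossFunctionScoreArray rejects probabilistic components:

import org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;

public class CompositeLossSketch {
    public static CompositeReconstructionDistribution build() {
        //Hypothetical layout: columns [0,100) scored with MSE after a sigmoid activation,
        //columns [100,128) scored with L2 on the raw (identity) output.
        //Because every component is a LossFunctionWrapper, hasLossFunction() returns true
        //and reconstructionError(...) takes the CompositeReconstructionDistribution branch.
        return new CompositeReconstructionDistribution.Builder()
                .addDistribution(100, new LossFunctionWrapper(Activation.SIGMOID, new LossMSE()))
                .addDistribution(28, new LossFunctionWrapper(Activation.IDENTITY, new LossL2()))
                .build();
    }
}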