Use of org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper in project deeplearning4j by deeplearning4j.
The class VariationalAutoencoder, method reconstructionLogProbability.
/**
* Return the log reconstruction probability given the specified number of samples.<br>
* See {@link #reconstructionProbability(INDArray, int)} for more details
*
* @param data       The data to calculate the log reconstruction probability for
* @param numSamples Number of samples to use when estimating the reconstruction log probability
* @return Column vector of reconstruction log probabilities for each example (shape: [numExamples,1])
*/
public INDArray reconstructionLogProbability(INDArray data, int numSamples) {
    if (numSamples <= 0) {
        throw new IllegalArgumentException("Invalid input: numSamples must be > 0. Got: " + numSamples);
    }
    if (reconstructionDistribution instanceof LossFunctionWrapper) {
        throw new UnsupportedOperationException("Cannot calculate reconstruction log probability when using "
                        + "a LossFunction (via LossFunctionWrapper) instead of a ReconstructionDistribution: ILossFunction "
                        + "instances are not in general probabilistic, hence it is not possible to calculate reconstruction probability");
    }

    //Forward pass through the encoder and mean for P(Z|X)
    setInput(data);
    VAEFwdHelper fwd = doForward(true, true);
    IActivation afn = conf().getLayer().getActivationFn();

    //Forward pass through logStd^2 for P(Z|X)
    INDArray pzxLogStd2W = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
    INDArray pzxLogStd2b = params.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B);

    INDArray meanZ = fwd.pzxMeanPreOut;
    INDArray logStdev2Z = fwd.encoderActivations[fwd.encoderActivations.length - 1].mmul(pzxLogStd2W)
                    .addiRowVector(pzxLogStd2b);
    pzxActivationFn.getActivation(meanZ, false);
    pzxActivationFn.getActivation(logStdev2Z, false);

    INDArray pzxSigma = Transforms.exp(logStdev2Z, false);
    Transforms.sqrt(pzxSigma, false);

    int minibatch = input.size(0);
    int size = fwd.pzxMeanPreOut.size(1);

    INDArray pxzw = params.get(VariationalAutoencoderParamInitializer.PXZ_W);
    INDArray pxzb = params.get(VariationalAutoencoderParamInitializer.PXZ_B);

    INDArray[] decoderWeights = new INDArray[decoderLayerSizes.length];
    INDArray[] decoderBiases = new INDArray[decoderLayerSizes.length];
    for (int i = 0; i < decoderLayerSizes.length; i++) {
        String wKey = "d" + i + WEIGHT_KEY_SUFFIX;
        String bKey = "d" + i + BIAS_KEY_SUFFIX;
        decoderWeights[i] = params.get(wKey);
        decoderBiases[i] = params.get(bKey);
    }

    INDArray sumReconstructionNegLogProbability = null;
    for (int i = 0; i < numSamples; i++) {
        INDArray e = Nd4j.randn(minibatch, size);
        //z = mu + sigma * e, with e ~ N(0,1)
        INDArray z = e.muli(pzxSigma).addi(meanZ);

        //Do forward pass through decoder
        int nDecoderLayers = decoderLayerSizes.length;
        INDArray currentActivations = z;
        for (int j = 0; j < nDecoderLayers; j++) {
            currentActivations = currentActivations.mmul(decoderWeights[j]).addiRowVector(decoderBiases[j]);
            afn.getActivation(currentActivations, false);
        }

        //And calculate reconstruction distribution preOut
        INDArray pxzDistributionPreOut = currentActivations.mmul(pxzw).addiRowVector(pxzb);
        if (i == 0) {
            sumReconstructionNegLogProbability =
                            reconstructionDistribution.exampleNegLogProbability(data, pxzDistributionPreOut);
        } else {
            sumReconstructionNegLogProbability.addi(
                            reconstructionDistribution.exampleNegLogProbability(data, pxzDistributionPreOut));
        }
    }

    setInput(null);
    return sumReconstructionNegLogProbability.divi(-numSamples);
}
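For context, a minimal, hypothetical usage sketch of the method above follows. The network configuration (layer sizes, the GaussianReconstructionDistribution choice, the data shape, and the 32-sample count) is an illustrative assumption, not taken from the snippet; internally the method averages per-example negative log probabilities over numSamples draws of z, so a larger numSamples gives a lower-variance but slower estimate.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ReconstructionLogProbSketch {
    public static void main(String[] args) {
        int nIn = 10;
        //VAE configured with a probabilistic reconstruction distribution, so reconstructionLogProbability is supported
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
            .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
                .nIn(nIn).nOut(5)
                .encoderLayerSizes(12).decoderLayerSizes(13)
                .reconstructionDistribution(new GaussianReconstructionDistribution(Activation.IDENTITY))
                .build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        //The layer object (not the network) exposes reconstructionLogProbability
        VariationalAutoencoder vae = (VariationalAutoencoder) net.getLayer(0);

        INDArray data = Nd4j.rand(8, nIn); //8 examples, nIn features each
        INDArray logProb = vae.reconstructionLogProbability(data, 32); //shape [8, 1]
        System.out.println(logProb);
    }
}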
Use of org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper in project deeplearning4j by deeplearning4j.
The class VariationalAutoencoder, method reconstructionError.
/**
* Return the reconstruction error for this variational autoencoder.<br>
* <b>NOTE (important):</b> This method is used ONLY for VAEs that have a standard neural network loss function (i.e.,
* an {@link org.nd4j.linalg.lossfunctions.ILossFunction} instance such as mean squared error) instead of using a
* probabilistic reconstruction distribution P(x|z) for the reconstructions (as presented in the VAE architecture by
* Kingma and Welling).<br>
* You can check if the VAE has a loss function using {@link #hasLossFunction()}<br>
* Consequently, the reconstruction error is a simple deterministic function (no Monte-Carlo sampling is required,
* unlike {@link #reconstructionProbability(INDArray, int)} and {@link #reconstructionLogProbability(INDArray, int)})
*
* @param data The data to calculate the reconstruction error on
* @return Column vector of reconstruction errors for each example (shape: [numExamples,1])
*/
public INDArray reconstructionError(INDArray data) {
    if (!hasLossFunction()) {
        throw new IllegalStateException("Cannot use reconstructionError method unless the variational autoencoder is "
                        + "configured with a standard loss function (via LossFunctionWrapper). For VAEs utilizing a reconstruction "
                        + "distribution, use the reconstructionProbability or reconstructionLogProbability methods");
    }

    INDArray pZXMean = activate(data, false);
    //Not probabilistic -> "mean" == output
    INDArray reconstruction = generateAtMeanGivenZ(pZXMean);

    if (reconstructionDistribution instanceof CompositeReconstructionDistribution) {
        CompositeReconstructionDistribution c = (CompositeReconstructionDistribution) reconstructionDistribution;
        return c.computeLossFunctionScoreArray(data, reconstruction);
    } else {
        LossFunctionWrapper lfw = (LossFunctionWrapper) reconstructionDistribution;
        ILossFunction lossFunction = lfw.getLossFunction();
        //Re: the ActivationIdentity here - the reconstruction array already has the activation function applied,
        // so we don't want to apply it again. i.e., we are passing the output, not the pre-output.
        return lossFunction.computeScoreArray(data, reconstruction, new ActivationIdentity(), null);
    }
}
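As a complement to the Spark-based test that follows, below is a minimal local sketch of the LossFunctionWrapper case that reconstructionError requires. The layer sizes and data shape are illustrative assumptions mirroring that test, not values mandated by the API.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;

public class ReconstructionErrorSketch {
    public static void main(String[] args) {
        int nIn = 10;
        //VAE configured with a standard loss function via LossFunctionWrapper - the case reconstructionError requires
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
            .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
                .nIn(nIn).nOut(5)
                .encoderLayerSizes(12).decoderLayerSizes(13)
                .reconstructionDistribution(new LossFunctionWrapper(Activation.IDENTITY, new LossMSE()))
                .build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        VariationalAutoencoder vae = (VariationalAutoencoder) net.getLayer(0);
        INDArray data = Nd4j.rand(8, nIn);

        if (vae.hasLossFunction()) { //true for LossFunctionWrapper configurations
            INDArray perExampleError = vae.reconstructionError(data); //shape [8, 1], deterministic
            System.out.println(perExampleError);
        }
    }
}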
Use of org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper in project deeplearning4j by deeplearning4j.
The class TestMiscFunctions, method testVaeReconstructionErrorWithKey.
@Test
public void testVaeReconstructionErrorWithKey() {
    //Simple test. We CAN do a direct comparison here vs. local, as reconstruction error is deterministic
    int nIn = 10;

    MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder().list()
                    .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
                                    .reconstructionDistribution(new LossFunctionWrapper(Activation.IDENTITY, new LossMSE()))
                                    .nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13).build())
                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(mlc);
    net.init();

    VariationalAutoencoder vae = (VariationalAutoencoder) net.getLayer(0);

    List<Tuple2<Integer, INDArray>> toScore = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        INDArray arr = Nd4j.rand(1, nIn);
        toScore.add(new Tuple2<Integer, INDArray>(i, arr));
    }

    JavaPairRDD<Integer, INDArray> rdd = sc.parallelizePairs(toScore);
    JavaPairRDD<Integer, Double> reconstrErrors = rdd.mapPartitionsToPair(
                    new VaeReconstructionErrorWithKeyFunction<Integer>(sc.broadcast(net.params()),
                                    sc.broadcast(mlc.toJson()), 16));

    Map<Integer, Double> l = reconstrErrors.collectAsMap();
    assertEquals(100, l.size());

    for (int i = 0; i < 100; i++) {
        assertTrue(l.containsKey(i));
        INDArray localToScore = toScore.get(i)._2();
        double localScore = vae.reconstructionError(localToScore).data().asDouble()[0];
        assertEquals(localScore, l.get(i), 1e-6);
    }
}