Example 6 with VariationalAutoencoder

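Use of org.deeplearning4j.nn.layers.variational.VariationalAutoencoder in the deeplearning4j project (by deeplearning4j).

From class TestMiscFunctions, method testVaeReconstructionErrorWithKey. The test scores 100 keyed random vectors on Spark and checks each distributed reconstruction error against the value computed locally by the same network.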

@Test
public void testVaeReconstructionErrorWithKey() {
    // Simple test: reconstruction error is deterministic, so distributed scores can be compared directly with local ones
    int nIn = 10;
    MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
                    // Identity activation + MSE loss used as the reconstruction "distribution"
                    .reconstructionDistribution(new LossFunctionWrapper(Activation.IDENTITY, new LossMSE()))
                    .nIn(nIn)
                    .nOut(5)
                    .encoderLayerSizes(12)
                    .decoderLayerSizes(13)
                    .build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(mlc);
    net.init();
    VariationalAutoencoder vae = (VariationalAutoencoder) net.getLayer(0);
    List<Tuple2<Integer, INDArray>> toScore = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        INDArray arr = Nd4j.rand(1, nIn);
        toScore.add(new Tuple2<Integer, INDArray>(i, arr));
    }
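    // sc is the JavaSparkContext provided by BaseSparkTest
    JavaPairRDD<Integer, INDArray> rdd = sc.parallelizePairs(toScore);
    // Broadcast the parameters and JSON config; 16 is the minibatch size used within each partition
    JavaPairRDD<Integer, Double> reconstrErrors = rdd.mapPartitionsToPair(
            new VaeReconstructionErrorWithKeyFunction<Integer>(
                    sc.broadcast(net.params()), sc.broadcast(mlc.toJson()), 16));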
    Map<Integer, Double> l = reconstrErrors.collectAsMap();
    assertEquals(100, l.size());
    for (int i = 0; i < 100; i++) {
        assertTrue(l.containsKey(i));
        INDArray localToScore = toScore.get(i)._2();
        double localScore = vae.reconstructionError(localToScore).data().asDouble()[0];
        assertEquals(localScore, l.get(i), 1e-6);
    }
}
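The test relies on the per-example reconstruction error being deterministic. As a rough illustration of what an MSE-based per-example score looks like, the plain-Java sketch below computes the mean squared difference between one input row and its reconstruction. It assumes an identity output activation, takes the reconstruction as already given, and assumes (as the name MSE suggests) that squared errors are averaged over the nIn columns. The class and method names are hypothetical; this is not DL4J's implementation of reconstructionError.

```java
/**
 * Hypothetical illustration (not DL4J code): per-example mean squared
 * reconstruction error, assuming an identity output activation and a
 * reconstruction that has already been computed deterministically.
 */
public class MseReconstructionErrorSketch {

    /** Mean of the squared differences between an input row and its reconstruction. */
    static double mseReconstructionError(double[] input, double[] reconstruction) {
        if (input.length != reconstruction.length) {
            throw new IllegalArgumentException("input and reconstruction lengths differ");
        }
        double sum = 0.0;
        for (int j = 0; j < input.length; j++) {
            double diff = input[j] - reconstruction[j];
            sum += diff * diff;
        }
        // Average over the columns (nIn in the test above)
        return sum / input.length;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0, 4.0};
        double[] xHat = {1.0, 2.5, 2.0, 4.0};
        // Squared differences: 0, 0.25, 1.0, 0 -> sum 1.25 -> mean 0.3125
        System.out.println(mseReconstructionError(x, xHat)); // prints 0.3125
    }
}
```

Because the function reads only its two arguments, repeated calls on the same row always give the same value, which is the property the test's 1e-6 comparison depends on.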
Also used: VariationalAutoencoder (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), LossMSE (org.nd4j.linalg.lossfunctions.impl.LossMSE), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Tuple2 (scala.Tuple2), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), LossFunctionWrapper (org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest), Test (org.junit.Test).
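Because each example's score depends only on that example and the broadcast parameters, the way the RDD is partitioned and the way each partition is split into minibatches of 16 should not change the result, up to floating-point rounding (hence the tolerance). The Spark-free sketch below mimics that structure and repeats the test's two checks: every key is present, and each distributed score matches the local one. All names (scoreByKey, partition, Keyed) are hypothetical, and meanSquare is only a deterministic stand-in for the real reconstruction error; this is an analogue under those assumptions, not the implementation of VaeReconstructionErrorWithKeyFunction.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.ToDoubleFunction;

/** Hypothetical, Spark-free analogue of keyed minibatch scoring (not DL4J code). */
public class KeyedBatchScoringSketch {

    /** A feature vector tagged with the key that must survive scoring. */
    static final class Keyed {
        final int key;
        final double[] features;
        Keyed(int key, double[] features) { this.key = key; this.features = features; }
    }

    /** Scores every record, one partition at a time, in minibatches of batchSize. */
    static Map<Integer, Double> scoreByKey(List<List<Keyed>> partitions, int batchSize,
                                           ToDoubleFunction<double[]> score) {
        Map<Integer, Double> out = new HashMap<>();
        for (List<Keyed> part : partitions) {
            for (int start = 0; start < part.size(); start += batchSize) {
                int end = Math.min(start + batchSize, part.size());
                // Each row is scored independently, so grouping does not affect the result
                for (Keyed k : part.subList(start, end)) {
                    out.put(k.key, score.applyAsDouble(k.features));
                }
            }
        }
        return out;
    }

    /** Splits records round-robin into numPartitions lists. */
    static List<List<Keyed>> partition(List<Keyed> records, int numPartitions) {
        List<List<Keyed>> parts = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) parts.add(new ArrayList<>());
        for (int i = 0; i < records.size(); i++) parts.get(i % numPartitions).add(records.get(i));
        return parts;
    }

    /** Deterministic stand-in for a reconstruction error: mean of squared features. */
    static double meanSquare(double[] x) {
        double s = 0.0;
        for (double v : x) s += v * v;
        return s / x.length;
    }

    public static void main(String[] args) {
        int nIn = 10;
        Random rng = new Random(12345);
        List<Keyed> records = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            double[] x = new double[nIn];
            for (int j = 0; j < nIn; j++) x[j] = rng.nextDouble();
            records.add(new Keyed(i, x));
        }
        Map<Integer, Double> distributed =
                scoreByKey(partition(records, 4), 16, KeyedBatchScoringSketch::meanSquare);
        // Same checks as the test: all keys present, each score equal to the local one
        if (distributed.size() != 100) throw new AssertionError("missing keys");
        for (Keyed k : records) {
            double local = meanSquare(k.features);
            if (Math.abs(local - distributed.get(k.key)) > 1e-6) {
                throw new AssertionError("mismatch for key " + k.key);
            }
        }
        System.out.println("all " + distributed.size() + " keyed scores match");
    }
}
```

Carrying the key alongside each row is what lets results computed on arbitrary partitions be joined back to their inputs, which is why the test can look up each score by index after collectAsMap().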