Example 1 with BaseRecurrentLayer

Use of org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer in project deeplearning4j by deeplearning4j.

From class ComputationGraphTestRNN, method testTruncatedBPTTVsBPTT:

@Test
public void testTruncatedBPTTVsBPTT() {
    //Under some (limited) circumstances, we expect BPTT and truncated BPTT to be identical
    //Specifically TBPTT over entire data vector
    int timeSeriesLength = 12;
    int miniBatchSize = 7;
    int nIn = 5;
    int nOut = 4;
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                    .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build(), "in")
                    .addLayer("1", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                    .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build(), "0")
                    .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(8).nOut(nOut)
                                    .activation(Activation.SOFTMAX).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build(), "1")
                    .setOutputs("out").backprop(true).build();
    assertEquals(BackpropType.Standard, conf.getBackpropType());
    ComputationGraphConfiguration confTBPTT = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                    .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build(), "in")
                    .addLayer("1", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                    .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build(), "0")
                    .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(8).nOut(nOut)
                                    .activation(Activation.SOFTMAX).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build(), "1")
                    .setOutputs("out").backprop(true)
                    .backpropType(BackpropType.TruncatedBPTT)
                    .tBPTTForwardLength(timeSeriesLength).tBPTTBackwardLength(timeSeriesLength)
                    .build();
    assertEquals(BackpropType.TruncatedBPTT, confTBPTT.getBackpropType());
    Nd4j.getRandom().setSeed(12345);
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();
    Nd4j.getRandom().setSeed(12345);
    ComputationGraph graphTBPTT = new ComputationGraph(confTBPTT);
    graphTBPTT.init();
    assertEquals(BackpropType.TruncatedBPTT, graphTBPTT.getConfiguration().getBackpropType());
    assertEquals(timeSeriesLength, graphTBPTT.getConfiguration().getTbpttFwdLength());
    assertEquals(timeSeriesLength, graphTBPTT.getConfiguration().getTbpttBackLength());
    INDArray inputData = Nd4j.rand(new int[] { miniBatchSize, nIn, timeSeriesLength });
    INDArray labels = Nd4j.rand(new int[] { miniBatchSize, nOut, timeSeriesLength });
    graph.setInput(0, inputData);
    graph.setLabel(0, labels);
    graphTBPTT.setInput(0, inputData);
    graphTBPTT.setLabel(0, labels);
    graph.computeGradientAndScore();
    graphTBPTT.computeGradientAndScore();
    Pair<Gradient, Double> graphPair = graph.gradientAndScore();
    Pair<Gradient, Double> graphTbpttPair = graphTBPTT.gradientAndScore();
    assertEquals(graphPair.getFirst().gradientForVariable(), graphTbpttPair.getFirst().gradientForVariable());
    assertEquals(graphPair.getSecond(), graphTbpttPair.getSecond());
    //Check states: expect stateMap to be empty but tBpttStateMap to not be
    Map<String, INDArray> l0StateMLN = graph.rnnGetPreviousState(0);
    Map<String, INDArray> l0StateTBPTT = graphTBPTT.rnnGetPreviousState(0);
    Map<String, INDArray> l1StateMLN = graph.rnnGetPreviousState(1);
    Map<String, INDArray> l1StateTBPTT = graphTBPTT.rnnGetPreviousState(1);
    Map<String, INDArray> l0TBPTTState = ((BaseRecurrentLayer<?>) graph.getLayer(0)).rnnGetTBPTTState();
    Map<String, INDArray> l0TBPTTStateTBPTT = ((BaseRecurrentLayer<?>) graphTBPTT.getLayer(0)).rnnGetTBPTTState();
    Map<String, INDArray> l1TBPTTState = ((BaseRecurrentLayer<?>) graph.getLayer(1)).rnnGetTBPTTState();
    Map<String, INDArray> l1TBPTTStateTBPTT = ((BaseRecurrentLayer<?>) graphTBPTT.getLayer(1)).rnnGetTBPTTState();
    assertTrue(l0StateMLN.isEmpty());
    assertTrue(l0StateTBPTT.isEmpty());
    assertTrue(l1StateMLN.isEmpty());
    assertTrue(l1StateTBPTT.isEmpty());
    assertTrue(l0TBPTTState.isEmpty());
    assertEquals(2, l0TBPTTStateTBPTT.size());
    assertTrue(l1TBPTTState.isEmpty());
    assertEquals(2, l1TBPTTStateTBPTT.size());
    INDArray tbpttActL0 = l0TBPTTStateTBPTT.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
    INDArray tbpttActL1 = l1TBPTTStateTBPTT.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
    Map<String, INDArray> activations = graph.feedForward(inputData, true);
    INDArray l0Act = activations.get("0");
    INDArray l1Act = activations.get("1");
    INDArray expL0Act = l0Act.tensorAlongDimension(timeSeriesLength - 1, 1, 0);
    INDArray expL1Act = l1Act.tensorAlongDimension(timeSeriesLength - 1, 1, 0);
    assertEquals(tbpttActL0, expL0Act);
    assertEquals(tbpttActL1, expL1Act);
}
Also used: RnnOutputLayer (org.deeplearning4j.nn.conf.layers.RnnOutputLayer), Gradient (org.deeplearning4j.nn.gradient.Gradient), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), GravesLSTM (org.deeplearning4j.nn.layers.recurrent.GravesLSTM), INDArray (org.nd4j.linalg.api.ndarray.INDArray), NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution), ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), BaseRecurrentLayer (org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer), Test (org.junit.Test)
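
For contrast, the test above covers the edge case where the truncation length equals the full series length, which is the only reason the TBPTT gradients match full BPTT exactly. Below is a minimal sketch of the more typical setup, where truncation is genuinely shorter than the sequence; the truncation length of 4 and the name confShortTBPTT are illustrative and not taken from the test. With a forward length below timeSeriesLength, the gradient equality asserted above would no longer be expected to hold.

// Hypothetical configuration (not from the test): truncation length 4 applied
// to the 12-step series used above. Gradients computed this way only
// approximate full-BPTT gradients.
ComputationGraphConfiguration confShortTBPTT = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder()
                .addInputs("in")
                .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7)
                                .activation(Activation.TANH).build(), "in")
                .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(7).nOut(4)
                                .activation(Activation.SOFTMAX).build(), "0")
                .setOutputs("out")
                .backprop(true)
                .backpropType(BackpropType.TruncatedBPTT)
                // each fit pass is split into length-4 forward/backward segments
                .tBPTTForwardLength(4).tBPTTBackwardLength(4)
                .build();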

Example 2 with BaseRecurrentLayer

Use of org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer in project deeplearning4j by deeplearning4j.

From class MultiLayerTestRNN, method testTruncatedBPTTVsBPTT:

@Test
public void testTruncatedBPTTVsBPTT() {
    //Under some (limited) circumstances, we expect BPTT and truncated BPTT to be identical
    //Specifically TBPTT over entire data vector
    int timeSeriesLength = 12;
    int miniBatchSize = 7;
    int nIn = 5;
    int nOut = 4;
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list()
                    .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                    .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build())
                    .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                    .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build())
                    .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(8).nOut(nOut)
                                    .activation(Activation.SOFTMAX).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build())
                    .backprop(true).build();
    assertEquals(BackpropType.Standard, conf.getBackpropType());
    MultiLayerConfiguration confTBPTT = new NeuralNetConfiguration.Builder().seed(12345).list()
                    .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                    .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build())
                    .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(7).nOut(8)
                                    .activation(Activation.TANH).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build())
                    .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).nIn(8).nOut(nOut)
                                    .activation(Activation.SOFTMAX).weightInit(WeightInit.DISTRIBUTION)
                                    .dist(new NormalDistribution(0, 0.5)).build())
                    .backprop(true)
                    .backpropType(BackpropType.TruncatedBPTT)
                    .tBPTTBackwardLength(timeSeriesLength).tBPTTForwardLength(timeSeriesLength)
                    .build();
    Nd4j.getRandom().setSeed(12345);
    MultiLayerNetwork mln = new MultiLayerNetwork(conf);
    mln.init();
    Nd4j.getRandom().setSeed(12345);
    MultiLayerNetwork mlnTBPTT = new MultiLayerNetwork(confTBPTT);
    mlnTBPTT.init();
    assertEquals(BackpropType.TruncatedBPTT, mlnTBPTT.getLayerWiseConfigurations().getBackpropType());
    assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttFwdLength());
    assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttBackLength());
    INDArray inputData = Nd4j.rand(new int[] { miniBatchSize, nIn, timeSeriesLength });
    INDArray labels = Nd4j.rand(new int[] { miniBatchSize, nOut, timeSeriesLength });
    mln.setInput(inputData);
    mln.setLabels(labels);
    mlnTBPTT.setInput(inputData);
    mlnTBPTT.setLabels(labels);
    mln.computeGradientAndScore();
    mlnTBPTT.computeGradientAndScore();
    Pair<Gradient, Double> mlnPair = mln.gradientAndScore();
    Pair<Gradient, Double> tbpttPair = mlnTBPTT.gradientAndScore();
    assertEquals(mlnPair.getFirst().gradientForVariable(), tbpttPair.getFirst().gradientForVariable());
    assertEquals(mlnPair.getSecond(), tbpttPair.getSecond());
    //Check states: expect stateMap to be empty but tBpttStateMap to not be
    Map<String, INDArray> l0StateMLN = mln.rnnGetPreviousState(0);
    Map<String, INDArray> l0StateTBPTT = mlnTBPTT.rnnGetPreviousState(0);
    Map<String, INDArray> l1StateMLN = mln.rnnGetPreviousState(1);
    Map<String, INDArray> l1StateTBPTT = mlnTBPTT.rnnGetPreviousState(1);
    Map<String, INDArray> l0TBPTTStateMLN = ((BaseRecurrentLayer<?>) mln.getLayer(0)).rnnGetTBPTTState();
    Map<String, INDArray> l0TBPTTStateTBPTT = ((BaseRecurrentLayer<?>) mlnTBPTT.getLayer(0)).rnnGetTBPTTState();
    Map<String, INDArray> l1TBPTTStateMLN = ((BaseRecurrentLayer<?>) mln.getLayer(1)).rnnGetTBPTTState();
    Map<String, INDArray> l1TBPTTStateTBPTT = ((BaseRecurrentLayer<?>) mlnTBPTT.getLayer(1)).rnnGetTBPTTState();
    assertTrue(l0StateMLN.isEmpty());
    assertTrue(l0StateTBPTT.isEmpty());
    assertTrue(l1StateMLN.isEmpty());
    assertTrue(l1StateTBPTT.isEmpty());
    assertTrue(l0TBPTTStateMLN.isEmpty());
    assertEquals(2, l0TBPTTStateTBPTT.size());
    assertTrue(l1TBPTTStateMLN.isEmpty());
    assertEquals(2, l1TBPTTStateTBPTT.size());
    INDArray tbpttActL0 = l0TBPTTStateTBPTT.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
    INDArray tbpttActL1 = l1TBPTTStateTBPTT.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION);
    List<INDArray> activations = mln.feedForward(inputData, true);
    INDArray l0Act = activations.get(1);
    INDArray l1Act = activations.get(2);
    INDArray expL0Act = l0Act.tensorAlongDimension(timeSeriesLength - 1, 1, 0);
    INDArray expL1Act = l1Act.tensorAlongDimension(timeSeriesLength - 1, 1, 0);
    assertEquals(tbpttActL0, expL0Act);
    assertEquals(tbpttActL1, expL1Act);
}
Also used: RnnOutputLayer (org.deeplearning4j.nn.conf.layers.RnnOutputLayer), Gradient (org.deeplearning4j.nn.gradient.Gradient), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), GravesLSTM (org.deeplearning4j.nn.layers.recurrent.GravesLSTM), INDArray (org.nd4j.linalg.api.ndarray.INDArray), NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution), BaseRecurrentLayer (org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer), Test (org.junit.Test)
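
The state assertions in both tests distinguish two maps: rnnGetPreviousState(...) holds state written by step-by-step inference through rnnTimeStep(...), while rnnGetTBPTTState() holds the state carried between truncation segments during fitting. Below is a minimal sketch of how the first map gets populated, reusing mln, miniBatchSize, and nIn from the test above; rnnTimeStep and rnnClearPreviousState are standard MultiLayerNetwork methods, though this usage is not part of the original test.

// After computeGradientAndScore() the previous-state map was empty; a single
// rnnTimeStep(...) call stores the recurrent state needed for the next step.
INDArray singleStep = Nd4j.rand(new int[] { miniBatchSize, nIn, 1 });
mln.rnnTimeStep(singleStep);
assertFalse(mln.rnnGetPreviousState(0).isEmpty());
// Clear the stored state before feeding an unrelated sequence.
mln.rnnClearPreviousState();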

Aggregations

NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 2 uses
NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution): 2 uses
RnnOutputLayer (org.deeplearning4j.nn.conf.layers.RnnOutputLayer): 2 uses
Gradient (org.deeplearning4j.nn.gradient.Gradient): 2 uses
BaseRecurrentLayer (org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer): 2 uses
GravesLSTM (org.deeplearning4j.nn.layers.recurrent.GravesLSTM): 2 uses
Test (org.junit.Test): 2 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2 uses
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 1 use
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1 use