Example 31 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class TestInvalidInput method testInputNinMismatchBidirectionalLSTM.

@Test
public void testInputNinMismatchBidirectionalLSTM() {
    //Rank 3 input with size 10 along dimension 1, but the first layer is configured with nIn = 5
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).build())
            .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    try {
        net.fit(Nd4j.create(1, 10, 5), Nd4j.create(1, 5, 5));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail("Expected DL4JException");
    }
}
Also used : MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DL4JException(org.deeplearning4j.exception.DL4JException) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)

Example 32 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class TestInvalidInput method testInputNinMismatchConvolutional.

@Test
public void testInputNinMismatchConvolutional() {
    //Rank 4 input, but input depth does not match nIn depth
    int h = 16;
    int w = 16;
    int d = 3;
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new ConvolutionLayer.Builder().nIn(d).nOut(5).build())
            .layer(1, new OutputLayer.Builder().nOut(10).build())
            .setInputType(InputType.convolutional(h, w, d))
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    try {
        net.feedForward(Nd4j.create(1, 5, h, w));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail("Expected DL4JException");
    }
}
Also used : MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DL4JException(org.deeplearning4j.exception.DL4JException) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)

Example 33 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class TestInvalidInput method testInvalidRnnTimeStep.

@Test
public void testInvalidRnnTimeStep() {
    //Idea: Using rnnTimeStep with a different number of examples between calls
    //(i.e., not calling reset between time steps)
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build())
            .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.rnnTimeStep(Nd4j.create(3, 5, 10));
    try {
        net.rnnTimeStep(Nd4j.create(5, 5, 10));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        System.out.println("testInvalidRnnTimeStep(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail("Expected DL4JException");
    }
}
Also used : MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DL4JException(org.deeplearning4j.exception.DL4JException) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)
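
The failure that testInvalidRnnTimeStep provokes can be pictured with a small stand-alone sketch. This is not DL4J's implementation: the class StatefulStepper, its step and reset methods, and the choice of IllegalStateException are hypothetical stand-ins. The sketch assumes only that stored state is sized by the mini-batch of the first call, so a later call with a different mini-batch size has to fail unless the state is cleared first.

```java
// Hypothetical stand-in for a stateful recurrent layer; not DL4J code.
final class StatefulStepper {
    // Stored state, one row per example; null until the first step.
    private double[][] state;
    private final int size;

    StatefulStepper(int size) {
        this.size = size;
    }

    // Handles one call for `miniBatch` examples.
    // Throws if the stored state was created for a different mini-batch size.
    void step(int miniBatch) {
        if (state == null) {
            state = new double[miniBatch][size];
        } else if (state.length != miniBatch) {
            throw new IllegalStateException("Stored state has mini-batch size " + state.length
                    + " but input has mini-batch size " + miniBatch + "; call reset() first");
        }
        // A real layer would update the stored state from the input here.
    }

    // Discards the stored state so the next call may use any mini-batch size.
    void reset() {
        state = null;
    }

    // Mini-batch size of the stored state, or 0 if there is none.
    int storedMiniBatch() {
        return state == null ? 0 : state.length;
    }

    public static void main(String[] args) {
        StatefulStepper stepper = new StatefulStepper(5);
        stepper.step(3);
        try {
            stepper.step(5); // 3 examples stored, 5 supplied
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
        stepper.reset();
        stepper.step(5); // accepted after the reset
    }
}
```

In MultiLayerNetwork itself, stored recurrent state is cleared with rnnClearPreviousState(), which plays the role of reset() in the sketch.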

Example 34 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class TestInvalidInput method testInputNinRank2Convolutional.

@Test
public void testInputNinRank2Convolutional() {
    //Rank 2 input, instead of rank 4 input. For example, forgetting the
    int h = 16;
    int w = 16;
    int d = 3;
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new ConvolutionLayer.Builder().nIn(d).nOut(5).build())
            .layer(1, new OutputLayer.Builder().nOut(10).build())
            .setInputType(InputType.convolutional(h, w, d))
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    try {
        net.feedForward(Nd4j.create(1, 5 * h * w));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail("Expected DL4JException");
    }
}
Also used : MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DL4JException(org.deeplearning4j.exception.DL4JException) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)
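
The two convolutional tests (testInputNinMismatchConvolutional and testInputNinRank2Convolutional) each violate one condition on the input shape: the depth must equal nIn, and the input must be rank 4. A minimal sketch of those two conditions, using plain Java only, might look as follows. ConvInputCheck, validate and isValid are hypothetical names, and the layout [miniBatch, depth, height, width] is read off the arrays the tests create; this is not how DL4J performs the check internally.

```java
// Hypothetical helper mirroring the two conditions the tests violate; not DL4J code.
final class ConvInputCheck {
    // Expects shape [miniBatch, depth, height, width] and depth == nIn.
    static void validate(long[] shape, long nIn) {
        if (shape.length != 4) {
            throw new IllegalArgumentException(
                    "Expected rank 4 input [miniBatch, depth, height, width], got rank " + shape.length);
        }
        if (shape[1] != nIn) {
            throw new IllegalArgumentException(
                    "Input depth " + shape[1] + " does not match nIn " + nIn);
        }
    }

    // Convenience wrapper that reports the outcome as a boolean.
    static boolean isValid(long[] shape, long nIn) {
        try {
            validate(shape, nIn);
            return true;
        } catch (IllegalArgumentException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        int h = 16, w = 16, d = 3;
        System.out.println(isValid(new long[] {1, d, h, w}, d));     // true: rank 4, depth 3
        System.out.println(isValid(new long[] {1, 5, h, w}, d));     // false: depth 5, nIn 3
        System.out.println(isValid(new long[] {1, 5 * h * w}, d));   // false: rank 2
    }
}
```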

Example 35 with MultiLayerNetwork

use of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in project deeplearning4j by deeplearning4j.

the class TestInvalidInput method testLabelsNOutMismatchRnnOutputLayer.

@Test
public void testLabelsNOutMismatchRnnOutputLayer() {
    //Labels have size 10 along dimension 1, but the output layer is configured with nOut = 5
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build())
            .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    try {
        net.fit(Nd4j.create(1, 5, 8), Nd4j.create(1, 10, 8));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail("Expected DL4JException");
    }
}
Also used : MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DL4JException(org.deeplearning4j.exception.DL4JException) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)
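
The recurrent tests in Examples 31 and 35 rely on the rank 3 layout [miniBatch, size, timeSeriesLength]: the size of the features must equal nIn of the first layer, and the size of the labels must equal nOut of the output layer. The sketch below restates that rule in plain Java. RnnShapeCheck and mismatch are hypothetical names introduced for illustration, and the messages are not the ones DL4J produces.

```java
import java.util.Optional;

// Hypothetical helper restating the shape rule the two tests exercise; not DL4J code.
final class RnnShapeCheck {
    // Returns a description of the first mismatch found, or empty if the shapes fit.
    // Both arrays are expected in the layout [miniBatch, size, timeSeriesLength].
    static Optional<String> mismatch(long[] features, long[] labels, long nIn, long nOut) {
        if (features.length != 3 || labels.length != 3) {
            return Optional.of("Expected rank 3 features and labels");
        }
        if (features[1] != nIn) {
            return Optional.of("Feature size " + features[1] + " does not match nIn " + nIn);
        }
        if (labels[1] != nOut) {
            return Optional.of("Label size " + labels[1] + " does not match nOut " + nOut);
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        // Shapes from testInputNinMismatchBidirectionalLSTM: feature size 10, nIn 5.
        System.out.println(mismatch(new long[] {1, 10, 5}, new long[] {1, 5, 5}, 5, 5));
        // Shapes from testLabelsNOutMismatchRnnOutputLayer: label size 10, nOut 5.
        System.out.println(mismatch(new long[] {1, 5, 8}, new long[] {1, 10, 8}, 5, 5));
        // Matching shapes produce no message.
        System.out.println(mismatch(new long[] {1, 5, 8}, new long[] {1, 5, 8}, 5, 5));
    }
}
```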

Aggregations

MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork) 326
Test (org.junit.Test) 277
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration) 206
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 166
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration) 111
DataSet (org.nd4j.linalg.dataset.DataSet) 91
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator) 70
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) 49
NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution) 43
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener) 41
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer) 40
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer) 38
Random (java.util.Random) 34
MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator) 30
ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) 28
DL4JException (org.deeplearning4j.exception.DL4JException) 20
Layer (org.deeplearning4j.nn.api.Layer) 20
ClassPathResource (org.nd4j.linalg.io.ClassPathResource) 20
File (java.io.File) 19
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph) 19