Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project (by deeplearning4j).
From class TestInvalidInput, method testInputNinMismatchBidirectionalLSTM.
/**
 * Verifies that fitting a network whose first layer is a {@code GravesBidirectionalLSTM}
 * configured with nIn = 5, using features whose second dimension is 10, is rejected with a
 * {@code DL4JException} rather than any other exception type.
 */
@Test
public void testInputNinMismatchBidirectionalLSTM() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
            .layer(0, new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).build())
            .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    try {
        // Features have size 10 in dimension 1 while layer 0 is configured with nIn = 5;
        // labels (size 5) match the output layer, so only the input size is wrong.
        net.fit(Nd4j.create(1, 10, 5), Nd4j.create(1, 5, 5));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        // Expected path: print the message so its wording can be inspected in test output.
        System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage());
    } catch (Exception e) {
        // Wrong exception type: fail with the original exception attached as the cause, so the
        // full stack trace is part of the test report rather than only printed to stderr.
        throw new AssertionError("Expected DL4JException but got " + e, e);
    }
}
Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project (by deeplearning4j).
From class TestInvalidInput, method testInputNinMismatchConvolutional.
/**
 * Verifies that feeding rank 4 input whose depth (5) differs from the convolution layer's
 * configured nIn depth ({@code d} = 3) is rejected with a {@code DL4JException}.
 */
@Test
public void testInputNinMismatchConvolutional() {
    //Rank 4 input, but input depth does not match nIn depth
    int h = 16;
    int w = 16;
    int d = 3;
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
            .layer(0, new ConvolutionLayer.Builder().nIn(d).nOut(5).build())
            .layer(1, new OutputLayer.Builder().nOut(10).build())
            .setInputType(InputType.convolutional(h, w, d))
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    try {
        // Depth 5 is deliberately different from d (3); height and width match the input type.
        net.feedForward(Nd4j.create(1, 5, h, w));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        // Expected path: print the message so its wording can be inspected in test output.
        System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage());
    } catch (Exception e) {
        // Wrong exception type: fail with the original exception attached as the cause, so the
        // full stack trace is part of the test report rather than only printed to stderr.
        throw new AssertionError("Expected DL4JException but got " + e, e);
    }
}
Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project (by deeplearning4j).
From class TestInvalidInput, method testInvalidRnnTimeStep.
/**
 * Verifies that calling {@code rnnTimeStep} twice with a different number of examples
 * (3, then 5), without resetting the stored RNN state in between, is rejected with a
 * {@code DL4JException}.
 */
@Test
public void testInvalidRnnTimeStep() {
    //Idea: Using rnnTimeStep with a different number of examples between calls
    //(i.e., not calling reset between time steps)
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
            .layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build())
            .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    // First call is valid and is expected to succeed (it sits outside the try on purpose).
    net.rnnTimeStep(Nd4j.create(3, 5, 10));
    try {
        // Second call changes dimension 0 from 3 to 5 without an intervening state reset.
        net.rnnTimeStep(Nd4j.create(5, 5, 10));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        // Expected path: print the message so its wording can be inspected in test output.
        System.out.println("testInvalidRnnTimeStep(): " + e.getMessage());
    } catch (Exception e) {
        // Wrong exception type: fail with the original exception attached as the cause, so the
        // full stack trace is part of the test report rather than only printed to stderr.
        throw new AssertionError("Expected DL4JException but got " + e, e);
    }
}
Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project (by deeplearning4j).
From class TestInvalidInput, method testInputNinRank2Convolutional.
/**
 * Verifies that feeding rank 2 input to a network whose first layer is a convolution layer
 * (input type {@code convolutional(h, w, d)}) is rejected with a {@code DL4JException}.
 */
@Test
public void testInputNinRank2Convolutional() {
    //Rank 2 input, instead of rank 4 input. For example, passing flattened image data
    //without reshaping it to [minibatch, depth, height, width] first.
    int h = 16;
    int w = 16;
    int d = 3;
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
            .layer(0, new ConvolutionLayer.Builder().nIn(d).nOut(5).build())
            .layer(1, new OutputLayer.Builder().nOut(10).build())
            .setInputType(InputType.convolutional(h, w, d))
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    try {
        // NOTE(review): the flattened length uses 5 channels rather than d, so the depth is
        // wrong as well as the rank; the rank check is presumably what triggers — confirm.
        net.feedForward(Nd4j.create(1, 5 * h * w));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        // Expected path: print the message so its wording can be inspected in test output.
        System.out.println("testInputNinRank2Convolutional(): " + e.getMessage());
    } catch (Exception e) {
        // Wrong exception type: fail with the original exception attached as the cause, so the
        // full stack trace is part of the test report rather than only printed to stderr.
        throw new AssertionError("Expected DL4JException but got " + e, e);
    }
}
Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project (by deeplearning4j).
From class TestInvalidInput, method testLabelsNOutMismatchRnnOutputLayer.
/**
 * Verifies that fitting with labels whose second dimension (10) does not match the
 * {@code RnnOutputLayer}'s nOut (5) is rejected with a {@code DL4JException}.
 */
@Test
public void testLabelsNOutMismatchRnnOutputLayer() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
            .layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build())
            .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    try {
        // Features match nIn = 5; labels use size 10 where the output layer has nOut = 5.
        net.fit(Nd4j.create(1, 5, 8), Nd4j.create(1, 10, 8));
        fail("Expected DL4JException");
    } catch (DL4JException e) {
        // Expected path: print the message so its wording can be inspected in test output.
        System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage());
    } catch (Exception e) {
        // Wrong exception type: fail with the original exception attached as the cause, so the
        // full stack trace is part of the test report rather than only printed to stderr.
        throw new AssertionError("Expected DL4JException but got " + e, e);
    }
}
Aggregations