Use of org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor in project deeplearning4j by deeplearning4j.
From the class TestRenders, method testHistogramComputationGraph.
@Test
public void testHistogramComputationGraph() throws Exception {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder()
            .addInputs("input")
            .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(3).build(), "input")
            .addLayer("cnn2", new ConvolutionLayer.Builder(4, 4).stride(2, 2).padding(1, 1).nIn(1).nOut(3).build(), "input")
            .addLayer("max1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).build(), "cnn1", "cnn2")
            .addLayer("output", new OutputLayer.Builder().nIn(7 * 7 * 6).nOut(10).build(), "max1")
            .setOutputs("output")
            .inputPreProcessor("cnn1", new FeedForwardToCnnPreProcessor(28, 28, 1))
            .inputPreProcessor("cnn2", new FeedForwardToCnnPreProcessor(28, 28, 1))
            .inputPreProcessor("output", new CnnToFeedForwardPreProcessor(7, 7, 6))
            .pretrain(false).backprop(true)
            .build();
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();
    graph.setListeners(new HistogramIterationListener(1), new ScoreIterationListener(1));
    DataSetIterator mnist = new MnistDataSetIterator(32, 640, false, true, false, 12345);
    graph.fit(mnist);
}
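The nIn(7 * 7 * 6) and the CnnToFeedForwardPreProcessor(7, 7, 6) above both encode the geometry of the merged "cnn1"/"cnn2" activations after pooling. Below is a minimal plain-Java sketch of that arithmetic, using the standard output-size formula out = (in - kernel + 2*padding) / stride + 1; the class and helper names are illustrative, not part of DL4J, and the pooling stride of 2 is assumed from the expected 7x7 result.

public class ConvSizeArithmetic {
    // Standard convolution/pooling output size: (in - kernel + 2*padding) / stride + 1
    static int outSize(int in, int kernel, int padding, int stride) {
        return (in - kernel + 2 * padding) / stride + 1;
    }

    public static void main(String[] args) {
        int cnn1 = outSize(28, 2, 0, 2);     // (28 - 2 + 0) / 2 + 1 = 14
        int cnn2 = outSize(28, 4, 1, 2);     // (28 - 4 + 2) / 2 + 1 = 14
        // "max1" merges cnn1 (3 channels) and cnn2 (3 channels): 6 channels at 14x14,
        // then 2x2 max pooling (stride 2 assumed) reduces this to 7x7
        int pooled = outSize(14, 2, 0, 2);   // = 7
        int flattened = pooled * pooled * 6; // = 294 = 7 * 7 * 6, matching nIn(7 * 7 * 6)
        System.out.println(cnn1 + " " + cnn2 + " " + pooled + " -> " + flattened);
    }
}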
Use of org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor in project deeplearning4j by deeplearning4j.
From the class MultipleEpochsIteratorTest, method testCifarDataSetIteratorReset.
// Use when checking the CIFAR dataset iterator (ignored by default)
@Ignore
@Test
public void testCifarDataSetIteratorReset() {
    int epochs = 2;
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .regularization(false)
            .learningRate(1.0)
            .weightInit(WeightInit.XAVIER)
            .seed(12345L)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(400).nOut(50).activation(Activation.RELU).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
            .pretrain(false).backprop(true)
            .inputPreProcessor(0, new CnnToFeedForwardPreProcessor(20, 20, 1))
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(1));
    MultipleEpochsIterator ds = new MultipleEpochsIterator(epochs, new CifarDataSetIterator(10, 20, new int[] { 20, 20, 1 }));
    net.fit(ds);
    assertEquals(epochs, ds.epochs);
    assertEquals(2, ds.batch);
}
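For context, the CnnToFeedForwardPreProcessor(20, 20, 1) attached to layer 0 flattens the 4d convolutional activations from the CIFAR iterator into the 2d format the dense layer expects, which is why nIn is 400 (20 * 20 * 1). A minimal sketch of that shape change using a plain ND4J reshape follows; the real preprocessor also handles array ordering and the backward pass, so this is only illustrative.

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class FlattenShapeSketch {
    public static void main(String[] args) {
        // CNN-style activations: [minibatch, channels, height, width] = [10, 1, 20, 20]
        INDArray cnnActivations = Nd4j.create(10, 1, 20, 20);
        // Feed-forward layout: [minibatch, channels * height * width] = [10, 400]
        INDArray flattened = cnnActivations.reshape(10, 1 * 20 * 20);
        System.out.println(Arrays.toString(flattened.shape())); // [10, 400]
    }
}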
Use of org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor in project deeplearning4j by deeplearning4j.
From the class ComputationGraphConfigurationTest, method testJSONBasic2.
@Test
public void testJSONBasic2() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder()
            .addInputs("input")
            .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input")
            .addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input")
            .addLayer("max1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).build(), "cnn1", "cnn2")
            .addLayer("dnn1", new DenseLayer.Builder().nOut(7).build(), "max1")
            .addLayer("max2", new SubsamplingLayer.Builder().build(), "max1")
            .addLayer("output", new OutputLayer.Builder().nIn(7).nOut(10).build(), "dnn1", "max2")
            .setOutputs("output")
            .inputPreProcessor("cnn1", new FeedForwardToCnnPreProcessor(32, 32, 3))
            .inputPreProcessor("cnn2", new FeedForwardToCnnPreProcessor(32, 32, 3))
            .inputPreProcessor("dnn1", new CnnToFeedForwardPreProcessor(8, 8, 5))
            .pretrain(false).backprop(true)
            .build();
    String json = conf.toJson();
    ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
    assertEquals(json, conf2.toJson());
    assertEquals(conf, conf2);
}
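If one wanted to additionally check that the preprocessor itself survives the JSON round trip, a short fragment along the following lines could be appended to the test. It assumes the getVertices() accessor on ComputationGraphConfiguration and getPreProcessor() on org.deeplearning4j.nn.conf.graph.LayerVertex, so treat it as a sketch rather than part of the original test.

// Hypothetical extra assertions (accessor names assumed, see note above):
LayerVertex dnn1Vertex = (LayerVertex) conf2.getVertices().get("dnn1");
assertTrue(dnn1Vertex.getPreProcessor() instanceof CnnToFeedForwardPreProcessor);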
Use of org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor in project deeplearning4j by deeplearning4j.
From the class TestInvalidConfigurations, method testCnnInvalidConfigOrInput_BadStrides.
@Test
public void testCnnInvalidConfigOrInput_BadStrides() {
    //Idea: same as testCnnInvalidConfigPaddingStridesHeight(), but here the network is fed incorrectly sized data,
    //or, equivalently, the network is set up without using the InputType functionality (hence the validation there is missing)
    int depthIn = 3;
    int hIn = 10;
    int wIn = 10;
    //Invalid: (10-3+0)/2+1 = 4.5
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .convolutionMode(ConvolutionMode.Strict)
            .list()
            .layer(0, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(2, 2).padding(0, 0).nIn(depthIn).nOut(5).build())
            .layer(1, new OutputLayer.Builder().nIn(5 * 4 * 4).nOut(10).build())
            .inputPreProcessor(1, new CnnToFeedForwardPreProcessor())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    try {
        net.feedForward(Nd4j.create(3, depthIn, hIn, wIn));
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail();
    }
}
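The comment above points out that this test deliberately avoids the InputType functionality, so the bad geometry is only detected once data is actually fed forward. For contrast, here is a minimal sketch of the same invalid setup declared via setInputType, where the (10 - 3 + 0) / 2 + 1 = 4.5 mismatch would be expected to surface during configuration validation instead; it is illustrative only and not part of the test suite.

// Same invalid geometry, but declared with an InputType so that ConvolutionMode.Strict
// can reject the non-integer output size when the configuration is processed.
MultiLayerConfiguration strictConf = new NeuralNetConfiguration.Builder()
        .convolutionMode(ConvolutionMode.Strict)
        .list()
        .layer(0, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(2, 2).padding(0, 0).nOut(5).build())
        .layer(1, new OutputLayer.Builder().nOut(10).build())
        .setInputType(InputType.convolutional(10, 10, 3)) // height, width, depth
        .build();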
Use of org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor in project deeplearning4j by deeplearning4j.
From the class ConvolutionLayerSetupTest, method testSubSamplingWithPadding.
@Test
public void testSubSamplingWithPadding() {
    MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list()
            .layer(0, new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14
            .layer(1, new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()) //(14-2+2)/2+1 = 8 -> 8x8x3
            .layer(2, new OutputLayer.Builder().nOut(3).build());
    new ConvolutionLayerSetup(builder, 28, 28, 1);
    MultiLayerConfiguration conf = builder.build();
    assertNotNull(conf.getInputPreProcess(2));
    assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor);
    CnnToFeedForwardPreProcessor proc = (CnnToFeedForwardPreProcessor) conf.getInputPreProcess(2);
    assertEquals(8, proc.getInputHeight());
    assertEquals(8, proc.getInputWidth());
    assertEquals(3, proc.getNumChannels());
    assertEquals(8 * 8 * 3, ((FeedForwardLayer) conf.getConf(2).getLayer()).getNIn());
}
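The asserted values follow directly from the size comments in the builder: the convolution layer gives (28 - 2 + 2*0)/2 + 1 = 14, the padded subsampling layer gives (14 - 2 + 2*1)/2 + 1 = 8, and flattening 8 x 8 x 3 yields 192, which is exactly the nIn that ConvolutionLayerSetup assigns to the output layer and that the CnnToFeedForwardPreProcessor(8, 8, 3) it inserts will produce.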