Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project.
From the class LayerConfigTest, method testMomentumLayerwiseOverride:
@Test
public void testMomentumLayerwiseOverride() {
    // Global momentum schedule: momentum becomes 0.1 from iteration 0 onward.
    Map<Integer, Double> globalSchedule = new HashMap<>();
    globalSchedule.put(0, 0.1);

    // Case 1: momentum and schedule set globally only — both layers inherit them.
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(Updater.NESTEROVS)
            .momentum(1.0)
            .momentumAfter(globalSchedule)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    assertEquals(1.0, conf.getConf(0).getLayer().getMomentum(), 0.0);
    assertEquals(1.0, conf.getConf(1).getLayer().getMomentum(), 0.0);
    assertEquals(0.1, conf.getConf(0).getLayer().getMomentumSchedule().get(0), 0.0);
    assertEquals(0.1, conf.getConf(1).getLayer().getMomentumSchedule().get(0), 0.0);

    // Case 2: layer 1 overrides both the momentum and the schedule;
    // layer 0 must still use the global values.
    Map<Integer, Double> layerOneSchedule = new HashMap<>();
    layerOneSchedule.put(0, 0.2);

    conf = new NeuralNetConfiguration.Builder()
            .updater(Updater.NESTEROVS)
            .momentum(1.0)
            .momentumAfter(globalSchedule)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .layer(1, new DenseLayer.Builder().nIn(2).nOut(2)
                    .momentum(2.0)
                    .momentumAfter(layerOneSchedule)
                    .build())
            .build();
    net = new MultiLayerNetwork(conf);
    net.init();

    assertEquals(1.0, conf.getConf(0).getLayer().getMomentum(), 0.0);
    assertEquals(2.0, conf.getConf(1).getLayer().getMomentum(), 0.0);
    assertEquals(0.1, conf.getConf(0).getLayer().getMomentumSchedule().get(0), 0.0);
    assertEquals(0.2, conf.getConf(1).getLayer().getMomentumSchedule().get(0), 0.0);
}
Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project.
From the class LayerConfigValidationTest, method testNesterovsNotSetGlobal:
@Test
public void testNesterovsNotSetGlobal() {
    // Momentum options are set without selecting the NESTEROVS updater.
    // Configuration validation should only emit warnings, not fail.
    Map<Integer, Double> momentumSchedule = new HashMap<>();
    momentumSchedule.put(0, 0.1);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .momentum(1.0)
            .momentumAfter(momentumSchedule)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
}
Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project.
From the class LayerConfigValidationTest, method testDropConnect:
@Test
public void testDropConnect() {
    // DropConnect enabled globally; only a warning is expected since
    // some layers may not define l1 or l2 regularization.
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .learningRate(0.3)
            .useDropConnect(true)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
}
Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project.
From the class LayerConfigValidationTest, method testWeightInitDistNotSet:
@Test
public void testWeightInitDistNotSet() {
    // A global distribution without a matching weight init is only a warning:
    // individual layers may still choose a different weight init locally.
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .learningRate(0.3)
            .dist(new GaussianDistribution(1e-3, 2))
            .list()
            .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
}
Usage of org.deeplearning4j.nn.multilayer.MultiLayerNetwork in the deeplearning4j project.
From the class ComputationGraph, method rnnActivateUsingStoredState:
/**
 * Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:<br>
 * (a) like rnnTimeStep does forward pass using stored state for RNN layers, and<br>
 * (b) unlike rnnTimeStep does not modify the RNN layer state<br>
 * Therefore multiple calls to this method with the same input should have the same output.<br>
 * Typically used during training only. Use rnnTimeStep for prediction/forward pass at test time.
 *
 * @param inputs Input to network
 * @param training Whether training or not
 * @param storeLastForTBPTT set to true if used as part of truncated BPTT training
 * @return Activations for each layer (including input, as per feedforward() etc)
 */
public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] inputs, boolean training, boolean storeLastForTBPTT) {
// Map from vertex name -> activation produced by that vertex (inputs included).
Map<String, INDArray> layerActivations = new HashMap<>();
//Do forward pass according to the topological ordering of the network
// Topological order guarantees every vertex's inputs are computed before the vertex itself.
for (int currVertexIdx : topologicalOrder) {
GraphVertex current = vertices[currVertexIdx];
if (current.isInputVertex()) {
// Input vertex: no computation; just fan the corresponding input array out
// to every vertex that consumes it.
VertexIndices[] inputsTo = current.getOutputVertices();
INDArray input = inputs[current.getVertexIndex()];
layerActivations.put(current.getVertexName(), input);
for (VertexIndices v : inputsTo) {
int vIdx = v.getVertexIndex();
int vIdxInputNum = v.getVertexEdgeNumber();
//This input: the 'vIdxInputNum'th input to vertex 'vIdx'
//TODO When to dup?
// dup() so downstream vertices cannot mutate the caller's input arrays.
vertices[vIdx].setInput(vIdxInputNum, input.dup());
}
} else {
INDArray out;
if (current.hasLayer()) {
// Layer vertex: dispatch on layer type so recurrent layers use their
// stored state without overwriting it (see method contract above).
Layer l = current.getLayer();
if (l instanceof RecurrentLayer) {
out = ((RecurrentLayer) l).rnnActivateUsingStoredState(current.getInputs()[0], training, storeLastForTBPTT);
} else if (l instanceof MultiLayerNetwork) {
// Nested MultiLayerNetwork: returns per-layer activations; only the
// final layer's output feeds the rest of this graph.
List<INDArray> temp = ((MultiLayerNetwork) l).rnnActivateUsingStoredState(current.getInputs()[0], training, storeLastForTBPTT);
out = temp.get(temp.size() - 1);
} else {
//non-recurrent layer
out = current.doForward(training);
}
layerActivations.put(current.getVertexName(), out);
} else {
// Non-layer vertex (e.g. merge/preprocessor style vertex): plain forward pass.
// NOTE(review): its activation is intentionally not recorded in layerActivations,
// matching the "activations for each layer" contract — confirm this is deliberate.
out = current.doForward(training);
}
//Now, set the inputs for the next vertices:
VertexIndices[] outputsTo = current.getOutputVertices();
if (outputsTo != null) {
for (VertexIndices v : outputsTo) {
int vIdx = v.getVertexIndex();
int inputNum = v.getVertexEdgeNumber();
//This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
vertices[vIdx].setInput(inputNum, out);
}
}
}
}
return layerActivations;
}
Aggregations