use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.
the class TransferLearningComplex method testLessSimpleMergeBackProp.
@Test
public void testLessSimpleMergeBackProp() {
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().learningRate(0.9)
                    .activation(Activation.IDENTITY)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD);

    /*
             inCentre                inRight
                |                        |
          denseCentre0             denseRight0
                |                        |
                |------ mergeRight ------|
                |                        |
             outCentre                outRight
    */

    ComputationGraphConfiguration conf = overallConf.graphBuilder().addInputs("inCentre", "inRight")
                    .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
                    .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).build(), "denseCentre0")
                    .addLayer("denseRight0", new DenseLayer.Builder().nIn(3).nOut(2).build(), "inRight")
                    .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
                    .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(), "mergeRight")
                    .setOutputs("outRight").setOutputs("outCentre").build();

    ComputationGraph modelToTune = new ComputationGraph(conf);
    modelToTune.init();
    modelToTune.getVertex("denseCentre0").setLayerAsFrozen();

    MultiDataSet randData = new MultiDataSet(new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 3)},
                    new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 2)});
    INDArray denseCentre0 = modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0");
    MultiDataSet otherRandData = new MultiDataSet(new INDArray[] {denseCentre0, randData.getFeatures(1)}, randData.getLabels());

    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToTune).setFeatureExtractor("denseCentre0").build();
    assertTrue(modelNow.getLayer("denseCentre0") instanceof FrozenLayer);

    int n = 0;
    while (n < 5) {
        if (n == 0) {
            //confirm activations out of the merge are equivalent
            assertEquals(modelToTune.feedForward(randData.getFeatures(), false).get("mergeRight"),
                            modelNow.feedForward(otherRandData.getFeatures(), false).get("mergeRight"));
        }

        //confirm activations out of the frozen vertex are the same as the input to the other model
        modelToTune.fit(randData);
        modelNow.fit(randData);

        assertEquals(otherRandData.getFeatures(0), modelNow.feedForward(randData.getFeatures(), false).get("denseCentre0"));
        assertEquals(otherRandData.getFeatures(0), modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));

        assertEquals(modelToTune.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
        assertEquals(modelToTune.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
        assertEquals(modelToTune.getLayer("outCentre").params(), modelNow.getLayer("outCentre").params());

        n++;
    }
}
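The same freezing behaviour is available for a plain MultiLayerNetwork through TransferLearning.Builder. Below is a minimal sketch, not part of the original test, using hypothetical layer sizes and the same 0.x-era API as above; it freezes the first layer via setFeatureExtractor and checks that it is wrapped in a FrozenLayer:

//Minimal sketch (hypothetical layer sizes): freezing part of a MultiLayerNetwork
MultiLayerConfiguration mlnConf = new NeuralNetConfiguration.Builder()
                .learningRate(0.9).activation(Activation.IDENTITY)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(3).nOut(2).build())
                .build();
MultiLayerNetwork net = new MultiLayerNetwork(mlnConf);
net.init();

//setFeatureExtractor(0) should freeze layer 0 (and any layer before it); layer 1 stays trainable
MultiLayerNetwork frozenNet = new TransferLearning.Builder(net)
                .fineTuneConfiguration(new FineTuneConfiguration.Builder().learningRate(0.9).build())
                .setFeatureExtractor(0).build();
assertTrue(frozenNet.getLayer(0) instanceof FrozenLayer);
assertFalse(frozenNet.getLayer(1) instanceof FrozenLayer);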
use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.
the class TransferLearningComplex method testMergeAndFreeze.
@Test
public void testMergeAndFreeze() {
    // in1 -> A -> B -> merge, in2 -> C -> merge -> D -> out
    //Goal here: test a number of things...
    // (a) Ensure that freezing C doesn't impact A and B. Only C should be frozen in this config
    // (b) Test global override (should be selective)

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.ADAM).learningRate(1e-4)
                    .activation(Activation.LEAKYRELU).graphBuilder().addInputs("in1", "in2")
                    .addLayer("A", new DenseLayer.Builder().nIn(10).nOut(9).build(), "in1")
                    .addLayer("B", new DenseLayer.Builder().nIn(9).nOut(8).build(), "A")
                    .addLayer("C", new DenseLayer.Builder().nIn(7).nOut(7).build(), "in2")
                    .addLayer("D", new DenseLayer.Builder().nIn(8 + 7).nOut(5).build(), "B", "C")
                    .addLayer("out", new OutputLayer.Builder().nIn(5).nOut(4).build(), "D")
                    .setOutputs("out").build();

    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    int[] topologicalOrder = graph.topologicalSortOrder();
    org.deeplearning4j.nn.graph.vertex.GraphVertex[] vertices = graph.getVertices();
    for (int i = 0; i < topologicalOrder.length; i++) {
        org.deeplearning4j.nn.graph.vertex.GraphVertex v = vertices[topologicalOrder[i]];
        log.info(i + "\t" + v.getVertexName());
    }

    ComputationGraph graph2 = new TransferLearning.GraphBuilder(graph)
                    .fineTuneConfiguration(new FineTuneConfiguration.Builder().learningRate(2e-2).build())
                    .setFeatureExtractor("C").build();

    boolean cFound = false;
    Layer[] layers = graph2.getLayers();
    for (Layer l : layers) {
        String name = l.conf().getLayer().getLayerName();
        log.info(name + "\t frozen: " + (l instanceof FrozenLayer));
        if ("C".equals(l.conf().getLayer().getLayerName())) {
            //Only C should be frozen in this config
            cFound = true;
            assertTrue(name, l instanceof FrozenLayer);
        } else {
            assertFalse(name, l instanceof FrozenLayer);
        }

        //Also check config:
        assertEquals(Updater.ADAM, l.conf().getLayer().getUpdater());
        assertEquals(2e-2, l.conf().getLayer().getLearningRate(), 1e-5);
        assertEquals(Activation.LEAKYRELU.getActivationFunction(), l.conf().getLayer().getActivationFn());
    }
    assertTrue(cFound);
}
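As a follow-on to goal (b), the sketch below (not part of the original test, hypothetical values) swaps the updater in the FineTuneConfiguration instead of the learning rate. FineTuneConfiguration only overrides the fields that are explicitly set, so the LEAKYRELU activation from the original configuration is expected to remain in place:

//Sketch (hypothetical values): FineTuneConfiguration overrides only explicitly set fields
ComputationGraph graph3 = new TransferLearning.GraphBuilder(graph)
                .fineTuneConfiguration(new FineTuneConfiguration.Builder()
                                .updater(Updater.SGD).learningRate(1e-2).build())
                .setFeatureExtractor("C").build();
for (Layer l : graph3.getLayers()) {
    assertEquals(Updater.SGD, l.conf().getLayer().getUpdater());             //overridden
    assertEquals(1e-2, l.conf().getLayer().getLearningRate(), 1e-6);         //overridden
    assertEquals(Activation.LEAKYRELU.getActivationFunction(), l.conf().getLayer().getActivationFn()); //untouched
}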
use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.
the class ComputationGraph method clone.
@Override
public ComputationGraph clone() {
    ComputationGraph cg = new ComputationGraph(configuration.clone());
    cg.init(params().dup(), false);
    if (solver != null) {
        //If solver is null, the updater hasn't been initialized yet, so there is no updater state to copy
        //(calling getUpdater() here would only force that initialization)
        ComputationGraphUpdater u = this.getUpdater();
        INDArray updaterState = u.getStateViewArray();
        if (updaterState != null) {
            cg.getUpdater().setStateViewArray(updaterState.dup());
        }
    }
    cg.listeners = this.listeners;
    for (int i = 0; i < topologicalOrder.length; i++) {
        if (!vertices[topologicalOrder[i]].hasLayer())
            continue;
        String layerName = vertices[topologicalOrder[i]].getVertexName();
        if (getLayer(layerName) instanceof FrozenLayer) {
            cg.getVertex(layerName).setLayerAsFrozen();
        }
    }
    return cg;
}
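Because clone() walks the topological order and calls setLayerAsFrozen() for every vertex whose layer is a FrozenLayer in the original, the frozen state survives cloning. A small sketch (not from the repo, assuming a graph like graph2 from testMergeAndFreeze above, where layer "C" is frozen):

//Sketch: frozen layers stay frozen in the copy; parameters are duplicated, not shared
ComputationGraph copy = graph2.clone();
assertTrue(copy.getLayer("C") instanceof FrozenLayer);
assertEquals(graph2.getLayer("C").params(), copy.getLayer("C").params());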
use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.
the class MultiLayerNetwork method summary.
/**
 * String detailing the architecture of the MultiLayerNetwork.
 * Columns are: layer index (with layer type), nIn, nOut, total number of parameters, and the shapes of the parameters.
 * Will also give information about frozen layers, if any.
 * @return Summary as a string
 */
public String summary() {
    String ret = "\n";
    ret += StringUtils.repeat("=", 140);
    ret += "\n";
    ret += String.format("%-40s%-15s%-15s%-30s\n", "LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape");
    ret += StringUtils.repeat("=", 140);
    ret += "\n";
    int frozenParams = 0;
    for (Layer currentLayer : layers) {
        String name = String.valueOf(currentLayer.getIndex());
        String paramShape = "-";
        String in = "-";
        String out = "-";
        String[] classNameArr = currentLayer.getClass().getName().split("\\.");
        String className = classNameArr[classNameArr.length - 1];
        String paramCount = String.valueOf(currentLayer.numParams());
        if (currentLayer.numParams() > 0) {
            paramShape = "";
            in = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn());
            out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut());
            Set<String> paraNames = currentLayer.conf().getLearningRateByParam().keySet();
            for (String aP : paraNames) {
                String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape());
                paramShape += aP + ":" + paramS + ", ";
            }
            paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString();
        }
        if (currentLayer instanceof FrozenLayer) {
            frozenParams += currentLayer.numParams();
            classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
            className = "Frozen " + classNameArr[classNameArr.length - 1];
        }
        ret += String.format("%-40s%-15s%-15s%-30s", name + " (" + className + ")", in + "," + out, paramCount, paramShape);
        ret += "\n";
    }
    ret += StringUtils.repeat("-", 140);
    ret += String.format("\n%30s %d", "Total Parameters: ", params().length());
    ret += String.format("\n%30s %d", "Trainable Parameters: ", params().length() - frozenParams);
    ret += String.format("\n%30s %d", "Frozen Parameters: ", frozenParams);
    ret += "\n";
    ret += StringUtils.repeat("=", 140);
    ret += "\n";
    return ret;
}
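For a network containing frozen layers, the summary above prefixes each frozen layer's type with "Frozen" and reports the frozen parameter count on its own line. A sketch (reusing the hypothetical frozenNet from the sketch after the first test):

//Sketch: summary() flags frozen layers and reports frozen vs. trainable parameter counts
String summary = frozenNet.summary();
System.out.println(summary);
assertTrue(summary.contains("Frozen DenseLayer"));
assertTrue(summary.contains("Frozen Parameters:"));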
use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j.
the class MultiLayerNetwork method calcBackpropGradients.
/** Calculate gradients and errors. Used in two places:
 * (a) backprop (for standard multi-layer network learning)
 * (b) backpropGradient (layer method, for when MultiLayerNetwork is used as a layer)
 * @param epsilon Errors (technically errors .* activations). Not used if withOutputLayer = true
 * @param withOutputLayer if true: assume the last layer is an output layer, and calculate errors based on the labels.
 *                        In this case, the epsilon input is not used (may/should be null).
 *                        If false: calculate backprop gradients from the supplied epsilon
 * @return Gradients and the error (epsilon) at the input
 */
protected Pair<Gradient, INDArray> calcBackpropGradients(INDArray epsilon, boolean withOutputLayer) {
    if (flattenedGradients == null)
        initGradientsView();
    String multiGradientKey;
    Gradient gradient = new DefaultGradient(flattenedGradients);
    Layer currLayer;

    //calculate and apply the backward gradient for every layer
    /*
     * Skip the output layer for the indexing and just loop backwards, updating the coefficients for each layer
     * (when withOutputLayer == true).
     *
     * Activate applies the activation function for each layer and sets that as the input for the following layer.
     *
     * Typical literature gives the most trivial case for the error calculation: w^T * error.
     * This implementation transposes a few things to handle mini batches, because ND4J organizes params as rows vs columns.
     */
    int numLayers = getnLayers();
    //Store gradients in a list; used to ensure iteration order in the DefaultGradient linked hash map (i.e., layer 0 first instead of output layer)
    LinkedList<Triple<String, INDArray, Character>> gradientList = new LinkedList<>();

    int layerFrom;
    Pair<Gradient, INDArray> currPair;
    if (withOutputLayer) {
        if (!(getOutputLayer() instanceof IOutputLayer)) {
            log.warn("Warning: final layer isn't output layer. You cannot use backprop without an output layer.");
            return null;
        }

        IOutputLayer outputLayer = (IOutputLayer) getOutputLayer();
        if (labels == null)
            throw new IllegalStateException("No labels found");
        outputLayer.setLabels(labels);
        currPair = outputLayer.backpropGradient(null);

        for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
            String origName = entry.getKey();
            multiGradientKey = String.valueOf(numLayers - 1) + "_" + origName;
            gradientList.addLast(new Triple<>(multiGradientKey, entry.getValue(),
                            currPair.getFirst().flatteningOrderForVariable(origName)));
        }
        if (getLayerWiseConfigurations().getInputPreProcess(numLayers - 1) != null)
            currPair = new Pair<>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(numLayers - 1)
                            .backprop(currPair.getSecond(), getInputMiniBatchSize()));

        layerFrom = numLayers - 2;
    } else {
        currPair = new Pair<>(null, epsilon);
        layerFrom = numLayers - 1;
    }

    //Calculate gradients for the remaining layers (the output layer, if any, has already been handled above)
    for (int j = layerFrom; j >= 0; j--) {
        currLayer = getLayer(j);
        if (currLayer instanceof FrozenLayer)
            break;
        currPair = currLayer.backpropGradient(currPair.getSecond());

        LinkedList<Triple<String, INDArray, Character>> tempList = new LinkedList<>();
        for (Map.Entry<String, INDArray> entry : currPair.getFirst().gradientForVariable().entrySet()) {
            String origName = entry.getKey();
            multiGradientKey = String.valueOf(j) + "_" + origName;
            tempList.addFirst(new Triple<>(multiGradientKey, entry.getValue(),
                            currPair.getFirst().flatteningOrderForVariable(origName)));
        }
        for (Triple<String, INDArray, Character> triple : tempList)
            gradientList.addFirst(triple);

        //Pass epsilon through input processor before passing to next layer (if applicable)
        if (getLayerWiseConfigurations().getInputPreProcess(j) != null)
            currPair = new Pair<>(currPair.getFirst(), getLayerWiseConfigurations().getInputPreProcess(j)
                            .backprop(currPair.getSecond(), getInputMiniBatchSize()));
    }

    //Add gradients to Gradients (map), in correct order
    for (Triple<String, INDArray, Character> triple : gradientList) {
        gradient.setGradientFor(triple.getFirst(), triple.getSecond(), triple.getThird());
    }

    return new Pair<>(gradient, currPair.getSecond());
}
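The `if (currLayer instanceof FrozenLayer) break;` check is what makes feature extraction work: walking backwards, backpropagation stops at the first frozen layer, so neither that layer nor anything below it receives a gradient. A sketch (not from the repo, reusing the hypothetical frozenNet from the sketch after the first test, where layer 0 is frozen):

//Sketch: fitting never updates the frozen layer's parameters, since backprop breaks before reaching it
INDArray frozenParamsBefore = frozenNet.getLayer(0).params().dup();
DataSet ds = new DataSet(Nd4j.rand(5, 4), Nd4j.rand(5, 2));   //shapes match the nIn=4 / nOut=2 sketch above
frozenNet.fit(ds);
assertEquals(frozenParamsBefore, frozenNet.getLayer(0).params());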