Use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j: the class TransferLearningHelper, method initHelperMLN.
private void initHelperMLN() {
    if (applyFrozen) {
        org.deeplearning4j.nn.api.Layer[] layers = origMLN.getLayers();
        //Wrap every layer up to and including frozenTill in a FrozenLayer
        for (int i = frozenTill; i >= 0; i--) {
            layers[i] = new FrozenLayer(layers[i]);
        }
        origMLN.setLayers(layers);
    }
    //Find the index of the last frozen layer in the network
    for (int i = 0; i < origMLN.getnLayers(); i++) {
        if (origMLN.getLayer(i) instanceof FrozenLayer) {
            frozenInputLayer = i;
        }
    }
    //Collect the configurations of the remaining (unfrozen) layers
    List<NeuralNetConfiguration> allConfs = new ArrayList<>();
    for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) {
        allConfs.add(origMLN.getLayer(i).conf());
    }
    MultiLayerConfiguration c = origMLN.getLayerWiseConfigurations();
    unFrozenSubsetMLN = new MultiLayerNetwork(new MultiLayerConfiguration.Builder()
                    .backprop(c.isBackprop()).inputPreProcessors(c.getInputPreProcessors())
                    .pretrain(c.isPretrain()).backpropType(c.getBackpropType())
                    .tBPTTForwardLength(c.getTbpttFwdLength()).tBPTTBackwardLength(c.getTbpttBackLength())
                    .confs(allConfs).build());
    unFrozenSubsetMLN.init();
    //Copy parameters from the original network into the unfrozen subset
    for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) {
        unFrozenSubsetMLN.getLayer(i - frozenInputLayer - 1).setParams(origMLN.getLayer(i).params());
    }
    //unFrozenSubsetMLN.setListeners(origMLN.getListeners());
}
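The helper above is driven through TransferLearningHelper's public API. A minimal usage sketch, assuming a pre-built MultiLayerNetwork origModel and a DataSetIterator trainIter (both hypothetical names, not from the source above): constructing the helper with a frozenTill index of 1 wraps layers 0 and 1 in FrozenLayer and triggers initHelperMLN(), after which training runs only over the unfrozen subset network.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class MlnFreezeSketch {
    public static void fitUnfrozen(MultiLayerNetwork origModel, DataSetIterator trainIter) {
        //Freeze layers 0..1; internally this calls initHelperMLN() above
        TransferLearningHelper helper = new TransferLearningHelper(origModel, 1);
        while (trainIter.hasNext()) {
            DataSet current = trainIter.next();
            //Forward pass through the frozen layers only: returns the "featurized" activations
            DataSet featurized = helper.featurize(current);
            //Fit the unfrozen subset network (unFrozenSubsetMLN) on the featurized data
            helper.fitFeaturized(featurized);
        }
    }
}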
Use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j: the class TransferLearningHelper, method initHelperGraph.
/**
 * Runs through the computation graph and saves off a new model that is simply the "unfrozen" part of the origModel.
 * This "unfrozen" model is then used for training with featurized data.
 */
private void initHelperGraph() {
    int[] backPropOrder = origGraph.topologicalSortOrder().clone();
    ArrayUtils.reverse(backPropOrder);
    Set<String> allFrozen = new HashSet<>();
    if (applyFrozen) {
        Collections.addAll(allFrozen, frozenOutputAt);
    }
    for (int i = 0; i < backPropOrder.length; i++) {
        org.deeplearning4j.nn.graph.vertex.GraphVertex gv = origGraph.getVertices()[backPropOrder[i]];
        if (applyFrozen && allFrozen.contains(gv.getVertexName())) {
            if (gv.hasLayer()) {
                //Need to freeze this layer
                org.deeplearning4j.nn.api.Layer l = gv.getLayer();
                gv.setLayerAsFrozen();
                //We also need to place the layer in the CompGraph Layer[] (replacing the old one)
                //This could no doubt be done more efficiently
                org.deeplearning4j.nn.api.Layer[] layers = origGraph.getLayers();
                for (int j = 0; j < layers.length; j++) {
                    if (layers[j] == l) {
                        //Place the new frozen layer to replace the original layer
                        layers[j] = gv.getLayer();
                        break;
                    }
                }
            }
            //Also mark this vertex's inputs as frozen
            VertexIndices[] inputs = gv.getInputVertices();
            if (inputs != null && inputs.length > 0) {
                for (int j = 0; j < inputs.length; j++) {
                    int inputVertexIdx = inputs[j].getVertexIndex();
                    String alsoFreeze = origGraph.getVertices()[inputVertexIdx].getVertexName();
                    allFrozen.add(alsoFreeze);
                }
            }
        } else {
            if (gv.hasLayer()) {
                if (gv.getLayer() instanceof FrozenLayer) {
                    allFrozen.add(gv.getVertexName());
                    //Also add this vertex's parents to allFrozen
                    VertexIndices[] inputs = gv.getInputVertices();
                    if (inputs != null && inputs.length > 0) {
                        for (int j = 0; j < inputs.length; j++) {
                            int inputVertexIdx = inputs[j].getVertexIndex();
                            String alsoFrozen = origGraph.getVertices()[inputVertexIdx].getVertexName();
                            allFrozen.add(alsoFrozen);
                        }
                    }
                }
            }
        }
    }
    for (int i = 0; i < backPropOrder.length; i++) {
        org.deeplearning4j.nn.graph.vertex.GraphVertex gv = origGraph.getVertices()[backPropOrder[i]];
        String gvName = gv.getVertexName();
        //Is this an unfrozen vertex with at least one frozen input vertex?
        if (!allFrozen.contains(gvName) && !gv.isInputVertex()) {
            VertexIndices[] inputs = gv.getInputVertices();
            for (int j = 0; j < inputs.length; j++) {
                int inputVertexIdx = inputs[j].getVertexIndex();
                String inputVertex = origGraph.getVertices()[inputVertexIdx].getVertexName();
                if (allFrozen.contains(inputVertex)) {
                    frozenInputVertices.add(inputVertex);
                }
            }
        }
    }
    TransferLearning.GraphBuilder builder = new TransferLearning.GraphBuilder(origGraph);
    for (String toRemove : allFrozen) {
        if (frozenInputVertices.contains(toRemove)) {
            builder.removeVertexKeepConnections(toRemove);
        } else {
            builder.removeVertexAndConnections(toRemove);
        }
    }
    Set<String> frozenInputVerticesSorted = new HashSet<>();
    frozenInputVerticesSorted.addAll(origGraph.getConfiguration().getNetworkInputs());
    frozenInputVerticesSorted.removeAll(allFrozen);
    //Remove the existing input vertices - just to add them back in a predictable order
    for (String existingInput : frozenInputVerticesSorted) {
        builder.removeVertexKeepConnections(existingInput);
    }
    frozenInputVerticesSorted.addAll(frozenInputVertices);
    //Sort all inputs to the computation graph, to get a predictable order
    graphInputs = new ArrayList<>(frozenInputVerticesSorted);
    Collections.sort(graphInputs);
    for (String asInput : graphInputs) {
        //Add the inputs back in the sorted (predictable) order
        builder.addInputs(asInput);
    }
    unFrozenSubsetGraph = builder.build();
    copyOrigParamsToSubsetGraph();
    if (frozenInputVertices.isEmpty()) {
        throw new IllegalArgumentException("No frozen layers found");
    }
}
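The ComputationGraph variant follows the same pattern as the MLN case. A hedged sketch, assuming a graph origGraph with a vertex named "fc1" (a hypothetical name) at which to freeze: passing vertex names to the constructor sets frozenOutputAt and triggers initHelperGraph() above, freezing "fc1" and everything feeding into it, while unfrozenGraph() exposes the trainable remainder.

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.nd4j.linalg.dataset.MultiDataSet;

public class GraphFreezeSketch {
    public static void fitUnfrozen(ComputationGraph origGraph, MultiDataSet data) {
        //Freeze "fc1" and (transitively) all vertices that feed into it
        TransferLearningHelper helper = new TransferLearningHelper(origGraph, "fc1");
        //Featurize once through the frozen subgraph, then fit only the unfrozen part
        MultiDataSet featurized = helper.featurize(data);
        helper.fitFeaturized(featurized);
        ComputationGraph unfrozen = helper.unfrozenGraph();
        System.out.println(unfrozen.summary());
    }
}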
Use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j: the class LayerUpdater, method update.
@Override
public void update(Layer layer, Gradient gradient, int iteration, int miniBatchSize) {
    String paramName;
    INDArray gradientOrig, gradient2;
    GradientUpdater updater;
    //Frozen layers are never updated
    if (layer instanceof FrozenLayer)
        return;
    preApply(layer, gradient, iteration);
    for (Map.Entry<String, INDArray> gradientPair : gradient.gradientForVariable().entrySet()) {
        paramName = gradientPair.getKey();
        //Skip the visible bias parameter unless pretraining
        if (!layer.conf().isPretrain() && PretrainParamInitializer.VISIBLE_BIAS_KEY.equals(paramName.split("_")[0]))
            continue;
        gradientOrig = gradientPair.getValue();
        LearningRatePolicy decay = layer.conf().getLearningRatePolicy();
        if (decay != LearningRatePolicy.None
                        || layer.conf().getLayer().getUpdater() == org.deeplearning4j.nn.conf.Updater.NESTEROVS)
            applyLrDecayPolicy(decay, layer, iteration, paramName);
        updater = init(paramName, layer);
        gradient2 = updater.getGradient(gradientOrig, iteration);
        postApply(layer, gradient2, paramName, miniBatchSize);
        gradient.setGradientFor(paramName, gradient2);
    }
}
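Because of the instanceof guard at the top of update(), wrapping a layer in FrozenLayer is enough to exclude its parameters from training. A small illustrative sketch (updater, someLayer, and the gradient are assumed to exist; none of these names come from the source above):

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.updater.LayerUpdater;

public class FrozenUpdateSketch {
    public static void applyUpdates(LayerUpdater updater, Layer someLayer, Gradient g,
                    int iteration, int miniBatchSize) {
        Layer frozen = new FrozenLayer(someLayer); //same wrapping as TransferLearningHelper uses
        updater.update(frozen, g, iteration, miniBatchSize);    //early return: parameters untouched
        updater.update(someLayer, g, iteration, miniBatchSize); //normal path: gradients processed and applied
    }
}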
Use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j: the class ComputationGraph, method summary.
/**
 * Returns a string detailing the architecture of the computation graph.
 * Vertices are printed in topological sort order.
 * Columns are: vertex name with layer/vertex type, nIn, nOut, total number of parameters,
 * the shapes of the parameters, and the inputs to the vertex.
 * Also gives information about frozen layers/vertices, if any.
 *
 * @return Summary as a string
 */
public String summary() {
    String ret = "\n";
    ret += StringUtils.repeat("=", 140);
    ret += "\n";
    ret += String.format("%-40s%-15s%-15s%-30s %s\n", "VertexName (VertexType)", "nIn,nOut", "TotalParams",
                    "ParamsShape", "Vertex Inputs");
    ret += StringUtils.repeat("=", 140);
    ret += "\n";
    int frozenParams = 0;
    for (int currVertexIdx : topologicalOrder) {
        GraphVertex current = vertices[currVertexIdx];
        String name = current.getVertexName();
        String[] classNameArr = current.getClass().toString().split("\\.");
        String className = classNameArr[classNameArr.length - 1];
        String connections = "-";
        if (!current.isInputVertex()) {
            connections = configuration.getVertexInputs().get(name).toString();
        }
        String paramCount = "-";
        String in = "-";
        String out = "-";
        String paramShape = "-";
        if (current.hasLayer()) {
            Layer currentLayer = ((LayerVertex) current).getLayer();
            classNameArr = currentLayer.getClass().getName().split("\\.");
            className = classNameArr[classNameArr.length - 1];
            paramCount = String.valueOf(currentLayer.numParams());
            if (currentLayer.numParams() > 0) {
                paramShape = "";
                in = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn());
                out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut());
                Set<String> paraNames = currentLayer.conf().getLearningRateByParam().keySet();
                for (String aP : paraNames) {
                    String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape());
                    paramShape += aP + ":" + paramS + ", ";
                }
                paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString();
            }
            if (currentLayer instanceof FrozenLayer) {
                frozenParams += currentLayer.numParams();
                classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\.");
                className = "Frozen " + classNameArr[classNameArr.length - 1];
            }
        }
        ret += String.format("%-40s%-15s%-15s%-30s %s", name + " (" + className + ")", in + "," + out,
                        paramCount, paramShape, connections);
        ret += "\n";
    }
    ret += StringUtils.repeat("-", 140);
    ret += String.format("\n%30s %d", "Total Parameters: ", params().length());
    ret += String.format("\n%30s %d", "Trainable Parameters: ", params().length() - frozenParams);
    ret += String.format("\n%30s %d", "Frozen Parameters: ", frozenParams);
    ret += "\n";
    ret += StringUtils.repeat("=", 140);
    ret += "\n";
    return ret;
}
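A brief usage note, assuming an initialized ComputationGraph graph (a hypothetical variable) in which some vertices were frozen via transfer learning:

import org.deeplearning4j.nn.graph.ComputationGraph;

public class SummarySketch {
    public static void printArchitecture(ComputationGraph graph) {
        //Frozen layers appear as "Frozen <LayerClass>" in the VertexName column,
        //and their parameters are excluded from the "Trainable Parameters" total
        System.out.println(graph.summary());
    }
}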
Use of org.deeplearning4j.nn.layers.FrozenLayer in project deeplearning4j by deeplearning4j: the class ComputationGraph, method calcBackpropGradients.
/**
 * Do backprop (gradient calculation).
 *
 * @param truncatedBPTT    false: normal backprop. true: calculate gradients using truncated BPTT for RNN layers
 * @param externalEpsilons usually null (for typical supervised learning). If non-null (and length > 0), assume that
 *                         the user has provided errors externally, as they would do for example in reinforcement
 *                         learning situations.
 */
protected void calcBackpropGradients(boolean truncatedBPTT, INDArray... externalEpsilons) {
    if (flattenedGradients == null)
        initGradientsView();
    LinkedList<Triple<String, INDArray, Character>> gradients = new LinkedList<>();
    //Do backprop according to the reverse of the topological ordering of the network
    //If true: already set epsilon for this vertex; later epsilons should be *added* to the existing one, not set
    boolean[] setVertexEpsilon = new boolean[topologicalOrder.length];
    for (int i = topologicalOrder.length - 1; i >= 0; i--) {
        GraphVertex current = vertices[topologicalOrder[i]];
        if (current.isInputVertex())
            continue; //No op
        //FIXME: make the frozen vertex feature extraction more flexible
        if (current.hasLayer() && current.getLayer() instanceof FrozenLayer)
            break;
        if (current.isOutputVertex()) {
            //Two reasons for a vertex to be an output vertex:
            //(a) it's an output layer (i.e., instanceof IOutputLayer), or
            //(b) it's a normal layer, but it has been marked as an output layer for use in external errors - for reinforcement learning, for example
            int thisOutputNumber = configuration.getNetworkOutputs().indexOf(current.getVertexName());
            if (current.getLayer() instanceof IOutputLayer) {
                IOutputLayer outputLayer = (IOutputLayer) current.getLayer();
                INDArray currLabels = labels[thisOutputNumber];
                outputLayer.setLabels(currLabels);
            } else {
                current.setEpsilon(externalEpsilons[thisOutputNumber]);
                setVertexEpsilon[topologicalOrder[i]] = true;
            }
        }
        Pair<Gradient, INDArray[]> pair = current.doBackward(truncatedBPTT);
        INDArray[] epsilons = pair.getSecond();
        //Inputs to the current GraphVertex:
        VertexIndices[] inputVertices = current.getInputVertices();
        //Set epsilons for the vertices that provide inputs to this vertex:
        if (inputVertices != null) {
            int j = 0;
            for (VertexIndices v : inputVertices) {
                GraphVertex gv = vertices[v.getVertexIndex()];
                if (setVertexEpsilon[gv.getVertexIndex()]) {
                    //This vertex: must output to multiple vertices... we want to add the epsilons here
                    INDArray currentEps = gv.getEpsilon();
                    //TODO: in some circumstances, it may be safe to do in-place add (but not always)
                    gv.setEpsilon(currentEps.add(epsilons[j++]));
                } else {
                    gv.setEpsilon(epsilons[j++]);
                }
                setVertexEpsilon[gv.getVertexIndex()] = true;
            }
        }
        if (pair.getFirst() != null) {
            Gradient g = pair.getFirst();
            Map<String, INDArray> map = g.gradientForVariable();
            LinkedList<Triple<String, INDArray, Character>> tempList = new LinkedList<>();
            for (Map.Entry<String, INDArray> entry : map.entrySet()) {
                String origName = entry.getKey();
                String newName = current.getVertexName() + "_" + origName;
                tempList.addFirst(new Triple<>(newName, entry.getValue(), g.flatteningOrderForVariable(origName)));
            }
            for (Triple<String, INDArray, Character> t : tempList)
                gradients.addFirst(t);
        }
    }
    //Now, add the gradients in the order we need them in for flattening (same as params order)
    Gradient gradient = new DefaultGradient(flattenedGradients);
    for (Triple<String, INDArray, Character> t : gradients) {
        gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
    }
    this.gradient = gradient;
}
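The early break on FrozenLayer above is what makes frozen feature extractors cheap: once backprop reaches a frozen vertex in reverse topological order, no further gradients are computed. A hedged sketch of setting this up via the transfer-learning API ("fc1" is a hypothetical vertex name, not from the source):

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.TransferLearning;

public class FeatureExtractorSketch {
    public static ComputationGraph freezeBelow(ComputationGraph origGraph) {
        //Wraps "fc1" (and everything feeding into it) in FrozenLayer instances;
        //subsequent fit() calls stop calcBackpropGradients() at the frozen boundary
        return new TransferLearning.GraphBuilder(origGraph)
                        .setFeatureExtractor("fc1")
                        .build();
    }
}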