Use of org.deeplearning4j.nn.graph.vertex.impl.LayerVertex in project deeplearning4j by deeplearning4j.
The summary method of the class ComputationGraph.
/**
 * Builds a human-readable table describing the architecture of this computation graph.
 * <p>
 * One row is emitted per vertex, in topological order, with the columns:
 * vertex name (with layer/vertex type), nIn and nOut, total number of parameters,
 * the shape of each parameter, and the inputs feeding the vertex. A {@code "-"} is
 * printed wherever a value does not apply (for example nIn/nOut for a layer whose
 * configuration is not a {@code FeedForwardLayer}, or the inputs of an input vertex).
 * <p>
 * Frozen layers are labelled {@code "Frozen <inner layer type>"} and their parameters
 * are reported separately in the totals printed below the table.
 *
 * @return the summary as a multi-line string
 */
public String summary() {
    final String doubleRule = StringUtils.repeat("=", 140);
    final String rowFormat = "%-40s%-15s%-15s%-30s %s";

    StringBuilder ret = new StringBuilder("\n");
    ret.append(doubleRule).append("\n");
    ret.append(String.format(rowFormat, "VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape",
            "Vertex Inputs")).append("\n");
    ret.append(doubleRule).append("\n");

    long frozenParams = 0;
    for (int currVertexIdx : topologicalOrder) {
        GraphVertex current = vertices[currVertexIdx];
        String name = current.getVertexName();
        String className = shortClassName(current.getClass());

        String connections = "-";
        if (!current.isInputVertex()) {
            connections = configuration.getVertexInputs().get(name).toString();
        }

        String paramCount = "-";
        String in = "-";
        String out = "-";
        String paramShape = "-";
        if (current.hasLayer()) {
            Layer currentLayer = ((LayerVertex) current).getLayer();
            className = shortClassName(currentLayer.getClass());

            long layerParamCount = currentLayer.numParams();
            paramCount = String.valueOf(layerParamCount);
            if (layerParamCount > 0) {
                // Only feed-forward style configurations expose nIn/nOut; other layers that
                // carry parameters keep the "-" placeholder instead of failing on the cast.
                if (currentLayer.conf().getLayer() instanceof FeedForwardLayer) {
                    FeedForwardLayer ffConf = (FeedForwardLayer) currentLayer.conf().getLayer();
                    in = String.valueOf(ffConf.getNIn());
                    out = String.valueOf(ffConf.getNOut());
                }

                // Join "name:shape" entries with ", ". Building the separator up front avoids
                // having to strip a trailing comma, which broke on an empty parameter-name set.
                Set<String> paraNames = currentLayer.conf().getLearningRateByParam().keySet();
                StringBuilder shapes = new StringBuilder();
                for (String aP : paraNames) {
                    if (shapes.length() > 0) {
                        shapes.append(", ");
                    }
                    shapes.append(aP).append(":")
                            .append(ArrayUtils.toString(currentLayer.paramTable().get(aP).shape()));
                }
                if (shapes.length() > 0) {
                    paramShape = shapes.toString();
                }
            }

            if (currentLayer instanceof FrozenLayer) {
                frozenParams += layerParamCount;
                // Report the wrapped layer's type rather than the FrozenLayer wrapper itself.
                className = "Frozen " + shortClassName(((FrozenLayer) currentLayer).getInsideLayer().getClass());
            }
        }

        ret.append(String.format(rowFormat, name + " (" + className + ")", in + "," + out, paramCount,
                paramShape, connections)).append("\n");
    }

    long totalParams = params().length();
    ret.append(StringUtils.repeat("-", 140));
    ret.append(String.format("\n%30s %d", "Total Parameters: ", totalParams));
    ret.append(String.format("\n%30s %d", "Trainable Parameters: ", totalParams - frozenParams));
    ret.append(String.format("\n%30s %d", "Frozen Parameters: ", frozenParams));
    ret.append("\n");
    ret.append(doubleRule);
    ret.append("\n");
    return ret.toString();
}

/**
 * Returns the name of the given class with its package prefix removed.
 * Nested classes keep their binary {@code Outer$Inner} form.
 */
private static String shortClassName(Class<?> clazz) {
    String fullName = clazz.getName();
    return fullName.substring(fullName.lastIndexOf('.') + 1);
}
Aggregations