Search in sources :

Example 1 with LayerVertex

Use of org.deeplearning4j.nn.graph.vertex.impl.LayerVertex in project deeplearning4j by deeplearning4j.

The summary method of the class ComputationGraph.

/**
 * Builds a human-readable table describing the architecture of this computation graph.
 * <p>
 * Vertices are listed in topological sort order, one row per vertex. The columns are: vertex name with
 * layer/vertex type, nIn and nOut, total number of parameters, the shape of each parameter array, and the
 * inputs to the vertex. A column that does not apply to a vertex is printed as "-". Layers that are
 * instances of {@code FrozenLayer} are labelled "Frozen" and their parameter counts are reported
 * separately in the totals printed below the table.
 *
 * @return Summary as a string
 */
public String summary() {
    // A StringBuilder avoids re-copying the whole table for every appended row.
    StringBuilder ret = new StringBuilder("\n");
    String headerRule = StringUtils.repeat("=", 140);
    ret.append(headerRule).append("\n");
    ret.append(String.format("%-40s%-15s%-15s%-30s %s\n", "VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs"));
    ret.append(headerRule).append("\n");
    int frozenParams = 0;
    for (int currVertexIdx : topologicalOrder) {
        GraphVertex current = vertices[currVertexIdx];
        String name = current.getVertexName();
        String className = unqualifiedName(current.getClass());
        String connections = "-";
        if (!current.isInputVertex()) {
            connections = configuration.getVertexInputs().get(name).toString();
        }
        String paramCount = "-";
        String in = "-";
        String out = "-";
        String paramShape = "-";
        if (current.hasLayer()) {
            Layer currentLayer = ((LayerVertex) current).getLayer();
            className = unqualifiedName(currentLayer.getClass());
            paramCount = String.valueOf(currentLayer.numParams());
            if (currentLayer.numParams() > 0) {
                // Only feed-forward style configurations expose nIn/nOut; for any other configuration
                // type keep the "-" placeholders instead of failing with a ClassCastException.
                Object layerConf = currentLayer.conf().getLayer();
                if (layerConf instanceof FeedForwardLayer) {
                    FeedForwardLayer feedForwardConf = (FeedForwardLayer) layerConf;
                    in = String.valueOf(feedForwardConf.getNIn());
                    out = String.valueOf(feedForwardConf.getNOut());
                }
                // Join "name:shape" entries with ", " up front, so there is no trailing separator to
                // strip afterwards (stripping threw when there were no entries at all).
                StringBuilder shapes = new StringBuilder();
                for (String paramName : currentLayer.conf().getLearningRateByParam().keySet()) {
                    if (shapes.length() > 0) {
                        shapes.append(", ");
                    }
                    shapes.append(paramName).append(":")
                            .append(ArrayUtils.toString(currentLayer.paramTable().get(paramName).shape()));
                }
                if (shapes.length() > 0) {
                    paramShape = shapes.toString();
                }
            }
            if (currentLayer instanceof FrozenLayer) {
                frozenParams += currentLayer.numParams();
                // Report the wrapped layer's type rather than the FrozenLayer wrapper itself.
                className = "Frozen " + unqualifiedName(((FrozenLayer) currentLayer).getInsideLayer().getClass());
            }
        }
        ret.append(String.format("%-40s%-15s%-15s%-30s %s", name + " (" + className + ")", in + "," + out, paramCount, paramShape, connections));
        ret.append("\n");
    }
    ret.append(StringUtils.repeat("-", 140));
    ret.append(String.format("\n%30s %d", "Total Parameters: ", params().length()));
    ret.append(String.format("\n%30s %d", "Trainable Parameters: ", params().length() - frozenParams));
    ret.append(String.format("\n%30s %d", "Frozen Parameters: ", frozenParams));
    ret.append("\n");
    ret.append(headerRule);
    ret.append("\n");
    return ret.toString();
}

/** Returns the part of the class's fully qualified name that follows the last '.'. */
private static String unqualifiedName(Class<?> type) {
    String qualified = type.getName();
    return qualified.substring(qualified.lastIndexOf('.') + 1);
}
Also used : LayerVertex(org.deeplearning4j.nn.graph.vertex.impl.LayerVertex) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) Layer(org.deeplearning4j.nn.api.Layer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer)

Aggregations

Layer (org.deeplearning4j.nn.api.Layer)1 IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer)1 RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer)1 FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer)1 GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex)1 LayerVertex (org.deeplearning4j.nn.graph.vertex.impl.LayerVertex)1 FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer)1