Search in sources :

Example 1 with InputVertex

Use of org.deeplearning4j.nn.graph.vertex.impl.InputVertex in the deeplearning4j project (by deeplearning4j).

From the class ComputationGraph, method init.

/**
     * Initialize the ComputationGraph, optionally with an existing parameters array.
     * If an existing parameters array is specified, it will be used (and the values will not be modified) in the network;
     * if no parameters array is specified, parameters will be initialized randomly according to the network configuration.
     * <p>
     * Calling this method again after a successful initialization is a no-op.
     *
     * @param parameters           Network parameter. May be null. If null: randomly initialize.
     * @param cloneParametersArray Whether the parameter array (if any) should be cloned, or used directly
     * @throws IllegalArgumentException if {@code parameters} is non-null and is not a row vector, or its length does
     *                                  not match the total number of parameters required by the configuration
     * @throws IllegalStateException    if the configured graph structure is inconsistent (a vertex names an input
     *                                  that is neither a network input nor a configured vertex)
     */
public void init(INDArray parameters, boolean cloneParametersArray) {
    if (initCalled)
        return;
    //First: build topological ordering, based on configuration. Used for forward pass, backprop and order of parameters/gradients
    topologicalOrder = topologicalSortOrder();
    //Initialization: create the GraphVertex objects, based on configuration structure
    Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();
    //Names of all of the (data) inputs to the ComputationGraph
    List<String> networkInputNames = configuration.getNetworkInputs();
    //Inputs for each layer and GraphNode:
    Map<String, List<String>> vertexInputs = configuration.getVertexInputs();
    this.vertices = new GraphVertex[networkInputNames.size() + configVertexMap.size()];
    //All names: inputs, layers and graph nodes (name -> vertex index map)
    Map<String, Integer> allNamesReverse = new HashMap<>();
    //Create network input vertices. They occupy indices [0, networkInputNames.size()) of the vertices array
    int vertexNumber = 0;
    for (String name : networkInputNames) {
        //Output vertices: set later
        GraphVertex gv = new InputVertex(this, name, vertexNumber, null);
        allNamesReverse.put(name, vertexNumber);
        vertices[vertexNumber++] = gv;
    }
    //Go through layers, and work out total number of parameters. Then allocate full parameters array
    int numParams = 0;
    int[] numParamsForVertex = new int[topologicalOrder.length];
    //No parameters for input vertices: their slots keep the array default of 0, so start after them.
    //The iteration order here must match the instantiation loop below, as both index by vertex number
    int i = networkInputNames.size();
    for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
        org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
        numParamsForVertex[i] = n.numParams(true);
        numParams += numParamsForVertex[i];
        i++;
    }
    boolean initializeParams;
    if (parameters != null) {
        if (!parameters.isRowVector())
            throw new IllegalArgumentException("Invalid parameters: should be a row vector");
        if (parameters.length() != numParams)
            throw new IllegalArgumentException("Invalid parameters: expected length " + numParams + ", got length " + parameters.length());
        if (cloneParametersArray)
            flattenedParams = parameters.dup();
        else
            flattenedParams = parameters;
        initializeParams = false;
    } else {
        flattenedParams = Nd4j.create(1, numParams);
        initializeParams = true;
    }
    //Given the topological ordering: work out the subset of the parameters array used for each layer
    // Then extract out for use when initializing the Layers.
    // Offsets advance in topological order, but the views are stored by vertex index
    INDArray[] paramsViewForVertex = new INDArray[topologicalOrder.length];
    int paramOffsetSoFar = 0;
    for (int vertexIdx : topologicalOrder) {
        int nParamsThisVertex = numParamsForVertex[vertexIdx];
        if (nParamsThisVertex != 0) {
            paramsViewForVertex[vertexIdx] = flattenedParams.get(NDArrayIndex.point(0), NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex));
        }
        paramOffsetSoFar += nParamsThisVertex;
    }
    List<Layer> tempLayerList = new ArrayList<>();
    defaultConfiguration.clearVariables();
    List<String> variables = defaultConfiguration.variables(false);
    //Instantiate the configured vertices; vertexNumber continues on from the last input vertex
    for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
        org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
        String name = nodeEntry.getKey();
        GraphVertex gv = n.instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber], initializeParams);
        if (gv.hasLayer()) {
            Layer l = gv.getLayer();
            tempLayerList.add(l);
            List<String> layerVariables = l.conf().variables();
            if (layerVariables != null) {
                //Variable names are recorded as <vertexName>_<layerVariable>
                for (String s : layerVariables) {
                    variables.add(gv.getVertexName() + "_" + s);
                }
            }
        }
        allNamesReverse.put(name, vertexNumber);
        vertices[vertexNumber++] = gv;
    }
    layers = tempLayerList.toArray(new Layer[tempLayerList.size()]);
    //Create the lookup table, so we can find vertices easily by name
    verticesMap = new HashMap<>();
    for (GraphVertex gv : vertices) {
        verticesMap.put(gv.getVertexName(), gv);
    }
    //Now: do another pass to set the input and output indices, for each vertex
    // These indices are used during forward and backward passes
    //To get output indices: need to essentially build the graph in reverse...
    //Key: vertex. Values: vertices that this node is an input for
    Map<String, List<String>> verticesOutputTo = new HashMap<>();
    for (GraphVertex gv : vertices) {
        String vertexName = gv.getVertexName();
        List<String> vertexInputNames = vertexInputs.get(vertexName);
        if (vertexInputNames == null)
            //No configured inputs for this vertex (e.g. the network input vertices created above)
            continue;
        //Build reverse network structure:
        for (String s : vertexInputNames) {
            List<String> list = verticesOutputTo.get(s);
            if (list == null) {
                list = new ArrayList<>();
                verticesOutputTo.put(s, list);
            }
            //Edge: s -> vertexName
            list.add(vertexName);
        }
    }
    for (GraphVertex gv : vertices) {
        String vertexName = gv.getVertexName();
        int vertexIndex = gv.getVertexIndex();
        List<String> vertexInputNames = vertexInputs.get(vertexName);
        if (vertexInputNames == null)
            continue;
        VertexIndices[] inputIndices = new VertexIndices[vertexInputNames.size()];
        for (int j = 0; j < vertexInputNames.size(); j++) {
            String inName = vertexInputNames.get(j);
            //Look up as a boxed Integer: unboxing a missing entry directly would fail with an uninformative NullPointerException
            Integer inputVertexIndexBoxed = allNamesReverse.get(inName);
            if (inputVertexIndexBoxed == null)
                throw new IllegalStateException("Vertex \"" + vertexName + "\" has input \"" + inName + "\", which is neither a network input nor a configured vertex; error in graph structure?");
            int inputVertexIndex = inputVertexIndexBoxed;
            //Output of vertex 'inputVertexIndex' is the jth input to the current vertex
            //For input indices, we need to know which output connection of vertex 'inputVertexIndex' this represents
            GraphVertex inputVertex = vertices[inputVertexIndex];
            //First: get the outputs of the input vertex...
            List<String> inputVertexOutputsTo = verticesOutputTo.get(inName);
            int outputNumberOfInput = inputVertexOutputsTo.indexOf(vertexName);
            if (outputNumberOfInput == -1)
                throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of outputs " + "for vertex " + inputVertex + "; error in graph structure?");
            //Overall here: the 'outputNumberOfInput'th output of vertex 'inputVertexIndex' is the jth input to the current vertex
            inputIndices[j] = new VertexIndices(inputVertexIndex, outputNumberOfInput);
        }
        gv.setInputVertices(inputIndices);
    }
    //Handle the outputs for each vertex
    for (GraphVertex gv : vertices) {
        String vertexName = gv.getVertexName();
        List<String> thisVertexOutputsTo = verticesOutputTo.get(vertexName);
        if (thisVertexOutputsTo == null || thisVertexOutputsTo.isEmpty())
            //Output vertex: feeds no other vertex, so there are no output indices to set
            continue;
        VertexIndices[] outputIndices = new VertexIndices[thisVertexOutputsTo.size()];
        int j = 0;
        for (String s : thisVertexOutputsTo) {
            //First, we have gv -> s
            //Which input in s does gv connect to? s may in general have multiple inputs...
            List<String> nextVertexInputNames = vertexInputs.get(s);
            int outputVertexInputNumber = nextVertexInputNames.indexOf(vertexName);
            int outputVertexIndex = allNamesReverse.get(s);
            outputIndices[j++] = new VertexIndices(outputVertexIndex, outputVertexInputNumber);
        }
        gv.setOutputVertices(outputIndices);
    }
    initCalled = true;
}
Also used : GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) InputVertex(org.deeplearning4j.nn.graph.vertex.impl.InputVertex) VertexIndices(org.deeplearning4j.nn.graph.vertex.VertexIndices) Layer(org.deeplearning4j.nn.api.Layer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray)

Aggregations

Layer (org.deeplearning4j.nn.api.Layer)1 IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer)1 RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer)1 FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer)1 GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex)1 VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices)1 InputVertex (org.deeplearning4j.nn.graph.vertex.impl.InputVertex)1 FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1