Use of org.deeplearning4j.nn.graph.vertex.impl.InputVertex in project deeplearning4j by deeplearning4j.
Class ComputationGraph, method init.
/**
 * Initialize the ComputationGraph, optionally with an existing parameters array.
 * If an existing parameters array is specified, it will be used (and the values will not be modified) in the network;
 * if no parameters array is specified, parameters will be initialized randomly according to the network configuration.
 * <p>
 * This method does nothing if the graph has already been initialized (i.e. a previous call completed).
 * <p>
 * Steps performed, in order:
 * <ol>
 * <li>compute the topological order of the vertices;</li>
 * <li>create one {@code InputVertex} per network input;</li>
 * <li>count the parameters of every configured vertex and obtain the flattened parameter array;</li>
 * <li>slice the flattened array into one view per vertex, following the topological order;</li>
 * <li>instantiate every configured vertex, collecting layers and their variable names;</li>
 * <li>wire up the input and output {@code VertexIndices} of every vertex.</li>
 * </ol>
 *
 * @param parameters           Network parameter. May be null. If null: randomly initialize.
 * @param cloneParametersArray Whether the parameter array (if any) should be cloned, or used directly
 * @throws IllegalArgumentException if {@code parameters} is non-null and is not a row vector, or its length does
 *                                  not match the number of parameters required by the configuration
 * @throws IllegalStateException    if the vertex input/output structure of the configuration is inconsistent
 */
public void init(INDArray parameters, boolean cloneParametersArray) {
    if (initCalled)
        return;

    //First: build topological ordering, based on configuration. Used for forward pass, backprop and order of parameters/gradients
    topologicalOrder = topologicalSortOrder();

    //Initialization: create the GraphVertex objects, based on configuration structure
    Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertexMap = configuration.getVertices();
    //Names of all of the (data) inputs to the ComputationGraph
    List<String> networkInputNames = configuration.getNetworkInputs();
    //Inputs for each layer and GraphNode:
    Map<String, List<String>> vertexInputs = configuration.getVertexInputs();
    this.vertices = new GraphVertex[networkInputNames.size() + configVertexMap.size()];

    //All names: inputs, layers and graph nodes (name -> vertex index lookup)
    Map<String, Integer> allNamesReverse = new HashMap<>();

    //Create network input vertices; these occupy indices [0, networkInputNames.size())
    int vertexNumber = 0;
    for (String name : networkInputNames) {
        //Output vertices: set later
        GraphVertex gv = new InputVertex(this, name, vertexNumber, null);
        allNamesReverse.put(name, vertexNumber);
        vertices[vertexNumber++] = gv;
    }

    //Go through layers, and work out total number of parameters. Then allocate full parameters array
    int numParams = 0;
    int[] numParamsForVertex = new int[topologicalOrder.length];
    //Input vertices have no parameters: their leading slots keep the default value 0, so start after them
    int configVertexIdx = networkInputNames.size();
    for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
        org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
        numParamsForVertex[configVertexIdx] = n.numParams(true);
        numParams += numParamsForVertex[configVertexIdx];
        configVertexIdx++;
    }

    boolean initializeParams;
    if (parameters != null) {
        if (!parameters.isRowVector())
            throw new IllegalArgumentException("Invalid parameters: should be a row vector");
        if (parameters.length() != numParams)
            throw new IllegalArgumentException("Invalid parameters: expected length " + numParams + ", got length " + parameters.length());
        if (cloneParametersArray)
            flattenedParams = parameters.dup();
        else
            flattenedParams = parameters;
        //Caller supplied the values: vertices must not re-initialize them
        initializeParams = false;
    } else {
        flattenedParams = Nd4j.create(1, numParams);
        initializeParams = true;
    }

    //Given the topological ordering: work out the subset of the parameters array used for each layer
    // Then extract out for use when initializing the Layers
    INDArray[] paramsViewForVertex = new INDArray[topologicalOrder.length];
    int paramOffsetSoFar = 0;
    for (int vertexIdx : topologicalOrder) {
        int nParamsThisVertex = numParamsForVertex[vertexIdx];
        if (nParamsThisVertex != 0) {
            //Vertices without parameters keep a null entry in paramsViewForVertex
            paramsViewForVertex[vertexIdx] = flattenedParams.get(NDArrayIndex.point(0), NDArrayIndex.interval(paramOffsetSoFar, paramOffsetSoFar + nParamsThisVertex));
        }
        paramOffsetSoFar += nParamsThisVertex;
    }

    //Instantiate the configured vertices. Iteration order over configVertexMap is the same as in the
    // parameter-counting loop above, so vertexNumber here lines up with the indices used for numParamsForVertex
    List<Layer> tempLayerList = new ArrayList<>();
    defaultConfiguration.clearVariables();
    List<String> variables = defaultConfiguration.variables(false);
    for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> nodeEntry : configVertexMap.entrySet()) {
        org.deeplearning4j.nn.conf.graph.GraphVertex n = nodeEntry.getValue();
        String name = nodeEntry.getKey();
        GraphVertex gv = n.instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber], initializeParams);
        if (gv.hasLayer()) {
            Layer l = gv.getLayer();
            tempLayerList.add(l);
            List<String> layerVariables = l.conf().variables();
            if (layerVariables != null) {
                //Variable names are registered as "<vertexName>_<variableName>"
                for (String s : layerVariables) {
                    variables.add(gv.getVertexName() + "_" + s);
                }
            }
        }
        allNamesReverse.put(name, vertexNumber);
        vertices[vertexNumber++] = gv;
    }
    layers = tempLayerList.toArray(new Layer[tempLayerList.size()]);

    //Create the lookup table, so we can find vertices easily by name
    verticesMap = new HashMap<>();
    for (GraphVertex gv : vertices) {
        verticesMap.put(gv.getVertexName(), gv);
    }

    //Now: do another pass to set the input and output indices, for each vertex
    // These indices are used during forward and backward passes
    //To get output indices: need to essentially build the graph in reverse...
    //Key: vertex. Values: vertices that this node is an input for
    Map<String, List<String>> verticesOutputTo = new HashMap<>();
    for (GraphVertex gv : vertices) {
        String vertexName = gv.getVertexName();
        List<String> vertexInputNames = vertexInputs.get(vertexName);
        if (vertexInputNames == null)
            continue;

        //Build reverse network structure:
        for (String s : vertexInputNames) {
            List<String> list = verticesOutputTo.get(s);
            if (list == null) {
                list = new ArrayList<>();
                verticesOutputTo.put(s, list);
            }
            //Edge: s -> vertexName
            list.add(vertexName);
        }
    }

    //Set the input indices for each vertex that has inputs
    for (GraphVertex gv : vertices) {
        String vertexName = gv.getVertexName();
        int vertexIndex = gv.getVertexIndex();
        List<String> vertexInputNames = vertexInputs.get(vertexName);
        if (vertexInputNames == null)
            continue;

        VertexIndices[] inputIndices = new VertexIndices[vertexInputNames.size()];
        for (int j = 0; j < vertexInputNames.size(); j++) {
            String inName = vertexInputNames.get(j);
            //Look up as Integer first: unboxing a missing entry directly would fail with a bare NullPointerException
            Integer inputVertexIndexOrNull = allNamesReverse.get(inName);
            if (inputVertexIndexOrNull == null)
                throw new IllegalStateException("Could not find input \"" + inName + "\" of vertex \"" + vertexName + "\" among the network inputs and vertices; error in graph structure?");
            int inputVertexIndex = inputVertexIndexOrNull;

            //Output of vertex 'inputVertexIndex' is the jth input to the current vertex
            //For input indices, we need to know which output connection of vertex 'inputVertexIndex' this represents
            GraphVertex inputVertex = vertices[inputVertexIndex];
            //First: get the outputs of the input vertex...
            List<String> inputVertexOutputsTo = verticesOutputTo.get(inName);
            int outputNumberOfInput = inputVertexOutputsTo.indexOf(vertexName);
            if (outputNumberOfInput == -1)
                throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of outputs " + "for vertex " + inputVertex + "; error in graph structure?");
            //Overall here: the 'outputNumberOfInput'th output of vertex 'inputVertexIndex' is the jth input to the current vertex
            inputIndices[j] = new VertexIndices(inputVertexIndex, outputNumberOfInput);
        }
        gv.setInputVertices(inputIndices);
    }

    //Set the output indices for each vertex that feeds at least one other vertex
    for (GraphVertex gv : vertices) {
        String vertexName = gv.getVertexName();
        List<String> thisVertexOutputsTo = verticesOutputTo.get(vertexName);
        if (thisVertexOutputsTo == null || thisVertexOutputsTo.isEmpty())
            //Output vertex: no other vertex takes it as an input
            continue;

        VertexIndices[] outputIndices = new VertexIndices[thisVertexOutputsTo.size()];
        int j = 0;
        for (String s : thisVertexOutputsTo) {
            //First, we have gv -> s
            //Which input in s does gv connect to? s may in general have multiple inputs...
            List<String> nextVertexInputNames = vertexInputs.get(s);
            int outputVertexInputNumber = nextVertexInputNames.indexOf(vertexName);
            int outputVertexIndex = allNamesReverse.get(s);
            outputIndices[j++] = new VertexIndices(outputVertexIndex, outputVertexInputNumber);
        }
        gv.setOutputVertices(outputIndices);
    }

    initCalled = true;
}
Aggregations