Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j.
From the class ComputationGraph, method pretrainLayer:
/**
 * Pretrain a specified layer with the given MultiDataSetIterator.
 * <p>
 * For each MultiDataSet from the iterator, performs a partial forward pass — only through
 * the graph vertices actually required to feed the target layer — and then fits the target
 * layer on the activations it receives. Returns without doing anything if pretraining is
 * disabled in the configuration, or if the named vertex exists but does not contain a layer.
 *
 * @param layerName Layer name
 * @param iter Training data
 * @throws IllegalStateException if {@code layerName} is not a vertex in this graph
 */
public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
if (!configuration.isPretrain())
return;
if (flattenedGradients == null)
initGradientsView();
if (!verticesMap.containsKey(layerName)) {
throw new IllegalStateException("Invalid vertex name: " + layerName);
}
if (!verticesMap.get(layerName).hasLayer()) {
//No op
return;
}
int layerIndex = verticesMap.get(layerName).getVertexIndex();
//Need to do partial forward pass. Simply following the topological ordering won't be efficient, as we might
// end up doing forward pass on layers we don't need to.
//However, we can start with the topological order, and prune out any layers we don't need to do
//Build the pruned ordering back-to-front: start from the target layer, then walk backwards through the
// full topological order, keeping only vertices that feed (directly or transitively) something already kept.
LinkedList<Integer> partialTopoSort = new LinkedList<>();
Set<Integer> seenSoFar = new HashSet<>();
partialTopoSort.add(topologicalOrder[layerIndex]);
seenSoFar.add(topologicalOrder[layerIndex]);
for (int j = layerIndex - 1; j >= 0; j--) {
//Do we need to do forward pass on this GraphVertex?
//If it is input to any other layer we need, then yes. Otherwise: no
VertexIndices[] outputsTo = vertices[topologicalOrder[j]].getOutputVertices();
boolean needed = false;
for (VertexIndices vi : outputsTo) {
if (seenSoFar.contains(vi.getVertexIndex())) {
needed = true;
break;
}
}
if (needed) {
partialTopoSort.addFirst(topologicalOrder[j]);
seenSoFar.add(topologicalOrder[j]);
}
}
//Flatten the pruned ordering into an array; the last element is the target vertex itself
int[] fwdPassOrder = new int[partialTopoSort.size()];
int k = 0;
for (Integer g : partialTopoSort) fwdPassOrder[k++] = g;
GraphVertex gv = vertices[fwdPassOrder[fwdPassOrder.length - 1]];
Layer layer = gv.getLayer();
//If the iterator is exhausted but supports resetting, rewind it so we can make a full pass
if (!iter.hasNext() && iter.resetSupported()) {
iter.reset();
}
while (iter.hasNext()) {
MultiDataSet multiDataSet = iter.next();
setInputs(multiDataSet.getFeatures());
//Forward pass over every required vertex EXCEPT the target layer (length - 1): we only need
// the target's inputs populated; layer.fit() below does the rest
for (int j = 0; j < fwdPassOrder.length - 1; j++) {
GraphVertex current = vertices[fwdPassOrder[j]];
if (current.isInputVertex()) {
VertexIndices[] inputsTo = current.getOutputVertices();
INDArray input = inputs[current.getVertexIndex()];
for (VertexIndices v : inputsTo) {
int vIdx = v.getVertexIndex();
int vIdxInputNum = v.getVertexEdgeNumber();
//This input: the 'vIdxInputNum'th input to vertex 'vIdx'
//TODO When to dup?
vertices[vIdx].setInput(vIdxInputNum, input.dup());
}
} else {
//Do forward pass:
INDArray out = current.doForward(true);
//Now, set the inputs for the next vertices:
VertexIndices[] outputsTo = current.getOutputVertices();
if (outputsTo != null) {
for (VertexIndices v : outputsTo) {
int vIdx = v.getVertexIndex();
int inputNum = v.getVertexEdgeNumber();
//This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
vertices[vIdx].setInput(inputNum, out);
}
}
}
}
//At this point: have done all of the required forward pass stuff. Can now pretrain layer on current input
//NOTE(review): fits on the layer's first input only, and flips the pretrain flag off inside the
// minibatch loop (redundant after the first iteration) — presumably intentional; confirm against callers
layer.fit(gv.getInputs()[0]);
layer.conf().setPretrain(false);
}
}
Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j.
From the class FlowIterationListener, method flattenToY:
/**
 * Finds every vertex that consumes one of the given input names, registers the
 * corresponding layer with the model at the given Y coordinate, and wires up the
 * connection from the matched input layer to it.
 *
 * @param model        model descriptor that newly discovered layers are added to
 * @param vertices     all vertices of the computation graph
 * @param currentInput names of the vertices whose downstream consumers are wanted
 * @param currentY     Y grid coordinate assigned to layers discovered at this level
 * @return the LayerInfo entries that were newly added to the model
 */
protected List<LayerInfo> flattenToY(ModelInfo model, GraphVertex[] vertices, List<String> currentInput, int currentY) {
    List<LayerInfo> results = new ArrayList<>();
    int x = 0;
    for (GraphVertex vertex : vertices) {
        VertexIndices[] inputIndices = vertex.getInputVertices();
        if (inputIndices == null) {
            continue;
        }
        for (VertexIndices inputIndex : inputIndices) {
            String inputName = vertices[inputIndex.getVertexIndex()].getVertexName();
            for (String input : currentInput) {
                if (!inputName.equals(input)) {
                    continue;
                }
                try {
                    LayerInfo info = model.getLayerInfoByName(vertex.getVertexName());
                    if (info == null) {
                        info = getLayerInfo(vertex.getLayer(), x, currentY, 121);
                    }
                    info.setName(vertex.getVertexName());
                    // Vertex without an underlying layer (e.g. a merge vertex): use its class name as the type
                    if (vertex.getLayer() == null) {
                        info.setLayerType(vertex.getClass().getSimpleName());
                    }
                    if (info.getName().endsWith("-merge")) {
                        info.setLayerType("MERGE");
                    }
                    // Only register and report the layer the first time it is encountered
                    if (model.getLayerInfoByName(vertex.getVertexName()) == null) {
                        x++;
                        model.addLayer(info);
                        results.add(info);
                    }
                    // Wire the connection from the matched input layer to this one;
                    // a null here means this vertex is fed directly by a graph input
                    LayerInfo connection = model.getLayerInfoByName(input);
                    if (connection != null) {
                        connection.addConnection(info);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }
    return results;
}
Aggregations