use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.
the class TestGraphNodes method testMergeNodeRNN.
/**
 * A MergeVertex fed RNN activations of shape [3, 4, 5] and [3, 6, 5] must concatenate them
 * along dimension 1 into [3, 10, 5]. Feeding the merged array back as epsilon must split it
 * into the two original inputs.
 */
@Test
public void testMergeNodeRNN() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex merge = new MergeVertex(null, "", -1);

    INDArray lhs = Nd4j.linspace(0, 59, 60).reshape(3, 4, 5);
    // offset by 100 so the two halves can't be confused inside the merged output
    INDArray rhs = Nd4j.linspace(0, 89, 90).reshape(3, 6, 5).addi(100);
    merge.setInputs(lhs, rhs);

    INDArray merged = merge.doForward(false);
    assertArrayEquals(new int[] { 3, 10, 5 }, merged.shape());

    int lhsDepth = 4;
    int mergedDepth = 10;
    INDArray lhsSlice = merged.get(NDArrayIndex.all(), NDArrayIndex.interval(0, lhsDepth), NDArrayIndex.all());
    assertEquals(lhs, lhsSlice);
    INDArray rhsSlice = merged.get(NDArrayIndex.all(), NDArrayIndex.interval(lhsDepth, mergedDepth), NDArrayIndex.all());
    assertEquals(rhs, rhsSlice);

    // backward pass: the merged activations used as epsilon must be split back into the inputs
    merge.setEpsilon(merged);
    INDArray[] split = merge.doBackward(false).getSecond();
    assertEquals(lhs, split[0]);
    assertEquals(rhs, split[1]);
}
use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.
the class RemoteFlowIterationListener method buildModelInfo.
/**
 * Builds the {@link ModelInfo} layout that this listener reports for {@code model}.
 * <p>
 * For a {@link ComputationGraph}:
 * <ul>
 *   <li>each network input becomes an INPUT node on row 0, one column per input;</li>
 *   <li>rows 1, 2, ... are filled with the vertices fed by the previous row via {@code flattenToY}.</li>
 * </ul>
 * For a {@link MultiLayerNetwork}:
 * <ul>
 *   <li>an "Input" node is placed at (0, 0);</li>
 *   <li>every layer is stacked below it in column 0, each connected to the next row;</li>
 *   <li>the last layer's connections are dropped.</li>
 * </ul>
 * Afterwards every layer without connections is typed "OUTPUT", and each layer type is
 * assigned a color.
 * <p>
 * Missing input arrays (null or empty) are reported with input/batch sizes of 0 instead of
 * failing.
 *
 * @param model the model to describe; for any other model type this method adds no layers
 * @return the populated layout description
 */
protected ModelInfo buildModelInfo(Model model) {
    ModelInfo modelInfo = new ModelInfo();
    if (model instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) model;
        /*
            We assume that the graph starts on its inputs: every layer connected to an input is on
            row y1, every layer connected to y1 is on y2, etc.
        */
        List<String> inputs = graph.getConfiguration().getNetworkInputs();
        // now we need to add inputs as y0 nodes
        int x = 0;
        for (String input : inputs) {
            GraphVertex vertex = graph.getVertex(input);
            INDArray[] vertexInputs = vertex.getInputs();
            long tadLength = 0;
            long numSamples = 0;
            // an input vertex may not have an array attached yet: report zero sizes instead of failing
            if (vertexInputs != null && vertexInputs.length > 0 && vertexInputs[0] != null) {
                INDArray gInput = vertexInputs[0];
                tadLength = Shape.getTADLength(gInput.shape(), ArrayUtil.range(1, gInput.rank()));
                // avoid division by zero when a non-leading dimension is empty
                numSamples = tadLength == 0 ? 0 : gInput.lengthLong() / tadLength;
            }
            StringBuilder builder = new StringBuilder();
            builder.append("Vertex name: ").append(input).append("<br/>");
            builder.append("Model input").append("<br/>");
            builder.append("Input size: ").append(tadLength).append("<br/>");
            builder.append("Batch size: ").append(numSamples).append("<br/>");
            LayerInfo info = new LayerInfo();
            info.setId(0);
            info.setName(input);
            info.setY(0);
            info.setX(x);
            info.setLayerType(INPUT);
            info.setDescription(new Description());
            info.getDescription().setMainLine("Model input");
            info.getDescription().setText(builder.toString());
            modelInfo.addLayer(info);
            x++;
        }
        GraphVertex[] vertices = graph.getVertices();
        // filling grid in LTR/TTB direction
        List<String> needle = new ArrayList<>();
        // we assume that the number of rows can't be higher than the total number of vertices
        for (int y = 1; y < vertices.length; y++) {
            if (needle.isEmpty()) {
                needle.addAll(inputs);
            }
            // for each grid row we look for nodes that are connected to the previous row
            List<LayerInfo> layersForGridY = flattenToY(modelInfo, vertices, needle, y);
            needle.clear();
            for (LayerInfo layerInfo : layersForGridY) {
                needle.add(layerInfo.getName());
            }
            // stop once a row places no new layers
            if (needle.isEmpty()) {
                break;
            }
        }
    } else if (model instanceof MultiLayerNetwork) {
        MultiLayerNetwork network = (MultiLayerNetwork) model;
        // manually adding input layer
        INDArray input = model.input();
        long tadLength = 0;
        long numSamples = 0;
        // model.input() may be null (presumably before any data was fed): report zero sizes instead
        if (input != null) {
            tadLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
            numSamples = tadLength == 0 ? 0 : input.lengthLong() / tadLength;
        }
        StringBuilder builder = new StringBuilder();
        builder.append("Model input").append("<br/>");
        builder.append("Input size: ").append(tadLength).append("<br/>");
        builder.append("Batch size: ").append(numSamples).append("<br/>");
        LayerInfo info = new LayerInfo();
        info.setId(0);
        info.setName("Input");
        info.setY(0);
        info.setX(0);
        info.setLayerType(INPUT);
        info.setDescription(new Description());
        info.getDescription().setMainLine("Model input");
        info.getDescription().setText(builder.toString());
        info.addConnection(0, 1);
        modelInfo.addLayer(info);
        // entry 0 is reserved for inputs
        int y = 1;
        // for MLN x value is always 0
        final int x = 0;
        for (Layer layer : network.getLayers()) {
            LayerInfo layerInfo = getLayerInfo(layer, x, y, y);
            // since it's MLN, we know connections in advance as curLayer + 1
            layerInfo.addConnection(x, y + 1);
            modelInfo.addLayer(layerInfo);
            y++;
        }
        // the last row has no successor: remove the connection added speculatively in the loop
        LayerInfo lastLayer = modelInfo.getLayerInfoByCoords(x, y - 1);
        lastLayer.dropConnections();
    }
    // find layers without connections, and mark them as output layers
    for (LayerInfo layerInfo : modelInfo.getLayers()) {
        if (layerInfo.getConnections().size() == 0) {
            layerInfo.setLayerType("OUTPUT");
        }
    }
    // one palette color per distinct layer type, wrapping around when the palette is exhausted;
    // INPUT and OUTPUT layers always get fixed colors
    int colorIndex = 0;
    for (String layerType : modelInfo.getLayerTypes()) {
        String curColor = colors.get(colorIndex++);
        if (colorIndex >= colors.size()) {
            colorIndex = 0;
        }
        for (LayerInfo layerInfo : modelInfo.getLayersByType(layerType)) {
            if (layerType.equals(INPUT)) {
                layerInfo.setColor("#99ff66");
            } else if (layerType.equals("OUTPUT")) {
                layerInfo.setColor("#e6e6e6");
            } else {
                layerInfo.setColor(curColor);
            }
        }
    }
    return modelInfo;
}
use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.
the class RemoteFlowIterationListener method flattenToY.
/**
 * Places every graph vertex that consumes one of {@code currentInput} onto grid row
 * {@code currentY}, and records the edges from those inputs to it.
 * <p>
 * A vertex that is not yet in {@code model} is handled as follows:
 * <ul>
 *   <li>it is created through {@code getLayerInfo};</li>
 *   <li>it is given the next free x position on this row;</li>
 *   <li>it is added to {@code model} and included in the returned list.</li>
 * </ul>
 * In every matching case, the layer named after the matching input, if present in
 * {@code model}, receives a connection to the vertex's layer.
 * <p>
 * Exceptions thrown while handling one matched edge are printed and skipped, so the rest of
 * the layout is still built.
 *
 * @param model        layout being filled in; mutated in place
 * @param vertices     all graph vertices, indexed the same way as {@link VertexIndices#getVertexIndex()}
 * @param currentInput names of the vertices placed on the previous row
 * @param currentY     grid row being filled
 * @return layers newly added to {@code model} for this row, in discovery order
 */
protected List<LayerInfo> flattenToY(ModelInfo model, GraphVertex[] vertices, List<String> currentInput, int currentY) {
    List<LayerInfo> newlyPlaced = new ArrayList<>();
    int column = 0;
    for (GraphVertex vertex : vertices) {
        VertexIndices[] sources = vertex.getInputVertices();
        if (sources == null) {
            continue;
        }
        for (VertexIndices source : sources) {
            String sourceName = vertices[source.getVertexIndex()].getVertexName();
            for (String input : currentInput) {
                if (!sourceName.equals(input)) {
                    continue;
                }
                try {
                    String vertexName = vertex.getVertexName();
                    LayerInfo info = model.getLayerInfoByName(vertexName);
                    if (info == null) {
                        // NOTE(review): 121 is passed as the last getLayerInfo argument; its meaning is not visible here
                        info = getLayerInfo(vertex.getLayer(), column, currentY, 121);
                    }
                    info.setName(vertexName);
                    // vertices without a Layer are labelled by their class name
                    if (vertex.getLayer() == null) {
                        info.setLayerType(vertex.getClass().getSimpleName());
                    }
                    if (info.getName().endsWith("-merge")) {
                        info.setLayerType("MERGE");
                    }
                    // only a vertex not yet in the model occupies a new column on this row
                    if (model.getLayerInfoByName(vertexName) == null) {
                        column++;
                        model.addLayer(info);
                        newlyPlaced.add(info);
                    }
                    // wire the edge input -> vertex; a missing upstream entry is simply skipped
                    LayerInfo upstream = model.getLayerInfoByName(input);
                    if (upstream != null) {
                        upstream.addConnection(info);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }
    return newlyPlaced;
}
use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.
the class FlowIterationListener method buildModelInfo.
/**
 * Builds the {@link ModelInfo} layout that this listener reports for {@code model}.
 * <p>
 * For a {@link ComputationGraph}:
 * <ul>
 *   <li>each network input becomes an INPUT node on row 0, one column per input;</li>
 *   <li>rows 1, 2, ... are filled with the vertices fed by the previous row via {@code flattenToY}.</li>
 * </ul>
 * For a {@link MultiLayerNetwork}:
 * <ul>
 *   <li>an "Input" node is placed at (0, 0);</li>
 *   <li>every layer is stacked below it in column 0, each connected to the next row;</li>
 *   <li>the last layer's connections are dropped.</li>
 * </ul>
 * Afterwards every layer without connections is typed "OUTPUT", and each layer type is
 * assigned a color.
 * <p>
 * Missing input arrays (null or empty) are reported with input/batch sizes of 0 instead of
 * failing, in both branches.
 *
 * @param model the model to describe; for any other model type this method adds no layers
 * @return the populated layout description
 */
protected ModelInfo buildModelInfo(Model model) {
    ModelInfo modelInfo = new ModelInfo();
    if (model instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) model;
        /*
            We assume that the graph starts on its inputs: every layer connected to an input is on
            row y1, every layer connected to y1 is on y2, etc.
        */
        List<String> inputs = graph.getConfiguration().getNetworkInputs();
        // now we need to add inputs as y0 nodes
        int x = 0;
        for (String input : inputs) {
            GraphVertex vertex = graph.getVertex(input);
            INDArray[] vertexInputs = vertex.getInputs();
            long tadLength = 0;
            long numSamples = 0;
            // an input vertex may not have an array attached yet: report zero sizes instead of failing
            if (vertexInputs != null && vertexInputs.length > 0 && vertexInputs[0] != null) {
                INDArray gInput = vertexInputs[0];
                tadLength = Shape.getTADLength(gInput.shape(), ArrayUtil.range(1, gInput.rank()));
                // avoid division by zero when a non-leading dimension is empty
                numSamples = tadLength == 0 ? 0 : gInput.lengthLong() / tadLength;
            }
            StringBuilder builder = new StringBuilder();
            builder.append("Vertex name: ").append(input).append("<br/>");
            builder.append("Model input").append("<br/>");
            builder.append("Input size: ").append(tadLength).append("<br/>");
            builder.append("Batch size: ").append(numSamples).append("<br/>");
            LayerInfo info = new LayerInfo();
            info.setId(0);
            info.setName(input);
            info.setY(0);
            info.setX(x);
            info.setLayerType(INPUT);
            info.setDescription(new Description());
            info.getDescription().setMainLine("Model input");
            info.getDescription().setText(builder.toString());
            modelInfo.addLayer(info);
            x++;
        }
        GraphVertex[] vertices = graph.getVertices();
        // filling grid in LTR/TTB direction
        List<String> needle = new ArrayList<>();
        // we assume that the number of rows can't be higher than the total number of vertices
        for (int y = 1; y < vertices.length; y++) {
            if (needle.isEmpty()) {
                needle.addAll(inputs);
            }
            // for each grid row we look for nodes that are connected to the previous row
            List<LayerInfo> layersForGridY = flattenToY(modelInfo, vertices, needle, y);
            needle.clear();
            for (LayerInfo layerInfo : layersForGridY) {
                needle.add(layerInfo.getName());
            }
            // stop once a row places no new layers
            if (needle.isEmpty()) {
                break;
            }
        }
    } else if (model instanceof MultiLayerNetwork) {
        MultiLayerNetwork network = (MultiLayerNetwork) model;
        // manually adding input layer
        INDArray input = model.input();
        long tadLength = 0;
        long numSamples = 0;
        // model.input() may be null (presumably before any data was fed): report zero sizes,
        // matching how the ComputationGraph branch treats missing inputs
        if (input != null) {
            tadLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
            numSamples = tadLength == 0 ? 0 : input.lengthLong() / tadLength;
        }
        StringBuilder builder = new StringBuilder();
        builder.append("Model input").append("<br/>");
        builder.append("Input size: ").append(tadLength).append("<br/>");
        builder.append("Batch size: ").append(numSamples).append("<br/>");
        LayerInfo info = new LayerInfo();
        info.setId(0);
        info.setName("Input");
        info.setY(0);
        info.setX(0);
        info.setLayerType(INPUT);
        info.setDescription(new Description());
        info.getDescription().setMainLine("Model input");
        info.getDescription().setText(builder.toString());
        info.addConnection(0, 1);
        modelInfo.addLayer(info);
        // entry 0 is reserved for inputs
        int y = 1;
        // for MLN x value is always 0
        final int x = 0;
        for (Layer layer : network.getLayers()) {
            LayerInfo layerInfo = getLayerInfo(layer, x, y, y);
            // since it's MLN, we know connections in advance as curLayer + 1
            layerInfo.addConnection(x, y + 1);
            modelInfo.addLayer(layerInfo);
            y++;
        }
        // the last row has no successor: remove the connection added speculatively in the loop
        LayerInfo lastLayer = modelInfo.getLayerInfoByCoords(x, y - 1);
        lastLayer.dropConnections();
    }
    // find layers without connections, and mark them as output layers
    for (LayerInfo layerInfo : modelInfo.getLayers()) {
        if (layerInfo.getConnections().size() == 0) {
            layerInfo.setLayerType("OUTPUT");
        }
    }
    // one palette color per distinct layer type, wrapping around when the palette is exhausted;
    // INPUT and OUTPUT layers always get fixed colors
    int colorIndex = 0;
    for (String layerType : modelInfo.getLayerTypes()) {
        String curColor = colors.get(colorIndex++);
        if (colorIndex >= colors.size()) {
            colorIndex = 0;
        }
        for (LayerInfo layerInfo : modelInfo.getLayersByType(layerType)) {
            if (layerType.equals(INPUT)) {
                layerInfo.setColor("#99ff66");
            } else if (layerType.equals("OUTPUT")) {
                layerInfo.setColor("#e6e6e6");
            } else {
                layerInfo.setColor(curColor);
            }
        }
    }
    return modelInfo;
}
use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.
the class TestGraphNodes method testCnnDepthMerge.
/**
 * A MergeVertex fed CNN activations of shape [1, d, h, w] must stack them along dimension 1.
 * Feeding the merged array back as epsilon must split it into the original inputs.
 * Covers a single-channel case ([1, 1, 2, 2] x 2) and a two-channel case ([1, 2, 3, 3] x 2).
 */
@Test
public void testCnnDepthMerge() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex merge = new MergeVertex(null, "", -1);

    // case 1: one channel each, 2x2 spatial
    INDArray lhs = Nd4j.linspace(0, 3, 4).reshape(1, 1, 2, 2);
    INDArray rhs = Nd4j.linspace(0, 3, 4).reshape(1, 1, 2, 2).addi(10);
    merge.setInputs(lhs, rhs);
    INDArray merged = merge.doForward(false);
    assertArrayEquals(new int[] { 1, 2, 2, 2 }, merged.shape());
    assertStackedAlongDepth(lhs, rhs, merged, 1, 2);
    merge.setEpsilon(merged);
    INDArray[] split = merge.doBackward(false).getSecond();
    assertEquals(lhs, split[0]);
    assertEquals(rhs, split[1]);

    // case 2 (slightly more complicated): two channels each, 3x3 spatial
    lhs = Nd4j.linspace(0, 17, 18).reshape(1, 2, 3, 3);
    rhs = Nd4j.linspace(0, 17, 18).reshape(1, 2, 3, 3).addi(100);
    merge.setInputs(lhs, rhs);
    merged = merge.doForward(false);
    assertArrayEquals(new int[] { 1, 4, 3, 3 }, merged.shape());
    assertStackedAlongDepth(lhs, rhs, merged, 2, 3);
    merge.setEpsilon(merged);
    split = merge.doBackward(false).getSecond();
    assertEquals(lhs, split[0]);
    assertEquals(rhs, split[1]);
}

/**
 * Asserts, element by element over a {@code size x size} spatial grid, that {@code merged} holds:
 * <ul>
 *   <li>{@code lhs} in channels [0, depth);</li>
 *   <li>{@code rhs} in channels [depth, 2 * depth).</li>
 * </ul>
 */
private static void assertStackedAlongDepth(INDArray lhs, INDArray rhs, INDArray merged, int depth, int size) {
    for (int row = 0; row < size; row++) {
        for (int col = 0; col < size; col++) {
            for (int c = 0; c < depth; c++) {
                assertEquals(lhs.getDouble(0, c, row, col), merged.getDouble(0, c, row, col), 1e-6);
            }
            for (int c = 0; c < depth; c++) {
                assertEquals(rhs.getDouble(0, c, row, col), merged.getDouble(0, depth + c, row, col), 1e-6);
            }
        }
    }
}
Aggregations