Search in sources :

Example 6 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.

The class TestGraphNodes, method testMergeNodeRNN.

@Test
public void testMergeNodeRNN() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex merge = new MergeVertex(null, "", -1);

    // Two rank-3 inputs that differ only in dimension 1 (4 vs 6).
    INDArray a = Nd4j.linspace(0, 59, 60).reshape(3, 4, 5);
    INDArray b = Nd4j.linspace(0, 89, 90).reshape(3, 6, 5).addi(100);
    merge.setInputs(a, b);

    // Forward: the inputs must be concatenated along dimension 1, a first, then b.
    INDArray merged = merge.doForward(false);
    assertArrayEquals(new int[] { 3, 10, 5 }, merged.shape());
    INDArray leadingSlice = merged.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4), NDArrayIndex.all());
    INDArray trailingSlice = merged.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 10), NDArrayIndex.all());
    assertEquals(a, leadingSlice);
    assertEquals(b, trailingSlice);

    // Backward: feeding the merged output back as epsilon must split it into the original inputs.
    merge.setEpsilon(merged);
    INDArray[] grads = merge.doBackward(false).getSecond();
    assertEquals(a, grads[0]);
    assertEquals(b, grads[1]);
}
Also used : GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Test(org.junit.Test)

Example 7 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.

The class RemoteFlowIterationListener, method buildModelInfo.

/**
 * Builds the {@link ModelInfo} grid that describes {@code model} for the flow UI.
 *
 * <p>For a {@link ComputationGraph}, every network input is placed on row {@code y = 0}. Each
 * following row holds the vertices fed by the previous row, as computed by {@code flattenToY}.
 * For a {@link MultiLayerNetwork}, a synthetic "Input" entry and then every layer are stacked in
 * column {@code x = 0}, each connected to the entry on the next row. Afterwards every layer
 * without connections is marked {@code "OUTPUT"}. Finally a colour is assigned per layer type:
 * fixed colours for input and output layers, and the {@code colors} palette cycled for all
 * other types.
 *
 * @param model the model to describe
 * @return the populated model description
 */
protected ModelInfo buildModelInfo(Model model) {
    ModelInfo modelInfo = new ModelInfo();
    if (model instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) model;
        /*
            We assume that the graph starts at its inputs: every layer connected to an input is on
            row y1, every layer connected to a y1 layer is on row y2, and so on.
         */
        List<String> inputs = graph.getConfiguration().getNetworkInputs();
        // now we need to add inputs as y0 nodes
        int x = 0;
        for (String input : inputs) {
            GraphVertex vertex = graph.getVertex(input);
            INDArray[] vertexInputs = vertex.getInputs();
            // Vertex inputs may be null or empty; FlowIterationListener guards the same case.
            // Report zero sizes instead of failing with NPE / ArrayIndexOutOfBoundsException.
            INDArray gInput = (vertexInputs == null || vertexInputs.length == 0) ? null : vertexInputs[0];
            long tadLength = inputExampleLength(gInput);
            long numSamples = inputBatchSize(gInput, tadLength);
            StringBuilder builder = new StringBuilder();
            builder.append("Vertex name: ").append(input).append("<br/>");
            builder.append("Model input").append("<br/>");
            builder.append("Input size: ").append(tadLength).append("<br/>");
            builder.append("Batch size: ").append(numSamples).append("<br/>");
            LayerInfo info = new LayerInfo();
            info.setId(0);
            info.setName(input);
            info.setY(0);
            info.setX(x);
            info.setLayerType(INPUT);
            info.setDescription(new Description());
            info.getDescription().setMainLine("Model input");
            info.getDescription().setText(builder.toString());
            modelInfo.addLayer(info);
            x++;
        }
        GraphVertex[] vertices = graph.getVertices();
        // filling grid in LTR/TTB direction
        List<String> needle = new ArrayList<>();
        // we assume that the number of rows can't be higher than the total number of vertices
        for (int y = 1; y < vertices.length; y++) {
            // only empty on the first pass (an empty result breaks the loop below)
            if (needle.isEmpty())
                needle.addAll(inputs);
            /*
                for each grid row we look for nodes that are connected to the previous row
             */
            List<LayerInfo> layersForGridY = flattenToY(modelInfo, vertices, needle, y);
            needle.clear();
            for (LayerInfo layerInfo : layersForGridY) {
                needle.add(layerInfo.getName());
            }
            // no new vertices were placed on this row: the walk is complete
            if (needle.isEmpty())
                break;
        }
    } else if (model instanceof MultiLayerNetwork) {
        MultiLayerNetwork network = (MultiLayerNetwork) model;
        // manually adding input layer
        // NOTE(review): model.input() may be null if no input has been set yet; the helpers
        // below then report zero sizes instead of throwing.
        INDArray input = model.input();
        long tadLength = inputExampleLength(input);
        long numSamples = inputBatchSize(input, tadLength);
        StringBuilder builder = new StringBuilder();
        builder.append("Model input").append("<br/>");
        builder.append("Input size: ").append(tadLength).append("<br/>");
        builder.append("Batch size: ").append(numSamples).append("<br/>");
        LayerInfo info = new LayerInfo();
        info.setId(0);
        info.setName("Input");
        info.setY(0);
        info.setX(0);
        info.setLayerType(INPUT);
        info.setDescription(new Description());
        info.getDescription().setMainLine("Model input");
        info.getDescription().setText(builder.toString());
        info.addConnection(0, 1);
        modelInfo.addLayer(info);
        // entry 0 is reserved for inputs
        int y = 1;
        // for MLN x value is always 0
        final int x = 0;
        for (Layer layer : network.getLayers()) {
            LayerInfo layerInfo = getLayerInfo(layer, x, y, y);
            // since it's MLN, we know connections in advance as curLayer + 1
            layerInfo.addConnection(x, y + 1);
            modelInfo.addLayer(layerInfo);
            y++;
        }
        // the last entry added has no successor, so remove the connection it was given
        LayerInfo layerInfo = modelInfo.getLayerInfoByCoords(x, y - 1);
        layerInfo.dropConnections();
    }
    // find layers without connections, and mark them as output layers
    for (LayerInfo layerInfo : modelInfo.getLayers()) {
        if (layerInfo.getConnections().size() == 0)
            layerInfo.setLayerType("OUTPUT");
    }
    // now we apply colors to distinct layer types, cycling through the palette
    AtomicInteger cnt = new AtomicInteger(0);
    for (String layerType : modelInfo.getLayerTypes()) {
        String curColor = colors.get(cnt.getAndIncrement());
        if (cnt.get() >= colors.size())
            cnt.set(0);
        for (LayerInfo layerInfo : modelInfo.getLayersByType(layerType)) {
            if (layerType.equals(INPUT)) {
                layerInfo.setColor("#99ff66");
            } else if (layerType.equals("OUTPUT")) {
                layerInfo.setColor("#e6e6e6");
            } else {
                layerInfo.setColor(curColor);
            }
        }
    }
    return modelInfo;
}

/**
 * Returns the TAD length of {@code input} over dimensions {@code 1 .. rank-1}, which is
 * presumably the per-example size. Returns 0 when {@code input} is {@code null}.
 */
private static long inputExampleLength(INDArray input) {
    if (input == null) {
        return 0;
    }
    return Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
}

/**
 * Returns {@code input.lengthLong() / exampleLength}, the number of examples. Returns 0 when
 * {@code input} is {@code null} or {@code exampleLength} is 0, which avoids a division by zero.
 */
private static long inputBatchSize(INDArray input, long exampleLength) {
    if (input == null || exampleLength == 0) {
        return 0;
    }
    return input.lengthLong() / exampleLength;
}
Also used : Layer(org.deeplearning4j.nn.api.Layer) SubsamplingLayer(org.deeplearning4j.nn.conf.layers.SubsamplingLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) BaseOutputLayer(org.deeplearning4j.nn.conf.layers.BaseOutputLayer) GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 8 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.

The class RemoteFlowIterationListener, method flattenToY.

/**
 * Places on grid row {@code currentY} every vertex that takes one of the vertices named in
 * {@code currentInput} as an input.
 *
 * <p>For every matching (input, vertex) pair, the method does the following:
 * <ul>
 *   <li>The vertex's {@code LayerInfo} is looked up in {@code model}. If it is not there, one is
 *       created at the next free column of the row.</li>
 *   <li>Its name is set to the vertex name. Its type is set to the vertex class name when the
 *       vertex has no layer, and to {@code "MERGE"} when the name ends with {@code "-merge"}.</li>
 *   <li>A newly created entry is added to {@code model} and to the returned list.</li>
 *   <li>A connection from the input's {@code LayerInfo} (if present in {@code model}) to the
 *       vertex's entry is recorded.</li>
 * </ul>
 * Any exception raised while handling a single pair is printed and that pair is skipped.
 *
 * @param model        the grid being built; mutated in place
 * @param vertices     all vertices of the graph, indexed by vertex index
 * @param currentInput names of the vertices placed on the previous row
 * @param currentY     the row being filled
 * @return the entries newly added to {@code model} on this row, in discovery order
 */
protected List<LayerInfo> flattenToY(ModelInfo model, GraphVertex[] vertices, List<String> currentInput, int currentY) {
    List<LayerInfo> added = new ArrayList<>();
    int column = 0;
    for (GraphVertex vertex : vertices) {
        VertexIndices[] sources = vertex.getInputVertices();
        if (sources == null) {
            // vertex has no input vertices, so it cannot be fed by the previous row
            continue;
        }
        for (VertexIndices source : sources) {
            String sourceName = vertices[source.getVertexIndex()].getVertexName();
            for (String input : currentInput) {
                if (!sourceName.equals(input)) {
                    continue;
                }
                try {
                    LayerInfo info = model.getLayerInfoByName(vertex.getVertexName());
                    if (info == null) {
                        // NOTE(review): the meaning of the constant 121 (last argument) is not visible here
                        info = getLayerInfo(vertex.getLayer(), column, currentY, 121);
                    }
                    info.setName(vertex.getVertexName());
                    if (vertex.getLayer() == null) {
                        // the vertex is not backed by a layer: label it by its class name
                        info.setLayerType(vertex.getClass().getSimpleName());
                    }
                    if (info.getName().endsWith("-merge")) {
                        info.setLayerType("MERGE");
                    }
                    boolean notYetPlaced = model.getLayerInfoByName(vertex.getVertexName()) == null;
                    if (notYetPlaced) {
                        column++;
                        model.addLayer(info);
                        added.add(info);
                    }
                    // record the edge input -> vertex; a missing upstream entry is simply skipped
                    LayerInfo upstream = model.getLayerInfoByName(input);
                    if (upstream != null) {
                        upstream.addConnection(info);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }
    return added;
}
Also used : GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) VertexIndices(org.deeplearning4j.nn.graph.vertex.VertexIndices)

Example 9 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.

The class FlowIterationListener, method buildModelInfo.

/**
 * Builds the {@link ModelInfo} grid that describes {@code model} for the flow UI.
 *
 * <p>Graphs: each network input goes on row 0. Every later row holds the vertices fed by the
 * previous row, and the walk stops once a row adds nothing new. Inputs whose vertex has no input
 * arrays are reported with zero sizes. Sequential networks: a synthetic "Input" entry followed
 * by one entry per layer, all in column 0, each wired to the next row. Afterwards entries
 * without connections become {@code "OUTPUT"}, and a colour is assigned per layer type.
 *
 * @param model the model to describe
 * @return the populated model description
 */
protected ModelInfo buildModelInfo(Model model) {
    ModelInfo modelInfo = new ModelInfo();
    if (model instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) model;
        List<String> inputs = graph.getConfiguration().getNetworkInputs();

        // Row 0: one entry per network input, left to right.
        int column = 0;
        for (String input : inputs) {
            GraphVertex vertex = graph.getVertex(input);
            INDArray[] vertexInputs = vertex.getInputs();
            long perExample = 0;
            long batch = 0;
            if (vertexInputs != null && vertexInputs.length != 0) {
                INDArray sample = vertexInputs[0];
                perExample = Shape.getTADLength(sample.shape(), ArrayUtil.range(1, sample.rank()));
                batch = sample.lengthLong() / perExample;
            }
            String text = new StringBuilder()
                    .append("Vertex name: ").append(input).append("<br/>")
                    .append("Model input").append("<br/>")
                    .append("Input size: ").append(perExample).append("<br/>")
                    .append("Batch size: ").append(batch).append("<br/>")
                    .toString();
            LayerInfo inputInfo = new LayerInfo();
            inputInfo.setId(0);
            inputInfo.setName(input);
            inputInfo.setY(0);
            inputInfo.setX(column);
            inputInfo.setLayerType(INPUT);
            inputInfo.setDescription(new Description());
            inputInfo.getDescription().setMainLine("Model input");
            inputInfo.getDescription().setText(text);
            modelInfo.addLayer(inputInfo);
            column++;
        }

        // Rows 1..n: breadth-first from the inputs. The row count is bounded by the number of
        // vertices; the walk ends early once a row adds no new vertex.
        GraphVertex[] vertices = graph.getVertices();
        List<String> frontier = new ArrayList<>(inputs);
        for (int row = 1; row < vertices.length; row++) {
            List<String> nextFrontier = new ArrayList<>();
            for (LayerInfo placed : flattenToY(modelInfo, vertices, frontier, row)) {
                nextFrontier.add(placed.getName());
            }
            if (nextFrontier.isEmpty()) {
                break;
            }
            frontier = nextFrontier;
        }
    } else if (model instanceof MultiLayerNetwork) {
        MultiLayerNetwork network = (MultiLayerNetwork) model;

        // Row 0: a synthetic entry describing the network input, wired to row 1.
        INDArray features = model.input();
        long perExample = Shape.getTADLength(features.shape(), ArrayUtil.range(1, features.rank()));
        long batch = features.lengthLong() / perExample;
        String text = new StringBuilder()
                .append("Model input").append("<br/>")
                .append("Input size: ").append(perExample).append("<br/>")
                .append("Batch size: ").append(batch).append("<br/>")
                .toString();
        LayerInfo inputInfo = new LayerInfo();
        inputInfo.setId(0);
        inputInfo.setName("Input");
        inputInfo.setY(0);
        inputInfo.setX(0);
        inputInfo.setLayerType(INPUT);
        inputInfo.setDescription(new Description());
        inputInfo.getDescription().setMainLine("Model input");
        inputInfo.getDescription().setText(text);
        inputInfo.addConnection(0, 1);
        modelInfo.addLayer(inputInfo);

        // Rows 1..n: the layers, stacked in a single column, each feeding the next row.
        final int column = 0;
        int row = 1;
        for (Layer layer : network.getLayers()) {
            LayerInfo layerInfo = getLayerInfo(layer, column, row, row);
            layerInfo.addConnection(column, row + 1);
            modelInfo.addLayer(layerInfo);
            row++;
        }
        // The last entry has nothing below it, so drop its forward connection.
        modelInfo.getLayerInfoByCoords(column, row - 1).dropConnections();
    }

    // Anything that feeds nothing is an output.
    for (LayerInfo entry : modelInfo.getLayers()) {
        if (entry.getConnections().size() == 0) {
            entry.setLayerType("OUTPUT");
        }
    }

    // One palette colour per layer type (wrapping around); inputs and outputs use fixed colours.
    int paletteIndex = 0;
    for (String layerType : modelInfo.getLayerTypes()) {
        String curColor = colors.get(paletteIndex);
        paletteIndex++;
        if (paletteIndex >= colors.size()) {
            paletteIndex = 0;
        }
        for (LayerInfo entry : modelInfo.getLayersByType(layerType)) {
            String color = layerType.equals(INPUT) ? "#99ff66"
                    : layerType.equals("OUTPUT") ? "#e6e6e6"
                    : curColor;
            entry.setColor(color);
        }
    }
    return modelInfo;
}
Also used : Layer(org.deeplearning4j.nn.api.Layer) SubsamplingLayer(org.deeplearning4j.nn.conf.layers.SubsamplingLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) BaseOutputLayer(org.deeplearning4j.nn.conf.layers.BaseOutputLayer) GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 10 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.

The class TestGraphNodes, method testCnnDepthMerge.

@Test
public void testCnnDepthMerge() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex merge = new MergeVertex(null, "", -1);

    // Single-channel 2x2 inputs: the output must stack them along dimension 1.
    INDArray lhs = Nd4j.linspace(0, 3, 4).reshape(1, 1, 2, 2);
    INDArray rhs = Nd4j.linspace(0, 3, 4).reshape(1, 1, 2, 2).addi(10);
    checkDepthMerge(merge, lhs, rhs, new int[] { 1, 2, 2, 2 });

    // Slightly more complicated case: two channels per input, 3x3 spatial.
    lhs = Nd4j.linspace(0, 17, 18).reshape(1, 2, 3, 3);
    rhs = Nd4j.linspace(0, 17, 18).reshape(1, 2, 3, 3).addi(100);
    checkDepthMerge(merge, lhs, rhs, new int[] { 1, 4, 3, 3 });
}

/**
 * Runs {@code merge} forward on {@code (lhs, rhs)} and checks the output shape. Checks that, at
 * every spatial position, the channels of {@code lhs} are followed by those of {@code rhs}
 * along dimension 1. Then feeds the output back as epsilon and checks that the backward pass
 * returns the original inputs.
 */
private void checkDepthMerge(GraphVertex merge, INDArray lhs, INDArray rhs, int[] expectedShape) {
    merge.setInputs(lhs, rhs);
    INDArray merged = merge.doForward(false);
    assertArrayEquals(expectedShape, merged.shape());
    int lhsDepth = lhs.shape()[1];
    int rhsDepth = rhs.shape()[1];
    int height = lhs.shape()[2];
    int width = lhs.shape()[3];
    for (int r = 0; r < height; r++) {
        for (int c = 0; c < width; c++) {
            for (int d = 0; d < lhsDepth; d++) {
                assertEquals(lhs.getDouble(0, d, r, c), merged.getDouble(0, d, r, c), 1e-6);
            }
            for (int d = 0; d < rhsDepth; d++) {
                assertEquals(rhs.getDouble(0, d, r, c), merged.getDouble(0, lhsDepth + d, r, c), 1e-6);
            }
        }
    }
    merge.setEpsilon(merged);
    INDArray[] grads = merge.doBackward(false).getSecond();
    assertEquals(lhs, grads[0]);
    assertEquals(rhs, grads[1]);
}
Also used : GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Test(org.junit.Test)

Aggregations

GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex)22 INDArray (org.nd4j.linalg.api.ndarray.INDArray)19 VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices)9 Test (org.junit.Test)9 Layer (org.deeplearning4j.nn.api.Layer)8 IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer)8 FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer)8 FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer)7 RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer)6 Gradient (org.deeplearning4j.nn.gradient.Gradient)4 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)4 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)4 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)2 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)2 BaseOutputLayer (org.deeplearning4j.nn.conf.layers.BaseOutputLayer)2 SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer)2 Pair (org.deeplearning4j.berkeley.Pair)1 Triple (org.deeplearning4j.berkeley.Triple)1 MaskState (org.deeplearning4j.nn.api.MaskState)1 DuplicateToTimeSeriesVertex (org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex)1