Search in sources :

Example 21 with GraphVertex

use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.

the class ComputationGraph method pretrainLayer.

/**
 * Pretrain a specified layer with the given MultiDataSetIterator.
 * <p>
 * Returns immediately (no-op) if the configuration does not have pretraining enabled, or if the
 * named vertex does not contain a layer. For every MultiDataSet supplied by the iterator, a partial
 * forward pass is run over only those vertices that (directly or indirectly) feed the target
 * vertex, and the target layer is then fit on its first input.
 *
 * @param layerName Name of the vertex whose layer should be pretrained
 * @param iter      Training data
 * @throws IllegalStateException if no vertex with the given name exists, or if the vertex cannot be
 *                               located in the topological ordering
 */
public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
    if (!configuration.isPretrain())
        return;
    if (flattenedGradients == null)
        initGradientsView();
    if (!verticesMap.containsKey(layerName)) {
        throw new IllegalStateException("Invalid vertex name: " + layerName);
    }
    GraphVertex target = verticesMap.get(layerName);
    if (!target.hasLayer()) {
        //No op
        return;
    }
    int layerIndex = target.getVertexIndex();

    //topologicalOrder[i] holds a *vertex index*; 'layerIndex' is itself a vertex index, not a position in
    // that array. Find the position of the target vertex within the topological ordering first.
    int topoPosition = -1;
    for (int i = 0; i < topologicalOrder.length; i++) {
        if (topologicalOrder[i] == layerIndex) {
            topoPosition = i;
            break;
        }
    }
    if (topoPosition < 0) {
        throw new IllegalStateException("Vertex \"" + layerName + "\" (index " + layerIndex
                        + ") not found in topological order");
    }

    //Need to do partial forward pass. Simply following the topological ordering won't be efficient, as we might
    // end up doing forward pass on layers we don't need to.
    //However, we can start with the topological order, and prune out any layers we don't need to do
    LinkedList<Integer> partialTopoSort = new LinkedList<>();
    Set<Integer> seenSoFar = new HashSet<>();
    partialTopoSort.add(layerIndex);
    seenSoFar.add(layerIndex);
    for (int j = topoPosition - 1; j >= 0; j--) {
        //Do we need to do forward pass on this GraphVertex?
        //If it is input to any other layer we need, then yes. Otherwise: no
        VertexIndices[] outputsTo = vertices[topologicalOrder[j]].getOutputVertices();
        if (outputsTo == null) {
            //Vertex feeds nothing (same null guard as in the forward pass below): cannot be required
            continue;
        }
        boolean needed = false;
        for (VertexIndices vi : outputsTo) {
            if (seenSoFar.contains(vi.getVertexIndex())) {
                needed = true;
                break;
            }
        }
        if (needed) {
            //Walking backwards through the ordering, so prepend to keep topological order
            partialTopoSort.addFirst(topologicalOrder[j]);
            seenSoFar.add(topologicalOrder[j]);
        }
    }

    int[] fwdPassOrder = new int[partialTopoSort.size()];
    int k = 0;
    for (Integer g : partialTopoSort) {
        fwdPassOrder[k++] = g;
    }

    //Last entry of the pruned order is always the vertex being pretrained
    GraphVertex gv = vertices[fwdPassOrder[fwdPassOrder.length - 1]];
    Layer layer = gv.getLayer();

    if (!iter.hasNext() && iter.resetSupported()) {
        iter.reset();
    }

    while (iter.hasNext()) {
        MultiDataSet multiDataSet = iter.next();
        setInputs(multiDataSet.getFeatures());

        //Forward pass over every required vertex except the target itself
        for (int j = 0; j < fwdPassOrder.length - 1; j++) {
            GraphVertex current = vertices[fwdPassOrder[j]];
            if (current.isInputVertex()) {
                VertexIndices[] inputsTo = current.getOutputVertices();
                INDArray input = inputs[current.getVertexIndex()];
                for (VertexIndices v : inputsTo) {
                    int vIdx = v.getVertexIndex();
                    int vIdxInputNum = v.getVertexEdgeNumber();
                    //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                    //TODO When to dup?
                    vertices[vIdx].setInput(vIdxInputNum, input.dup());
                }
            } else {
                //Do forward pass:
                INDArray out = current.doForward(true);
                //Now, set the inputs for the next vertices:
                VertexIndices[] outputsTo = current.getOutputVertices();
                if (outputsTo != null) {
                    for (VertexIndices v : outputsTo) {
                        int vIdx = v.getVertexIndex();
                        int inputNum = v.getVertexEdgeNumber();
                        //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                        vertices[vIdx].setInput(inputNum, out);
                    }
                }
            }
        }

        //At this point: have done all of the required forward pass stuff. Can now pretrain layer on current input
        layer.fit(gv.getInputs()[0]);
        //NOTE(review): the pretrain flag is cleared after every minibatch (original behaviour, kept as-is);
        // confirm whether this was intended to happen once, after the loop.
        layer.conf().setPretrain(false);
    }
}
Also used : Layer(org.deeplearning4j.nn.api.Layer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) RecurrentLayer(org.deeplearning4j.nn.api.layers.RecurrentLayer) FeedForwardLayer(org.deeplearning4j.nn.conf.layers.FeedForwardLayer) IOutputLayer(org.deeplearning4j.nn.api.layers.IOutputLayer) GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VertexIndices(org.deeplearning4j.nn.graph.vertex.VertexIndices)

Example 22 with GraphVertex

use of org.deeplearning4j.nn.graph.vertex.GraphVertex in project deeplearning4j by deeplearning4j.

the class FlowIterationListener method flattenToY.

/**
 * Builds (and registers in the model) a LayerInfo for each vertex that has at least one input
 * vertex whose name appears in {@code currentInput}, and records a connection from that input's
 * LayerInfo (when the model already knows it) to the vertex's LayerInfo.
 * <p>
 * Any exception raised while handling a single matching (vertex, input) pair is printed to
 * standard error and processing continues with the next pair.
 *
 * @param model        model that is queried for existing LayerInfo entries and receives new ones
 * @param vertices     all vertices of the graph; input edges are resolved against this array
 * @param currentInput names of the vertices whose direct consumers should be collected
 * @param currentY     value forwarded to {@code getLayerInfo} for newly created entries
 *                     (presumably the row/Y coordinate - confirm against getLayerInfo)
 * @return the LayerInfo objects that were newly added to the model by this call, in discovery order
 */
protected List<LayerInfo> flattenToY(ModelInfo model, GraphVertex[] vertices, List<String> currentInput, int currentY) {
    List<LayerInfo> added = new ArrayList<>();
    // Counter handed to getLayerInfo; incremented once per LayerInfo newly added to the model
    int column = 0;

    for (GraphVertex candidate : vertices) {
        VertexIndices[] incoming = candidate.getInputVertices();
        if (incoming == null) {
            // Vertex without input edges: nothing can feed it
            continue;
        }

        for (VertexIndices edge : incoming) {
            GraphVertex source = vertices[edge.getVertexIndex()];
            String sourceName = source.getVertexName();

            for (String wanted : currentInput) {
                if (!sourceName.equals(wanted)) {
                    continue;
                }

                try {
                    String candidateName = candidate.getVertexName();

                    // Reuse the model's entry if there is one, otherwise create a fresh one
                    LayerInfo layerInfo = model.getLayerInfoByName(candidateName);
                    if (layerInfo == null) {
                        layerInfo = getLayerInfo(candidate.getLayer(), column, currentY, 121);
                    }
                    layerInfo.setName(candidateName);

                    // Vertices that wrap no layer are labelled with their own class name
                    if (candidate.getLayer() == null) {
                        layerInfo.setLayerType(candidate.getClass().getSimpleName());
                    }
                    if (layerInfo.getName().endsWith("-merge")) {
                        layerInfo.setLayerType("MERGE");
                    }

                    // Register only entries the model does not know yet
                    boolean unknownToModel = model.getLayerInfoByName(candidateName) == null;
                    if (unknownToModel) {
                        column++;
                        model.addLayer(layerInfo);
                        added.add(layerInfo);
                    }

                    // Map the connection input -> vertex. A missing entry for the input is left
                    // unlinked (the original notes this happens for direct input connections).
                    LayerInfo upstream = model.getLayerInfoByName(wanted);
                    if (upstream != null) {
                        upstream.addConnection(layerInfo);
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }

    return added;
}
Also used : GraphVertex(org.deeplearning4j.nn.graph.vertex.GraphVertex) VertexIndices(org.deeplearning4j.nn.graph.vertex.VertexIndices)

Aggregations

GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex)22 INDArray (org.nd4j.linalg.api.ndarray.INDArray)19 VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices)9 Test (org.junit.Test)9 Layer (org.deeplearning4j.nn.api.Layer)8 IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer)8 FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer)8 FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer)7 RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer)6 Gradient (org.deeplearning4j.nn.gradient.Gradient)4 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)4 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)4 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)2 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)2 BaseOutputLayer (org.deeplearning4j.nn.conf.layers.BaseOutputLayer)2 SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer)2 Pair (org.deeplearning4j.berkeley.Pair)1 Triple (org.deeplearning4j.berkeley.Triple)1 MaskState (org.deeplearning4j.nn.api.MaskState)1 DuplicateToTimeSeriesVertex (org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex)1