Search in sources :

Example 6 with GraphVertex

use of org.deeplearning4j.nn.conf.graph.GraphVertex in project deeplearning4j by deeplearning4j.

From the class ComputationGraphConfiguration, method fromJson.

/**
 * Create a computation graph configuration from json.
 *
 * <p>After deserializing, every {@link LayerVertex} whose layer has no activation function set
 * is patched up from the raw JSON: if the layer's JSON node carries an "activationFunction"
 * text field (the enumeration-style name used by configs generated with v0.7.1 or earlier),
 * that name is converted via {@code Activation.fromString(...)} and set on the layer.
 * Vertices whose JSON does not have the expected legacy structure are left untouched.
 *
 * @param json the neural net configuration from json
 * @return {@link ComputationGraphConfiguration}
 * @throws RuntimeException wrapping the {@link IOException} if the json cannot be deserialized
 */
public static ComputationGraphConfiguration fromJson(String json) {
    //As per MultiLayerConfiguration.fromJson()
    ObjectMapper mapper = NeuralNetConfiguration.mapper();
    ComputationGraphConfiguration conf;
    try {
        conf = mapper.readValue(json, ComputationGraphConfiguration.class);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    //To maintain backward compatibility after activation function refactoring (configs generated with v0.7.1 or earlier)
    // Previously: enumeration used for activation functions. Now: use classes
    Map<String, GraphVertex> vertexMap = conf.getVertices();
    // Raw "vertices" JSON node; parsed lazily, only once a layer actually needs patching
    JsonNode vertices = null;
    for (Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()) {
        if (!(entry.getValue() instanceof LayerVertex)) {
            continue;
        }
        LayerVertex lv = (LayerVertex) entry.getValue();
        if (lv.getLayerConf() == null || lv.getLayerConf().getLayer() == null) {
            continue;
        }
        Layer layer = lv.getLayerConf().getLayer();
        if (layer.getActivationFn() != null) {
            continue;
        }
        try {
            if (vertices == null) {
                JsonNode jsonNode = mapper.readTree(json);
                vertices = jsonNode.get("vertices");
                if (vertices == null) {
                    // No "vertices" object in the raw JSON: nothing to recover for any layer
                    break;
                }
            }
            // Look up by the vertex map key: it is the name the vertex is stored under,
            // and (unlike the layer's own name) it is never null for an existing entry
            JsonNode vertexNode = vertices.get(entry.getKey());
            if (vertexNode == null) {
                continue;
            }
            JsonNode layerVertexNode = vertexNode.get("LayerVertex");
            if (layerVertexNode == null || !layerVertexNode.has("layerConf") || !layerVertexNode.get("layerConf").has("layer")) {
                continue;
            }
            JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");
            if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                continue;
            }
            //Should only have 1 element: "dense", "output", etc
            JsonNode layerNode = layerWrapperNode.elements().next();
            JsonNode activationFunction = layerNode.get("activationFunction");
            if (activationFunction != null) {
                IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                layer.setActivationFn(ia);
            }
        } catch (IOException e) {
            log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
        }
    }
    return conf;
}
Also used : LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) JsonNode(org.nd4j.shade.jackson.databind.JsonNode) IOException(java.io.IOException) IActivation(org.nd4j.linalg.activations.IActivation) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) Layer(org.deeplearning4j.nn.conf.layers.Layer) GraphVertex(org.deeplearning4j.nn.conf.graph.GraphVertex) ObjectMapper(org.nd4j.shade.jackson.databind.ObjectMapper)

Aggregations

GraphVertex (org.deeplearning4j.nn.conf.graph.GraphVertex)6 LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex)6 Layer (org.deeplearning4j.nn.conf.layers.Layer)3 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)2 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)2 Updater (org.deeplearning4j.nn.conf.Updater)2 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)2 ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater)2 IActivation (org.nd4j.linalg.activations.IActivation)2 INDArray (org.nd4j.linalg.api.ndarray.INDArray)2 IOException (java.io.IOException)1 Field (java.lang.reflect.Field)1 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 LinkedHashMap (java.util.LinkedHashMap)1 Map (java.util.Map)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 Persistable (org.deeplearning4j.api.storage.Persistable)1 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)1 InputType (org.deeplearning4j.nn.conf.inputs.InputType)1