Search in sources:

Example 1 with JsonNode

use of org.nd4j.shade.jackson.databind.JsonNode in project deeplearning4j by deeplearning4j.

The class MultiLayerConfiguration, method fromJson.

/**
     * Create a neural net configuration from json.
     * <p>
     * After deserializing, this method also walks the raw JSON to restore fields whose
     * representation changed in older config formats:
     * <ul>
     *     <li>pre-0.6.0: loss functions were serialized as a {@code lossFunction} enum value
     *         rather than a loss-function class ({@code lossFn})</li>
     *     <li>pre-0.7.2: activation functions were serialized as an {@code activationFunction}
     *         string rather than an {@code IActivation} class</li>
     * </ul>
     *
     * @param json the neural net configuration as json
     * @return {@link MultiLayerConfiguration}
     * @throws RuntimeException if the json cannot be deserialized at all
     */
public static MultiLayerConfiguration fromJson(String json) {
    MultiLayerConfiguration conf;
    ObjectMapper mapper = NeuralNetConfiguration.mapper();
    try {
        conf = mapper.readValue(json, MultiLayerConfiguration.class);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
    // Previously: enumeration used for loss functions. Now: use classes
    // In the past, could have only been an OutputLayer or RnnOutputLayer using these enums
    int layerCount = 0;
    //Lazily-parsed "confs" array from the raw JSON; parsed at most once for the whole method
    JsonNode confs = null;
    for (NeuralNetConfiguration nnc : conf.getConfs()) {
        Layer l = nnc.getLayer();
        if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) {
            //lossFn field null -> may be an old config format, with lossFunction field being for the enum
            //if so, try walking the JSON graph to extract out the appropriate enum value
            BaseOutputLayer ol = (BaseOutputLayer) l;
            try {
                //Only re-parse the raw JSON the first time we need it (previously parsed on every iteration)
                if (confs == null) {
                    confs = mapper.readTree(json).get("confs");
                }
                if (confs instanceof ArrayNode) {
                    ArrayNode layerConfs = (ArrayNode) confs;
                    JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
                    if (outputLayerNNCNode == null)
                        //Should never happen...
                        return conf;
                    JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
                    JsonNode lossFunctionNode = null;
                    //Legacy configs wrap the layer under its type name: "output" or "rnnoutput"
                    if (outputLayerNode.has("output")) {
                        lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
                    } else if (outputLayerNode.has("rnnoutput")) {
                        lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
                    }
                    if (lossFunctionNode != null) {
                        String lossFunctionEnumStr = lossFunctionNode.asText();
                        LossFunctions.LossFunction lossFunction = null;
                        try {
                            lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
                        } catch (Exception e) {
                            log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
                        }
                        if (lossFunction != null) {
                            //Map the legacy enum value to the equivalent ILossFunction implementation
                            switch(lossFunction) {
                                case MSE:
                                    ol.setLossFn(new LossMSE());
                                    break;
                                case XENT:
                                    ol.setLossFn(new LossBinaryXENT());
                                    break;
                                case NEGATIVELOGLIKELIHOOD:
                                    ol.setLossFn(new LossNegativeLogLikelihood());
                                    break;
                                case MCXENT:
                                    ol.setLossFn(new LossMCXENT());
                                    break;
                                //Remaining: TODO
                                case EXPLL:
                                case RMSE_XENT:
                                case SQUARED_LOSS:
                                case RECONSTRUCTION_CROSSENTROPY:
                                case CUSTOM:
                                default:
                                    log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", lossFunction);
                                    break;
                            }
                        }
                    }
                } else {
                    log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", (confs != null ? confs.getClass() : null));
                }
            } catch (IOException e) {
                log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
                break;
            }
        }
        //Try to load the old format if necessary, and create the appropriate IActivation instance
        if (l.getActivationFn() == null) {
            try {
                //Only re-parse the raw JSON the first time we need it
                if (confs == null) {
                    confs = mapper.readTree(json).get("confs");
                }
                if (confs instanceof ArrayNode) {
                    ArrayNode layerConfs = (ArrayNode) confs;
                    JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
                    if (outputLayerNNCNode == null)
                        //Should never happen...
                        return conf;
                    JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
                    //Note: previously this used 'continue' on a malformed wrapper, which skipped the
                    //layerCount increment below and desynchronized all subsequent layers from the
                    //JSON array. Using a nested 'if' keeps the counter correct.
                    if (layerWrapperNode != null && layerWrapperNode.size() == 1) {
                        //Should only have 1 element: "dense", "output", etc
                        JsonNode layerNode = layerWrapperNode.elements().next();
                        JsonNode activationFunction = layerNode.get("activationFunction");
                        if (activationFunction != null) {
                            IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                            l.setActivationFn(ia);
                        }
                    }
                }
            } catch (IOException e) {
                log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
            }
        }
        layerCount++;
    }
    return conf;
}
Also used : LossNegativeLogLikelihood(org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood) LossMCXENT(org.nd4j.linalg.lossfunctions.impl.LossMCXENT) JsonNode(org.nd4j.shade.jackson.databind.JsonNode) IOException(java.io.IOException) IActivation(org.nd4j.linalg.activations.IActivation) IOException(java.io.IOException) LossFunctions(org.nd4j.linalg.lossfunctions.LossFunctions) LossMSE(org.nd4j.linalg.lossfunctions.impl.LossMSE) ArrayNode(org.nd4j.shade.jackson.databind.node.ArrayNode) LossBinaryXENT(org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT) ObjectMapper(org.nd4j.shade.jackson.databind.ObjectMapper)

Example 2 with JsonNode

use of org.nd4j.shade.jackson.databind.JsonNode in project deeplearning4j by deeplearning4j.

The class StorageLevelDeserializer, method deserialize.

/**
 * Deserializes a {@link StorageLevel} from its textual JSON representation.
 * Both a non-textual node and the literal string {@code "null"} are treated
 * as an absent value and yield {@code null}.
 */
@Override
public StorageLevel deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
    JsonNode tree = jsonParser.getCodec().readTree(jsonParser);
    String text = tree.textValue();
    //textValue() returns null for non-textual nodes; "null" is how an absent level was serialized
    boolean absent = (text == null) || "null".equals(text);
    return absent ? null : StorageLevel.fromString(text);
}
Also used : JsonNode(org.nd4j.shade.jackson.databind.JsonNode)

Example 3 with JsonNode

use of org.nd4j.shade.jackson.databind.JsonNode in project deeplearning4j by deeplearning4j.

The class ComputationGraphConfiguration, method fromJson.

/**
     * Create a computation graph configuration from json.
     * <p>
     * To maintain backward compatibility with configs generated with v0.7.1 or earlier
     * (where activation functions were serialized as an enum-style string rather than a
     * class), this also walks the raw JSON and restores any null activation functions
     * from the legacy {@code activationFunction} field.
     *
     * @param json the neural net configuration as json
     * @return {@link ComputationGraphConfiguration}
     * @throws RuntimeException if the json cannot be deserialized at all
     */
public static ComputationGraphConfiguration fromJson(String json) {
    //As per MultiLayerConfiguration.fromJson()
    ObjectMapper mapper = NeuralNetConfiguration.mapper();
    ComputationGraphConfiguration conf;
    try {
        conf = mapper.readValue(json, ComputationGraphConfiguration.class);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    //To maintain backward compatibility after activation function refactoring (configs generated with v0.7.1 or earlier)
    // Previously: enumeration used for activation functions. Now: use classes
    Map<String, GraphVertex> vertexMap = conf.getVertices();
    //Lazily-parsed "vertices" object from the raw JSON; parsed at most once
    JsonNode vertices = null;
    for (Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()) {
        if (!(entry.getValue() instanceof LayerVertex)) {
            continue;
        }
        LayerVertex lv = (LayerVertex) entry.getValue();
        if (lv.getLayerConf() != null && lv.getLayerConf().getLayer() != null) {
            Layer layer = lv.getLayerConf().getLayer();
            if (layer.getActivationFn() == null) {
                String layerName = layer.getLayerName();
                try {
                    if (vertices == null) {
                        JsonNode jsonNode = mapper.readTree(json);
                        vertices = jsonNode.get("vertices");
                    }
                    if (vertices == null) {
                        //No "vertices" field in the raw JSON: nothing legacy to recover
                        //(previously this caused an NPE on the next line)
                        break;
                    }
                    JsonNode vertexNode = vertices.get(layerName);
                    if (vertexNode == null) {
                        //Vertex not present under this layer's name (previously an NPE)
                        continue;
                    }
                    JsonNode layerVertexNode = vertexNode.get("LayerVertex");
                    if (layerVertexNode == null || !layerVertexNode.has("layerConf") || !layerVertexNode.get("layerConf").has("layer")) {
                        continue;
                    }
                    JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");
                    if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                        continue;
                    }
                    //Should only have 1 element: "dense", "output", etc
                    JsonNode layerNode = layerWrapperNode.elements().next();
                    JsonNode activationFunction = layerNode.get("activationFunction");
                    if (activationFunction != null) {
                        IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                        layer.setActivationFn(ia);
                    }
                } catch (IOException e) {
                    log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
                }
            }
        }
    }
    return conf;
}
Also used : LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) JsonNode(org.nd4j.shade.jackson.databind.JsonNode) IOException(java.io.IOException) IActivation(org.nd4j.linalg.activations.IActivation) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) Layer(org.deeplearning4j.nn.conf.layers.Layer) GraphVertex(org.deeplearning4j.nn.conf.graph.GraphVertex) ObjectMapper(org.nd4j.shade.jackson.databind.ObjectMapper)

Aggregations

JsonNode (org.nd4j.shade.jackson.databind.JsonNode)3 IOException (java.io.IOException)2 IActivation (org.nd4j.linalg.activations.IActivation)2 ObjectMapper (org.nd4j.shade.jackson.databind.ObjectMapper)2 GraphVertex (org.deeplearning4j.nn.conf.graph.GraphVertex)1 LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex)1 Layer (org.deeplearning4j.nn.conf.layers.Layer)1 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)1 LossFunctions (org.nd4j.linalg.lossfunctions.LossFunctions)1 LossBinaryXENT (org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT)1 LossMCXENT (org.nd4j.linalg.lossfunctions.impl.LossMCXENT)1 LossMSE (org.nd4j.linalg.lossfunctions.impl.LossMSE)1 LossNegativeLogLikelihood (org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood)1 ArrayNode (org.nd4j.shade.jackson.databind.node.ArrayNode)1