Use of org.nd4j.shade.jackson.databind.JsonNode in project deeplearning4j by deeplearning4j.
The class MultiLayerConfiguration, method fromJson.
/**
 * Create a neural net configuration from json.
 * <p>
 * After deserializing, this also upgrades configurations written by older versions by walking
 * the raw JSON tree:
 * <ul>
 * <li>an output layer whose loss function is null may carry a legacy {@code lossFunction} enum
 * value (configs generated with v0.5.0 or earlier); MSE, XENT, NEGATIVELOGLIKELIHOOD and MCXENT
 * are mapped to their loss function classes, any other value is logged and left unset;</li>
 * <li>a layer whose activation function is null may carry a legacy {@code activationFunction}
 * string (pre-0.7.2 configs), which is converted to an {@link IActivation}.</li>
 * </ul>
 * Problems encountered during this upgrade are logged as warnings rather than thrown.
 *
 * @param json the JSON representation of the multi-layer configuration
 * @return {@link MultiLayerConfiguration}
 * @throws RuntimeException wrapping the {@link IOException} if {@code json} cannot be deserialized
 */
public static MultiLayerConfiguration fromJson(String json) {
    MultiLayerConfiguration conf;
    ObjectMapper mapper = NeuralNetConfiguration.mapper();
    try {
        conf = mapper.readValue(json, MultiLayerConfiguration.class);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
    // Previously: enumeration used for loss functions. Now: use classes
    // IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums
    int layerCount = 0;
    // The "confs" array of the raw JSON tree; only parsed once some layer actually needs upgrading
    JsonNode confs = null;
    for (NeuralNetConfiguration nnc : conf.getConfs()) {
        // Claim this layer's index and advance the counter before anything else: the body below
        // may 'continue' early, and that must not leave later layers paired with the wrong
        // element of the JSON "confs" array
        int layerIdx = layerCount++;
        Layer l = nnc.getLayer();
        if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) {
            //lossFn field null -> may be an old config format, with lossFunction field being for the enum
            //if so, try walking the JSON graph to extract out the appropriate enum value
            BaseOutputLayer ol = (BaseOutputLayer) l;
            try {
                if (confs == null) {
                    // Parse the tree only while it has not yet yielded a "confs" node
                    confs = mapper.readTree(json).get("confs");
                }
                if (confs instanceof ArrayNode) {
                    JsonNode outputLayerNNCNode = ((ArrayNode) confs).get(layerIdx);
                    if (outputLayerNNCNode == null) {
                        //Should never happen...
                        return conf;
                    }
                    JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
                    JsonNode lossFunctionNode = null;
                    // A missing "layer" field simply means there is no legacy value to recover
                    if (outputLayerNode != null) {
                        if (outputLayerNode.has("output")) {
                            lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
                        } else if (outputLayerNode.has("rnnoutput")) {
                            lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
                        }
                    }
                    if (lossFunctionNode != null) {
                        String lossFunctionEnumStr = lossFunctionNode.asText();
                        LossFunctions.LossFunction lossFunction = null;
                        try {
                            lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
                        } catch (Exception e) {
                            log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
                        }
                        if (lossFunction != null) {
                            switch (lossFunction) {
                                case MSE:
                                    ol.setLossFn(new LossMSE());
                                    break;
                                case XENT:
                                    ol.setLossFn(new LossBinaryXENT());
                                    break;
                                case NEGATIVELOGLIKELIHOOD:
                                    ol.setLossFn(new LossNegativeLogLikelihood());
                                    break;
                                case MCXENT:
                                    ol.setLossFn(new LossMCXENT());
                                    break;
                                //Remaining: TODO
                                case EXPLL:
                                case RMSE_XENT:
                                case SQUARED_LOSS:
                                case RECONSTRUCTION_CROSSENTROPY:
                                case CUSTOM:
                                default:
                                    log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", lossFunction);
                                    break;
                            }
                        }
                    }
                } else {
                    log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", (confs != null ? confs.getClass() : null));
                }
            } catch (IOException e) {
                log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
                // The same JSON string would fail to parse again for every later layer; stop upgrading
                break;
            }
        }
        //Try to load the old format if necessary, and create the appropriate IActivation instance
        if (l.getActivationFn() == null) {
            try {
                if (confs == null) {
                    confs = mapper.readTree(json).get("confs");
                }
                if (confs instanceof ArrayNode) {
                    JsonNode layerNNCNode = ((ArrayNode) confs).get(layerIdx);
                    if (layerNNCNode == null) {
                        //Should never happen...
                        return conf;
                    }
                    JsonNode layerWrapperNode = layerNNCNode.get("layer");
                    if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                        // Safe to skip: the index for this layer was already consumed above
                        continue;
                    }
                    //Should only have 1 element: "dense", "output", etc
                    JsonNode layerNode = layerWrapperNode.elements().next();
                    JsonNode activationFunction = layerNode.get("activationFunction");
                    if (activationFunction != null) {
                        IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                        l.setActivationFn(ia);
                    }
                }
            } catch (IOException e) {
                log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
            }
        }
    }
    return conf;
}
Use of org.nd4j.shade.jackson.databind.JsonNode in project deeplearning4j by deeplearning4j.
The class StorageLevelDeserializer, method deserialize.
/**
 * Reads a {@link StorageLevel} from its textual JSON form.
 * A non-textual node, or the literal string "null", yields {@code null}; any other
 * text is handed to {@code StorageLevel.fromString}.
 */
@Override
public StorageLevel deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
    JsonNode tree = jsonParser.getCodec().readTree(jsonParser);
    // textValue() is null for anything that is not a JSON string
    String levelName = tree.textValue();
    boolean absent = levelName == null || levelName.equals("null");
    return absent ? null : StorageLevel.fromString(levelName);
}
Use of org.nd4j.shade.jackson.databind.JsonNode in project deeplearning4j by deeplearning4j.
The class ComputationGraphConfiguration, method fromJson.
/**
 * Create a computation graph configuration from json.
 * <p>
 * After deserializing, any layer whose activation function is null but whose raw JSON still
 * carries a legacy {@code activationFunction} string (configs generated with v0.7.1 or earlier)
 * is upgraded to the corresponding {@link IActivation}. Layers whose JSON does not have the
 * expected shape are skipped; parse problems are logged as warnings rather than thrown.
 *
 * @param json the JSON representation of the computation graph configuration
 * @return {@link ComputationGraphConfiguration}
 * @throws RuntimeException wrapping the {@link IOException} if {@code json} cannot be deserialized
 */
public static ComputationGraphConfiguration fromJson(String json) {
    //As per MultiLayerConfiguration.fromJson()
    ObjectMapper mapper = NeuralNetConfiguration.mapper();
    ComputationGraphConfiguration conf;
    try {
        conf = mapper.readValue(json, ComputationGraphConfiguration.class);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    //To maintain backward compatibility after activation function refactoring (configs generated with v0.7.1 or earlier)
    // Previously: enumeration used for activation functions. Now: use classes
    Map<String, GraphVertex> vertexMap = conf.getVertices();
    // The "vertices" object of the raw JSON tree; only parsed once some layer actually needs upgrading
    JsonNode vertices = null;
    for (Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()) {
        if (!(entry.getValue() instanceof LayerVertex)) {
            continue;
        }
        LayerVertex lv = (LayerVertex) entry.getValue();
        if (lv.getLayerConf() == null || lv.getLayerConf().getLayer() == null) {
            continue;
        }
        Layer layer = lv.getLayerConf().getLayer();
        if (layer.getActivationFn() != null) {
            continue;
        }
        try {
            if (vertices == null) {
                vertices = mapper.readTree(json).get("vertices");
                if (vertices == null) {
                    // Without a "vertices" object there is nothing to recover for any layer
                    log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON: no 'vertices' field");
                    break;
                }
            }
            // Look up by the vertex-map key: the JSON "vertices" object was deserialized into this
            // map, so its keys match exactly, whereas the layer name is not guaranteed to be set
            JsonNode vertexNode = vertices.get(entry.getKey());
            JsonNode layerVertexNode = (vertexNode == null ? null : vertexNode.get("LayerVertex"));
            if (layerVertexNode == null || !layerVertexNode.has("layerConf") || !layerVertexNode.get("layerConf").has("layer")) {
                continue;
            }
            JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");
            if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                continue;
            }
            //Should only have 1 element: "dense", "output", etc
            JsonNode layerNode = layerWrapperNode.elements().next();
            JsonNode activationFunction = layerNode.get("activationFunction");
            if (activationFunction != null) {
                IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                layer.setActivationFn(ia);
            }
        } catch (IOException e) {
            log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e);
        }
    }
    return conf;
}
Aggregations