Usage of org.deeplearning4j.nn.modelimport.keras.layers.KerasInput in the deeplearning4j project (by deeplearning4j).
Class KerasModel, method helperInferOutputTypes.
/**
 * Helper invoked from the constructor: walks {@code layersOrdered} and records,
 * for each layer (keyed by its layer name), the {@link InputType} that layer produces.
 *
 * <p>A {@link KerasInput} layer reports its own output type and additionally
 * supplies the model's truncated-BPTT value. Every other layer's output type is
 * computed from the previously recorded output types of its inbound layers, passed
 * in the order returned by {@code getInboundLayerNames()}.
 *
 * <p>NOTE(review): this relies on {@code layersOrdered} listing each layer after all
 * of its inbound layers. If an inbound layer's type has not been recorded,
 * {@code outputTypes.get} yields {@code null} and that {@code null} is handed to
 * {@code getOutputType} unchanged.
 *
 * <p>NOTE(review): if several {@link KerasInput} layers are present, each one
 * overwrites {@code truncatedBPTT}, so the last one visited wins.
 *
 * @throws InvalidKerasConfigurationException   propagated from layer output-type inference
 * @throws UnsupportedKerasConfigurationException propagated from layer output-type inference
 */
protected void helperInferOutputTypes() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    this.outputTypes = new HashMap<String, InputType>();
    for (KerasLayer current : this.layersOrdered) {
        final InputType produced;
        if (!(current instanceof KerasInput)) {
            // Collect the already-recorded output types of every inbound layer, in order.
            InputType[] inboundTypes = new InputType[current.getInboundLayerNames().size()];
            int slot = 0;
            for (String inboundName : current.getInboundLayerNames()) {
                inboundTypes[slot] = this.outputTypes.get(inboundName);
                slot++;
            }
            produced = current.getOutputType(inboundTypes);
        } else {
            // Input layers know their own output type; there is nothing upstream to consult.
            produced = current.getOutputType();
            /*
             * TODO: figure out how to infer truncated BPTT value for non-sequence inputs
             *
             * In Keras, truncated BPTT is specified implicitly by specifying a fixed
             * size input and by passing in the "unroll" argument to recurrent layers.
             * Currently, the only setting in which we can confidently determine the
             * value of truncated BPTT is if the original input has two dimensions,
             * the first of which is sequence length. Hypothetically, we should be
             * able to do this for other types of inputs, but that's less straightforward.
             */
            this.truncatedBPTT = ((KerasInput) current).getTruncatedBptt();
        }
        this.outputTypes.put(current.getLayerName(), produced);
    }
}
Aggregations