Example 1 with KerasInput

Use of org.deeplearning4j.nn.modelimport.keras.layers.KerasInput in project deeplearning4j by deeplearning4j.

From class KerasModel, method helperInferOutputTypes:

/**
 * Helper method called from the constructor. Infers and records the output
 * type of every layer.
 */
protected void helperInferOutputTypes() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    this.outputTypes = new HashMap<>();
    for (KerasLayer layer : this.layersOrdered) {
        InputType outputType = null;
        if (layer instanceof KerasInput) {
            // Input layers carry their own type; no inbound layers to consult.
            outputType = layer.getOutputType();
            /*
             * TODO: figure out how to infer truncated BPTT value for non-sequence inputs
             *
             * In Keras, truncated BPTT is specified implicitly by specifying a fixed-size
             * input and by passing the "unroll" argument to recurrent layers.
             * Currently, the only setting in which we can confidently determine the
             * value of truncated BPTT is if the original input has two dimensions,
             * the first of which is sequence length. Hypothetically, we should be
             * able to do this for other types of inputs, but that's less straightforward.
             */
            this.truncatedBPTT = ((KerasInput) layer).getTruncatedBptt();
        } else {
            // Collect the already-recorded output types of all inbound layers.
            // This assumes layersOrdered lists every layer after its inbound layers.
            InputType[] inputTypes = new InputType[layer.getInboundLayerNames().size()];
            int i = 0;
            for (String inboundLayerName : layer.getInboundLayerNames()) {
                inputTypes[i++] = this.outputTypes.get(inboundLayerName);
            }
            outputType = layer.getOutputType(inputTypes);
        }
        this.outputTypes.put(layer.getLayerName(), outputType);
    }
}
Also used: InputType (org.deeplearning4j.nn.conf.inputs.InputType), KerasInput (org.deeplearning4j.nn.modelimport.keras.layers.KerasInput)
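
The method appears to depend on layersOrdered being topologically sorted: when a non-input layer is reached, the output types of its inbound layers must already be in outputTypes. The sketch below reproduces that propagation pattern in isolation. It is hypothetical and not DL4J code: the class OutputTypeInference, the nested Layer type and its shape functions are illustrative, and shapes are plain int[] arrays instead of InputType. The truncated-BPTT rule (only a 2-D input of shape (timesteps, features) yields a value) follows the TODO comment above; returning 0 for "cannot infer" is an assumption of this sketch, not necessarily what KerasInput.getTruncatedBptt() does.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical stand-alone sketch of layer-by-layer output-shape inference. */
public class OutputTypeInference {

    /** Hypothetical layer: either an input with a fixed shape, or a function of its inbound shapes. */
    static final class Layer {
        final String name;
        final List<String> inbound;
        final int[] inputShape;                  // non-null only for input layers
        final Function<int[][], int[]> shapeFn;  // null for input layers

        private Layer(String name, List<String> inbound, int[] inputShape,
                      Function<int[][], int[]> shapeFn) {
            this.name = name;
            this.inbound = inbound;
            this.inputShape = inputShape;
            this.shapeFn = shapeFn;
        }

        static Layer input(String name, int... shape) {
            return new Layer(name, List.of(), shape, null);
        }

        static Layer op(String name, Function<int[][], int[]> shapeFn, String... inbound) {
            return new Layer(name, List.of(inbound), null, shapeFn);
        }

        boolean isInput() {
            return inputShape != null;
        }
    }

    // Insertion-ordered so results print in the order layers were visited.
    final Map<String, int[]> outputShapes = new LinkedHashMap<>();
    int truncatedBptt = 0;

    /** Assumed rule from the TODO: only a 2-D (timesteps, features) input gives a value; 0 means unknown. */
    static int truncatedBpttFor(int[] shape) {
        return shape.length == 2 ? shape[0] : 0;
    }

    /** Visits layers in the given (assumed topological) order and records each output shape. */
    void inferOutputShapes(List<Layer> layersInTopologicalOrder) {
        outputShapes.clear();
        for (Layer layer : layersInTopologicalOrder) {
            int[] out;
            if (layer.isInput()) {
                out = layer.inputShape;
                // Like the original, each input layer overwrites the value; the last one wins.
                truncatedBptt = truncatedBpttFor(out);
            } else {
                int[][] ins = new int[layer.inbound.size()][];
                int i = 0;
                for (String inboundName : layer.inbound) {
                    int[] shape = outputShapes.get(inboundName);
                    if (shape == null) {
                        throw new IllegalStateException("No shape for '" + inboundName
                                + "'; layers are not in topological order");
                    }
                    ins[i++] = shape;
                }
                out = layer.shapeFn.apply(ins);
            }
            outputShapes.put(layer.name, out);
        }
    }

    public static void main(String[] args) {
        // Illustrative shape rules: 10 timesteps x 8 features -> 32-unit recurrent layer -> 5-unit dense layer.
        List<Layer> layers = List.of(
                Layer.input("in", 10, 8),
                Layer.op("lstm", ins -> new int[] {ins[0][0], 32}, "in"),
                Layer.op("dense", ins -> new int[] {ins[0][0], 5}, "lstm"));
        OutputTypeInference inference = new OutputTypeInference();
        inference.inferOutputShapes(layers);
        for (Map.Entry<String, int[]> e : inference.outputShapes.entrySet()) {
            System.out.println(e.getKey() + " -> " + Arrays.toString(e.getValue()));
        }
        System.out.println("truncatedBptt = " + inference.truncatedBptt);
        // Prints:
        // in -> [10, 8]
        // lstm -> [10, 32]
        // dense -> [10, 5]
        // truncatedBptt = 10
    }
}
```

If "dense" is listed before "lstm", the sketch throws IllegalStateException, whereas a plain map lookup as in the original would silently pass a null entry on to the next layer.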
