
Example 1 with InvalidInputTypeException

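Usage of org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException in project deeplearning4j by deeplearning4j.

From class ElementWiseVertex, method getOutputType. The method checks that all inputs to the vertex have the same type and, for CNN inputs, the same depth, width and height. It then returns the first input's type as the output type.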

@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
    if (vertexInputs.length == 1)
        return vertexInputs[0];
    InputType first = vertexInputs[0];
    if (first.getType() != InputType.Type.CNN) {
        //FF, RNN or flat CNN data inputs
        for (int i = 1; i < vertexInputs.length; i++) {
            if (vertexInputs[i].getType() != first.getType()) {
                throw new InvalidInputTypeException("Invalid input: ElementWise vertex cannot process activations of different types:" + " first type = " + first.getType() + ", input type " + (i + 1) + " = " + vertexInputs[i].getType());
            }
        }
    } else {
        //CNN inputs... also check that the depth, width and heights match:
        InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
        int fd = firstConv.getDepth();
        int fw = firstConv.getWidth();
        int fh = firstConv.getHeight();
        for (int i = 1; i < vertexInputs.length; i++) {
            if (vertexInputs[i].getType() != InputType.Type.CNN) {
                throw new InvalidInputTypeException("Invalid input: ElementWise vertex cannot process activations of different types:" + " first type = " + InputType.Type.CNN + ", input type " + (i + 1) + " = " + vertexInputs[i].getType());
            }
            InputType.InputTypeConvolutional otherConv = (InputType.InputTypeConvolutional) vertexInputs[i];
            int od = otherConv.getDepth();
            int ow = otherConv.getWidth();
            int oh = otherConv.getHeight();
            if (fd != od || fw != ow || fh != oh) {
                throw new InvalidInputTypeException("Invalid input: ElementWise vertex cannot process CNN activations of different sizes:" + " first [depth,width,height] = [" + fd + "," + fw + "," + fh + "], input " + (i + 1) + " = [" + od + "," + ow + "," + oh + "]");
            }
        }
    }
    //Output has the same type and shape/size as the (validated) inputs
    return first;
}

Also used: InputType (org.deeplearning4j.nn.conf.inputs.InputType), InvalidInputTypeException (org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException)
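To make the validation rule easier to experiment with, here is a minimal, self-contained sketch of the same logic without any DL4J dependency. Kind, Input and outputType are hypothetical stand-ins for InputType and getOutputType, and IllegalArgumentException replaces InvalidInputTypeException. Unlike the original method, the sketch also rejects an empty input list. It assumes Java 16+ for records.

```java
import java.util.Arrays;

public class ElementWiseShapeCheck {

    // Hypothetical stand-in for InputType.Type (not the DL4J enum).
    enum Kind { FF, RNN, CNN_FLAT, CNN }

    // Hypothetical stand-in for InputType: a kind plus depth/width/height (used only for CNN).
    record Input(Kind kind, int depth, int width, int height) {
        static Input of(Kind kind) { return new Input(kind, 0, 0, 0); }
        static Input cnn(int d, int w, int h) { return new Input(Kind.CNN, d, w, h); }
    }

    // Mirrors the rule in getOutputType: all inputs must share one kind, and CNN inputs
    // must also share [depth, width, height]. The first input is returned as the output type.
    static Input outputType(Input... inputs) {
        if (inputs.length == 0) {
            // Added check (not in the original): avoid indexing into an empty array.
            throw new IllegalArgumentException("At least one input is required");
        }
        Input first = inputs[0];
        for (int i = 1; i < inputs.length; i++) {
            Input other = inputs[i];
            if (other.kind() != first.kind()) {
                throw new IllegalArgumentException("Different types: first = " + first.kind()
                        + ", input " + (i + 1) + " = " + other.kind());
            }
            boolean sizeMismatch = first.depth() != other.depth()
                    || first.width() != other.width()
                    || first.height() != other.height();
            if (first.kind() == Kind.CNN && sizeMismatch) {
                throw new IllegalArgumentException("Different CNN sizes: first = "
                        + Arrays.toString(new int[]{first.depth(), first.width(), first.height()})
                        + ", input " + (i + 1) + " = "
                        + Arrays.toString(new int[]{other.depth(), other.width(), other.height()}));
            }
        }
        return first;
    }

    public static void main(String[] args) {
        // Two CNN inputs with identical [depth, width, height]: accepted.
        System.out.println(outputType(Input.cnn(3, 28, 28), Input.cnn(3, 28, 28)));
        // Depth differs (3 vs 16): rejected.
        try {
            outputType(Input.cnn(3, 28, 28), Input.cnn(16, 28, 28));
        } catch (IllegalArgumentException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}
```

Returning the first input mirrors the original design: once every input is known to match, any of them describes the output equally well.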
