Usage of org.apache.sysml.lops.ConvolutionTransform.OperationTypes in the Apache incubator-systemml project.
Class ConvolutionOp, method constructConvolutionLops:
/**
 * Builds the {@link ConvolutionTransform} lop for this convolution hop, fusing it with its
 * producer where one of the supported patterns applies.
 *
 * <p>All fusions require {@code OptimizerUtils.ALLOW_OPERATOR_FUSION}:
 * <ul>
 *   <li>CP only: {@code MAX_POOLING} / {@code MAX_POOLING_BACKWARD} whose first input passes
 *       {@code isInputReLU} becomes {@code RELU_MAX_POOLING} / {@code RELU_MAX_POOLING_BACKWARD},
 *       and the relu hop is bypassed: the lop reads the relu hop's first input directly.</li>
 *   <li>Any exec type: {@code BIAS_ADD} whose first input passes {@code isInputConv2d} becomes
 *       {@code DIRECT_CONV2D_BIAS_ADD}. Its lop inputs are the conv2d's image, then the bias,
 *       then the conv2d's remaining inputs.</li>
 * </ul>
 * Otherwise the operation type comes from {@code HopsConv2Lops} and the lop consumes
 * {@code inputs} in order.
 *
 * @param et     execution type the lop is generated for
 * @param inputs input hops of this operator; must contain exactly {@code getNumExpectedInputs()} entries
 * @return the constructed convolution lop, with all inputs attached
 * @throws HopsException if {@code inputs} does not have the expected number of entries
 * @throws LopsException declared by the signature for failures during lop construction
 */
public Lop constructConvolutionLops(ExecType et, ArrayList<Hop> inputs) throws HopsException, LopsException {
	if (inputs.size() != getNumExpectedInputs())
		throw new HopsException("Incorrect number of inputs for " + op.name());

	final int numThreads = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
	OperationTypes lopType = HopsConv2Lops.get(op);

	final boolean fusionEnabled = OptimizerUtils.ALLOW_OPERATOR_FUSION;
	final boolean isCP = (et == ExecType.CP);
	final Hop firstInput = inputs.get(0);

	// Relu fusion is restricted to CP; per the original note it is not needed for other backends.
	final boolean fuseRelu = fusionEnabled && isCP
		&& (op == ConvOp.MAX_POOLING || op == ConvOp.MAX_POOLING_BACKWARD)
		&& isInputReLU(firstInput);
	final boolean fuseBias = fusionEnabled
		&& op == ConvOp.BIAS_ADD
		&& isInputConv2d(firstInput);

	// Hops whose entries at index >= 1 become the trailing inputs of the lop.
	ArrayList<Hop> paramSource = inputs;
	Lop primary;
	Lop bias = null;

	if (fuseRelu) {
		primary = firstInput.getInput().get(0).constructLops();
		lopType = (op == ConvOp.MAX_POOLING)
			? OperationTypes.RELU_MAX_POOLING
			: OperationTypes.RELU_MAX_POOLING_BACKWARD;
	} else if (fuseBias) {
		lopType = OperationTypes.DIRECT_CONV2D_BIAS_ADD;
		final ArrayList<Hop> convInputs = firstInput.getInput();
		primary = convInputs.get(0).constructLops(); // image
		bias = inputs.get(1).constructLops();        // bias
		// take the remaining operands from conv2d rather than from bias_add
		paramSource = convInputs;
	} else {
		primary = firstInput.constructLops();
	}

	// TODO: inserting a reblock here (addReblockIfNecessary + setReblockedOutputDimension)
	// requires knowing the number of columns a priori.
	final ConvolutionTransform result =
		new ConvolutionTransform(primary, lopType, getDataType(), getValueType(), et, numThreads);
	setOutputDimensions(result);
	setLineNumbers(result);
	primary.addOutput(result);

	if (bias != null) {
		result.addInput(bias);
		bias.addOutput(result);
	}

	// Remaining operands (filter and shape parameters) are attached in hop order.
	for (int pos = 1; pos < paramSource.size(); pos++) {
		final Lop operand = paramSource.get(pos).constructLops();
		result.addInput(operand);
		operand.addOutput(result);
	}

	// Called last, once every input is attached, to force the order of the added lops.
	result.setLevel();
	return result;
}
Aggregations