Usage of org.apache.sysml.hops.AggUnaryOp in the Apache incubator-systemml project:
class PlanSelectionFuseCostBased, method rGetComputeCosts.
/**
 * Recursively computes a relative compute cost weight for {@code current} and all of its
 * (transitive) inputs, storing each result in {@code computeCosts} keyed by hop ID.
 * <p>
 * Inputs are costed before their consumers. Hops that already have an entry in
 * {@code computeCosts} are skipped, so shared subexpressions are costed only once. Hop types
 * not covered by the cost model default to a cost of 1. Operations of covered hop types
 * without an explicit cost also default to 1 and are reported via {@code LOG.warn}.
 * <p>
 * NOTE(review): {@code partition} is only passed through the recursion and is not consulted
 * here. Costs are therefore computed for every reachable hop, not only partition members.
 * Confirm this is intended.
 *
 * @param current      hop to cost (together with its inputs)
 * @param partition    hop IDs of the fusion partition (currently not consulted)
 * @param computeCosts output map from hop ID to cost weight; also serves as the memo table
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
    if (computeCosts.containsKey(current.getHopID()))
        return;
    // recursively process children (inputs are costed before their consumers)
    for (Hop c : current.getInput())
        rGetComputeCosts(c, partition, computeCosts);
    // get costs for given hop (default 1 for hop types without a cost model)
    double costs = 1;
    if (current instanceof UnaryOp)
        costs = getUnaryOpComputeCosts((UnaryOp) current);
    else if (current instanceof BinaryOp)
        costs = getBinaryOpComputeCosts((BinaryOp) current);
    else if (current instanceof TernaryOp)
        costs = getTernaryOpComputeCosts((TernaryOp) current);
    else if (current instanceof ParameterizedBuiltinOp || current instanceof IndexingOp || current instanceof ReorgOp)
        costs = 1;
    else if (current instanceof AggBinaryOp)
        costs = 2; // matrix vector
    else if (current instanceof AggUnaryOp)
        costs = getAggUnaryOpComputeCosts((AggUnaryOp) current);
    computeCosts.put(current.getHopID(), costs);
}

/** Returns the cost weight of a unary operation; unmodeled operations warn and cost 1. */
private static double getUnaryOpComputeCosts(UnaryOp op) {
    switch (op.getOp()) {
        case ABS:
        case ROUND:
        case CEIL:
        case FLOOR:
        case SIGN:
        case NCOL:
        case NROW:
        case PRINT:
        case ASSERT:
        case CAST_AS_BOOLEAN:
        case CAST_AS_DOUBLE:
        case CAST_AS_INT:
        case CAST_AS_MATRIX:
        case CAST_AS_SCALAR:
        case CUMSUM:
        case CUMMIN:
        case CUMMAX:
        case CUMPROD:
            return 1;
        case SPROP:
        case SQRT:
            return 2;
        case EXP:
            return 18;
        case SIGMOID:
            return 21;
        case LOG:
        case LOG_NZ:
            return 32;
        case SIN:
            return 18;
        case COS:
            return 22;
        case TAN:
            return 42;
        case ASIN:
            return 93;
        case ACOS:
            return 103;
        case ATAN:
            return 40;
        // TODO: hyperbolic costs are placeholders copied from ASIN/ACOS/ATAN
        case SINH:
            return 93;
        case COSH:
            return 103;
        case TANH:
            return 40;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + op.getOp());
            return 1;
    }
}

/** Returns the cost weight of a binary operation; unmodeled operations warn and cost 1. */
private static double getBinaryOpComputeCosts(BinaryOp op) {
    switch (op.getOp()) {
        case MULT:
        case PLUS:
        case MINUS:
        case MIN:
        case MAX:
        case AND:
        case OR:
        case EQUAL:
        case NOTEQUAL:
        case LESS:
        case LESSEQUAL:
        case GREATER:
        case GREATEREQUAL:
        case CBIND:
        case RBIND:
            return 1;
        case INTDIV:
            return 6;
        case MODULUS:
            return 8;
        case DIV:
            return 22;
        case LOG:
        case LOG_NZ:
            return 32;
        case POW:
            // a literal exponent of 2 is costed like a single multiplication
            return HopRewriteUtils.isLiteralOfValue(op.getInput().get(1), 2) ? 1 : 16;
        case MINUS_NZ:
        case MINUS1_MULT:
            return 2;
        case CENTRALMOMENT:
            // moment(X, order): the order is the second input
            return getCentralMomentComputeCosts(op.getInput().get(1), 0);
        case COVARIANCE:
            return 23;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + op.getOp());
            return 1;
    }
}

/** Returns the cost weight of a ternary operation; unmodeled operations warn and cost 1. */
private static double getTernaryOpComputeCosts(TernaryOp op) {
    switch (op.getOp()) {
        case PLUS_MULT:
        case MINUS_MULT:
            return 2;
        case CTABLE:
            return 3;
        case CENTRALMOMENT:
            // moment(X, W, order): the order is the THIRD input (index 2), input 1 holds the
            // weights; weighting adds one unit over the unweighted cost
            return getCentralMomentComputeCosts(op.getInput().get(2), 1);
        case COVARIANCE:
            return 23;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + op.getOp());
            return 1;
    }
}

/**
 * Returns the cost weight of a central moment given the hop holding its order.
 * A non-literal order is costed as order 2. An unrecognized literal order falls back to 1,
 * without adding {@code weightCosts}.
 */
private static double getCentralMomentComputeCosts(Hop orderInput, double weightCosts) {
    int type = (int) (orderInput instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) orderInput) : 2);
    switch (type) {
        case 0: // count
            return 1 + weightCosts;
        case 1: // mean
            return 8 + weightCosts;
        case 2: // cm2
            return 16 + weightCosts;
        case 3: // cm3
            return 31 + weightCosts;
        case 4: // cm4
            return 51 + weightCosts;
        case 5: // variance
            return 16 + weightCosts;
        default:
            return 1;
    }
}

/** Returns the cost weight of a unary aggregate; unmodeled aggregates warn and cost 1. */
private static double getAggUnaryOpComputeCosts(AggUnaryOp op) {
    switch (op.getOp()) {
        case SUM:
            return 4;
        case SUM_SQ:
            return 5;
        case MIN:
        case MAX:
            return 1;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + op.getOp());
            return 1;
    }
}
Usage of org.apache.sysml.hops.AggUnaryOp in the Apache incubator-systemml project:
class TemplateRow, method rConstructCplan.
/**
 * Recursively constructs the CNode for {@code hop} within the row template and registers it
 * in {@code tmp} under the hop's ID.
 * <p>
 * Each input of {@code hop} is handled according to the best ROW/CELL memo table entry of
 * {@code hop}:
 * <ul>
 * <li>Inputs that are plan references are constructed recursively.</li>
 * <li>All other inputs become {@code CNodeData} leaves and are recorded in {@code inHops}.</li>
 * </ul>
 * Selected inputs are additionally registered in {@code inHops2} under the names "X" and "B1".
 * Matrix-typed outputs receive the dimensions of {@code hop}.
 *
 * @param hop             hop to construct a CNode for
 * @param memo            memo table used to look up the best plan entry of each hop
 * @param tmp             map from hop ID to constructed CNode; doubles as memoization table
 * @param inHops          input hops of the generated operator (modified in place)
 * @param inHops2         named special inputs ("X", "B1") of the generated operator (modified in place)
 * @param compileLiterals forwarded to TemplateUtils.createCNodeData and TemplateUtils.skipTranspose
 * @throws RuntimeException for unsupported unary/binary matrix operations, or if no CNode
 *                          could be constructed for {@code hop}
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
// memoization for common subexpression elimination and to avoid redundant work
if (tmp.containsKey(hop.getHopID()))
return;
// recursively process required childs
// the best ROW/CELL plan entry decides which inputs are constructed recursively (plan refs)
// and which are materialized as CNodeData inputs
MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
for (int i = 0; i < hop.getInput().size(); i++) {
Hop c = hop.getInput().get(i);
if (me != null && me.isPlanRef(i))
rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
else {
CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
tmp.put(c.getHopID(), cdata);
inHops.add(c);
}
}
// construct cnode for current hop
CNode out = null;
if (hop instanceof AggUnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
// row aggregates: a single-column input only needs a row lookup (or the scalar itself)
if (((AggUnaryOp) hop).getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
if (hop.getInput().get(0).getDim2() == 1)
out = (cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
else {
// unary type derived from the aggregate name, e.g., AggOp.SUM -> UnaryType.ROW_SUMS
String opcode = "ROW_" + ((AggUnaryOp) hop).getOp().name().toUpperCase() + "S";
out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
inHops2.put("X", hop.getInput().get(0));
}
} else if (((AggUnaryOp) hop).getDirection() == Direction.Col && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
// vector add without temporary copy
if (cdata1 instanceof CNodeBinary && ((CNodeBinary) cdata1).getType().isVectorScalarPrimitive())
out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), ((CNodeBinary) cdata1).getType().getVectorAddPrimitive());
else
out = cdata1;
} else if (((AggUnaryOp) hop).getDirection() == Direction.RowCol && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
// full sum: matrix intermediates are reduced via ROW_SUMS, scalars are passed through
out = (cdata1.getDataType().isMatrix()) ? new CNodeUnary(cdata1, UnaryType.ROW_SUMS) : cdata1;
}
// NOTE: any other direction/op combination leaves out == null and fails below
} else if (hop instanceof AggBinaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
// transposed left input t(X) %*% B: operate on X, replacing the transpose hop in inHops
if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
// correct input under transpose
cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
inHops.remove(hop.getInput().get(0));
if (cdata1 instanceof CNodeData)
inHops.add(hop.getInput().get(0).getInput().get(0));
// note: vectorMultAdd applicable to vector-scalar, and vector-vector
if (hop.getInput().get(1).getDim2() == 1)
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
else {
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD);
if (!inHops2.containsKey("B1")) {
// incl modification of X for consistency
if (cdata1 instanceof CNodeData)
inHops2.put("X", hop.getInput().get(0).getInput().get(0));
inHops2.put("B1", hop.getInput().get(1));
}
}
if (!inHops2.containsKey("X"))
inHops2.put("X", hop.getInput().get(0).getInput().get(0));
} else {
// both inputs single-column: scalar multiply of the two looked-up values
if (hop.getInput().get(0).getDim2() == 1 && hop.getInput().get(1).getDim2() == 1)
out = new CNodeBinary((cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), (cdata2.getDataType() == DataType.SCALAR) ? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
// single-column right input: dot product
else if (hop.getInput().get(1).getDim2() == 1) {
out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
inHops2.put("X", hop.getInput().get(0));
} else {
// multi-column right input: vector-matrix multiply
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MATRIXMULT);
inHops2.put("X", hop.getInput().get(0));
inHops2.put("B1", hop.getInput().get(1));
}
}
} else if (HopRewriteUtils.isTransposeOperation(hop)) {
// NOTE(review): tmp.get(hop.getHopID()) is always null here (the early return above fires
// when it is present); presumably skipTranspose ignores this argument for transpose hops
// and resolves the transpose input instead - confirm
out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
inHops.add(hop.getInput().get(0));
} else if (hop instanceof UnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
// if one input is a matrix then we need to do vector by scalar operations
// (&& binds tighter than ||: known multi-column input, OR unknown dims with a matrix input)
if (hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1 || (!hop.dimsKnown() && cdata1.getDataType() == DataType.MATRIX)) {
if (HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY)) {
String opname = "VECT_" + ((UnaryOp) hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(opname));
if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
inHops2.put("X", hop.getInput().get(0));
} else
throw new RuntimeException("Unsupported unary matrix " + "operation: " + ((UnaryOp) hop).getOp().name());
} else // general scalar case
{
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
String primitiveOpName = ((UnaryOp) hop).getOp().toString();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
}
} else if (HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
// special case for cbind with zeros
// (the check accepts any constant-valued data generator as right input, not only zeros)
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = null;
if (HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) {
cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
// rm 0-matrix
inHops.remove(hop.getInput().get(1));
} else {
cdata2 = tmp.get(hop.getInput().get(1).getHopID());
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
}
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
inHops2.put("X", hop.getInput().get(0));
} else if (hop instanceof BinaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
// if one input is a matrix then we need to do vector by scalar operations
// vector path if either input is known multi-column, or if some dims are unknown, the output
// is not known to be single-column, and at least one input CNode is a matrix
if ((hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1) || (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1) || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()) && // not a known vector output
(hop.getDim2() != 1) && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix()))) {
if (HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY)) {
// matrix with matrix/row vector: VECT_<op>; otherwise VECT_<op>_SCALAR with
// column-vector operands wrapped in row lookups
if (TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2) || TemplateUtils.isRowVector(cdata2))) {
String opname = "VECT_" + ((BinaryOp) hop).getOp().name();
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
} else {
String opname = "VECT_" + ((BinaryOp) hop).getOp().name() + "_SCALAR";
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
if (TemplateUtils.isColVector(cdata2))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
}
if (cdata1 instanceof CNodeData && !inHops2.containsKey("X") && !(cdata1.getDataType() == DataType.SCALAR)) {
inHops2.put("X", hop.getInput().get(0));
}
} else
throw new RuntimeException("Unsupported binary matrix " + "operation: " + ((BinaryOp) hop).getOp().name());
} else // one input is a vector/scalar other is a scalar
{
String primitiveOpName = ((BinaryOp) hop).getOp().toString();
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
if (// vector or vector can be inferred from lhs
TemplateUtils.isColVector(cdata2) || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix()))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
}
} else if (hop instanceof TernaryOp) {
TernaryOp top = (TernaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
// add lookups if required
// (only the first and third inputs are wrapped; the second is used as-is)
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
// construct ternary cnode, primitive operation derived from OpOp3
out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString()));
} else if (HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
// n-ary cbind: vector inputs are wrapped in lookups if necessary; the first input may become "X"
CNode[] inputs = new CNode[hop.getInput().size()];
for (int i = 0; i < hop.getInput().size(); i++) {
Hop c = hop.getInput().get(i);
CNode cdata = tmp.get(c.getHopID());
if (TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata))
cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
inputs[i] = cdata;
if (i == 0 && cdata instanceof CNodeData && !inHops2.containsKey("X"))
inHops2.put("X", c);
}
out = new CNodeNary(inputs, NaryType.VECT_CBIND);
} else if (hop instanceof ParameterizedBuiltinOp) {
// assumes a replace() operation: reads the "pattern" and "replacement" parameters
// NOTE(review): the target CNode comes from getTargetHop(), but the lookup wrapping below
// inspects getInput().get(0); these agree only if the target is input 0 - confirm
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
// a literal NaN pattern selects the dedicated REPLACE_NAN primitive
TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
} else if (hop instanceof IndexingOp) {
// NOTE(review): passes the input's column count and input 4 (presumably the column lower
// bound) to a LOOKUP_RVECT1 (multi-column output) or LOOKUP_RC1 (single column) ternary - confirm
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), (hop.getDim2() != 1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
}
// no branch above produced a CNode for this hop
if (out == null) {
throw new RuntimeException(hop.getHopID() + " " + hop.getOpString());
}
// propagate the hop's dimensions to matrix-typed outputs
if (out.getDataType().isMatrix()) {
out.setNumRows(hop.getDim1());
out.setNumCols(hop.getDim2());
}
tmp.put(hop.getHopID(), out);
}
Usage of org.apache.sysml.hops.AggUnaryOp in the Apache incubator-systemml project:
class DMLTranslator, method processBuiltinFunctionExpression.
/**
* Construct Hops from parse tree : Process BuiltinFunction Expression in an
* assignment statement
*
* @param source built-in function expression
* @param target data identifier
* @param hops map of high-level operators
* @return high-level operator
*/
private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
Hop expr = processExpression(source.getFirstExpr(), null, hops);
Hop expr2 = null;
if (source.getSecondExpr() != null) {
expr2 = processExpression(source.getSecondExpr(), null, hops);
}
Hop expr3 = null;
if (source.getThirdExpr() != null) {
expr3 = processExpression(source.getThirdExpr(), null, hops);
}
Hop currBuiltinOp = null;
if (target == null) {
target = createTarget(source);
}
// Construct the hop based on the type of Builtin function
switch(source.getOpCode()) {
case EVAL:
currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOpN.EVAL, processAllExpressions(source.getAllExpr(), hops));
break;
case COLSUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, Direction.Col, expr);
break;
case COLMAX:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, Direction.Col, expr);
break;
case COLMIN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, Direction.Col, expr);
break;
case COLMEAN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.Col, expr);
break;
case COLSD:
// colStdDevs = sqrt(colVariances)
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Col, expr);
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
break;
case COLVAR:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Col, expr);
break;
case ROWSUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, Direction.Row, expr);
break;
case ROWMAX:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, Direction.Row, expr);
break;
case ROWINDEXMAX:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAXINDEX, Direction.Row, expr);
break;
case ROWINDEXMIN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MININDEX, Direction.Row, expr);
break;
case ROWMIN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, Direction.Row, expr);
break;
case ROWMEAN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.Row, expr);
break;
case ROWSD:
// rowStdDevs = sqrt(rowVariances)
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Row, expr);
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
break;
case ROWVAR:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Row, expr);
break;
case NROW:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
long nRows = expr.getDim1();
if (nRows == -1) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NROW, expr);
} else {
currBuiltinOp = new LiteralOp(nRows);
}
break;
case NCOL:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
long nCols = expr.getDim2();
if (nCols == -1) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NCOL, expr);
} else {
currBuiltinOp = new LiteralOp(nCols);
}
break;
case LENGTH:
long nRows2 = expr.getDim1();
long nCols2 = expr.getDim2();
/*
* If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
* Else create a UnaryOp so that a control program instruction is generated
*/
if ((nCols2 == -1) || (nRows2 == -1)) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.LENGTH, expr);
} else {
long lval = (nCols2 * nRows2);
currBuiltinOp = new LiteralOp(lval);
}
break;
case SUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, Direction.RowCol, expr);
break;
case MEAN:
if (expr2 == null) {
// example: x = mean(Y);
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.RowCol, expr);
} else {
// example: x = mean(Y,W);
// stable weighted mean is implemented by using centralMoment with order = 0
Hop orderHop = new LiteralOp(0);
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.CENTRALMOMENT, expr, expr2, orderHop);
}
break;
case SD:
// stdDev = sqrt(variance)
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
HopRewriteUtils.setOutputParametersForScalar(currBuiltinOp);
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
break;
case VAR:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
break;
case MIN:
// construct AggUnary for min(X) but BinaryOp for min(X,Y)
if (expr2 == null) {
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, Direction.RowCol, expr);
} else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MIN, expr, expr2);
}
break;
case MAX:
// construct AggUnary for max(X) but BinaryOp for max(X,Y)
if (expr2 == null) {
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, Direction.RowCol, expr);
} else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MAX, expr, expr2);
}
break;
case PPRED:
String sop = ((StringIdentifier) source.getThirdExpr()).getValue();
sop = sop.replace("\"", "");
OpOp2 operation;
if (sop.equalsIgnoreCase(">="))
operation = OpOp2.GREATEREQUAL;
else if (sop.equalsIgnoreCase(">"))
operation = OpOp2.GREATER;
else if (sop.equalsIgnoreCase("<="))
operation = OpOp2.LESSEQUAL;
else if (sop.equalsIgnoreCase("<"))
operation = OpOp2.LESS;
else if (sop.equalsIgnoreCase("=="))
operation = OpOp2.EQUAL;
else if (sop.equalsIgnoreCase("!="))
operation = OpOp2.NOTEQUAL;
else {
LOG.error(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
throw new ParseException(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
}
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), operation, expr, expr2);
break;
case PROD:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.PROD, Direction.RowCol, expr);
break;
case TRACE:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.TRACE, Direction.RowCol, expr);
break;
case TRANS:
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.TRANSPOSE, expr);
break;
case REV:
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.REV, expr);
break;
case CBIND:
case RBIND:
OpOp2 appendOp1 = (source.getOpCode() == BuiltinFunctionOp.CBIND) ? OpOp2.CBIND : OpOp2.RBIND;
OpOpN appendOp2 = (source.getOpCode() == BuiltinFunctionOp.CBIND) ? OpOpN.CBIND : OpOpN.RBIND;
currBuiltinOp = (source.getAllExpr().length == 2) ? new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOp1, expr, expr2) : new NaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOp2, processAllExpressions(source.getAllExpr(), hops));
break;
case DIAG:
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.DIAG, expr);
break;
case TABLE:
// Always a TertiaryOp is created for table().
// - create a hop for weights, if not provided in the function call.
int numTableArgs = source._args.length;
switch(numTableArgs) {
case 2:
case 4:
// example DML statement: F = ctable(A,B) or F = ctable(A,B,10,15)
// here, weight is interpreted as 1.0
Hop weightHop = new LiteralOp(1.0);
// set dimensions
weightHop.setDim1(0);
weightHop.setDim2(0);
weightHop.setNnz(-1);
weightHop.setRowsInBlock(0);
weightHop.setColsInBlock(0);
if (numTableArgs == 2)
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop);
else {
Hop outDim1 = processExpression(source._args[2], null, hops);
Hop outDim2 = processExpression(source._args[3], null, hops);
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop, outDim1, outDim2);
}
break;
case 3:
case 5:
// example DML statement: F = ctable(A,B,W) or F = ctable(A,B,W,10,15)
if (numTableArgs == 3)
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3);
else {
Hop outDim1 = processExpression(source._args[3], null, hops);
Hop outDim2 = processExpression(source._args[4], null, hops);
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3, outDim1, outDim2);
}
break;
default:
throw new ParseException("Invalid number of arguments " + numTableArgs + " to table() function.");
}
break;
// data type casts
case CAST_AS_SCALAR:
currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.CAST_AS_SCALAR, expr);
break;
case CAST_AS_MATRIX:
currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.CAST_AS_MATRIX, expr);
break;
case CAST_AS_FRAME:
currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), Hop.OpOp1.CAST_AS_FRAME, expr);
break;
// value type casts
case CAST_AS_DOUBLE:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.DOUBLE, Hop.OpOp1.CAST_AS_DOUBLE, expr);
break;
case CAST_AS_INT:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT, Hop.OpOp1.CAST_AS_INT, expr);
break;
case CAST_AS_BOOLEAN:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, Hop.OpOp1.CAST_AS_BOOLEAN, expr);
break;
// Boolean binary
case XOR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.XOR, expr, expr2);
break;
case BITWAND:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWAND, expr, expr2);
break;
case BITWOR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWOR, expr, expr2);
break;
case BITWXOR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWXOR, expr, expr2);
break;
case BITWSHIFTL:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWSHIFTL, expr, expr2);
break;
case BITWSHIFTR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWSHIFTR, expr, expr2);
break;
case ABS:
case SIN:
case COS:
case TAN:
case ASIN:
case ACOS:
case ATAN:
case SINH:
case COSH:
case TANH:
case SIGN:
case SQRT:
case EXP:
case ROUND:
case CEIL:
case FLOOR:
case CUMSUM:
case CUMPROD:
case CUMMIN:
case CUMMAX:
Hop.OpOp1 mathOp1;
switch(source.getOpCode()) {
case ABS:
mathOp1 = Hop.OpOp1.ABS;
break;
case SIN:
mathOp1 = Hop.OpOp1.SIN;
break;
case COS:
mathOp1 = Hop.OpOp1.COS;
break;
case TAN:
mathOp1 = Hop.OpOp1.TAN;
break;
case ASIN:
mathOp1 = Hop.OpOp1.ASIN;
break;
case ACOS:
mathOp1 = Hop.OpOp1.ACOS;
break;
case ATAN:
mathOp1 = Hop.OpOp1.ATAN;
break;
case SINH:
mathOp1 = Hop.OpOp1.SINH;
break;
case COSH:
mathOp1 = Hop.OpOp1.COSH;
break;
case TANH:
mathOp1 = Hop.OpOp1.TANH;
break;
case SIGN:
mathOp1 = Hop.OpOp1.SIGN;
break;
case SQRT:
mathOp1 = Hop.OpOp1.SQRT;
break;
case EXP:
mathOp1 = Hop.OpOp1.EXP;
break;
case ROUND:
mathOp1 = Hop.OpOp1.ROUND;
break;
case CEIL:
mathOp1 = Hop.OpOp1.CEIL;
break;
case FLOOR:
mathOp1 = Hop.OpOp1.FLOOR;
break;
case CUMSUM:
mathOp1 = Hop.OpOp1.CUMSUM;
break;
case CUMPROD:
mathOp1 = Hop.OpOp1.CUMPROD;
break;
case CUMMIN:
mathOp1 = Hop.OpOp1.CUMMIN;
break;
case CUMMAX:
mathOp1 = Hop.OpOp1.CUMMAX;
break;
default:
LOG.error(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
throw new ParseException(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
}
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp1, expr);
break;
case LOG:
if (expr2 == null) {
Hop.OpOp1 mathOp2;
switch(source.getOpCode()) {
case LOG:
mathOp2 = Hop.OpOp1.LOG;
break;
default:
LOG.error(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
throw new ParseException(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
}
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp2, expr);
} else {
Hop.OpOp2 mathOp3;
switch(source.getOpCode()) {
case LOG:
mathOp3 = Hop.OpOp2.LOG;
break;
default:
LOG.error(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
throw new ParseException(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
}
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp3, expr, expr2);
}
break;
case MOMENT:
if (expr3 == null) {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.CENTRALMOMENT, expr, expr2);
} else {
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.CENTRALMOMENT, expr, expr2, expr3);
}
break;
case COV:
if (expr3 == null) {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.COVARIANCE, expr, expr2);
} else {
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.COVARIANCE, expr, expr2, expr3);
}
break;
case QUANTILE:
if (expr3 == null) {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.QUANTILE, expr, expr2);
} else {
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.QUANTILE, expr, expr2, expr3);
}
break;
case INTERQUANTILE:
if (expr3 == null) {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.INTERQUANTILE, expr, expr2);
} else {
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.INTERQUANTILE, expr, expr2, expr3);
}
break;
case IQM:
if (expr2 == null) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.IQM, expr);
} else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.IQM, expr, expr2);
}
break;
case MEDIAN:
if (expr2 == null) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.MEDIAN, expr);
} else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.MEDIAN, expr, expr2);
}
break;
case IFELSE:
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.IFELSE, expr, expr2, expr3);
break;
case SEQ:
HashMap<String, Hop> randParams = new HashMap<>();
randParams.put(Statement.SEQ_FROM, expr);
randParams.put(Statement.SEQ_TO, expr2);
randParams.put(Statement.SEQ_INCR, (expr3 != null) ? expr3 : new LiteralOp(1));
// note incr: default -1 (for from>to) handled during runtime
currBuiltinOp = new DataGenOp(DataGenMethod.SEQ, target, randParams);
break;
case SAMPLE:
{
Expression[] in = source.getAllExpr();
// arguments: range/size/replace/seed; defaults: replace=FALSE
HashMap<String, Hop> tmpparams = new HashMap<>();
// range
tmpparams.put(DataExpression.RAND_MAX, expr);
tmpparams.put(DataExpression.RAND_ROWS, expr2);
tmpparams.put(DataExpression.RAND_COLS, new LiteralOp(1));
if (in.length == 4) {
tmpparams.put(DataExpression.RAND_PDF, expr3);
Hop seed = processExpression(in[3], null, hops);
tmpparams.put(DataExpression.RAND_SEED, seed);
} else if (in.length == 3) {
// check if the third argument is "replace" or "seed"
if (expr3.getValueType() == ValueType.BOOLEAN) {
tmpparams.put(DataExpression.RAND_PDF, expr3);
tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED));
} else if (expr3.getValueType() == ValueType.INT) {
tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
tmpparams.put(DataExpression.RAND_SEED, expr3);
} else
throw new HopsException("Invalid input type " + expr3.getValueType() + " in sample().");
} else if (in.length == 2) {
tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED));
}
currBuiltinOp = new DataGenOp(DataGenMethod.SAMPLE, target, tmpparams);
break;
}
case SOLVE:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.SOLVE, expr, expr2);
break;
case INVERSE:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.INVERSE, expr);
break;
case CHOLESKY:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.CHOLESKY, expr);
break;
case OUTER:
if (!(expr3 instanceof LiteralOp))
throw new HopsException("Operator for outer builtin function must be a constant: " + expr3);
OpOp2 op = Hop.getOpOp2ForOuterVectorOperation(((LiteralOp) expr3).getStringValue());
if (op == null)
throw new HopsException("Unsupported outer vector binary operation: " + ((LiteralOp) expr3).getStringValue());
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, expr, expr2);
// flag op as specific outer vector operation
((BinaryOp) currBuiltinOp).setOuterVectorOperation(true);
// force size reevaluation according to 'outer' flag otherwise danger of incorrect dims
currBuiltinOp.refreshSizeInformation();
break;
case CONV2D:
{
Hop image = expr;
ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case BIAS_ADD:
{
ArrayList<Hop> inHops1 = new ArrayList<>();
inHops1.add(expr);
inHops1.add(expr2);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.BIAS_ADD, inHops1);
setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
break;
}
case BIAS_MULTIPLY:
{
ArrayList<Hop> inHops1 = new ArrayList<>();
inHops1.add(expr);
inHops1.add(expr2);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.BIAS_MULTIPLY, inHops1);
setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
break;
}
case AVG_POOL:
case MAX_POOL:
{
Hop image = expr;
ArrayList<Hop> inHops1 = getALHopsForPoolingForwardIM2COL(image, source, 1, hops);
if (source.getOpCode() == BuiltinFunctionOp.MAX_POOL)
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING, inHops1);
else
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.AVG_POOLING, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case AVG_POOL_BACKWARD:
case MAX_POOL_BACKWARD:
{
Hop image = expr;
// process dout as well
ArrayList<Hop> inHops1 = getALHopsForConvOpPoolingCOL2IM(image, source, 1, hops);
if (source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD)
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING_BACKWARD, inHops1);
else
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.AVG_POOLING_BACKWARD, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case CONV2D_BACKWARD_FILTER:
{
Hop image = expr;
ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case CONV2D_BACKWARD_DATA:
{
Hop image = expr;
ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
default:
throw new ParseException("Unsupported builtin function type: " + source.getOpCode());
}
boolean isConvolution = source.getOpCode() == BuiltinFunctionOp.CONV2D || source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA || source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER || source.getOpCode() == BuiltinFunctionOp.MAX_POOL || source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD || source.getOpCode() == BuiltinFunctionOp.AVG_POOL || source.getOpCode() == BuiltinFunctionOp.AVG_POOL_BACKWARD;
if (!isConvolution) {
// Since the dimension of output doesnot match that of input variable for these operations
setIdentifierParams(currBuiltinOp, source.getOutput());
}
currBuiltinOp.setParseInfo(source);
return currBuiltinOp;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class SpoofCompiler method rConstructModifiedHopDag.
/**
 * Recursively rewrites the HOP DAG below {@code hop}: every hop whose ID has an entry in
 * {@code clas} is replaced by a {@code SpoofFusedOp} that executes the corresponding generated
 * class, and the original hop's parents are rewired to the new operator. Traversal then
 * continues over the inputs of the (possibly new) operator.
 *
 * @param hop current hop to process
 * @param cplans per hop ID, a pair whose value is the CNode template of the fused plan
 *        (only {@code getValue()} is read here)
 * @param clas per hop ID, the pair (input hops of the fused operator, generated class)
 * @param memo IDs of hops already processed; updated in place to avoid revisiting shared sub-DAGs
 */
private static void rConstructModifiedHopDag(Hop hop, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans, HashMap<Long, Pair<Hop[], Class<?>>> clas, HashSet<Long> memo) {
if (memo.contains(hop.getHopID()))
// already processed
return;
// hnew is the hop that replaces 'hop' in the DAG; it stays 'hop' if no generated class exists
Hop hnew = hop;
if (clas.containsKey(hop.getHopID())) {
// replace sub-dag with generated operator
Pair<Hop[], Class<?>> tmpCla = clas.get(hop.getHopID());
CNodeTpl tmpCNode = cplans.get(hop.getHopID()).getValue();
hnew = new SpoofFusedOp(hop.getName(), hop.getDataType(), hop.getValueType(), tmpCla.getValue(), false, tmpCNode.getOutputDimType());
Hop[] inHops = tmpCla.getKey();
for (int i = 0; i < inHops.length; i++) {
// Outer-product templates: the input matching the template's third input (index 2) is
// wrapped in a transpose, unless TemplateUtils reports it already has a transpose parent
// under the outer product (presumably the transposed factor; confirm against the template).
if (tmpCNode instanceof CNodeOuterProduct && inHops[i].getHopID() == ((CNodeData) tmpCNode.getInput().get(2)).getHopID() && !TemplateUtils.hasTransposeParentUnderOuterProduct(inHops[i])) {
hnew.addInput(HopRewriteUtils.createTranspose(inHops[i]));
} else
// add inputs
hnew.addInput(inHops[i]);
}
// modify output parameters
// (the fused operator inherits dims, block sizes and nnz of the hop it replaces)
HopRewriteUtils.setOutputParameters(hnew, hop.getDim1(), hop.getDim2(), hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz());
// Template-specific post-processing. Note that in the transpose and cast cases below, hnew
// is reassigned to a wrapper hop whose input is the SpoofFusedOp.
if (tmpCNode instanceof CNodeOuterProduct && ((CNodeOuterProduct) tmpCNode).isTransposeOutput())
hnew = HopRewriteUtils.createTranspose(hnew);
else if (tmpCNode instanceof CNodeMultiAgg) {
// Multi-aggregate: the fused op outputs a 1 x k matrix (k = number of roots), where
// column i+1 holds the result of root i.
ArrayList<Hop> roots = ((CNodeMultiAgg) tmpCNode).getRootNodes();
hnew.setDataType(DataType.MATRIX);
HopRewriteUtils.setOutputParameters(hnew, 1, roots.size(), inHops[0].getRowsInBlock(), inHops[0].getColsInBlock(), -1);
// inject artificial right indexing operations for all parents of all nodes
// (scalar indexing for AggUnaryOp roots, matrix indexing otherwise)
for (int i = 0; i < roots.size(); i++) {
Hop hnewi = (roots.get(i) instanceof AggUnaryOp) ? HopRewriteUtils.createScalarIndexing(hnew, 1, i + 1) : HopRewriteUtils.createIndexingOp(hnew, 1, i + 1);
HopRewriteUtils.rewireAllParentChildReferences(roots.get(i), hnewi);
}
} else if (tmpCNode instanceof CNodeCell && ((CNodeCell) tmpCNode).requiredCastDtm()) {
// Cell template computing a scalar whose consumer expects a matrix: mark the fused op
// as scalar output and wrap it in a cast-to-matrix.
HopRewriteUtils.setOutputParametersForScalar(hnew);
hnew = HopRewriteUtils.createUnary(hnew, OpOp1.CAST_AS_MATRIX);
} else if (tmpCNode instanceof CNodeRow && (((CNodeRow) tmpCNode).getRowType() == RowType.NO_AGG_CONST || ((CNodeRow) tmpCNode).getRowType() == RowType.COL_AGG_CONST))
// Row templates with a constant output width: propagate that width to the fused op
// (hnew has not been reassigned in this branch, so the cast is to the SpoofFusedOp).
((SpoofFusedOp) hnew).setConstDim2(((CNodeRow) tmpCNode).getConstDim2());
// Multi-agg roots were already rewired individually above; every other template
// replaces 'hop' as a whole.
if (!(tmpCNode instanceof CNodeMultiAgg))
HopRewriteUtils.rewireAllParentChildReferences(hop, hnew);
// Mark the replacement as processed (the same ID is added again after the recursion below).
memo.add(hnew.getHopID());
}
// process hops recursively (parent-child links modified)
for (int i = 0; i < hnew.getInput().size(); i++) {
Hop c = hnew.getInput().get(i);
rConstructModifiedHopDag(c, cplans, clas, memo);
}
memo.add(hnew.getHopID());
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method simplifyUnaryAggReorgOperation.
/**
 * Removes a reorg operation (transpose, reshape, or rev) that directly feeds a full (RowCol)
 * unary aggregate, e.g. {@code sum(t(X)) -> sum(X)} or {@code max(rev(X)) -> max(X)}.
 * Full aggregates such as sum/min/max/mean depend only on the multiset of cell values, not on
 * cell positions.
 *
 * <p>Trace is the exception: it reads only the diagonal, so it is invariant under transpose but
 * not under rev or reshape (e.g. {@code trace(rev(X)) != trace(X)} in general). For trace, only
 * the transpose case is rewritten.
 *
 * <p>The rewrite requires the aggregate to be the reorg's only consumer. Removing the reorg
 * therefore cannot affect other parts of the DAG.
 *
 * @param parent parent of {@code hi} (not used by this rewrite)
 * @param hi candidate hop; its input is rewired in place if the rewrite applies
 * @param pos position of {@code hi} among the parent's inputs (not used by this rewrite)
 * @return {@code hi}, possibly with its reorg input bypassed
 */
private static Hop simplifyUnaryAggReorgOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof AggUnaryOp
            && ((AggUnaryOp) hi).getDirection() == Direction.RowCol // full uagg
            && hi.getInput().get(0) instanceof ReorgOp) { // reorg operation
        AggUnaryOp agg = (AggUnaryOp) hi;
        ReorgOp rop = (ReorgOp) hi.getInput().get(0);

        boolean validReorg = rop.getOp() == ReOrgOp.TRANSPOSE
                || rop.getOp() == ReOrgOp.RESHAPE
                || rop.getOp() == ReOrgOp.REV;
        // trace depends on cell positions: only transpose preserves the diagonal
        boolean positionInvariant = agg.getOp() != Hop.AggOp.TRACE
                || rop.getOp() == ReOrgOp.TRANSPOSE;
        boolean singleConsumer = rop.getParent().size() == 1; // uagg only reorg consumer

        if (validReorg && positionInvariant && singleConsumer) {
            Hop input = rop.getInput().get(0);
            HopRewriteUtils.removeAllChildReferences(hi);
            HopRewriteUtils.removeAllChildReferences(rop);
            HopRewriteUtils.addChildReference(hi, input);
            LOG.debug("Applied simplifyUnaryAggReorgOperation");
        }
    }
    return hi;
}
Aggregations