Usage of org.apache.sysml.hops.AggBinaryOp in the Apache incubator-systemml project.
From class RewriteAlgebraicSimplificationStatic, method simplifyBushyBinaryOperation.
/**
 * Re-associates a chain of two identical associative element-wise binary
 * operations so that the matrix multiplication inside the chain becomes a
 * direct operand of the outer operation:
 *
 *   (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)      (right-nested chain)
 *   (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
 *   ((Z%*%v)*X)*Y   -> (Z%*%v)*(X*Y)      (left-nested chain)
 *
 * Note: the rewrite is restricted to a matrix multiplication (ba()) at the leaf
 * and at the root (the parent), instead of any data at the leaf, to avoid
 * reorganizing too eagerly, which would lose additional rewrite potential.
 * This rewrite has two goals: (1) enable XtwXv, and (2) increase piggybacking
 * potential by creating bushy trees.
 *
 * @param parent parent high-level operator (must be an AggBinaryOp to rewrite)
 * @param hi high-level operator, candidate outer binary operation
 * @param pos position of hi among the inputs of parent
 * @return the rewritten operator if a rewrite was applied, otherwise hi
 */
private static Hop simplifyBushyBinaryOperation(Hop parent, Hop hi, int pos) {
    // only element-wise binary ops that directly feed a matrix multiplication qualify
    if (!(hi instanceof BinaryOp) || !(parent instanceof AggBinaryOp))
        return hi;

    BinaryOp outer = (BinaryOp) hi;
    Hop lhs = outer.getInput().get(0);
    Hop rhs = outer.getInput().get(1);
    OpOp2 outerOp = outer.getOp();

    boolean bothMatrices = lhs.getDataType() == DataType.MATRIX
        && rhs.getDataType() == DataType.MATRIX;
    if (!bothMatrices || !HopRewriteUtils.isValidOp(outerOp, LOOKUP_VALID_ASSOCIATIVE_BINARY))
        return hi;

    // case 1: (X*(Y*op())) -> (X*Y)*op()
    if (rhs instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) rhs;
        Hop innerLhs = inner.getInput().get(0);
        Hop innerRhs = inner.getInput().get(1);
        boolean rightNestedMatMult = inner.getOp() == outerOp
            && innerRhs.getDataType() == DataType.MATRIX
            && innerRhs instanceof AggBinaryOp;
        if (rightNestedMatMult) {
            BinaryOp grouped = HopRewriteUtils.createBinary(lhs, innerLhs, outerOp);
            BinaryOp rewritten = HopRewriteUtils.createBinary(grouped, innerRhs, outerOp);
            HopRewriteUtils.replaceChildReference(parent, outer, rewritten, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation1");
            return rewritten;
        }
    }

    // case 2: ((op()*X)*Y) -> op()*(X*Y)
    if (lhs instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) lhs;
        Hop innerLhs = inner.getInput().get(0);
        Hop innerRhs = inner.getInput().get(1);
        // per dimension: X not vector, or Y vector (keeps the new X*Y well-formed)
        boolean leftNestedMatMult = inner.getOp() == outerOp
            && innerLhs.getDataType() == DataType.MATRIX
            && innerLhs instanceof AggBinaryOp
            && (innerRhs.getDim2() > 1 || rhs.getDim2() == 1)
            && (innerRhs.getDim1() > 1 || rhs.getDim1() == 1);
        if (leftNestedMatMult) {
            BinaryOp grouped = HopRewriteUtils.createBinary(innerRhs, rhs, outerOp);
            BinaryOp rewritten = HopRewriteUtils.createBinary(innerLhs, grouped, outerOp);
            HopRewriteUtils.replaceChildReference(parent, outer, rewritten, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation2");
            return rewritten;
        }
    }

    return hi;
}
Usage of org.apache.sysml.hops.AggBinaryOp in the Apache incubator-systemml project.
From class RewriteAlgebraicSimplificationStatic, method canonicalizeMatrixMultScalarAdd.
/**
 * Brings scalar additions and subtractions around a matrix multiplication into
 * the single canonical shape U%*%V+s, so that follow-up rewrites (e.g., wdivmm
 * or wcemm with epsilon) only need to match one pattern:
 *
 *   eps + U%*%V  ->  U%*%V + eps     (inputs swapped)
 *   U%*%V - eps  ->  U%*%V + (-eps)  (operator flipped, scalar negated)
 *
 * The given hop is modified in place.
 *
 * @param hi high-level operator
 * @return the same high-level operator instance
 */
private static Hop canonicalizeMatrixMultScalarAdd(Hop hi) {
    if (!(hi instanceof BinaryOp))
        return hi;

    BinaryOp bop = (BinaryOp) hi;
    Hop first = hi.getInput().get(0);
    Hop second = hi.getInput().get(1);

    boolean scalarPlusMatMult = first.getDataType().isScalar()
        && second instanceof AggBinaryOp
        && bop.getOp() == OpOp2.PLUS;
    if (scalarPlusMatMult) {
        // re-attach the inputs in swapped order: matrix mult first, scalar second
        HopRewriteUtils.removeAllChildReferences(bop);
        HopRewriteUtils.addChildReference(bop, second, 0);
        HopRewriteUtils.addChildReference(bop, first, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd1 (line " + hi.getBeginLine() + ").");
        return hi;
    }

    boolean matMultMinusScalar = second.getDataType().isScalar()
        && first instanceof AggBinaryOp
        && bop.getOp() == OpOp2.MINUS;
    if (matMultMinusScalar) {
        // a - s == a + (-s)
        bop.setOp(OpOp2.PLUS);
        Hop negated = HopRewriteUtils.createBinaryMinus(second);
        HopRewriteUtils.replaceChildReference(bop, second, negated, 1);
        LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line " + hi.getBeginLine() + ").");
    }
    return hi;
}
Usage of org.apache.sysml.hops.AggBinaryOp in the Apache incubator-systemml project.
From class Explain, method getNodeLabel.
/**
 * Builds the display label of a hop. The label contains:
 * - the op string;
 * - for an AggBinaryOp, its matrix-mult method when one is set;
 * - the data flow flags (rblk/chkpt) when SHOW_DATA_FLOW_PROPERTIES is enabled;
 * - the source position as [file line:col-line:col], where the file name is
 *   omitted if it is unknown;
 * - the update type, when it is in-place.
 *
 * @param hop high-level operator to describe
 * @return the label text
 */
private static String getNodeLabel(Hop hop) {
    StringBuilder sb = new StringBuilder();
    sb.append(hop.getOpString());
    if (hop instanceof AggBinaryOp) {
        AggBinaryOp aggBinOp = (AggBinaryOp) hop;
        if (aggBinOp.getMMultMethod() != null)
            sb.append(" ").append(aggBinOp.getMMultMethod().name()).append(" ");
    }
    // data flow properties
    if (SHOW_DATA_FLOW_PROPERTIES) {
        if (hop.requiresReblock() && hop.requiresCheckpoint())
            sb.append(", rblk,chkpt");
        else if (hop.requiresReblock())
            sb.append(", rblk");
        else if (hop.requiresCheckpoint())
            sb.append(", chkpt");
    }
    // source position, prefixed with the file name when it is known
    sb.append("[");
    if (hop.getFilename() != null)
        sb.append(hop.getFilename()).append(" ");
    sb.append(hop.getBeginLine()).append(":").append(hop.getBeginColumn())
        .append("-").append(hop.getEndLine()).append(":").append(hop.getEndColumn())
        .append("]");
    // Locale.ROOT: the default-locale lowercase mapping would turn "INPLACE"
    // into "ınplace" under a Turkish locale
    if (hop.getUpdateType().isInPlace())
        sb.append(",").append(hop.getUpdateType().toString().toLowerCase(java.util.Locale.ROOT));
    return sb.toString();
}
Usage of org.apache.sysml.hops.AggBinaryOp in the Apache incubator-systemml project.
From class PlanSelectionFuseCostBased, method rGetComputeCosts.
/**
 * Recursively assigns a relative per-cell compute cost to every hop of the DAG
 * rooted at {@code current} and memoizes it in {@code computeCosts}, keyed by hop id.
 *
 * Children are costed before their parent. Hops already present in the map are
 * skipped, so shared sub-DAGs are visited once.
 *
 * Operators without an explicit case keep the default cost 1. For UnaryOp,
 * BinaryOp, TernaryOp and AggUnaryOp a warning is logged in that case; other
 * hop classes get cost 1 silently.
 *
 * @param current root of the (sub-)DAG to cost
 * @param partition fusion partition; currently only passed through the recursion
 * @param computeCosts output map from hop id to compute cost
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
    if (computeCosts.containsKey(current.getHopID()))
        return;
    // recursively process children first
    for (Hop c : current.getInput())
        rGetComputeCosts(c, partition, computeCosts);
    // get costs for given hop; unmatched operators keep this default
    double costs = 1;
    if (current instanceof UnaryOp) {
        switch (((UnaryOp) current).getOp()) {
            case ABS:
            case ROUND:
            case CEIL:
            case FLOOR:
            case SIGN:
            case SELP:
                costs = 1;
                break;
            case SPROP:
            case SQRT:
                costs = 2;
                break;
            case EXP:
                costs = 18;
                break;
            case SIGMOID:
                costs = 21;
                break;
            case LOG:
            case LOG_NZ:
                costs = 32;
                break;
            case NCOL:
            case NROW:
            case PRINT:
            case CAST_AS_BOOLEAN:
            case CAST_AS_DOUBLE:
            case CAST_AS_INT:
            case CAST_AS_MATRIX:
            case CAST_AS_SCALAR:
                costs = 1;
                break;
            case SIN:
                costs = 18;
                break;
            case COS:
                costs = 22;
                break;
            case TAN:
                costs = 42;
                break;
            case ASIN:
                costs = 93;
                break;
            case ACOS:
                costs = 103;
                break;
            case ATAN:
                costs = 40;
                break;
            case CUMSUM:
            case CUMMIN:
            case CUMMAX:
            case CUMPROD:
                costs = 1;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ((UnaryOp) current).getOp());
        }
    } else if (current instanceof BinaryOp) {
        switch (((BinaryOp) current).getOp()) {
            case MULT:
            case PLUS:
            case MINUS:
            case MIN:
            case MAX:
            case AND:
            case OR:
            case EQUAL:
            case NOTEQUAL:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case CBIND:
            case RBIND:
                costs = 1;
                break;
            case INTDIV:
                costs = 6;
                break;
            case MODULUS:
                costs = 8;
                break;
            case DIV:
                costs = 22;
                break;
            case LOG:
            case LOG_NZ:
                costs = 32;
                break;
            case POW:
                // squaring is as cheap as a multiplication
                costs = (HopRewriteUtils.isLiteralOfValue(current.getInput().get(1), 2) ? 1 : 16);
                break;
            case MINUS_NZ:
            case MINUS1_MULT:
                costs = 2;
                break;
            case CENTRALMOMENT:
                switch (getLiteralCentralMomentOrder(current)) {
                    case 0: costs = 1; break;  // count
                    case 1: costs = 8; break;  // mean
                    case 2: costs = 16; break; // cm2
                    case 3: costs = 31; break; // cm3
                    case 4: costs = 51; break; // cm4
                    case 5: costs = 16; break; // variance
                }
                break;
            case COVARIANCE:
                costs = 23;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ((BinaryOp) current).getOp());
        }
    } else if (current instanceof TernaryOp) {
        switch (((TernaryOp) current).getOp()) {
            case PLUS_MULT:
            case MINUS_MULT:
                costs = 2;
                break;
            case CTABLE:
                costs = 3;
                break;
            case CENTRALMOMENT:
                // weighted cm: one extra operation per cell compared to the binary form
                switch (getLiteralCentralMomentOrder(current)) {
                    case 0: costs = 2; break;  // count
                    case 1: costs = 9; break;  // mean
                    case 2: costs = 17; break; // cm2
                    case 3: costs = 32; break; // cm3
                    case 4: costs = 52; break; // cm4
                    case 5: costs = 17; break; // variance
                }
                break;
            case COVARIANCE:
                costs = 23;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ((TernaryOp) current).getOp());
        }
    } else if (current instanceof ParameterizedBuiltinOp) {
        costs = 1;
    } else if (current instanceof IndexingOp) {
        costs = 1;
    } else if (current instanceof ReorgOp) {
        costs = 1;
    } else if (current instanceof AggBinaryOp) {
        // constant cost; presumably assumes matrix-vector multiplication
        costs = 2;
    } else if (current instanceof AggUnaryOp) {
        switch (((AggUnaryOp) current).getOp()) {
            case SUM:
                costs = 4;
                break;
            case SUM_SQ:
                costs = 5;
                break;
            case MIN:
            case MAX:
                costs = 1;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ((AggUnaryOp) current).getOp());
        }
    }
    computeCosts.put(current.getHopID(), costs);
}

/**
 * Returns the order of a central-moment hop if it is given as a literal, and 2
 * (the cm2 cost class) otherwise.
 *
 * The order is the last input in both forms: cm(X, order) as a BinaryOp and
 * cm(X, W, order) as a TernaryOp. Reading a fixed index 1 would return the
 * weights W for the ternary form.
 *
 * @param cm central-moment hop (BinaryOp or TernaryOp)
 * @return literal order, or 2 if the order is not a literal
 */
private static int getLiteralCentralMomentOrder(Hop cm) {
    Hop order = cm.getInput().get(cm.getInput().size() - 1);
    return (int) (order instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) order) : 2);
}
Aggregations