Usage of org.apache.sysml.hops.BinaryOp in the Apache incubator-systemml project.
Class IPAPassRemoveConstantBinaryOps, method rRemoveConstantBinaryOp.
/**
 * Recursively walks the HOP DAG below {@code hop}. For every plain (non-outer) cell-wise
 * multiplication with a matrix left input whose right input is a data op named in
 * {@code mOnes}, it replaces that right input with the scalar literal 1.
 * Nodes already marked visited are skipped. Each node is marked visited after its
 * inputs have been processed.
 *
 * @param hop   current DAG node
 * @param mOnes map keyed by the names of variables that hold matrices of ones
 */
private static void rRemoveConstantBinaryOp(Hop hop, HashMap<String, Hop> mOnes) {
    if (hop.isVisited()) {
        return;
    }
    if (hop instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hop;
        boolean plainMult = bop.getOp() == OpOp2.MULT && !bop.isOuterVectorOperator();
        if (plainMult && hop.getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop rhs = hop.getInput().get(1);
            if (rhs instanceof DataOp && mOnes.containsKey(rhs.getName())) {
                // Swap the matrix of ones for literal 1. Later algebraic simplification
                // rewrites remove the resulting multiply-by-one, which avoids recursive
                // processing and rewiring here.
                HopRewriteUtils.removeChildReferenceByPos(hop, rhs, 1);
                HopRewriteUtils.addChildReference(hop, new LiteralOp(1), 1);
            }
        }
    }
    // Descend into all current inputs; this also visits a freshly inserted literal.
    for (Hop input : hop.getInput()) {
        rRemoveConstantBinaryOp(input, mOnes);
    }
    hop.setVisited();
}
Usage of org.apache.sysml.hops.BinaryOp in the Apache incubator-systemml project.
Class HopRewriteUtils, method isNotMatrixVectorBinaryOperation.
/**
 * Checks that {@code hop} is not a matrix-vector binary operation.
 * Any non-binary operator qualifies. A binary operator qualifies only if both of these hold:
 * its dimensions are known, and its right input has neither a single row against a
 * multi-row left input nor a single column against a multi-column left input.
 *
 * @param hop operator to inspect
 * @return true if hop is not a matrix-vector binary operation
 */
public static boolean isNotMatrixVectorBinaryOperation(Hop hop) {
    if (!(hop instanceof BinaryOp)) {
        return true;
    }
    BinaryOp bop = (BinaryOp) hop;
    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);
    boolean singleRowRhs = lhs.getDim1() > 1 && rhs.getDim1() == 1;
    boolean singleColRhs = lhs.getDim2() > 1 && rhs.getDim2() == 1;
    return isDimsKnown(bop) && !singleRowRhs && !singleColRhs;
}
Usage of org.apache.sysml.hops.BinaryOp in the Apache incubator-systemml project.
Class HopRewriteUtils, method isBinarySparseSafe.
/**
 * Determines whether a binary operator is treated as sparse-safe.
 * MULT and DIV always are. Any other binary operator is sparse-safe only if one of its
 * inputs is a literal for which
 * {@code OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe} holds.
 * When both inputs are literals, the left one is used.
 *
 * @param hop operator to inspect
 * @return true if hop is a sparse-safe binary operator, false otherwise (including non-binary ops)
 */
public static boolean isBinarySparseSafe(Hop hop) {
    if (hop instanceof BinaryOp) {
        if (isBinary(hop, OpOp2.MULT, OpOp2.DIV)) {
            return true;
        }
        BinaryOp bop = (BinaryOp) hop;
        LiteralOp literal;
        if (bop.getInput().get(0) instanceof LiteralOp) {
            literal = (LiteralOp) bop.getInput().get(0);
        } else if (bop.getInput().get(1) instanceof LiteralOp) {
            literal = (LiteralOp) bop.getInput().get(1);
        } else {
            // no literal operand to base a conditional decision on
            return false;
        }
        return OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe(bop.getOp(), literal);
    }
    return false;
}
Usage of org.apache.sysml.hops.BinaryOp in the Apache incubator-systemml project.
Class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSquaredLoss.
/**
 * Searches for weighted squared loss expressions and replaces them with a quaternary operator (WSLOSS).
 * Currently, this search includes the following three patterns. Each is also matched with the
 * operands of the subtraction swapped (e.g., U %*% t(V) - X instead of X - U %*% t(V)):
 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 *
 * NOTE: We include the transpose in the pattern because at runtime we need to compute
 * U %*% t(V) pointwise. Having V rather than t(V) at hand allows a cache-friendly implementation
 * without the additional memory an internal transpose would require.
 *
 * This rewrite is conceptually a static rewrite. However, the current MR runtime only supports
 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here, during the
 * general rewrite, because this is an uncommon case. The intention is to remove this constraint
 * as soon as the runtime or hop/lop compilation has been generalized.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate sum() root of the pattern)
 * @param pos position of hi within the inputs of parent
 * @return the new quaternary operator (already rewired into parent) if a pattern matched, otherwise hi
 */
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) {
// NOTE: there might also be a general simplification without a custom operator
// via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
// replacement operator; stays null if no pattern matches
Hop hnew = null;
if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol && // all patterns rooted by a full (RowCol) aggregate
((AggUnaryOp) hi).getOp() == AggOp.SUM && // ... whose aggregation function is sum()
hi.getInput().get(0) instanceof BinaryOp && // all patterns subrooted by a binary op
hi.getInput().get(0).getDim2() > 1) { // not applied for vector-vector mult (input must have more than one column)
BinaryOp bop = (BinaryOp) hi.getInput().get(0);
// ensures that at most one of the three patterns is applied
boolean appliedPattern = false;
// pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2),
// including the alternative sum (W * (U %*% t(V) - X) ^ 2)
if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // W is a matrix; the equal-size check below prevents mv
HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) { // exponent is the literal 2
Hop W = bop.getInput().get(0);
// (X - U %*% t(V))
Hop tmp = bop.getInput().get(1).getInput().get(0);
if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS) && // base of the power is a subtraction; the equal-size check below prevents mv
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
// a) sum (W * (X - U %*% t(V)) ^ 2)
// position of U %*% t(V) within tmp; -1 if neither operand matches
int uvIndex = -1;
if (// ba (binary aggregate, AggBinaryOp) guarantees matrices
tmp.getInput().get(1) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT (next line), see method doc
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) { // presumably: U's rank (columns) fits in one block - TODO confirm
uvIndex = 1;
} else // b) sum (W * (U %*% t(V) - X) ^ 2)
if (// ba (binary aggregate, AggBinaryOp) guarantees matrices
tmp.getInput().get(0) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT (next line), see method doc
HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) {
uvIndex = 0;
}
if (// rewrite match
uvIndex >= 0) {
// X is the other operand of the subtraction
Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
// WSLOSS takes V rather than t(V): strip an existing transpose, otherwise introduce one
if (!HopRewriteUtils.isTransposeOperation(V)) {
V = HopRewriteUtils.createTranspose(V);
} else {
V = V.getInput().get(0);
}
// handle special case of post_nz: if W is the non-zero indicator of X, W is replaced by
// literal 1 (presumably the runtime then weights by the non-zeros of X - TODO confirm)
if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
W = new LiteralOp(1);
}
// construct quaternary hop (last argument true here, false for patterns 2/3;
// presumably the post-weighting flag)
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
HopRewriteUtils.setOutputParametersForScalar(hnew);
appliedPattern = true;
// logs e.g. "...Loss10" / "...Loss11": pattern number followed by uvIndex
LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
}
}
}
// pattern 2 (pre weighting): sum ((X - W * (U %*% t(V))) ^ 2),
// including the alternative sum ((W * (U %*% t(V)) - X) ^ 2)
if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // the equal-size check below prevents mv
HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
Hop lleft = bop.getInput().get(0).getInput().get(0);
Hop lright = bop.getInput().get(0).getInput().get(1);
// a) sum ((X - W * (U %*% t(V))) ^ 2)
// position of W * (U %*% t(V)) within the subtraction; -1 if neither operand matches.
// NOTE(review): the index is chosen on operator shape only. The MULT / size / blocksize checks
// run afterwards, so if lright has the right shape but fails them, the lleft alternative (b)
// is never tried.
int wuvIndex = -1;
if (lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp) {
wuvIndex = 1;
} else // b) sum ((W * (U %*% t(V)) - X) ^ 2)
if (lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp) {
wuvIndex = 0;
}
if (// rewrite match
wuvIndex >= 0) {
Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
// (W * (U %*% t(V)))
Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
if (((BinaryOp) tmp).getOp() == OpOp2.MULT && tmp.getInput().get(0).getDataType() == DataType.MATRIX && // W is a matrix; the equal-size check below prevents mv
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && // BLOCKSIZE CONSTRAINT (next line), see method doc
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
Hop W = tmp.getInput().get(0);
Hop U = tmp.getInput().get(1).getInput().get(0);
Hop V = tmp.getInput().get(1).getInput().get(1);
// WSLOSS takes V rather than t(V): strip an existing transpose, otherwise introduce one
if (!HopRewriteUtils.isTransposeOperation(V)) {
V = HopRewriteUtils.createTranspose(V);
} else {
V = V.getInput().get(0);
}
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
HopRewriteUtils.setOutputParametersForScalar(hnew);
appliedPattern = true;
LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
}
}
}
// pattern 3 (no weighting): sum ((X - (U %*% t(V))) ^ 2),
// including the alternative sum (((U %*% t(V)) - X) ^ 2)
if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // the equal-size check below prevents mv
HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
Hop lleft = bop.getInput().get(0).getInput().get(0);
Hop lright = bop.getInput().get(0).getInput().get(1);
// a) sum ((X - (U %*% t(V))) ^ 2)
// position of U %*% t(V) within the subtraction; -1 if neither operand matches
int uvIndex = -1;
if (// ba (binary aggregate, AggBinaryOp) guarantees matrices
lright instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT (next line), see method doc
HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) {
uvIndex = 1;
} else // b) sum (((U %*% t(V)) - X) ^ 2)
if (// ba (binary aggregate, AggBinaryOp) guarantees matrices
lleft instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT (next line), see method doc
HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) {
uvIndex = 0;
}
if (// rewrite match
uvIndex >= 0) {
Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
// (U %*% t(V))
Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
// no weighting: W is the literal 1
Hop W = new LiteralOp(1);
Hop U = tmp.getInput().get(0);
Hop V = tmp.getInput().get(1);
// WSLOSS takes V rather than t(V): strip an existing transpose, otherwise introduce one
if (!HopRewriteUtils.isTransposeOperation(V)) {
V = HopRewriteUtils.createTranspose(V);
} else {
V = V.getInput().get(0);
}
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
HopRewriteUtils.setOutputParametersForScalar(hnew);
appliedPattern = true;
LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
}
}
}
// relink the new hop into the original position of hi under parent
if (hnew != null) {
HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
}
return hi;
}
Usage of org.apache.sysml.hops.BinaryOp in the Apache incubator-systemml project.
Class RewriteAlgebraicSimplificationDynamic, method fuseAxpyBinaryOperationChain.
/**
 * Fuses a scalar-matrix multiplication into an adjacent matrix addition or subtraction.
 * The result is a ternary PLUS_MULT / MINUS_MULT operator:
 * (a) X + s*Y -> X +* sY, (b) s*Y + X -> X +* sY, (c) X - s*Y -> X -* sY.
 * The s*Y operand must be a scalar-matrix multiplication of the same size as X.
 * It must also be consumed only by this operator.
 * If s is the literal 0, the whole expression is replaced by X.
 * Outer vector operations are never rewritten.
 *
 * @param parent parent of hi; rewired to the replacement if a pattern applies
 * @param hi     candidate binary operator
 * @param pos    position of hi within the inputs of parent
 * @return the replacement operator, or hi if no pattern matched
 */
private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp) || ((BinaryOp) hi).isOuterVectorOperator()) {
        return hi;
    }
    BinaryOp bop = (BinaryOp) hi;
    boolean isPlus = bop.getOp() == OpOp2.PLUS;
    boolean isMinus = bop.getOp() == OpOp2.MINUS;
    if (!isPlus && !isMinus) {
        return hi;
    }

    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);
    Hop fused = null;
    if (isPlus && left.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isScalarMatrixBinaryMult(right)
            && HopRewriteUtils.isEqualSize(left, right)
            && right.getParent().size() == 1) {
        // (a) X + s*Y, with s*Y consumed only here
        fused = fuseAxpyOperands(left, right, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " + hi.getBeginLine() + ")");
    } else if (isPlus && right.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isScalarMatrixBinaryMult(left)
            && HopRewriteUtils.isEqualSize(left, right)
            && left.getParent().size() == 1) {
        // (b) s*Y + X, with s*Y consumed only here
        fused = fuseAxpyOperands(right, left, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " + hi.getBeginLine() + ")");
    } else if (isMinus && left.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isScalarMatrixBinaryMult(right)
            && HopRewriteUtils.isEqualSize(left, right)
            && right.getParent().size() == 1) {
        // (c) X - s*Y, with s*Y consumed only here
        fused = fuseAxpyOperands(left, right, OpOp3.MINUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain3. (line " + hi.getBeginLine() + ")");
    }

    if (fused == null) {
        return hi;
    }
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    return fused;
}

/**
 * Builds the fused ternary operator {@code x op s*y} from matrix {@code x} and the
 * scalar-matrix product {@code scaledY}. The scalar may appear on either side of
 * {@code scaledY}. If the scalar is the literal 0, {@code x} is returned unchanged.
 *
 * @param x       matrix operand kept as the first ternary input
 * @param scaledY scalar-matrix multiplication s*Y (or Y*s)
 * @param op      PLUS_MULT or MINUS_MULT
 * @return the new ternary operator, or x for a zero literal scalar
 */
private static Hop fuseAxpyOperands(Hop x, Hop scaledY, OpOp3 op) {
    boolean scalarFirst = scaledY.getInput().get(0).getDataType() == DataType.SCALAR;
    Hop s = scaledY.getInput().get(scalarFirst ? 0 : 1);
    Hop y = scaledY.getInput().get(scalarFirst ? 1 : 0);
    if (s instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) s) == 0) {
        return x;
    }
    return HopRewriteUtils.createTernaryOp(x, s, y, op);
}
Aggregations