Usage of org.apache.sysml.hops.QuaternaryOp in the Apache incubator-systemml project:
class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSquaredLoss.
/**
 * Searches for weighted squared loss expressions and replaces them with a quaternary operator.
 * Currently, this search includes the following three patterns:
 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 * Each pattern is also matched with the operands of the subtraction swapped,
 * e.g., sum (W * (U %*% t(V) - X) ^ 2), since the square makes the order irrelevant.
 *
 * NOTE: We include transpose into the pattern because during runtime we need to compute
 * U %*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
 * without additional memory requirements for internal transpose.
 *
 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
 * as we generalize the runtime or hop/lop compilation.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent (used to relink the replacement)
 * @return the new WSLOSS quaternary operator if a pattern matched, otherwise hi
 */
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) {
	// NOTE: there might be also a general simplification without custom operator
	// via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
	Hop hnew = null;
	if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol // all patterns rooted by sum()
		&& ((AggUnaryOp) hi).getOp() == AggOp.SUM
		&& hi.getInput().get(0) instanceof BinaryOp // all patterns subrooted by binary op
		&& hi.getInput().get(0).getDim2() > 1) { // not applied for vector-vector mult
		BinaryOp bop = (BinaryOp) hi.getInput().get(0);
		boolean appliedPattern = false;

		// Pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2) or sum (W * (U %*% t(V) - X) ^ 2)
		if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW)
			&& bop.getInput().get(0).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) // prevent mv
			&& bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) {
			Hop W = bop.getInput().get(0);
			// (X - U %*% t(V))
			Hop tmp = bop.getInput().get(1).getInput().get(0);
			if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS)
				&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) // prevent mv
				&& tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
				int uvIndex = -1;
				if (tmp.getInput().get(1) instanceof AggBinaryOp // a) sum (W * (X - U %*% t(V)) ^ 2); ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
					uvIndex = 1;
				} else if (tmp.getInput().get(0) instanceof AggBinaryOp // b) sum (W * (U %*% t(V) - X) ^ 2); ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
					uvIndex = 0;
				}
				if (uvIndex >= 0) { // rewrite match
					Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
					Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
					Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
					// the quaternary op takes V itself, so strip an existing transpose or add one
					if (!HopRewriteUtils.isTransposeOperation(V)) {
						V = HopRewriteUtils.createTranspose(V);
					} else {
						V = V.getInput().get(0);
					}
					// handle special case of post_nz
					if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
						W = new LiteralOp(1);
					}
					// construct quaternary hop
					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
					HopRewriteUtils.setOutputParametersForScalar(hnew);
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
				}
			}
		}

		// Pattern 2 (pre weighting): sum ((X - W * (U %*% t(V))) ^ 2) or sum ((W * (U %*% t(V)) - X) ^ 2)
		if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2
			&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
			&& bop.getInput().get(0).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) // prevent mv
			&& bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
			Hop lleft = bop.getInput().get(0).getInput().get(0);
			Hop lright = bop.getInput().get(0).getInput().get(1);
			// Fully validate each W * (U %*% t(V)) candidate before selecting it, so that a
			// non-matching right operand does not prevent the left operand from being considered.
			int wuvIndex = -1;
			if (HopRewriteUtils.isBinary(lright, OpOp2.MULT) // a) sum ((X - W * (U %*% t(V))) ^ 2)
				&& lright.getInput().get(1) instanceof AggBinaryOp
				&& lright.getInput().get(0).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(lright.getInput().get(0), lright.getInput().get(1)) // prevent mv
				&& HopRewriteUtils.isSingleBlock(lright.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
				wuvIndex = 1;
			} else if (HopRewriteUtils.isBinary(lleft, OpOp2.MULT) // b) sum ((W * (U %*% t(V)) - X) ^ 2)
				&& lleft.getInput().get(1) instanceof AggBinaryOp
				&& lleft.getInput().get(0).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(lleft.getInput().get(0), lleft.getInput().get(1)) // prevent mv
				&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
				wuvIndex = 0;
			}
			if (wuvIndex >= 0) { // rewrite match
				Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
				// (W * (U %*% t(V)))
				Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
				Hop W = tmp.getInput().get(0);
				Hop U = tmp.getInput().get(1).getInput().get(0);
				Hop V = tmp.getInput().get(1).getInput().get(1);
				if (!HopRewriteUtils.isTransposeOperation(V)) {
					V = HopRewriteUtils.createTranspose(V);
				} else {
					V = V.getInput().get(0);
				}
				hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
				HopRewriteUtils.setOutputParametersForScalar(hnew);
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
			}
		}

		// Pattern 3 (no weighting): sum ((X - (U %*% t(V))) ^ 2) or sum (((U %*% t(V)) - X) ^ 2)
		if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2
			&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
			&& bop.getInput().get(0).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) // prevent mv
			&& bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
			Hop lleft = bop.getInput().get(0).getInput().get(0);
			Hop lright = bop.getInput().get(0).getInput().get(1);
			int uvIndex = -1;
			if (lright instanceof AggBinaryOp // a) sum ((X - (U %*% t(V))) ^ 2); ba guarantees matrices
				&& HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
				uvIndex = 1;
			} else if (lleft instanceof AggBinaryOp // b) sum (((U %*% t(V)) - X) ^ 2); ba guarantees matrices
				&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
				uvIndex = 0;
			}
			if (uvIndex >= 0) { // rewrite match
				Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
				// (U %*% t(V))
				Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
				// no weighting
				Hop W = new LiteralOp(1);
				Hop U = tmp.getInput().get(0);
				Hop V = tmp.getInput().get(1);
				if (!HopRewriteUtils.isTransposeOperation(V)) {
					V = HopRewriteUtils.createTranspose(V);
				} else {
					V = V.getInput().get(0);
				}
				hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
				HopRewriteUtils.setOutputParametersForScalar(hnew);
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
			}
		}
	}
	// relink new hop into original position
	if (hnew != null) {
		HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
		hi = hnew;
	}
	return hi;
}
Usage of org.apache.sysml.hops.QuaternaryOp in the Apache incubator-systemml project:
class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSigmoidMMChains.
/**
 * Fuses a weighted sigmoid over a matrix-multiplication chain into one WSIGMOID quaternary
 * operator. Recognized forms:
 * 1) W * sigmoid(Y%*%t(X))            (basic)
 * 2) W * sigmoid(-(Y%*%t(X)))         (minus, expressed as 0 - Y%*%t(X))
 * 3) W * log(sigmoid(Y%*%t(X)))       (log)
 * 4) W * log(sigmoid(-(Y%*%t(X))))    (log_minus)
 * W must be a matrix of the same size as its multiplication partner, the result must have more
 * than one column, and Y must satisfy the single-block constraint.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent (used to relink the replacement)
 * @return the new quaternary operator if a form matched, otherwise hi
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
	// Guard: only W * f(...) with an equally sized matrix W and a unary f (sigmoid/log) qualifies.
	boolean isCandidate = HopRewriteUtils.isBinary(hi, OpOp2.MULT)
		&& hi.getDim2() > 1 // not applied for vector-vector mult
		&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) // prevent mv
		&& hi.getInput().get(0).getDataType() == DataType.MATRIX
		&& hi.getInput().get(1) instanceof UnaryOp;
	if (!isCandidate) {
		return hi;
	}

	Hop weights = hi.getInput().get(0);
	UnaryOp fn = (UnaryOp) hi.getInput().get(1);
	Hop fnArg = fn.getInput().get(0);

	Hop product = null; // the matched Y %*% t(X)
	boolean withLog = false;
	boolean withMinus = false;
	int variant = 0;

	if (fn.getOp() == OpOp1.SIGMOID) {
		if (fnArg instanceof AggBinaryOp
			&& HopRewriteUtils.isSingleBlock(fnArg.getInput().get(0), true)) {
			product = fnArg; // form 1
			variant = 1;
		} else if (HopRewriteUtils.isBinary(fnArg, OpOp2.MINUS)
			&& fnArg.getInput().get(0) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValueSafe((LiteralOp) fnArg.getInput().get(0)) == 0
			&& fnArg.getInput().get(1) instanceof AggBinaryOp
			&& HopRewriteUtils.isSingleBlock(fnArg.getInput().get(1).getInput().get(0), true)) {
			product = fnArg.getInput().get(1); // form 2
			withMinus = true;
			variant = 2;
		}
	} else if (fn.getOp() == OpOp1.LOG && HopRewriteUtils.isUnary(fnArg, OpOp1.SIGMOID)) {
		Hop sigArg = fnArg.getInput().get(0);
		if (sigArg instanceof AggBinaryOp
			&& HopRewriteUtils.isSingleBlock(sigArg.getInput().get(0), true)) {
			product = sigArg; // form 3
			withLog = true;
			variant = 3;
		} else if (HopRewriteUtils.isBinary(sigArg, OpOp2.MINUS)
			&& sigArg.getInput().get(0) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValueSafe((LiteralOp) sigArg.getInput().get(0)) == 0
			&& sigArg.getInput().get(1) instanceof AggBinaryOp
			&& HopRewriteUtils.isSingleBlock(sigArg.getInput().get(1).getInput().get(0), true)) {
			product = sigArg.getInput().get(1); // form 4
			withLog = true;
			withMinus = true;
			variant = 4;
		}
	}
	if (product == null) {
		return hi;
	}

	Hop yFactor = product.getInput().get(0);
	Hop xFactor = product.getInput().get(1);
	// the quaternary op takes X itself, so strip an existing transpose or add one
	xFactor = HopRewriteUtils.isTransposeOperation(xFactor)
		? xFactor.getInput().get(0)
		: HopRewriteUtils.createTranspose(xFactor);

	Hop fused = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WSIGMOID,
		weights, yFactor, xFactor, withLog, withMinus);
	fused.setOutputBlocksizes(weights.getRowsInBlock(), weights.getColsInBlock());
	fused.refreshSizeInformation();
	LOG.debug("Applied simplifyWeightedSigmoid" + variant + " (line " + hi.getBeginLine() + ")");

	// relink new hop into original position
	HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
	return fused;
}
Usage of org.apache.sysml.hops.QuaternaryOp in the Apache systemml project:
class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedDivMM.
/**
 * Searches for weighted divide matrix-multiplication chains and replaces them with a WDIVMM
 * quaternary operator. Matched patterns ("op" is one of LOOKUP_VALID_WDIVMM_BINARY, whose
 * index 0 is used as MULT and index 1 as DIV; x is a scalar):
 * 1)  t(U) %*% (W op (U%*%t(V)))          2)  (W op (U%*%t(V))) %*% V
 * 1e) t(U) %*% (W/(U%*%t(V) + x))         2e) (W/(U%*%t(V) + x)) %*% V
 * 3)  t(U) %*% ((X!=0)*(U%*%t(V)-X))      4)  ((X!=0)*(U%*%t(V)-X)) %*% V
 * 5)  t(U) %*% (W*(U%*%t(V)-X))           6)  (W*(U%*%t(V)-X)) %*% V
 * 7)  W*(U%*%t(V))                         (no outer matrix multiply)
 *
 * In every pattern, W and X must have the same size as U%*%t(V) ("prevent mv"). This excludes
 * matrix-vector and matrix-scalar broadcasting, which the fused operator's inputs do not
 * represent. U must satisfy the single-block (rank) constraint. The t(U) %*% ... patterns wrap
 * the fused operator in a transpose for efficient target indexing; redundant transposes are
 * removed by other rewrites.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent (used to relink the replacement)
 * @return the new (possibly transposed) quaternary operator if a pattern matched, otherwise hi
 */
private static Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) {
	Hop hnew = null;
	boolean appliedPattern = false;
	// note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops)
	// note: by operator precedence, the dim2 > 1 check only guards the right-input alternative
	if (HopRewriteUtils.isMatrixMultiply(hi)
		&& (hi.getInput().get(0) instanceof BinaryOp
			&& HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
		|| hi.getInput().get(1) instanceof BinaryOp
			&& hi.getDim2() > 1 // not applied for vector-vector mult
			&& HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY))) {
		Hop left = hi.getInput().get(0);
		Hop right = hi.getInput().get(1);

		// Pattern 1) t(U) %*% (W op (U%*%t(V)))
		if (right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) right).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
			&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) // prevent mv
			&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1))
			&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
			Hop W = right.getInput().get(0);
			Hop U = right.getInput().get(1).getInput().get(0);
			Hop V = right.getInput().get(1).getInput().get(1);
			if (HopRewriteUtils.isTransposeOfItself(left, U)) {
				if (!HopRewriteUtils.isTransposeOperation(V))
					V = HopRewriteUtils.createTranspose(V);
				else
					V = V.getInput().get(0);
				boolean mult = ((BinaryOp) right).getOp() == OpOp2.MULT;
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 1, mult, false);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				// add output transpose for efficient target indexing (redundant t() removed by other rewrites)
				hnew = HopRewriteUtils.createTranspose(hnew);
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM1 (line " + hi.getBeginLine() + ")");
			}
		}

		// Pattern 1e) t(U) %*% (W/(U%*%t(V) + x))
		if (!appliedPattern
			&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[1]) // DIV
			&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) // prevent mv
			&& HopRewriteUtils.isBinary(right.getInput().get(1), Hop.OpOp2.PLUS)
			&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
			&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
			&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
			Hop W = right.getInput().get(0);
			Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
			Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
			Hop X = right.getInput().get(1).getInput().get(1);
			if (HopRewriteUtils.isTransposeOfItself(left, U)) {
				if (!HopRewriteUtils.isTransposeOperation(V))
					V = HopRewriteUtils.createTranspose(V);
				else
					V = V.getInput().get(0);
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 3, false, // 3=>DIV_LEFT_EPS
					false);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				// add output transpose for efficient target indexing (redundant t() removed by other rewrites)
				hnew = HopRewriteUtils.createTranspose(hnew);
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM1e (line " + hi.getBeginLine() + ")");
			}
		}

		// Pattern 2) (W op (U%*%t(V))) %*% V
		if (!appliedPattern && left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
			&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) // prevent mv
			&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1))
			&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
			Hop W = left.getInput().get(0);
			Hop U = left.getInput().get(1).getInput().get(0);
			Hop V = left.getInput().get(1).getInput().get(1);
			if (HopRewriteUtils.isTransposeOfItself(right, V)) {
				if (!HopRewriteUtils.isTransposeOperation(V))
					V = right;
				else
					V = V.getInput().get(0);
				boolean mult = ((BinaryOp) left).getOp() == OpOp2.MULT;
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 2, mult, false);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM2 (line " + hi.getBeginLine() + ")");
			}
		}

		// Pattern 2e) (W/(U%*%t(V) + x)) %*% V
		if (!appliedPattern
			&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[1]) // DIV
			&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) // prevent mv
			&& HopRewriteUtils.isBinary(left.getInput().get(1), Hop.OpOp2.PLUS)
			&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
			&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
			&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
			Hop W = left.getInput().get(0);
			Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
			Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
			Hop X = left.getInput().get(1).getInput().get(1);
			if (HopRewriteUtils.isTransposeOfItself(right, V)) {
				if (!HopRewriteUtils.isTransposeOperation(V))
					V = right;
				else
					V = V.getInput().get(0);
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 4, false, // 4=>DIV_RIGHT_EPS
					false);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM2e (line " + hi.getBeginLine() + ")");
			}
		}

		// Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
		if (!appliedPattern
			&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
			&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) // prevent mv (weights)
			&& HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
			&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
			&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(right.getInput().get(1).getInput().get(0), right.getInput().get(1).getInput().get(1)) // prevent mv (X)
			&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
			Hop W = right.getInput().get(0);
			Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
			Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
			Hop X = right.getInput().get(1).getInput().get(1);
			if (HopRewriteUtils.isNonZeroIndicator(W, X) // W-X constraint
				&& HopRewriteUtils.isTransposeOfItself(left, U)) { // t(U)-U constraint
				if (!HopRewriteUtils.isTransposeOperation(V))
					V = HopRewriteUtils.createTranspose(V);
				else
					V = V.getInput().get(0);
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 1, true, true);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				// add output transpose for efficient target indexing (redundant t() removed by other rewrites)
				hnew = HopRewriteUtils.createTranspose(hnew);
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM3 (line " + hi.getBeginLine() + ")");
			}
		}

		// Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
		if (!appliedPattern
			&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
			&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) // prevent mv (weights)
			&& HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
			&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
			&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(left.getInput().get(1).getInput().get(0), left.getInput().get(1).getInput().get(1)) // prevent mv (X)
			&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
			Hop W = left.getInput().get(0);
			Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
			Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
			Hop X = left.getInput().get(1).getInput().get(1);
			if (HopRewriteUtils.isNonZeroIndicator(W, X) // W-X constraint
				&& HopRewriteUtils.isTransposeOfItself(right, V)) { // V-t(V) constraint
				if (!HopRewriteUtils.isTransposeOperation(V))
					V = right;
				else
					V = V.getInput().get(0);
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 2, true, true);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM4 (line " + hi.getBeginLine() + ")");
			}
		}

		// Pattern 5) t(U) %*% (W*(U%*%t(V)-X))
		if (!appliedPattern
			&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
			&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) // prevent mv/scalar (weights)
			&& HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
			&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
			&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(right.getInput().get(1).getInput().get(0), right.getInput().get(1).getInput().get(1)) // prevent mv (X)
			&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
			Hop W = right.getInput().get(0);
			Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
			Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
			Hop X = right.getInput().get(1).getInput().get(1);
			if (HopRewriteUtils.isTransposeOfItself(left, U)) { // t(U)-U constraint
				if (!HopRewriteUtils.isTransposeOperation(V))
					V = HopRewriteUtils.createTranspose(V);
				else
					V = V.getInput().get(0);
				// note: x and w exchanged compared to patterns 1-4, 7
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 1, true, true);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				// add output transpose for efficient target indexing (redundant t() removed by other rewrites)
				hnew = HopRewriteUtils.createTranspose(hnew);
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM5 (line " + hi.getBeginLine() + ")");
			}
		}

		// Pattern 6) (W*(U%*%t(V)-X)) %*% V
		if (!appliedPattern
			&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
			&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) // prevent mv/scalar (weights)
			&& HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
			&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
			&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(left.getInput().get(1).getInput().get(0), left.getInput().get(1).getInput().get(1)) // prevent mv (X)
			&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
			Hop W = left.getInput().get(0);
			Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
			Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
			Hop X = left.getInput().get(1).getInput().get(1);
			if (HopRewriteUtils.isTransposeOfItself(right, V)) { // V-t(V) constraint
				if (!HopRewriteUtils.isTransposeOperation(V))
					V = right;
				else
					V = V.getInput().get(0);
				// note: x and w exchanged compared to patterns 1-4, 7
				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 2, true, true);
				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
				hnew.refreshSizeInformation();
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedDivMM6 (line " + hi.getBeginLine() + ")");
			}
		}
	}

	// Pattern 7) (W*(U%*%t(V)))
	if (!appliedPattern
		&& HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
		&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) // prevent mv
		&& hi.getDim2() > 1 // not applied for vector-vector mult
		&& hi.getInput().get(0).getDataType() == DataType.MATRIX
		&& hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
		&& HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1))
		&& (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) // no mmchain
		&& HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
		Hop W = hi.getInput().get(0);
		Hop U = hi.getInput().get(1).getInput().get(0);
		Hop V = hi.getInput().get(1).getInput().get(1);
		// W is sparse and U/V unknown or dense; or if U/V are dense
		if ((HopRewriteUtils.isSparse(W) && !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V)) || (HopRewriteUtils.isDense(U) && HopRewriteUtils.isDense(V))) {
			V = !HopRewriteUtils.isTransposeOperation(V) ? HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
			hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false);
			hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
			hnew.refreshSizeInformation();
			appliedPattern = true;
			LOG.debug("Applied simplifyWeightedDivMM7 (line " + hi.getBeginLine() + ")");
		}
	}
	// relink new hop into original position
	if (hnew != null) {
		HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
		hi = hnew;
	}
	return hi;
}
Usage of org.apache.sysml.hops.QuaternaryOp in the Apache systemml project:
class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSquaredLoss.
/**
 * Searches for weighted squared loss expressions and replaces them with a quaternary operator.
 * Currently, this search includes the following three patterns:
 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 * Each pattern is also matched with the operands of the subtraction swapped,
 * e.g., sum (W * (U %*% t(V) - X) ^ 2), since the square makes the order irrelevant.
 *
 * NOTE: We include transpose into the pattern because during runtime we need to compute
 * U %*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
 * without additional memory requirements for internal transpose.
 *
 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
 * as we generalize the runtime or hop/lop compilation.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi among the inputs of parent (used to relink the replacement)
 * @return the new WSLOSS quaternary operator if a pattern matched, otherwise hi
 */
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) {
	// NOTE: there might be also a general simplification without custom operator
	// via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
	Hop hnew = null;
	if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol // all patterns rooted by sum()
		&& ((AggUnaryOp) hi).getOp() == AggOp.SUM
		&& hi.getInput().get(0) instanceof BinaryOp // all patterns subrooted by binary op
		&& hi.getInput().get(0).getDim2() > 1) { // not applied for vector-vector mult
		BinaryOp bop = (BinaryOp) hi.getInput().get(0);
		boolean appliedPattern = false;

		// Pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2) or sum (W * (U %*% t(V) - X) ^ 2)
		if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW)
			&& bop.getInput().get(0).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) // prevent mv
			&& bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) {
			Hop W = bop.getInput().get(0);
			// (X - U %*% t(V))
			Hop tmp = bop.getInput().get(1).getInput().get(0);
			if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS)
				&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) // prevent mv
				&& tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
				int uvIndex = -1;
				if (tmp.getInput().get(1) instanceof AggBinaryOp // a) sum (W * (X - U %*% t(V)) ^ 2); ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
					uvIndex = 1;
				} else if (tmp.getInput().get(0) instanceof AggBinaryOp // b) sum (W * (U %*% t(V) - X) ^ 2); ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
					uvIndex = 0;
				}
				if (uvIndex >= 0) { // rewrite match
					Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
					Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
					Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
					// the quaternary op takes V itself, so strip an existing transpose or add one
					if (!HopRewriteUtils.isTransposeOperation(V)) {
						V = HopRewriteUtils.createTranspose(V);
					} else {
						V = V.getInput().get(0);
					}
					// handle special case of post_nz
					if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
						W = new LiteralOp(1);
					}
					// construct quaternary hop
					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
					HopRewriteUtils.setOutputParametersForScalar(hnew);
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
				}
			}
		}

		// Pattern 2 (pre weighting): sum ((X - W * (U %*% t(V))) ^ 2) or sum ((W * (U %*% t(V)) - X) ^ 2)
		if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2
			&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
			&& bop.getInput().get(0).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) // prevent mv
			&& bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
			Hop lleft = bop.getInput().get(0).getInput().get(0);
			Hop lright = bop.getInput().get(0).getInput().get(1);
			// Fully validate each W * (U %*% t(V)) candidate before selecting it, so that a
			// non-matching right operand does not prevent the left operand from being considered.
			int wuvIndex = -1;
			if (HopRewriteUtils.isBinary(lright, OpOp2.MULT) // a) sum ((X - W * (U %*% t(V))) ^ 2)
				&& lright.getInput().get(1) instanceof AggBinaryOp
				&& lright.getInput().get(0).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(lright.getInput().get(0), lright.getInput().get(1)) // prevent mv
				&& HopRewriteUtils.isSingleBlock(lright.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
				wuvIndex = 1;
			} else if (HopRewriteUtils.isBinary(lleft, OpOp2.MULT) // b) sum ((W * (U %*% t(V)) - X) ^ 2)
				&& lleft.getInput().get(1) instanceof AggBinaryOp
				&& lleft.getInput().get(0).getDataType() == DataType.MATRIX
				&& HopRewriteUtils.isEqualSize(lleft.getInput().get(0), lleft.getInput().get(1)) // prevent mv
				&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
				wuvIndex = 0;
			}
			if (wuvIndex >= 0) { // rewrite match
				Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
				// (W * (U %*% t(V)))
				Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
				Hop W = tmp.getInput().get(0);
				Hop U = tmp.getInput().get(1).getInput().get(0);
				Hop V = tmp.getInput().get(1).getInput().get(1);
				if (!HopRewriteUtils.isTransposeOperation(V)) {
					V = HopRewriteUtils.createTranspose(V);
				} else {
					V = V.getInput().get(0);
				}
				hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
				HopRewriteUtils.setOutputParametersForScalar(hnew);
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
			}
		}

		// Pattern 3 (no weighting): sum ((X - (U %*% t(V))) ^ 2) or sum (((U %*% t(V)) - X) ^ 2)
		if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2
			&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
			&& bop.getInput().get(0).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) // prevent mv
			&& bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
			Hop lleft = bop.getInput().get(0).getInput().get(0);
			Hop lright = bop.getInput().get(0).getInput().get(1);
			int uvIndex = -1;
			if (lright instanceof AggBinaryOp // a) sum ((X - (U %*% t(V))) ^ 2); ba guarantees matrices
				&& HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
				uvIndex = 1;
			} else if (lleft instanceof AggBinaryOp // b) sum (((U %*% t(V)) - X) ^ 2); ba guarantees matrices
				&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
				uvIndex = 0;
			}
			if (uvIndex >= 0) { // rewrite match
				Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
				// (U %*% t(V))
				Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
				// no weighting
				Hop W = new LiteralOp(1);
				Hop U = tmp.getInput().get(0);
				Hop V = tmp.getInput().get(1);
				if (!HopRewriteUtils.isTransposeOperation(V)) {
					V = HopRewriteUtils.createTranspose(V);
				} else {
					V = V.getInput().get(0);
				}
				hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
				HopRewriteUtils.setOutputParametersForScalar(hnew);
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
			}
		}
	}
	// relink new hop into original position
	if (hnew != null) {
		HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
		hi = hnew;
	}
	return hi;
}
use of org.apache.sysml.hops.QuaternaryOp in project systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedCrossEntropy.
/**
 * Searches for weighted cross-entropy expressions and replaces them with a WCEMM quaternary operator.
 * Currently, this search includes the following two patterns:
 * 1) sum (X * log(U %*% t(V)))         (no epsilon)
 * 2) sum (X * log(U %*% t(V) + eps))   (eps is a scalar literal)
 *
 * Both patterns are only matched for a full (RowCol) sum over a binary op whose output has more
 * than one column (i.e., not for vector-vector products), and only if U passes
 * HopRewriteUtils.isSingleBlock(U, true) (BLOCKSIZE CONSTRAINT).
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position at which hi is referenced by parent (used when relinking)
 * @return the new quaternary operator if a pattern was applied, otherwise hi
 */
private static Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) {
    Hop hnew = null;
    boolean appliedPattern = false;

    if (hi instanceof AggUnaryOp
            && ((AggUnaryOp) hi).getDirection() == Direction.RowCol // pattern rooted by sum()
            && ((AggUnaryOp) hi).getOp() == AggOp.SUM
            && hi.getInput().get(0) instanceof BinaryOp // pattern subrooted by binary op
            && hi.getInput().get(0).getDim2() > 1) { // not applied for vector-vector mult
        BinaryOp bop = (BinaryOp) hi.getInput().get(0);
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);

        // Pattern 1) sum( X * log(U %*% t(V)))
        if (bop.getOp() == OpOp2.MULT
                && left.getDataType() == DataType.MATRIX
                && HopRewriteUtils.isEqualSize(left, right) // equal sizes: prevent matrix-vector (mv)
                && HopRewriteUtils.isUnary(right, OpOp1.LOG)
                && right.getInput().get(0) instanceof AggBinaryOp // ba guarantees matrices
                && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop uv = right.getInput().get(0);
            Hop X = left;
            Hop U = uv.getInput().get(0);
            Hop V = unwrapOrTransposeFactor(uv.getInput().get(1));
            // type code 0 with eps literal 0.0; presumably the eps-free (basic) WCEMM variant -- confirm
            hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM,
                X, U, V, new LiteralOp(0.0), 0, false, false);
            hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedCEMM (line " + hi.getBeginLine() + ")");
        }

        // Pattern 2) sum( X * log(U %*% t(V) + eps))
        if (!appliedPattern
                && bop.getOp() == OpOp2.MULT
                && left.getDataType() == DataType.MATRIX
                && HopRewriteUtils.isEqualSize(left, right) // equal sizes: prevent matrix-vector (mv)
                && HopRewriteUtils.isUnary(right, OpOp1.LOG)
                && HopRewriteUtils.isBinary(right.getInput().get(0), OpOp2.PLUS)
                && right.getInput().get(0).getInput().get(0) instanceof AggBinaryOp // ba guarantees matrices
                && right.getInput().get(0).getInput().get(1) instanceof LiteralOp // eps must be a literal
                && right.getInput().get(0).getInput().get(1).getDataType() == DataType.SCALAR
                && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop plus = right.getInput().get(0);
            Hop uv = plus.getInput().get(0);
            Hop X = left;
            Hop U = uv.getInput().get(0);
            Hop eps = plus.getInput().get(1);
            Hop V = unwrapOrTransposeFactor(uv.getInput().get(1));
            hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM,
                X, U, V, eps, 1, false, // 1 => BASIC_EPS
                false);
            hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
            // keep the flag consistent with pattern 1 so any pattern appended later is not applied twice
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedCEMMEps (line " + hi.getBeginLine() + ")");
        }
    }

    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}

/**
 * Given the right factor R of a matrix product U %*% R, returns the hop V such that R = t(V).
 * If R is already a transpose t(V), its input V is returned directly (avoiding t(t(V)));
 * otherwise a new transpose of R is created.
 *
 * @param rightFactor right input of the matrix product
 * @return hop V with rightFactor = t(V)
 */
private static Hop unwrapOrTransposeFactor(Hop rightFactor) {
    if (HopRewriteUtils.isTransposeOperation(rightFactor)) {
        return rightFactor.getInput().get(0);
    }
    return HopRewriteUtils.createTranspose(rightFactor);
}
Aggregations