Usage of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project.
Class RewriteAlgebraicSimplificationStatic, method removeUnnecessaryMinus.
/**
 * Removes a double negation on matrices: the pattern 0 - (0 - X), i.e., -(-(X)),
 * is replaced by X in the inputs of {@code parent}.
 *
 * @param parent parent high-level operator of {@code hi}
 * @param hi high-level operator to inspect
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return X if the rewrite applied, otherwise {@code hi}
 * @throws HopsException if HopsException occurs
 */
private Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos) throws HopsException {
	if (hi.getDataType() == DataType.MATRIX && HopRewriteUtils.isBinary(hi, OpOp2.MINUS) //first minus: 0 - hi2
		&& hi.getInput().get(0) instanceof LiteralOp
		&& ((LiteralOp) hi.getInput().get(0)).getDoubleValue() == 0) {
		Hop hi2 = hi.getInput().get(1);
		if (hi2.getDataType() == DataType.MATRIX && HopRewriteUtils.isBinary(hi2, OpOp2.MINUS) //second minus: 0 - hi3
			&& hi2.getInput().get(0) instanceof LiteralOp
			&& ((LiteralOp) hi2.getInput().get(0)).getDoubleValue() == 0) {
			Hop hi3 = hi2.getInput().get(1);
			//remove unnecessary chain of -(-()) by linking hi3 directly into parent
			HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos);
			//release the two bypassed minus hops
			HopRewriteUtils.cleanupUnreferenced(hi, hi2);
			//log the rewrite under its correct name and with the matched line, like sibling rewrites
			LOG.debug("Applied removeUnnecessaryMinus (line " + hi.getBeginLine() + ").");
			hi = hi3;
		}
	}
	return hi;
}
Usage of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project.
Class RewriteAlgebraicSimplificationStatic, method fuseLogNzUnaryOperation.
/**
 * Fuses the pattern (X != 0) * log(X) over matrices into the unary operation log_nz(X).
 * <p>
 * NOTE(review): the original explanatory comment here was truncated; what survives says the
 * rewrite concerns the memory estimate and preventing dense intermediates if X is ultra sparse.
 *
 * @param parent parent high-level operator of {@code hi}
 * @param hi high-level operator to inspect
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new log_nz hop if the rewrite applied, otherwise {@code hi}
 * @throws HopsException if HopsException occurs
 */
private Hop fuseLogNzUnaryOperation(Hop parent, Hop hi, int pos) throws HopsException {
	//pattern: (X != 0) * log(X) -> log_nz(X), with both mult inputs being matrices
	if (HopRewriteUtils.isBinary(hi, OpOp2.MULT)
		&& hi.getInput().get(0).getDataType() == DataType.MATRIX
		&& hi.getInput().get(1).getDataType() == DataType.MATRIX
		&& HopRewriteUtils.isUnary(hi.getInput().get(1), OpOp1.LOG)) {
		Hop pred = hi.getInput().get(0);
		Hop logX = hi.getInput().get(1);
		Hop X = logX.getInput().get(0);
		if (HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL)
			&& pred.getInput().get(0) == X //reference equality: depends on common subexpression elimination
			&& pred.getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValueSafe((LiteralOp) pred.getInput().get(1)) == 0) {
			//remember the line of the matched expression before it is replaced
			int line = hi.getBeginLine();
			Hop hnew = HopRewriteUtils.createUnary(X, OpOp1.LOG_NZ);
			//relink new hop into original position
			HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
			//release the bypassed mult, predicate and log hops (same cleanup as removeUnnecessaryMinus),
			//so they no longer appear as consumers of X
			HopRewriteUtils.cleanupUnreferenced(hi, pred, logX);
			hi = hnew;
			LOG.debug("Applied fuseLogNzUnaryOperation (line " + line + ").");
		}
	}
	return hi;
}
Usage of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project.
Class RewriteAlgebraicSimplificationStatic, method fuseDatagenAndMinusOperation.
/**
 * Fuses a negation of a uniform rand() data generator into the generator itself by negating
 * and exchanging its bounds: 0 - rand(min=a, max=b, pdf=uniform) -> rand(min=-b, max=-a, ...).
 * <p>
 * The rewrite is applied only if the rand hop has a single consumer and its min, max and pdf
 * parameters are literals; all consumers of the minus are then rewired to the modified rand hop.
 *
 * @param hi high-level operator to inspect
 * @return the modified DataGenOp if the rewrite applied, otherwise {@code hi}
 * @throws HopsException if HopsException occurs
 */
private Hop fuseDatagenAndMinusOperation(Hop hi) throws HopsException {
	//only 0 - X is a negation; other binary ops with a zero literal on the left
	//(e.g., 0 + X, 0 * X, max(0, X)) must not be turned into a rand with negated bounds
	if (HopRewriteUtils.isBinary(hi, OpOp2.MINUS)) {
		BinaryOp bop = (BinaryOp) hi;
		Hop left = bop.getInput().get(0);
		Hop right = bop.getInput().get(1);
		if (right instanceof DataGenOp && ((DataGenOp) right).getOp() == DataGenMethod.RAND
			&& left instanceof LiteralOp && ((LiteralOp) left).getDoubleValue() == 0.0) {
			DataGenOp inputGen = (DataGenOp) right;
			HashMap<String, Integer> params = inputGen.getParamIndexMap();
			Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF));
			int ixMin = params.get(DataExpression.RAND_MIN);
			int ixMax = params.get(DataExpression.RAND_MAX);
			Hop min = right.getInput().get(ixMin);
			Hop max = right.getInput().get(ixMax);
			//apply rewrite under additional conditions (for simplicity): the rand hop has a
			//single consumer (this minus), and min/max/pdf are literals with a uniform pdf
			if (inputGen.getParent().size() == 1 && min instanceof LiteralOp && max instanceof LiteralOp
				&& pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue())) {
				//exchange and *-1 (special case 0 stays 0 instead of -0 for consistency)
				double maxVal = ((LiteralOp) max).getDoubleValue();
				double minVal = ((LiteralOp) min).getDoubleValue();
				double newMinVal = (maxVal == 0) ? 0 : (-1 * maxVal);
				double newMaxVal = (minVal == 0) ? 0 : (-1 * minVal);
				Hop newMin = new LiteralOp(newMinVal);
				Hop newMax = new LiteralOp(newMaxVal);
				HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin);
				HopRewriteUtils.addChildReference(inputGen, newMin, ixMin);
				HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax);
				HopRewriteUtils.addChildReference(inputGen, newMax, ixMax);
				//rewire all parents (avoid anomalies with replicated datagen)
				List<Hop> parents = new ArrayList<Hop>(bop.getParent());
				for (Hop p : parents) {
					HopRewriteUtils.replaceChildReference(p, bop, inputGen);
				}
				//release the bypassed minus hop so that it no longer remains listed
				//as a consumer of the rand hop and the zero literal
				HopRewriteUtils.cleanupUnreferenced(bop);
				hi = inputGen;
				LOG.debug("Applied fuseDatagenAndMinusOperation (line " + bop.getBeginLine() + ").");
			}
		}
	}
	return hi;
}
Usage of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project.
Class RewriteCommonSubexpressionElimination, method rule_CommonSubexpressionElimination_MergeLeafs.
/**
 * Merges duplicate leaf hops of a DAG. Leaf literals (keyed by value type and name) and leaf
 * read DataOps (keyed by name) are registered on first encounter; later inputs that match an
 * already registered leaf are redirected to that registered instance.
 * <p>
 * NOTE(review): a redirected input keeps {@code hop} in its own parent list; only the
 * representative gains {@code hop} as a parent here.
 *
 * @param hop current hop; returns immediately if it is already marked visited
 * @param dataops registered leaf read operations by name (updated in place)
 * @param literalops registered literals by "valuetype_name" key (updated in place)
 * @return number of input references redirected within this sub-DAG
 * @throws HopsException if HopsException occurs
 */
private int rule_CommonSubexpressionElimination_MergeLeafs(Hop hop, HashMap<String, Hop> dataops, HashMap<String, Hop> literalops) throws HopsException {
	//skip hops already handled
	if (hop.isVisited()) {
		return 0;
	}

	int numMerged = 0;
	if (!hop.getInput().isEmpty()) {
		//INNER NODE: redirect leaf inputs to their registered representatives, then descend
		for (int ix = 0; ix < hop.getInput().size(); ix++) {
			Hop child = hop.getInput().get(ix);
			Hop representative = null;
			if (child instanceof DataOp && ((DataOp) child).isRead() && dataops.containsKey(child.getName())) {
				representative = dataops.get(child.getName());
			} else if (child instanceof LiteralOp) {
				//null if no literal with this key was registered yet
				representative = literalops.get(child.getValueType() + "_" + child.getName());
			}
			//rewire only if the child duplicates a different, already registered leaf
			if (representative != null && representative != child) {
				representative.getParent().add(hop);
				hop.getInput().set(ix, representative);
				numMerged++;
			}
			//descend into the original child (for a leaf this only registers/marks it)
			numMerged += rule_CommonSubexpressionElimination_MergeLeafs(child, dataops, literalops);
		}
	} else if (hop instanceof LiteralOp) {
		//LEAF NODE (literal): the first occurrence becomes the representative
		String key = hop.getValueType() + "_" + hop.getName();
		if (!literalops.containsKey(key)) {
			literalops.put(key, hop);
		}
	} else if (hop instanceof DataOp && ((DataOp) hop).isRead()) {
		//LEAF NODE (read): the first occurrence becomes the representative
		if (!dataops.containsKey(hop.getName())) {
			dataops.put(hop.getName(), hop);
		}
	}

	hop.setVisited();
	return numMerged;
}
Usage of org.apache.sysml.hops.LiteralOp in the Apache incubator-systemml project.
Class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSquaredLoss.
/**
 * Searches for weighted squared loss expressions and replaces them with a quaternary operator.
 * Currently, this search includes the following three patterns:
 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 * Each pattern is also matched with the two operands of the minus exchanged, e.g.,
 * sum (W * (U %*% t(V) - X) ^ 2), because squaring makes both forms equal.
 *
 * NOTE: We include transpose into the pattern because during runtime we need to compute
 * U %*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
 * without additional memory requirements for internal transpose.
 *
 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
 * as we have generalized the runtime or hop/lop compilation.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate full sum aggregate)
 * @param pos position of hi within the inputs of parent
 * @return the new WSLOSS quaternary operator (relinked into parent) if a pattern matched, otherwise hi
 * @throws HopsException if HopsException occurs
 */
private Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) throws HopsException {
	//NOTE: there might be also a general simplification without custom operator
	//via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
	Hop hnew = null;
	if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol //all patterns rooted by a full sum()
		&& ((AggUnaryOp) hi).getOp() == AggOp.SUM
		&& hi.getInput().get(0) instanceof BinaryOp //all patterns subrooted by a binary op
		&& hi.getInput().get(0).getDim2() > 1) { //not applied for vector-vector mult
		BinaryOp bop = (BinaryOp) hi.getInput().get(0);
		//ensures at most one of the three patterns is applied
		boolean appliedPattern = false;

		//pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2),
		//alternatively: sum (W * (U %*% t(V) - X) ^ 2)
		if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW)
			&& bop.getInput().get(0).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //W matrix of equal size (prevent mv)
			&& bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) { //exponent is literal 2
			Hop W = bop.getInput().get(0);
			//(X - U %*% t(V)) or (U %*% t(V) - X)
			Hop tmp = bop.getInput().get(1).getInput().get(0);
			if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS)
				&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
				&& tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
				//uvIndex: position of the U %*% t(V) operand within the minus, -1 if none
				int uvIndex = -1;
				//a) sum (W * (X - U %*% t(V)) ^ 2)
				if (tmp.getInput().get(1) instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) { //BLOCKSIZE CONSTRAINT on U
					uvIndex = 1;
				}
				//b) sum (W * (U %*% t(V) - X) ^ 2)
				else if (tmp.getInput().get(0) instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) { //BLOCKSIZE CONSTRAINT on U
					uvIndex = 0;
				}
				if (uvIndex >= 0) { //rewrite match
					//X is the other operand of the minus
					Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
					Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
					Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
					//the operator takes V (not t(V)): unwrap an existing transpose, otherwise add one
					if (!HopRewriteUtils.isTransposeOperation(V)) {
						V = HopRewriteUtils.createTranspose(V);
					} else {
						V = V.getInput().get(0);
					}
					//handle special case of post_nz (presumably W is the nonzero indicator of X, i.e. X!=0)
					if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
						W = new LiteralOp(1);
					}
					//construct quaternary hop (last argument true only for the post-weighting pattern)
					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
					HopRewriteUtils.setOutputParametersForScalar(hnew);
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
				}
			}
		}

		//pattern 2 (pre weighting): sum ((X - W * (U %*% t(V))) ^ 2),
		//alternatively: sum ((W * (U %*% t(V)) - X) ^ 2)
		if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 //exponent is literal 2
			&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
			&& bop.getInput().get(0).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
			&& bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
			Hop lleft = bop.getInput().get(0).getInput().get(0);
			Hop lright = bop.getInput().get(0).getInput().get(1);
			//wuvIndex: position of the W * (U %*% t(V)) operand within the minus, -1 if none
			int wuvIndex = -1;
			//a) sum ((X - W * (U %*% t(V))) ^ 2)
			if (lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp) {
				wuvIndex = 1;
			}
			//b) sum ((W * (U %*% t(V)) - X) ^ 2)
			//NOTE(review): b) is only tried if lright does not match the shape test of a); the MULT
			//and size checks below are not part of this selection, so a non-MULT lright skips b)
			else if (lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp) {
				wuvIndex = 0;
			}
			if (wuvIndex >= 0) { //rewrite match
				//X is the other operand of the minus
				Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
				//(W * (U %*% t(V))); the cast below is safe because the index selection required a BinaryOp
				Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
				if (((BinaryOp) tmp).getOp() == OpOp2.MULT
					&& tmp.getInput().get(0).getDataType() == DataType.MATRIX
					&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) { //BLOCKSIZE CONSTRAINT on U
					Hop W = tmp.getInput().get(0);
					Hop U = tmp.getInput().get(1).getInput().get(0);
					Hop V = tmp.getInput().get(1).getInput().get(1);
					//the operator takes V (not t(V)): unwrap an existing transpose, otherwise add one
					if (!HopRewriteUtils.isTransposeOperation(V)) {
						V = HopRewriteUtils.createTranspose(V);
					} else {
						V = V.getInput().get(0);
					}
					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
					HopRewriteUtils.setOutputParametersForScalar(hnew);
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
				}
			}
		}

		//pattern 3 (no weighting): sum ((X - (U %*% t(V))) ^ 2),
		//alternatively: sum (((U %*% t(V)) - X) ^ 2)
		if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 //exponent is literal 2
			&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
			&& bop.getInput().get(0).getDataType() == DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
			&& bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
			Hop lleft = bop.getInput().get(0).getInput().get(0);
			Hop lright = bop.getInput().get(0).getInput().get(1);
			//uvIndex: position of the U %*% t(V) operand within the minus, -1 if none
			int uvIndex = -1;
			//a) sum ((X - (U %*% t(V))) ^ 2)
			if (lright instanceof AggBinaryOp //ba guarantees matrices
				&& HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) { //BLOCKSIZE CONSTRAINT on U
				uvIndex = 1;
			}
			//b) sum (((U %*% t(V)) - X) ^ 2)
			else if (lleft instanceof AggBinaryOp //ba guarantees matrices
				&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) { //BLOCKSIZE CONSTRAINT on U
				uvIndex = 0;
			}
			if (uvIndex >= 0) { //rewrite match
				//X is the other operand of the minus
				Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
				//(U %*% t(V))
				Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
				//no weighting
				Hop W = new LiteralOp(1);
				Hop U = tmp.getInput().get(0);
				Hop V = tmp.getInput().get(1);
				//the operator takes V (not t(V)): unwrap an existing transpose, otherwise add one
				if (!HopRewriteUtils.isTransposeOperation(V)) {
					V = HopRewriteUtils.createTranspose(V);
				} else {
					V = V.getInput().get(0);
				}
				hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
				HopRewriteUtils.setOutputParametersForScalar(hnew);
				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
			}
		}
	}
	//relink new hop into original position
	if (hnew != null) {
		HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
		hi = hnew;
	}
	return hi;
}
Aggregations