use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteForLoopVectorization method vectorizeScalarAggregate.
private static StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
StatementBlock ret = sb;
// check missing and supported increment values
if (!(increment != null && increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0)) {
return ret;
}
// check for applicability
boolean leftScalar = false;
boolean rightScalar = false;
// row or col
boolean rowIx = false;
if (csb.getHops() != null && csb.getHops().size() == 1) {
Hop root = csb.getHops().get(0);
if (root.getDataType() == DataType.SCALAR && root.getInput().get(0) instanceof BinaryOp) {
BinaryOp bop = (BinaryOp) root.getInput().get(0);
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
// check for left scalar plus
if (HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) && left instanceof DataOp && left.getDataType() == DataType.SCALAR && root.getName().equals(left.getName()) && right instanceof UnaryOp && ((UnaryOp) right).getOp() == OpOp1.CAST_AS_SCALAR && right.getInput().get(0) instanceof IndexingOp) {
IndexingOp ix = (IndexingOp) right.getInput().get(0);
if (ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp && ix.getInput().get(1).getName().equals(itervar)) {
leftScalar = true;
rowIx = true;
} else if (ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp && ix.getInput().get(3).getName().equals(itervar)) {
leftScalar = true;
rowIx = false;
}
} else // check for right scalar plus
if (HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) && right instanceof DataOp && right.getDataType() == DataType.SCALAR && root.getName().equals(right.getName()) && left instanceof UnaryOp && ((UnaryOp) left).getOp() == OpOp1.CAST_AS_SCALAR && left.getInput().get(0) instanceof IndexingOp) {
IndexingOp ix = (IndexingOp) left.getInput().get(0);
if (ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp && ix.getInput().get(1).getName().equals(itervar)) {
rightScalar = true;
rowIx = true;
} else if (ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp && ix.getInput().get(3).getName().equals(itervar)) {
rightScalar = true;
rowIx = false;
}
}
}
}
// apply rewrite if possible
if (leftScalar || rightScalar) {
Hop root = csb.getHops().get(0);
BinaryOp bop = (BinaryOp) root.getInput().get(0);
Hop cast = bop.getInput().get(leftScalar ? 1 : 0);
Hop ix = cast.getInput().get(0);
int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
// replace cast with sum
AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
HopRewriteUtils.removeChildReference(cast, ix);
HopRewriteUtils.removeChildReference(bop, cast);
HopRewriteUtils.addChildReference(bop, newSum, leftScalar ? 1 : 0);
// modify indexing expression according to loop predicate from-to
// NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
int index1 = rowIx ? 1 : 3;
int index2 = rowIx ? 2 : 4;
HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index1), from, index1);
HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index2), to, index2);
// update indexing size information
if (rowIx)
((IndexingOp) ix).setRowLowerEqualsUpper(false);
else
((IndexingOp) ix).setColLowerEqualsUpper(false);
ix.refreshSizeInformation();
ret = csb;
LOG.debug("Applied vectorizeScalarSumForLoop.");
}
return ret;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class LiteralReplacement method replaceLiteralFullUnaryAggregateRightIndexing.
private static LiteralOp replaceLiteralFullUnaryAggregateRightIndexing(Hop c, LocalVariableMap vars) {
LiteralOp ret = null;
// full unary aggregate w/ indexed matrix less than 10^6 cells
if (c instanceof AggUnaryOp && isReplaceableUnaryAggregate((AggUnaryOp) c) && c.getInput().get(0) instanceof IndexingOp && c.getInput().get(0).getInput().get(0) instanceof DataOp) {
IndexingOp rix = (IndexingOp) c.getInput().get(0);
Hop data = rix.getInput().get(0);
Hop rl = rix.getInput().get(1);
Hop ru = rix.getInput().get(2);
Hop cl = rix.getInput().get(3);
Hop cu = rix.getInput().get(4);
if (data instanceof DataOp && vars.keySet().contains(data.getName()) && isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) && isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars)) {
long rlval = getIntValueDataLiteral(rl, vars);
long ruval = getIntValueDataLiteral(ru, vars);
long clval = getIntValueDataLiteral(cl, vars);
long cuval = getIntValueDataLiteral(cu, vars);
MatrixObject mo = (MatrixObject) vars.get(data.getName());
// dimensions might not have been updated during recompile
if (mo.getNumRows() * mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE) {
MatrixBlock mBlock = mo.acquireRead();
MatrixBlock mBlock2 = mBlock.slice((int) (rlval - 1), (int) (ruval - 1), (int) (clval - 1), (int) (cuval - 1), new MatrixBlock());
double value = replaceUnaryAggregate((AggUnaryOp) c, mBlock2);
mo.release();
// literal substitution (always double)
ret = new LiteralOp(value);
}
}
}
return ret;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class HopRewriteUtils method createAggUnaryOp.
public static AggUnaryOp createAggUnaryOp(Hop input, AggOp op, Direction dir) {
DataType dt = (dir == Direction.RowCol) ? DataType.SCALAR : input.getDataType();
AggUnaryOp auop = new AggUnaryOp(input.getName(), dt, input.getValueType(), op, dir, input);
auop.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
copyLineNumbers(input, auop);
auop.refreshSizeInformation();
return auop;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedCrossEntropy.
private static Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) {
Hop hnew = null;
boolean appliedPattern = false;
if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol && // pattern rooted by sum()
((AggUnaryOp) hi).getOp() == AggOp.SUM && // pattern subrooted by binary op
hi.getInput().get(0) instanceof BinaryOp && // not applied for vector-vector mult
hi.getInput().get(0).getDim2() > 1) {
BinaryOp bop = (BinaryOp) hi.getInput().get(0);
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
// Pattern 1) sum( X * log(U %*% t(V)))
if (bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && // prevent mb
HopRewriteUtils.isEqualSize(left, right) && HopRewriteUtils.isUnary(right, OpOp1.LOG) && // ba gurantees matrices
right.getInput().get(0) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0), true)) {
Hop X = left;
Hop U = right.getInput().get(0).getInput().get(0);
Hop V = right.getInput().get(0).getInput().get(1);
if (!HopRewriteUtils.isTransposeOperation(V))
V = HopRewriteUtils.createTranspose(V);
else
V = V.getInput().get(0);
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V, new LiteralOp(0.0), 0, false, false);
hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
appliedPattern = true;
LOG.debug("Applied simplifyWeightedCEMM (line " + hi.getBeginLine() + ")");
}
// Pattern 2) sum( X * log(U %*% t(V) + eps))
if (!appliedPattern && bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX && HopRewriteUtils.isEqualSize(left, right) && HopRewriteUtils.isUnary(right, OpOp1.LOG) && HopRewriteUtils.isBinary(right.getInput().get(0), OpOp2.PLUS) && right.getInput().get(0).getInput().get(0) instanceof AggBinaryOp && right.getInput().get(0).getInput().get(1) instanceof LiteralOp && right.getInput().get(0).getInput().get(1).getDataType() == DataType.SCALAR && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0).getInput().get(0), true)) {
Hop X = left;
Hop U = right.getInput().get(0).getInput().get(0).getInput().get(0);
Hop V = right.getInput().get(0).getInput().get(0).getInput().get(1);
Hop eps = right.getInput().get(0).getInput().get(1);
if (!HopRewriteUtils.isTransposeOperation(V))
V = HopRewriteUtils.createTranspose(V);
else
V = V.getInput().get(0);
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V, eps, 1, false, // 1 => BASIC_EPS
false);
hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
LOG.debug("Applied simplifyWeightedCEMMEps (line " + hi.getBeginLine() + ")");
}
}
// relink new hop into original position
if (hnew != null) {
HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
}
return hi;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyUnnecessaryAggregate.
private static Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos) {
// e.g., sum(X) -> as.scalar(X) if 1x1 (applies to sum, min, max, prod, trace)
if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol) {
AggUnaryOp uhi = (AggUnaryOp) hi;
Hop input = uhi.getInput().get(0);
if (HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_UNNECESSARY_AGGREGATE)) {
if (input.getDim1() == 1 && input.getDim2() == 1) {
UnaryOp cast = HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_SCALAR);
// remove unnecessary aggregation
HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
hi = cast;
LOG.debug("Applied simplifyUnncessaryAggregate");
}
}
}
return hi;
}
Aggregations