Use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyUnnecessaryAggregate.
/**
 * Removes a full (row-and-column) aggregate over a 1x1 input by replacing it with a scalar
 * cast of that input, e.g. {@code sum(X) -> as.scalar(X)} when X is 1x1
 * (applies to sum, min, max, prod, trace).
 *
 * <p>Only aggregates with direction {@code RowCol} whose operator is listed in
 * {@code LOOKUP_VALID_UNNECESSARY_AGGREGATE} are rewritten. The replacement is hung under
 * {@code parent} at input position {@code pos}.
 *
 * @param parent hop whose input at {@code pos} is {@code hi}
 * @param hi     candidate aggregate hop
 * @param pos    position of {@code hi} among {@code parent}'s inputs
 * @return the new cast hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos) {
    if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol) {
        AggUnaryOp uhi = (AggUnaryOp) hi;
        Hop input = uhi.getInput().get(0);
        if (HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_UNNECESSARY_AGGREGATE)
                && input.getDim1() == 1 && input.getDim2() == 1) {
            // Aggregating a single cell yields that cell; only the matrix->scalar cast remains.
            UnaryOp cast = HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_SCALAR);
            // remove unnecessary aggregation
            HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
            hi = cast;
            LOG.debug("Applied simplifyUnnecessaryAggregate");
        }
    }
    return hi;
}
Use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyScalarMVBinaryOperation.
/**
 * Converts a matrix-matrix binary operation whose right operand is a 1x1 matrix into a
 * matrix-scalar operation by inserting a scalar cast on the right input,
 * i.e. {@code X * s -> X * as.scalar(s)}.
 *
 * <p>Only binary hops that support matrix-scalar operations and have two matrix inputs are
 * considered, and the right input must have known dimensions of exactly 1x1. The hop is
 * modified in place (its right input is swapped for the cast); the same hop is returned.
 *
 * @param hi candidate binary hop
 * @return {@code hi}
 */
private static Hop simplifyScalarMVBinaryOperation(Hop hi) {
    boolean matrixMatrixBinary = hi instanceof BinaryOp
            && ((BinaryOp) hi).supportsMatrixScalarOperations()
            && hi.getInput().get(0).getDataType() == DataType.MATRIX
            && hi.getInput().get(1).getDataType() == DataType.MATRIX;
    if (!matrixMatrixBinary) {
        return hi;
    }

    Hop rhs = hi.getInput().get(1);
    // Treat the right operand as a scalar only when its size is known to be a single cell.
    boolean rhsIsSingleCell = HopRewriteUtils.isDimsKnown(rhs)
            && rhs.getDim1() == 1
            && rhs.getDim2() == 1;
    if (rhsIsSingleCell) {
        // Swap hi's right input (position 1) for an as.scalar cast over it.
        UnaryOp asScalar = HopRewriteUtils.createUnary(rhs, OpOp1.CAST_AS_SCALAR);
        HopRewriteUtils.replaceChildReference(hi, rhs, asScalar, 1);
        LOG.debug("Applied simplifyScalarMVBinaryOperation.");
    }
    return hi;
}
Use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyEmptyUnaryOperation.
/**
 * Replaces a unary operation over an empty input with a generated data hop of value 0,
 * created via {@code HopRewriteUtils.createDataGenOp(input, 0)} and hung under
 * {@code parent} at position {@code pos}.
 *
 * <p>Applies only to unary operators listed in {@code LOOKUP_VALID_EMPTY_UNARY}, and only
 * when {@code HopRewriteUtils.isEmpty} reports the operand as empty.
 *
 * @param parent hop whose input at {@code pos} is {@code hi}
 * @param hi     candidate unary hop
 * @param pos    position of {@code hi} among {@code parent}'s inputs
 * @return the generated replacement hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyEmptyUnaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof UnaryOp)) {
        return hi;
    }
    UnaryOp unary = (UnaryOp) hi;
    Hop operand = unary.getInput().get(0);

    boolean applicable = HopRewriteUtils.isValidOp(unary.getOp(), LOOKUP_VALID_EMPTY_UNARY)
            && HopRewriteUtils.isEmpty(operand);
    if (!applicable) {
        return hi;
    }

    // Hang a zero-valued data generator (derived from the operand) in place of hi.
    Hop zeros = HopRewriteUtils.createDataGenOp(operand, 0);
    HopRewriteUtils.replaceChildReference(parent, hi, zeros, pos);
    LOG.debug("Applied simplifyEmptyUnaryOperation");
    return zeros;
}
Use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyRowwiseAggregate.
/**
 * Simplifies a row-wise aggregate (direction {@code Row}) whose operator is listed in
 * {@code LOOKUP_VALID_ROW_COL_AGGREGATE}, based on the shape of its input:
 * <ul>
 *   <li>input with exactly one column and operator {@code VAR}: replaced by a hop built with
 *       {@code HopRewriteUtils.createDataGenOp(input, uhi, 0)}, i.e. a column of zeros
 *       (the variance of each single-value row is zero);</li>
 *   <li>input with exactly one column and any other valid operator: replaced by the input
 *       itself;</li>
 *   <li>input with exactly one row: the aggregate is converted in place into a full
 *       ({@code RowCol}) aggregate with scalar output, wrapped in a {@code CAST_AS_MATRIX},
 *       and every former parent of {@code hi} is rewired to that cast.</li>
 * </ul>
 * The first two cases only rewire {@code parent} at position {@code pos}; the third rewires
 * all parents of {@code hi}.
 *
 * @param parent hop whose input at {@code pos} is {@code hi}
 * @param hi     candidate aggregate hop
 * @param pos    position of {@code hi} among {@code parent}'s inputs
 * @return the hop that replaced {@code hi} if a rewrite applied, otherwise {@code hi}
 */
// The unchecked warning comes from casting the clone() of hi's parent list below.
@SuppressWarnings("unchecked")
private static Hop simplifyRowwiseAggregate(Hop parent, Hop hi, int pos) {
if (hi instanceof AggUnaryOp) {
AggUnaryOp uhi = (AggUnaryOp) hi;
Hop input = uhi.getInput().get(0);
if (HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE)) {
if (uhi.getDirection() == Direction.Row) {
// Case 1: single-column input.
if (input.getDim2() == 1) {
if (uhi.getOp() == AggOp.VAR) {
// For the row variance aggregation, if the input is a column vector,
// the row variances will each be zero.
// Therefore, perform a rewrite from ROWVAR(X) to a column vector of
// zeros.
Hop emptyCol = HopRewriteUtils.createDataGenOp(input, uhi, 0);
HopRewriteUtils.replaceChildReference(parent, hi, emptyCol, pos);
// Drop the now-bypassed aggregate; input is passed too, presumably so it is
// also removed if nothing else references it (TODO confirm cleanupUnreferenced semantics).
HopRewriteUtils.cleanupUnreferenced(hi, input);
// replace current HOP with new empty column HOP
hi = emptyCol;
LOG.debug("Applied simplifyRowwiseAggregate for rowVars");
} else {
// All other valid row aggregations over a column vector will result
// in the column vector itself.
// Therefore, remove unnecessary row aggregation for 1 col
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = input;
LOG.debug("Applied simplifyRowwiseAggregate1");
}
// Case 2: single-row input -> full aggregate wrapped back into a matrix.
} else if (input.getDim1() == 1) {
// get old parents (before creating cast over aggregate)
// Snapshot on purpose: createUnary below presumably registers the cast as a
// parent of uhi, and the cast must not be rewired onto itself; the copy also
// isolates the loop from replaceChildReference mutating hi's parent list.
ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone();
// simplify row-aggregate to full aggregate
// NOTE(review): only direction and data type are updated here; confirm uhi's
// size information is refreshed elsewhere for the scalar output.
uhi.setDirection(Direction.RowCol);
uhi.setDataType(DataType.SCALAR);
// create cast to keep same output datatype
UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX);
// rehang cast under all parents
// (every former parent, not just `parent`, since they all expect a matrix output)
for (Hop p : parents) {
int ix = HopRewriteUtils.getChildReferencePos(p, hi);
HopRewriteUtils.replaceChildReference(p, hi, cast, ix);
}
hi = cast;
LOG.debug("Applied simplifyRowwiseAggregate2");
}
}
}
}
return hi;
}
Use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationStatic, method simplifyBinaryMatrixScalarOperation.
/**
 * Pushes a scalar cast below a binary operation, so the operation itself is evaluated on
 * scalars:
 * <ul>
 *   <li>{@code as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)}</li>
 *   <li>{@code as.scalar(X*s) -> as.scalar(X) * s}</li>
 *   <li>{@code as.scalar(s*X) -> s * as.scalar(X)}</li>
 * </ul>
 * Only applies when {@code hi} is a {@code CAST_AS_SCALAR} over a binary hop whose operator
 * is listed in {@code LOOKUP_VALID_SCALAR_BINARY}. The new binary hop is hung under
 * {@code parent} at position {@code pos}.
 *
 * @param parent hop whose input at {@code pos} is {@code hi}
 * @param hi     candidate scalar-cast hop
 * @param pos    position of {@code hi} among {@code parent}'s inputs
 * @return the new binary hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyBinaryMatrixScalarOperation(Hop parent, Hop hi, int pos) {
    if (HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)
            && hi.getInput().get(0) instanceof BinaryOp
            && HopRewriteUtils.isBinary(hi.getInput().get(0), LOOKUP_VALID_SCALAR_BINARY)) {
        BinaryOp bin = (BinaryOp) hi.getInput().get(0);
        Hop left = bin.getInput().get(0);
        Hop right = bin.getInput().get(1);
        boolean leftIsMatrix = left.getDataType() == DataType.MATRIX;
        boolean rightIsMatrix = right.getDataType() == DataType.MATRIX;

        BinaryOp bout = null;
        if (leftIsMatrix && rightIsMatrix) {
            // as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
            UnaryOp cast1 = HopRewriteUtils.createUnary(left, OpOp1.CAST_AS_SCALAR);
            UnaryOp cast2 = HopRewriteUtils.createUnary(right, OpOp1.CAST_AS_SCALAR);
            bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp());
        } else if (leftIsMatrix) {
            // as.scalar(X*s) -> as.scalar(X) * s
            UnaryOp cast = HopRewriteUtils.createUnary(left, OpOp1.CAST_AS_SCALAR);
            bout = HopRewriteUtils.createBinary(cast, right, bin.getOp());
        } else if (rightIsMatrix) {
            // as.scalar(s*X) -> s * as.scalar(X)
            UnaryOp cast = HopRewriteUtils.createUnary(right, OpOp1.CAST_AS_SCALAR);
            bout = HopRewriteUtils.createBinary(left, cast, bin.getOp());
        }

        if (bout != null) {
            HopRewriteUtils.replaceChildReference(parent, hi, bout, pos);
            // Return the hop now wired under parent (as the other rewrites do); returning the
            // detached cast would let any follow-up rewrite chained on the result operate on
            // a hop that is no longer in the DAG.
            hi = bout;
            LOG.debug("Applied simplifyBinaryMatrixScalarOperation.");
        }
    }
    return hi;
}
Aggregations