use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyNnzComputation.
/**
 * Rewrites a non-zero count that is expressed as an aggregate into a literal:
 * sum(ppred(X,0,"!=")) -> literal(nnz(X)).
 *
 * The pattern is matched with the zero literal on either side of the
 * comparison. The rewrite is applied only if {@code X.getNnz()} is positive
 * (presumably non-positive values mean "unknown" or "empty" -- confirm).
 *
 * @param parent hop whose child reference to {@code hi} is replaced
 * @param hi     candidate hop to be simplified
 * @param pos    position handed to {@code HopRewriteUtils.replaceChildReference}
 * @return the new literal hop if the rewrite fired, otherwise {@code hi}
 */
private static Hop simplifyNnzComputation(Hop parent, Hop hi, int pos) {
    // candidate must be an aggregate ...
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }

    // ... more precisely a full sum over rows and columns
    AggUnaryOp agg = (AggUnaryOp) hi;
    if (agg.getOp() != AggOp.SUM || agg.getDirection() != Direction.RowCol) {
        return hi;
    }

    // the aggregated input must be a "!=" comparison
    Hop cmp = hi.getInput().get(0);
    if (!HopRewriteUtils.isBinary(cmp, OpOp2.NOTEQUAL)) {
        return hi;
    }

    // find the non-literal operand: the other operand has to be the literal 0
    Hop lhs = cmp.getInput().get(0);
    Hop rhs = cmp.getInput().get(1);
    Hop target = null;
    if (lhs instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) lhs) == 0) {
        target = rhs;
    } else if (rhs instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) rhs) == 0) {
        target = lhs;
    }

    // nothing to do without a matching operand or without a positive nnz
    if (target == null || !(target.getNnz() > 0)) {
        return hi;
    }

    // swap the whole aggregate for the literal and drop the dangling hop
    Hop literal = new LiteralOp(target.getNnz());
    HopRewriteUtils.replaceChildReference(parent, hi, literal, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied simplifyNnzComputation.");
    return literal;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyRowSumsMVMult.
/**
 * Rewrites a row-wise sum over an element-wise multiplication into a matrix
 * multiplication with the transposed right operand.
 *
 * The rewrite fires only if the left operand has more than one row and more
 * than one column, and the right operand has exactly one row and more than
 * one column. The introduced transpose is removed by another rewrite if
 * unnecessary, i.e., if Y==t(Z).
 *
 * @param parent hop whose child reference to {@code hi} is replaced
 * @param hi     candidate hop to be simplified
 * @param pos    position handed to {@code HopRewriteUtils.replaceChildReference}
 * @return the new matrix multiply hop if the rewrite fired, otherwise {@code hi}
 */
private static Hop simplifyRowSumsMVMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }

    AggUnaryOp rowAgg = (AggUnaryOp) hi;
    Hop product = rowAgg.getInput().get(0);

    // rowsums over b(*)
    boolean rowSumsOfMult = rowAgg.getOp() == AggOp.SUM
        && rowAgg.getDirection() == Direction.Row
        && HopRewriteUtils.isBinary(product, OpOp2.MULT);
    if (!rowSumsOfMult) {
        return hi;
    }

    Hop lhs = product.getInput().get(0);
    Hop rhs = product.getInput().get(1);

    // MV (row vector): true matrix on the left, row vector on the right
    boolean matrixTimesRowVector = lhs.getDim1() > 1 && lhs.getDim2() > 1
        && rhs.getDim1() == 1 && rhs.getDim2() > 1;
    if (!matrixTimesRowVector) {
        return hi;
    }

    // build the replacement operators
    ReorgOp transposed = HopRewriteUtils.createTranspose(rhs);
    AggBinaryOp matMult = HopRewriteUtils.createMatrixMultiply(lhs, transposed);

    // hook the replacement under the parent and release the old operators
    HopRewriteUtils.replaceChildReference(parent, hi, matMult, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, product);
    LOG.debug("Applied simplifyRowSumsMVMult");
    return matMult;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method simplifyTraceMatrixMult.
/**
 * Rewrites the trace of a matrix multiplication into a sum over an
 * element-wise multiplication: trace(X%*%Y) -> sum(X*t(Y)).
 *
 * @param parent hop whose child reference to {@code hi} is replaced
 * @param hi     candidate hop to be simplified
 * @param pos    position handed to {@code HopRewriteUtils.replaceChildReference}
 * @return the new sum hop if the rewrite fired, otherwise {@code hi}
 */
private static Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos) {
    // trace()
    boolean isTrace = hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE;
    if (!isTrace) {
        return hi;
    }

    // X%*%Y
    Hop matMult = hi.getInput().get(0);
    if (!HopRewriteUtils.isMatrixMultiply(matMult)) {
        return hi;
    }

    Hop lhs = matMult.getInput().get(0);
    Hop rhs = matMult.getInput().get(1);

    // build the replacement operators (incl refresh size inside for transpose)
    ReorgOp transposed = HopRewriteUtils.createTranspose(rhs);
    BinaryOp elementwise = HopRewriteUtils.createBinary(lhs, transposed, OpOp2.MULT);
    AggUnaryOp total = HopRewriteUtils.createSum(elementwise);

    // hook the new sub-DAG under the parent and release the old operators
    HopRewriteUtils.replaceChildReference(parent, hi, total, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, matMult);
    LOG.debug("Applied simplifyTraceMatrixMult");
    return total;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method removeUnnecessaryAggregates.
/**
 * Collapses two directly nested unary aggregates into the outer one, e.g.,
 * sum(rowSums(X)) -> sum(X) or min(colMins(X)) -> min(X).
 *
 * Preconditions checked here: the outer aggregate is a full (row-and-column)
 * aggregate and the inner aggregate has exactly one parent. Supported
 * combinations are SUM over SUM, SUM over SUM_SQ, MIN over MIN, and MAX over
 * MAX. For SUM over SUM_SQ the outer aggregate is turned into SUM_SQ so the
 * squaring is preserved.
 *
 * @param hi candidate hop to be simplified
 * @return always {@code hi}; it is modified in place if the rewrite fired
 */
private static Hop removeUnnecessaryAggregates(Hop hi) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }

    AggUnaryOp outer = (AggUnaryOp) hi;
    Hop child = hi.getInput().get(0);
    boolean nestedFullAgg = child instanceof AggUnaryOp
        && outer.getDirection() == Direction.RowCol
        && child.getParent().size() == 1;
    if (!nestedFullAgg) {
        return hi;
    }

    AggUnaryOp inner = (AggUnaryOp) child;
    AggOp outerOp = outer.getOp();
    AggOp innerOp = inner.getOp();

    // only aggregate pairs that can be computed in a single pass
    boolean sumPair = outerOp == AggOp.SUM && (innerOp == AggOp.SUM || innerOp == AggOp.SUM_SQ);
    boolean minPair = outerOp == AggOp.MIN && innerOp == AggOp.MIN;
    boolean maxPair = outerOp == AggOp.MAX && innerOp == AggOp.MAX;
    if (!(sumPair || minPair || maxPair)) {
        return hi;
    }

    // detach the inner aggregate and wire its input directly to the outer one
    Hop source = inner.getInput().get(0);
    HopRewriteUtils.removeAllChildReferences(inner);
    HopRewriteUtils.replaceChildReference(outer, inner, source);

    // keep the squaring of the removed inner aggregate
    if (innerOp == AggOp.SUM_SQ) {
        outer.setOp(AggOp.SUM_SQ);
    }

    LOG.debug("Applied removeUnnecessaryAggregates (line " + hi.getBeginLine() + ").");
    return hi;
}
use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.
the class PlanSelectionFuseCostBased method rGetComputeCosts.
/**
 * Recursively assigns a compute cost to {@code current} and to all hops
 * reachable through its inputs, storing each cost in {@code computeCosts}
 * under the hop ID. Hops that already have an entry are skipped, so every
 * hop is costed at most once.
 *
 * The cost depends on the hop class and its operation code; hops and
 * operations without a dedicated entry get a cost of 1 (with a warning for
 * unknown operation codes). The numbers are relative per-operation weights;
 * their unit is not visible here -- TODO confirm.
 *
 * @param current      root of the (sub-)DAG to be costed
 * @param partition    only handed through to the recursive calls
 * @param computeCosts output map from hop ID to compute cost
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
    // memoization: skip hops that were costed before
    if (computeCosts.containsKey(current.getHopID())) {
        return;
    }

    // cost all inputs first
    for (Hop child : current.getInput()) {
        rGetComputeCosts(child, partition, computeCosts);
    }

    // cost of this hop, chosen by hop class
    double opCosts = 1;
    if (current instanceof UnaryOp) {
        opCosts = unaryOpComputeCosts((UnaryOp) current);
    } else if (current instanceof BinaryOp) {
        opCosts = binaryOpComputeCosts((BinaryOp) current);
    } else if (current instanceof TernaryOp) {
        opCosts = ternaryOpComputeCosts((TernaryOp) current);
    } else if (current instanceof ParameterizedBuiltinOp
        || current instanceof IndexingOp
        || current instanceof ReorgOp) {
        opCosts = 1;
    } else if (current instanceof AggBinaryOp) {
        // matrix vector
        opCosts = 2;
    } else if (current instanceof AggUnaryOp) {
        opCosts = aggUnaryOpComputeCosts((AggUnaryOp) current);
    }

    computeCosts.put(current.getHopID(), opCosts);
}

/** Returns the compute cost of a unary operation; 1 (plus a warning) if unknown. */
private static double unaryOpComputeCosts(UnaryOp op) {
    switch (op.getOp()) {
        case ABS:
        case ROUND:
        case CEIL:
        case FLOOR:
        case SIGN:
        case SELP:
        case NCOL:
        case NROW:
        case PRINT:
        case CAST_AS_BOOLEAN:
        case CAST_AS_DOUBLE:
        case CAST_AS_INT:
        case CAST_AS_MATRIX:
        case CAST_AS_SCALAR:
        case CUMSUM:
        case CUMMIN:
        case CUMMAX:
        case CUMPROD:
            return 1;
        case SPROP:
        case SQRT:
            return 2;
        case EXP:
        case SIN:
            return 18;
        case SIGMOID:
            return 21;
        case COS:
            return 22;
        case LOG:
        case LOG_NZ:
            return 32;
        case ATAN:
            return 40;
        case TAN:
            return 42;
        case ASIN:
            return 93;
        case ACOS:
            return 103;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + op.getOp());
            return 1;
    }
}

/** Returns the compute cost of a binary operation; 1 (plus a warning) if unknown. */
private static double binaryOpComputeCosts(BinaryOp op) {
    switch (op.getOp()) {
        case MULT:
        case PLUS:
        case MINUS:
        case MIN:
        case MAX:
        case AND:
        case OR:
        case EQUAL:
        case NOTEQUAL:
        case LESS:
        case LESSEQUAL:
        case GREATER:
        case GREATEREQUAL:
        case CBIND:
        case RBIND:
            return 1;
        case MINUS_NZ:
        case MINUS1_MULT:
            return 2;
        case INTDIV:
            return 6;
        case MODULUS:
            return 8;
        case POW:
            // a power with literal exponent 2 is costed like a cheap operation
            return HopRewriteUtils.isLiteralOfValue(op.getInput().get(1), 2) ? 1 : 16;
        case DIV:
            return 22;
        case COVARIANCE:
            return 23;
        case LOG:
        case LOG_NZ:
            return 32;
        case CENTRALMOMENT:
            // indexed by type: count, mean, cm2, cm3, cm4, variance
            return centralMomentComputeCosts(op, new double[] { 1, 8, 16, 31, 51, 16 });
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + op.getOp());
            return 1;
    }
}

/** Returns the compute cost of a ternary operation; 1 (plus a warning) if unknown. */
private static double ternaryOpComputeCosts(TernaryOp op) {
    switch (op.getOp()) {
        case PLUS_MULT:
        case MINUS_MULT:
            return 2;
        case CTABLE:
            return 3;
        case COVARIANCE:
            return 23;
        case CENTRALMOMENT:
            // indexed by type: count, mean, cm2, cm3, cm4, variance
            // NOTE(review): the type is read from input 1 exactly as in the
            // binary case -- confirm this is the intended input for ternary hops
            return centralMomentComputeCosts(op, new double[] { 2, 9, 17, 32, 52, 17 });
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + op.getOp());
            return 1;
    }
}

/** Returns the compute cost of a unary aggregate; 1 (plus a warning) if unknown. */
private static double aggUnaryOpComputeCosts(AggUnaryOp op) {
    switch (op.getOp()) {
        case MIN:
        case MAX:
            return 1;
        case SUM:
            return 4;
        case SUM_SQ:
            return 5;
        default:
            LOG.warn("Cost model not " + "implemented yet for: " + op.getOp());
            return 1;
    }
}

/**
 * Looks up the cost of a central moment by its type, which is taken from
 * input 1 if that input is a literal and defaults to 2 otherwise. Types
 * outside the given table cost 1.
 */
private static double centralMomentComputeCosts(Hop hop, double[] costsByType) {
    int type = (int) (hop.getInput().get(1) instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) hop.getInput().get(1)) : 2);
    return (type >= 0 && type < costsByType.length) ? costsByType[type] : 1;
}
Aggregations