Usage of org.apache.sysml.hops.IndexingOp in the Apache incubator-systemml project.
Class RewriteAlgebraicSimplificationStatic, method simplifySlicedMatrixMult.
/**
 * Pushes a single-cell slice below a matrix multiply,
 * e.g., (X %*% Y)[i,j] -> X[i,] %*% Y[,j].
 * <p>
 * The rewrite fires only if {@code hi} is an {@link IndexingOp} whose row range and
 * column range each select a single index, and whose input is a matrix multiply that
 * has exactly one parent (this indexing op). The product hop is modified in place: its
 * inputs are replaced by a row slice of X and a column slice of Y.
 * <p>
 * NOTE(review): {@code parent} and {@code pos} are not used here, and the original
 * indexing op still references the (now 1x1) product; confirm the caller or a later
 * rewrite takes care of that linkage.
 *
 * @param parent not used in this body
 * @param hi     candidate hop
 * @param pos    not used in this body
 * @return the rewired matrix multiply if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) {
	if (!(hi instanceof IndexingOp))
		return hi;
	IndexingOp rix = (IndexingOp) hi;
	if (!rix.isRowLowerEqualsUpper() || !rix.isColLowerEqualsUpper())
		return hi;
	Hop mm = rix.getInput().get(0);
	// the slice must be the product's only consumer, because the product is rewired in place
	if (mm.getParent().size() != 1 || !HopRewriteUtils.isMatrixMultiply(mm))
		return hi;

	Hop left = mm.getInput().get(0);
	Hop right = mm.getInput().get(1);
	Hop rowIx = rix.getInput().get(1); // rl (== ru)
	Hop colIx = rix.getInput().get(3); // cl (== cu)

	// detach the product from its current inputs before attaching the slices
	HopRewriteUtils.removeAllChildReferences(mm);

	// left[rowIx, 1:createValueHop(left, false)] -- presumably 1:ncol(left)
	IndexingOp rowSlice = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, left,
		rowIx, rowIx, new LiteralOp(1), HopRewriteUtils.createValueHop(left, false), true, false);
	rowSlice.setOutputBlocksizes(left.getRowsInBlock(), left.getColsInBlock());
	rowSlice.refreshSizeInformation();

	// right[1:createValueHop(right, true), colIx] -- presumably 1:nrow(right)
	IndexingOp colSlice = new IndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, right,
		new LiteralOp(1), HopRewriteUtils.createValueHop(right, true), colIx, colIx, false, true);
	colSlice.setOutputBlocksizes(right.getRowsInBlock(), right.getColsInBlock());
	colSlice.refreshSizeInformation();

	HopRewriteUtils.addChildReference(mm, rowSlice, 0);
	HopRewriteUtils.addChildReference(mm, colSlice, 1);
	mm.refreshSizeInformation();

	LOG.debug("Applied simplifySlicedMatrixMult");
	return mm;
}
Usage of org.apache.sysml.hops.IndexingOp in the Apache incubator-systemml project.
Class PlanSelectionFuseCostBased, method rGetComputeCosts.
/**
 * Recursively estimates relative compute costs for {@code current} and all of its
 * transitive inputs, memoizing one value per hop ID in {@code computeCosts}.
 * <p>
 * Inputs are costed before their consumer. Hops already present in the map are
 * skipped, so shared sub-DAGs are visited only once. Any hop or operator without an
 * explicit entry below keeps the default cost of 1. Unknown UnaryOp, BinaryOp,
 * TernaryOp and AggUnaryOp operators additionally log a warning.
 *
 * @param current      hop to cost
 * @param partition    forwarded unchanged to the recursive calls; it does not restrict
 *                     which hops are costed in this method
 * @param computeCosts memo table and output: hop ID to estimated cost
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
	if (computeCosts.containsKey(current.getHopID()))
		return;

	// recursively process children first
	for (Hop c : current.getInput())
		rGetComputeCosts(c, partition, computeCosts);

	// get costs for given hop (default 1 for anything not listed below)
	double costs = 1;
	if (current instanceof UnaryOp) {
		switch (((UnaryOp) current).getOp()) {
			case ABS:
			case ROUND:
			case CEIL:
			case FLOOR:
			case SIGN:
			case SELP:
				costs = 1;
				break;
			case SPROP:
			case SQRT:
				costs = 2;
				break;
			case EXP:
				costs = 18;
				break;
			case SIGMOID:
				costs = 21;
				break;
			case LOG:
			case LOG_NZ:
				costs = 32;
				break;
			case NCOL:
			case NROW:
			case PRINT:
			case CAST_AS_BOOLEAN:
			case CAST_AS_DOUBLE:
			case CAST_AS_INT:
			case CAST_AS_MATRIX:
			case CAST_AS_SCALAR:
				costs = 1;
				break;
			case SIN:
				costs = 18;
				break;
			case COS:
				costs = 22;
				break;
			case TAN:
				costs = 42;
				break;
			case ASIN:
				costs = 93;
				break;
			case ACOS:
				costs = 103;
				break;
			case ATAN:
				costs = 40;
				break;
			case CUMSUM:
			case CUMMIN:
			case CUMMAX:
			case CUMPROD:
				costs = 1;
				break;
			default:
				LOG.warn("Cost model not " + "implemented yet for: " + ((UnaryOp) current).getOp());
		}
	} else if (current instanceof BinaryOp) {
		switch (((BinaryOp) current).getOp()) {
			case MULT:
			case PLUS:
			case MINUS:
			case MIN:
			case MAX:
			case AND:
			case OR:
			case EQUAL:
			case NOTEQUAL:
			case LESS:
			case LESSEQUAL:
			case GREATER:
			case GREATEREQUAL:
			case CBIND:
			case RBIND:
				costs = 1;
				break;
			case INTDIV:
				costs = 6;
				break;
			case MODULUS:
				costs = 8;
				break;
			case DIV:
				costs = 22;
				break;
			case LOG:
			case LOG_NZ:
				costs = 32;
				break;
			case POW:
				// squaring is special-cased as cheap
				costs = (HopRewriteUtils.isLiteralOfValue(current.getInput().get(1), 2) ? 1 : 16);
				break;
			case MINUS_NZ:
			case MINUS1_MULT:
				costs = 2;
				break;
			case CENTRALMOMENT: {
				// unweighted form centralMoment(X, order): the type/order is input 1;
				// a non-literal type falls back to the cm2 cost
				Hop cmType = current.getInput().get(1);
				int type = (int) (cmType instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) cmType) : 2);
				switch (type) {
					case 0: // count
						costs = 1;
						break;
					case 1: // mean
						costs = 8;
						break;
					case 2: // cm2
						costs = 16;
						break;
					case 3: // cm3
						costs = 31;
						break;
					case 4: // cm4
						costs = 51;
						break;
					case 5: // variance
						costs = 16;
						break;
				}
				break;
			}
			case COVARIANCE:
				costs = 23;
				break;
			default:
				LOG.warn("Cost model not " + "implemented yet for: " + ((BinaryOp) current).getOp());
		}
	} else if (current instanceof TernaryOp) {
		switch (((TernaryOp) current).getOp()) {
			case PLUS_MULT:
			case MINUS_MULT:
				costs = 2;
				break;
			case CTABLE:
				costs = 3;
				break;
			case CENTRALMOMENT: {
				// weighted form centralMoment(X, W, order): input 1 is the weight matrix W,
				// so the type/order must be read from input 2. Reading input 1 never found
				// a literal and always fell back to the cm2 cost.
				Hop cmType = current.getInput().get(2);
				int type = (int) (cmType instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) cmType) : 2);
				switch (type) {
					case 0: // count
						costs = 2;
						break;
					case 1: // mean
						costs = 9;
						break;
					case 2: // cm2
						costs = 17;
						break;
					case 3: // cm3
						costs = 32;
						break;
					case 4: // cm4
						costs = 52;
						break;
					case 5: // variance
						costs = 17;
						break;
				}
				break;
			}
			case COVARIANCE:
				costs = 23;
				break;
			default:
				LOG.warn("Cost model not " + "implemented yet for: " + ((TernaryOp) current).getOp());
		}
	} else if (current instanceof ParameterizedBuiltinOp) {
		costs = 1;
	} else if (current instanceof IndexingOp) {
		costs = 1;
	} else if (current instanceof ReorgOp) {
		costs = 1;
	} else if (current instanceof AggBinaryOp) {
		// matrix vector
		costs = 2;
	} else if (current instanceof AggUnaryOp) {
		switch (((AggUnaryOp) current).getOp()) {
			case SUM:
				costs = 4;
				break;
			case SUM_SQ:
				costs = 5;
				break;
			case MIN:
			case MAX:
				costs = 1;
				break;
			default:
				LOG.warn("Cost model not " + "implemented yet for: " + ((AggUnaryOp) current).getOp());
		}
	}

	computeCosts.put(current.getHopID(), costs);
}
Aggregations