use of org.apache.sysml.hops.ParameterizedBuiltinOp in project incubator-systemml by apache.
the class TemplateCell method isValidOperation.
protected static boolean isValidOperation(Hop hop) {
// prepare indicators for binary operations
boolean isBinaryMatrixScalar = false;
boolean isBinaryMatrixVector = false;
boolean isBinaryMatrixMatrix = false;
if (hop instanceof BinaryOp && hop.getDataType().isMatrix()) {
Hop left = hop.getInput().get(0);
Hop right = hop.getInput().get(1);
DataType ldt = left.getDataType();
DataType rdt = right.getDataType();
isBinaryMatrixScalar = (ldt.isScalar() || rdt.isScalar());
isBinaryMatrixVector = hop.dimsKnown() && ((ldt.isMatrix() && TemplateUtils.isVectorOrScalar(right)) || (rdt.isMatrix() && TemplateUtils.isVectorOrScalar(left)));
isBinaryMatrixMatrix = hop.dimsKnown() && HopRewriteUtils.isEqualSize(left, right) && ldt.isMatrix() && rdt.isMatrix();
}
// prepare indicators for ternary operations
boolean isTernaryVectorScalarVector = false;
boolean isTernaryMatrixScalarMatrixDense = false;
boolean isTernaryIfElse = (HopRewriteUtils.isTernary(hop, OpOp3.IFELSE) && hop.getDataType().isMatrix());
if (hop instanceof TernaryOp && hop.getInput().size() == 3 && hop.dimsKnown() && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) {
Hop left = hop.getInput().get(0);
Hop right = hop.getInput().get(2);
isTernaryVectorScalarVector = TemplateUtils.isVector(left) && TemplateUtils.isVector(right);
isTernaryMatrixScalarMatrixDense = HopRewriteUtils.isEqualSize(left, right) && !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right);
}
// check supported unary, binary, ternary operations
return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrix || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense || isTernaryIfElse || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp() == ParamBuiltinOp.REPLACE));
}
use of org.apache.sysml.hops.ParameterizedBuiltinOp in project incubator-systemml by apache.
the class TemplateCell method rConstructCplan.
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
// memoization for common subexpression elimination and to avoid redundant work
if (tmp.containsKey(hop.getHopID()))
return;
MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CELL);
// recursively process required childs
if (me != null && me.type.isIn(TemplateType.ROW, TemplateType.OUTER)) {
CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
tmp.put(hop.getHopID(), cdata);
inHops.add(hop);
return;
}
for (int i = 0; i < hop.getInput().size(); i++) {
Hop c = hop.getInput().get(i);
if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MAGG || memo.contains(c.getHopID(), TemplateType.CELL)))
rConstructCplan(c, memo, tmp, inHops, compileLiterals);
else if (me != null && (me.type == TemplateType.MAGG || me.type == TemplateType.CELL) && HopRewriteUtils.isMatrixMultiply(hop) && // skip transpose
i == 0)
rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
else {
CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
tmp.put(c.getHopID(), cdata);
inHops.add(c);
}
}
// construct cnode for current hop
CNode out = null;
if (hop instanceof UnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
String primitiveOpName = ((UnaryOp) hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
} else if (hop instanceof BinaryOp) {
BinaryOp bop = (BinaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
String primitiveOpName = bop.getOp().name();
// add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
// construct binary cnode
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
} else if (hop instanceof TernaryOp) {
TernaryOp top = (TernaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
// add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
// construct ternary cnode, primitive operation derived from OpOp3
out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
} else if (hop instanceof ParameterizedBuiltinOp) {
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
} else if (hop instanceof IndexingOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
} else if (HopRewriteUtils.isTransposeOperation(hop)) {
out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
// correct indexing types of existing lookups
if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class))
TemplateUtils.rFlipVectorLookups(out);
// maintain input hops
if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
inHops.add(hop.getInput().get(0));
} else if (hop instanceof AggUnaryOp) {
// aggregation handled in template implementation (note: we do not compile
// ^2 of SUM_SQ into the operator to simplify the detection of single operators)
out = tmp.get(hop.getInput().get(0).getHopID());
} else if (hop instanceof AggBinaryOp) {
// (1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
out = new CNodeUnary(cdata1, UnaryType.POW2);
} else {
CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
if (cdata1 instanceof CNodeData && !inHops.contains(hop.getInput().get(0).getInput().get(0)))
inHops.add(hop.getInput().get(0).getInput().get(0));
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
if (TemplateUtils.isColVector(cdata2))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
}
}
tmp.put(hop.getHopID(), out);
}
use of org.apache.sysml.hops.ParameterizedBuiltinOp in project incubator-systemml by apache.
the class DMLTranslator method processParameterizedBuiltinFunctionExpression.
/**
* Construct Hops from parse tree : Process ParameterizedBuiltinFunction Expression in an
* assignment statement
*
* @param source parameterized built-in function
* @param target data identifier
* @param hops map of high-level operators
* @return high-level operator
*/
private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFunctionExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
// this expression has multiple "named" parameters
HashMap<String, Hop> paramHops = new HashMap<>();
// -- construct hops for all input parameters
// -- store them in hashmap so that their "name"s are maintained
Hop pHop = null;
for (String paramName : source.getVarParams().keySet()) {
pHop = processExpression(source.getVarParam(paramName), null, hops);
paramHops.put(paramName, pHop);
}
Hop currBuiltinOp = null;
if (target == null) {
target = createTarget(source);
}
// construct hop based on opcode
switch(source.getOpCode()) {
case CDF:
case INVCDF:
case QNORM:
case QT:
case QF:
case QCHISQ:
case QEXP:
case PNORM:
case PT:
case PF:
case PCHISQ:
case PEXP:
currBuiltinOp = constructDfHop(target.getName(), target.getDataType(), target.getValueType(), source.getOpCode(), paramHops);
break;
case GROUPEDAGG:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.GROUPEDAGG, paramHops);
break;
case RMEMPTY:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.RMEMPTY, paramHops);
break;
case REPLACE:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.REPLACE, paramHops);
break;
case ORDER:
ArrayList<Hop> inputs = new ArrayList<>();
inputs.add(paramHops.get("target"));
inputs.add(paramHops.get("by"));
inputs.add(paramHops.get("decreasing"));
inputs.add(paramHops.get("index.return"));
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), ReOrgOp.SORT, inputs);
break;
case TRANSFORMAPPLY:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.TRANSFORMAPPLY, paramHops);
break;
case TRANSFORMDECODE:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.TRANSFORMDECODE, paramHops);
break;
case TRANSFORMCOLMAP:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.TRANSFORMCOLMAP, paramHops);
break;
case TRANSFORMMETA:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.TRANSFORMMETA, paramHops);
break;
case TOSTRING:
// check for input data type and only compile toString Hop for matrices/frames,
// for scalars, we compile (s + "") to ensure consistent string output value types
currBuiltinOp = !paramHops.get("target").getDataType().isScalar() ? new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.TOSTRING, paramHops) : HopRewriteUtils.createBinary(paramHops.get("target"), new LiteralOp(""), OpOp2.PLUS);
break;
default:
throw new ParseException(source.printErrorLocation() + "processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + source.getOpCode());
}
setIdentifierParams(currBuiltinOp, source.getOutput());
currBuiltinOp.setParseInfo(source);
return currBuiltinOp;
}
use of org.apache.sysml.hops.ParameterizedBuiltinOp in project systemml by apache.
the class RewriteAlgebraicSimplificationStatic method simplifyGroupedAggregate.
private static Hop simplifyGroupedAggregate(Hop hi) {
if (// aggregate
hi instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hi).getOp() == ParamBuiltinOp.GROUPEDAGG) {
ParameterizedBuiltinOp phi = (ParameterizedBuiltinOp) hi;
if (// aggregate(fn="count")
phi.isCountFunction() && // only for vector
phi.getTargetHop().getDim2() == 1) {
HashMap<String, Integer> params = phi.getParamIndexMap();
int ix1 = params.get(Statement.GAGG_TARGET);
int ix2 = params.get(Statement.GAGG_GROUPS);
// check for unnecessary memory consumption for "count"
if (ix1 != ix2 && phi.getInput().get(ix1) != phi.getInput().get(ix2)) {
Hop th = phi.getInput().get(ix1);
Hop gh = phi.getInput().get(ix2);
HopRewriteUtils.replaceChildReference(hi, th, gh, ix1);
LOG.debug("Applied simplifyGroupedAggregateCount");
}
}
}
return hi;
}
use of org.apache.sysml.hops.ParameterizedBuiltinOp in project systemml by apache.
the class RewriteAlgebraicSimplificationStatic method simplifyOuterSeqExpand.
private static Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos) {
if (HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) && ((BinaryOp) hi).isOuterVectorOperator()) {
if ((// pattern a: outer(v, t(seq(1,m)), "==")
HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)) && HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0))) || // pattern b: outer(seq(1,m), t(v) "==")
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) {
// determine variable parameters for pattern a/b
boolean isPatternB = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0));
boolean isTransposeRight = HopRewriteUtils.isTransposeOperation(hi.getInput().get(1));
Hop trgt = isPatternB ? (isTransposeRight ? // get v from t(v)
hi.getInput().get(1).getInput().get(0) : // create v via t(v')
HopRewriteUtils.createTranspose(hi.getInput().get(1))) : // get v directly
hi.getInput().get(0);
Hop seq = isPatternB ? hi.getInput().get(0) : hi.getInput().get(1).getInput().get(0);
String direction = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? "rows" : "cols";
// setup input parameter hops
HashMap<String, Hop> inputargs = new HashMap<>();
inputargs.put("target", trgt);
inputargs.put("max", HopRewriteUtils.getBasic1NSequenceMax(seq));
inputargs.put("dir", new LiteralOp(direction));
inputargs.put("ignore", new LiteralOp(true));
inputargs.put("cast", new LiteralOp(false));
// create new hop
ParameterizedBuiltinOp pbop = HopRewriteUtils.createParameterizedBuiltinOp(trgt, inputargs, ParamBuiltinOp.REXPAND);
// relink new hop into original position
HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos);
hi = pbop;
LOG.debug("Applied simplifyOuterSeqExpand (line " + hi.getBeginLine() + ")");
}
}
return hi;
}
Aggregations