use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.
the class TemplateCell method isValidOperation.
/**
 * Decides whether the given hop can be handled by this template.
 *
 * <p>A hop qualifies if its output data type is MATRIX, the operation is
 * reported as supported by {@code TemplateUtils.isOperationSupported}, and it
 * is one of: a unary op, a binary op of an accepted shape (matrix-scalar,
 * matrix-vector, or equally sized matrix-matrix), a ternary op of an accepted
 * shape (IFELSE, vector-scalar-vector, or non-sparse equally sized
 * matrix-scalar-matrix), or a parameterized builtin REPLACE.
 *
 * @param hop high-level operator to check
 * @return true if the operator is accepted, false otherwise
 */
protected static boolean isValidOperation(Hop hop) {
    // binary operations: collapse the three accepted shapes into one indicator
    boolean validBinary = false;
    if (hop instanceof BinaryOp && hop.getDataType().isMatrix()) {
        Hop in1 = hop.getInput().get(0);
        Hop in2 = hop.getInput().get(1);
        DataType dt1 = in1.getDataType();
        DataType dt2 = in2.getDataType();
        boolean matrixScalar = dt1.isScalar() || dt2.isScalar();
        boolean matrixVector = hop.dimsKnown()
            && ((dt1.isMatrix() && TemplateUtils.isVectorOrScalar(in2))
                || (dt2.isMatrix() && TemplateUtils.isVectorOrScalar(in1)));
        boolean matrixMatrix = hop.dimsKnown()
            && HopRewriteUtils.isEqualSize(in1, in2)
            && dt1.isMatrix() && dt2.isMatrix();
        validBinary = matrixScalar || matrixVector || matrixMatrix;
    }

    // ternary operations: IFELSE is accepted independent of the input shapes,
    // the other two cases require known dims and (matrix, scalar, matrix) inputs
    boolean validTernary = HopRewriteUtils.isTernary(hop, OpOp3.IFELSE)
        && hop.getDataType().isMatrix();
    if (hop instanceof TernaryOp
        && hop.getInput().size() == 3
        && hop.dimsKnown()
        && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) {
        Hop first = hop.getInput().get(0);
        Hop third = hop.getInput().get(2);
        boolean vectorScalarVector = TemplateUtils.isVector(first) && TemplateUtils.isVector(third);
        boolean denseMatrixScalarMatrix = HopRewriteUtils.isEqualSize(first, third)
            && !HopRewriteUtils.isSparse(first)
            && !HopRewriteUtils.isSparse(third);
        validTernary = validTernary || vectorScalarVector || denseMatrixScalarMatrix;
    }

    // common preconditions for every supported operation
    if (hop.getDataType() != DataType.MATRIX || !TemplateUtils.isOperationSupported(hop)) {
        return false;
    }

    // supported unary, binary, ternary, and replace operations
    boolean isReplace = hop instanceof ParameterizedBuiltinOp
        && ((ParameterizedBuiltinOp) hop).getOp() == ParamBuiltinOp.REPLACE;
    return hop instanceof UnaryOp || validBinary || validTernary || isReplace;
}
use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.
the class TemplateCell method rConstructCplan.
/**
 * Recursively constructs the CNode for the given hop and registers it in
 * {@code tmp} under the hop's ID. Inputs are processed first: depending on
 * the memo table entry they are either expanded recursively or turned into
 * data nodes, in which case the corresponding hop is also added to
 * {@code inHops}.
 *
 * @param hop hop for which a CNode is constructed
 * @param memo memo table, queried for the best CELL entry of the hop
 * @param tmp map from hop ID to the CNode already constructed for that hop;
 *        read for the inputs and updated with the result for {@code hop}
 * @param inHops set collecting the hops that become data inputs
 * @param compileLiterals flag passed through to the TemplateUtils helpers
 *        that create data nodes
 */
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
// memoization for common subexpression elimination and to avoid redundant work
if (tmp.containsKey(hop.getHopID()))
return;
MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CELL);
// recursively process required childs
// a best entry of type ROW or OUTER is not expanded here: the hop itself
// becomes a data node and is recorded as an input
if (me != null && me.type.isIn(TemplateType.ROW, TemplateType.OUTER)) {
CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
tmp.put(hop.getHopID(), cdata);
inHops.add(hop);
return;
}
// per input: (1) recurse into plan references that are not DataOps,
// (2) for the first input of a matrix multiply, recurse into that input's
// own first input, (3) otherwise treat the input as a data node
for (int i = 0; i < hop.getInput().size(); i++) {
Hop c = hop.getInput().get(i);
if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MAGG || memo.contains(c.getHopID(), TemplateType.CELL)))
rConstructCplan(c, memo, tmp, inHops, compileLiterals);
else if (me != null && (me.type == TemplateType.MAGG || me.type == TemplateType.CELL) && HopRewriteUtils.isMatrixMultiply(hop) && // skip transpose
i == 0)
rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
else {
CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
tmp.put(c.getHopID(), cdata);
inHops.add(c);
}
}
// construct cnode for current hop
CNode out = null;
if (hop instanceof UnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
// the unary primitive is resolved by name from the hop's operation
String primitiveOpName = ((UnaryOp) hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
} else if (hop instanceof BinaryOp) {
BinaryOp bop = (BinaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
String primitiveOpName = bop.getOp().name();
// add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
// construct binary cnode
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
} else if (hop instanceof TernaryOp) {
TernaryOp top = (TernaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
// add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
// construct ternary cnode, primitive operation derived from OpOp3
out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
} else if (hop instanceof ParameterizedBuiltinOp) {
// target, pattern, and replacement are fetched by parameter name
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
// NOTE(review): the lookup check uses getInput().get(0) while the cnode was
// fetched via getTargetHop() - confirm that the target is always input 0
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
// a literal pattern named "Double.NaN" selects the dedicated REPLACE_NAN primitive
TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
} else if (hop instanceof IndexingOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
// LOOKUP_RC1 over the indexed input, the input's number of columns as literal,
// and input 4 of the indexing hop (presumably a column bound - TODO confirm)
out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
} else if (HopRewriteUtils.isTransposeOperation(hop)) {
out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
// correct indexing types of existing lookups
if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class))
TemplateUtils.rFlipVectorLookups(out);
// maintain input hops
if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
inHops.add(hop.getInput().get(0));
} else if (hop instanceof AggUnaryOp) {
// aggregation handled in template implementation (note: we do not compile
// ^2 of SUM_SQ into the operator to simplify the detection of single operators)
out = tmp.get(hop.getInput().get(0).getHopID());
} else if (hop instanceof AggBinaryOp) {
// (1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
out = new CNodeUnary(cdata1, UnaryType.POW2);
} else {
// left side: skip the transpose and register its input as an input hop if
// it ends up as a plain data node
CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
if (cdata1 instanceof CNodeData && !inHops.contains(hop.getInput().get(0).getInput().get(0)))
inHops.add(hop.getInput().get(0).getInput().get(0));
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
if (TemplateUtils.isColVector(cdata2))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
}
}
// NOTE(review): if none of the branches above matched, out is still null and
// null is stored for this hop - confirm callers never reach this state
tmp.put(hop.getHopID(), out);
}
use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.
the class HopRewriteUtils method createTernaryOp.
/**
 * Creates a ternary operator named "tmp" with data type MATRIX and value
 * type DOUBLE over the three given inputs. Output block sizes and line
 * numbers are taken from the first input, and the size information of the
 * new operator is refreshed before it is returned.
 *
 * @param mleft first input; also the source of block sizes and line numbers
 * @param smid second input
 * @param mright third input
 * @param op ternary operation type
 * @return the newly created ternary operator
 */
public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 op) {
    TernaryOp result = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
        op, mleft, smid, mright);

    // inherit meta data from the first input
    copyLineNumbers(mleft, result);
    result.setOutputBlocksizes(mleft.getRowsInBlock(), mleft.getColsInBlock());

    result.refreshSizeInformation();
    return result;
}
use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationStatic method simplifyReverseOperation.
/**
 * Replaces a matrix multiply whose first input is a CTABLE over a basic 1..N
 * sequence and a basic N..1 sequence of equal row count by a REV reorg
 * operation on the multiply's second input.
 *
 * NOTE: by definition this would be a dynamic rewrite; it is nevertheless
 * applied as a static rewrite so that it runs before dags are split, which
 * would hide the table information if dimensions are not specified.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the parent's inputs
 * @return the new reorg operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyReverseOperation(Hop parent, Hop hi, int pos) {
    // candidates are matrix multiplies with a ternary first input
    if (!(hi instanceof AggBinaryOp) || !(hi.getInput().get(0) instanceof TernaryOp)) {
        return hi;
    }

    TernaryOp table = (TernaryOp) hi.getInput().get(0);
    if (table.getOp() != OpOp3.CTABLE) {
        return hi;
    }

    // the table must be built from a 1..N and an N..1 sequence of equal length
    Hop seq1N = table.getInput().get(0);
    if (!HopRewriteUtils.isBasic1NSequence(seq1N)) {
        return hi;
    }
    Hop seqN1 = table.getInput().get(1);
    if (!HopRewriteUtils.isBasicN1Sequence(seqN1)) {
        return hi;
    }
    if (seq1N.getDim1() != seqN1.getDim1()) {
        return hi;
    }

    // rewire the parent to the reverse of the second multiply input and
    // drop the now unreferenced multiply and table
    ReorgOp reversed = HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV);
    HopRewriteUtils.replaceChildReference(parent, hi, reversed, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, table);
    LOG.debug("Applied simplifyReverseOperation.");
    return reversed;
}
use of org.apache.sysml.hops.TernaryOp in project systemml by apache.
the class HopRewriteUtils method createTernaryOp.
/**
 * Builds a ternary operator named "tmp" (data type MATRIX, value type
 * DOUBLE) from the three given inputs. The first input provides the output
 * block sizes and the line numbers; size information is refreshed on the
 * new operator before returning it.
 *
 * @param mleft first input; also the source of block sizes and line numbers
 * @param smid second input
 * @param mright third input
 * @param op ternary operation type
 * @return the newly created ternary operator
 */
public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 op) {
    TernaryOp created = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
        op, mleft, smid, mright);

    // meta data is inherited from the first input
    copyLineNumbers(mleft, created);
    created.setOutputBlocksizes(mleft.getRowsInBlock(), mleft.getColsInBlock());

    created.refreshSizeInformation();
    return created;
}
Aggregations