use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateCell method rConstructCplan.
/**
 * Recursively builds the cnode for {@code hop} and for all inputs that are
 * fused into the same cell template, registering each result in {@code tmp}.
 * Inputs that are not fused are turned into data cnodes and recorded in
 * {@code inHops}.
 *
 * @param hop current hop to construct a cnode for
 * @param memo memo table queried for the best CellTpl entry of {@code hop}
 * @param tmp map from hop ID to constructed cnode; also serves as memoization
 * @param inHops collects the hops that become inputs of the fused operator
 * @param compileLiterals passed through to {@code TemplateUtils.createCNodeData}
 */
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
    //memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;

    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CellTpl);

    //a hop whose best entry is a row or outer-product template is not fused
    //here; it is registered as a plain data input of this template instead
    if (me != null && (me.type == TemplateType.RowTpl || me.type == TemplateType.OuterProdTpl)) {
        CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
        tmp.put(hop.getHopID(), cdata);
        inHops.add(hop);
        return;
    }

    //recursively process required childs
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MultiAggTpl || memo.contains(c.getHopID(), TemplateType.CellTpl)))
            rConstructCplan(c, memo, tmp, inHops, compileLiterals);
        else {
            //not part of the fused plan: materialize as data node and input
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }

    //construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        String primitiveOpName = bop.getOp().name();
        //add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        //X^2 and X*2 with a literal 2 are mapped to dedicated unary primitives
        if (bop.getOp() == OpOp2.POW && cdata2.isLiteral() && cdata2.getVarname().equals("2"))
            out = new CNodeUnary(cdata1, UnaryType.POW2);
        else if (bop.getOp() == OpOp2.MULT && cdata2.isLiteral() && cdata2.getVarname().equals("2"))
            out = new CNodeUnary(cdata1, UnaryType.MULT2);
        else
            //default binary
            out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        //add lookups if required (only first and third input are wrapped)
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        //construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
    } else if (hop instanceof ParameterizedBuiltinOp) {
        ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp) hop;
        Hop target = pbop.getTargetHop();
        CNode cdata1 = tmp.get(target.getHopID());
        //decide on the lookup from the target hop itself (the hop cdata1 was
        //created for) instead of assuming the target is the first input
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, target);
        CNode cdata2 = tmp.get(pbop.getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(pbop.getParameterHop("replacement").getHopID());
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        //lookup built from the input's column count and the fifth indexing input
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        //transpose is passed through to the cnode of its input
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggUnaryOp) {
        //aggregation handled in template implementation (note: we do not compile
        //^2 of SUM_SQ into the operator to simplify the detection of single operators)
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggBinaryOp) {
        //(1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
        if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
            CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
            out = new CNodeUnary(cdata1, UnaryType.POW2);
        } else {
            CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            if (TemplateUtils.isColVector(cdata2))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
        }
    }

    //note: out remains null if none of the branches above matched;
    //the hop is still registered so that it is not processed again
    tmp.put(hop.getHopID(), out);
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateCell method constructCplan.
/**
 * Constructs the cell-template cplan rooted at {@code hop}.
 *
 * @param hop root hop of the fused operator
 * @param memo memo table passed to the recursive construction
 * @param compileLiterals passed through to the recursive construction
 * @return pair of the ordered input hops and the constructed template node
 */
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    //recursively process required cplan output
    HashMap<Long, CNode> tmp = new HashMap<>();
    HashSet<Hop> inHops = new HashSet<>();
    hop.resetVisitStatus();
    rConstructCplan(hop, memo, tmp, inHops, compileLiterals);
    hop.resetVisitStatus();

    //reorder inputs (ensure matrices/vectors come first) and prune literals
    //note: we order by number of cells and subsequently sparsity to ensure
    //that sparse inputs are used as the main input w/o unnecessary conversion
    List<Hop> sinHops = inHops.stream()
        .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral()))
        .sorted(new HopInputComparator())
        .collect(Collectors.toList());

    //collect the cnodes of the ordered inputs
    ArrayList<CNode> inputs = new ArrayList<>();
    for (Hop in : sinHops) {
        inputs.add(tmp.get(in.getHopID()));
    }

    //construct template node
    CNodeCell tpl = new CNodeCell(inputs, tmp.get(hop.getHopID()));
    tpl.setCellType(TemplateUtils.getCellType(hop));
    tpl.setAggOp(TemplateUtils.getAggOp(hop));

    //sparse-safe if the root multiplies with the main input, or divides the
    //main input (main input = first entry of the ordered inputs)
    boolean sparseSafe = HopRewriteUtils.isBinary(hop, OpOp2.MULT) && hop.getInput().contains(sinHops.get(0));
    if (!sparseSafe) {
        sparseSafe = HopRewriteUtils.isBinary(hop, OpOp2.DIV) && hop.getInput().get(0) == sinHops.get(0);
    }
    tpl.setSparseSafe(sparseSafe);
    tpl.setRequiresCastDtm(hop instanceof AggBinaryOp);

    // return cplan instance
    Hop[] orderedInputs = sinHops.toArray(new Hop[0]);
    return new Pair<Hop[], CNodeTpl>(orderedInputs, tpl);
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateCell method isValidOperation.
/**
 * Checks whether the given hop is an operation this template accepts:
 * a supported matrix-valued unary, binary, ternary, or REPLACE operation.
 *
 * @param hop hop to check
 * @return true if the hop has matrix output, is reported as supported by
 *         {@code TemplateUtils.isOperationSupported}, and matches one of
 *         the accepted operation shapes
 */
protected static boolean isValidOperation(Hop hop) {
    //indicators for binary operations
    boolean binMatrixScalar = false;
    boolean binMatrixVector = false;
    boolean binMatrixMatrixDense = false;
    if (hop instanceof BinaryOp && hop.getDataType().isMatrix()) {
        Hop lhs = hop.getInput().get(0);
        Hop rhs = hop.getInput().get(1);
        DataType lhsType = lhs.getDataType();
        DataType rhsType = rhs.getDataType();
        boolean dimsKnown = hop.dimsKnown();

        binMatrixScalar = lhsType.isScalar() || rhsType.isScalar();
        if (dimsKnown) {
            binMatrixVector = (lhsType.isMatrix() && TemplateUtils.isVectorOrScalar(rhs))
                || (rhsType.isMatrix() && TemplateUtils.isVectorOrScalar(lhs));
            binMatrixMatrixDense = HopRewriteUtils.isEqualSize(lhs, rhs)
                && lhsType.isMatrix() && rhsType.isMatrix()
                && !HopRewriteUtils.isSparse(lhs) && !HopRewriteUtils.isSparse(rhs);
        }
    }

    //indicators for ternary operations (matrix, scalar, matrix inputs only)
    boolean terVectorScalarVector = false;
    boolean terMatrixScalarMatrixDense = false;
    if (hop instanceof TernaryOp && hop.getInput().size() == 3 && hop.dimsKnown() && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) {
        Hop first = hop.getInput().get(0);
        Hop third = hop.getInput().get(2);
        terVectorScalarVector = TemplateUtils.isVector(first) && TemplateUtils.isVector(third);
        terMatrixScalarMatrixDense = HopRewriteUtils.isEqualSize(first, third)
            && !HopRewriteUtils.isSparse(first) && !HopRewriteUtils.isSparse(third);
    }

    //only supported operations with matrix output qualify
    if (hop.getDataType() != DataType.MATRIX || !TemplateUtils.isOperationSupported(hop)) {
        return false;
    }

    //check supported unary, binary, ternary operations
    return hop instanceof UnaryOp
        || binMatrixScalar
        || binMatrixVector
        || binMatrixMatrixDense
        || terVectorScalarVector
        || terMatrixScalarMatrixDense
        || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp() == ParamBuiltinOp.REPLACE);
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateOuterProduct method rConstructCplan.
/**
 * Recursively builds the cnode for {@code hop} and for all inputs that are
 * part of the outer-product plan, registering each result in {@code tmp}.
 * Inputs that are not plan references become data cnodes and are recorded
 * in {@code inHops}; the special inputs are recorded in {@code inHops2}
 * under the keys "_X", "_U" and "_V".
 *
 * @param hop current hop to construct a cnode for
 * @param memo memo table queried for the best OuterProdTpl entry of {@code hop}
 * @param tmp map from hop ID to constructed cnode; also serves as memoization
 * @param inHops collects the hops that become inputs of the fused operator
 * @param inHops2 collects the named inputs "_X", "_U" and "_V"
 * @param compileLiterals passed through to {@code TemplateUtils} helpers
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    //memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;

    //recursively process required childs
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OuterProdTpl);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        //guard against a missing memo entry (as TemplateCell does): without
        //an entry no input is a plan reference, so all inputs become data nodes
        if (me != null && me.isPlanRef(i))
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }

    //construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        //use the enum constant name so the lookup does not depend on toString()
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        //use the enum constant name so the lookup does not depend on toString()
        String primitiveOpName = ((BinaryOp) hop).getOp().name();
        //for equally sized inputs, remember the side that is a data node as X
        //(second input if the first one is not a data node)
        if (HopRewriteUtils.isEqualSize(hop.getInput().get(0), hop.getInput().get(1))) {
            Hop main = hop.getInput().get((cdata1 instanceof CNodeData) ? 0 : 1);
            inHops2.put("_X", main);
        }
        //add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof AggBinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        //handle transpose in outer or final product
        cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
        cdata2 = TemplateUtils.skipTranspose(cdata2, hop.getInput().get(1), tmp, compileLiterals);
        //outer product U%*%t(V), see open
        if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
            //keep U and V for later reference
            inHops2.put("_U", hop.getInput().get(0));
            if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(1)))
                inHops2.put("_V", hop.getInput().get(1).getInput().get(0));
            else
                inHops2.put("_V", hop.getInput().get(1));
            out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
        } else {
            //final left/right matrix mult, see close
            //(operands are swapped if the first one is a scalar)
            if (cdata1.getDataType().isScalar())
                out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
            else
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        //transpose is passed through to the cnode of its input
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggUnaryOp && ((AggUnaryOp) hop).getOp() == AggOp.SUM && ((AggUnaryOp) hop).getDirection() == Direction.RowCol) {
        //full sum is passed through to the cnode of its input
        out = tmp.get(hop.getInput().get(0).getHopID());
    }

    //note: out remains null if none of the branches above matched;
    //the hop is still registered so that it is not processed again
    tmp.put(hop.getHopID(), out);
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateOuterProduct method constructCplan.
/**
 * Constructs the outer-product-template cplan rooted at {@code hop}.
 *
 * @param hop root hop of the fused operator
 * @param memo memo table passed to the recursive construction
 * @param compileLiterals passed through to the recursive construction
 * @return pair of the ordered input hops (X, U, V first) and the
 *         constructed template node
 */
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    //recursively process required cplan output
    HashMap<Long, CNode> tmp = new HashMap<>();
    HashSet<Hop> inHops = new HashSet<>();
    HashMap<String, Hop> inHops2 = new HashMap<>();
    hop.resetVisitStatus();
    rConstructCplan(hop, memo, tmp, inHops, inHops2, compileLiterals);
    hop.resetVisitStatus();

    //reorder inputs (ensure matrix is first input): moving V, then U, then X
    //to the front yields the final order X, U, V, followed by the rest
    Hop X = inHops2.get("_X");
    Hop U = inHops2.get("_U");
    Hop V = inHops2.get("_V");
    LinkedList<Hop> sinHops = new LinkedList<>(inHops);
    for (Hop front : new Hop[] { V, U, X }) {
        sinHops.remove(front);
        sinHops.addFirst(front);
    }

    //construct template node
    //NOTE(review): an X/U/V that was never recorded is null here; it is
    //skipped for the cnode inputs but still part of the returned hop array
    ArrayList<CNode> inputs = new ArrayList<>();
    for (Hop in : sinHops) {
        if (in == null) {
            continue;
        }
        inputs.add(tmp.get(in.getHopID()));
    }
    CNodeOuterProduct tpl = new CNodeOuterProduct(inputs, tmp.get(hop.getHopID()));
    tpl.setOutProdType(TemplateUtils.getOuterProductType(X, U, V, hop));

    //transpose the output for left outer products whose root is not a transpose
    boolean rootIsTranspose = HopRewriteUtils.isTransposeOperation(hop);
    tpl.setTransposeOutput(!rootIsTranspose && tpl.getOutProdType() == OutProdType.LEFT_OUTER_PRODUCT);

    Hop[] orderedInputs = sinHops.toArray(new Hop[0]);
    return new Pair<Hop[], CNodeTpl>(orderedInputs, tpl);
}
Aggregations