Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project incubator-systemml by Apache.
Class SpoofCompiler, method cleanupCPlans.
/**
 * Prunes generated cplans down to the inputs they actually reference and
 * drops plans that are invalid or not worth keeping. Incremental construction
 * can leave unused inputs behind; removing them avoids unnecessary redundant
 * computation.
 *
 * @param cplans map of cplans, each value pairing the input hops with the template
 * @return a new map that contains only the retained (and cleaned) cplans
 */
private static HashMap<Long, Pair<Hop[], CNodeTpl>> cleanupCPlans(HashMap<Long, Pair<Hop[], CNodeTpl>> cplans) {
    HashMap<Long, Pair<Hop[], CNodeTpl>> retained = new HashMap<Long, Pair<Hop[], CNodeTpl>>();
    for (Entry<Long, Pair<Hop[], CNodeTpl>> entry : cplans.entrySet()) {
        final Long planKey = entry.getKey();
        final Pair<Hop[], CNodeTpl> plan = entry.getValue();
        final CNodeTpl tpl = plan.getValue();
        final Hop[] inHops = plan.getKey();
        final boolean isCell = tpl instanceof CNodeCell;
        final boolean isMultiAgg = tpl instanceof CNodeMultiAgg;

        // gather the IDs of all leaf nodes reachable from the plan output(s)
        HashSet<Long> leafIds = new HashSet<Long>();
        if (isMultiAgg) {
            for (CNode root : ((CNodeMultiAgg) tpl).getOutputs()) {
                rCollectLeafIDs(root, leafIds);
            }
        } else {
            rCollectLeafIDs(tpl.getOutput(), leafIds);
        }

        // register the plan; shrink its inputs if the counts do not line up
        if (inHops.length != leafIds.size()) {
            tpl.cleanupInputs(leafIds);
            ArrayList<Hop> kept = new ArrayList<Hop>();
            for (int i = 0; i < inHops.length; i++) {
                Hop candidate = inHops[i];
                if (candidate == null || !leafIds.contains(candidate.getHopID())) {
                    continue;
                }
                kept.add(candidate);
            }
            retained.put(planKey, new Pair<Hop[], CNodeTpl>(kept.toArray(new Hop[0]), tpl));
        } else {
            retained.put(planKey, plan);
        }

        // discard plans that use column indexing on the main input
        if (isCell) {
            CNodeData mainIn = (CNodeData) tpl.getInput().get(0);
            if (rHasLookupRC1(tpl.getOutput(), mainIn) || isLookupRC1(tpl.getOutput(), mainIn)) {
                retained.remove(planKey);
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
                }
            }
        } else if (isMultiAgg) {
            CNodeData mainIn = (CNodeData) tpl.getInput().get(0);
            for (CNode root : ((CNodeMultiAgg) tpl).getOutputs()) {
                if (rHasLookupRC1(root, mainIn) || isLookupRC1(root, mainIn)) {
                    retained.remove(planKey);
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
                    }
                }
            }
        }

        // strip spurious lookups on the main input (cell, outer-product, multi-agg)
        if (isCell || tpl instanceof CNodeOuterProduct) {
            CNodeData mainIn = (CNodeData) tpl.getInput().get(0);
            rFindAndRemoveLookup(tpl.getOutput(), mainIn);
        } else if (isMultiAgg) {
            CNodeData mainIn = (CNodeData) tpl.getInput().get(0);
            rFindAndRemoveLookupMultiAgg((CNodeMultiAgg) tpl, mainIn);
        }

        // discard plans consisting of a single operation without aggregation
        boolean trivialPlan;
        if (isCell) {
            boolean singleNoAgg = ((CNodeCell) tpl).getCellType() == CellType.NO_AGG
                && TemplateUtils.hasSingleOperation(tpl);
            trivialPlan = singleNoAgg || TemplateUtils.hasNoOperation(tpl);
        } else {
            trivialPlan = tpl instanceof CNodeRow && TemplateUtils.hasSingleOperation(tpl);
        }
        if (trivialPlan) {
            retained.remove(planKey);
        }

        // discard plans whose output is nothing but a data node
        if (tpl.getOutput() instanceof CNodeData) {
            retained.remove(planKey);
        }
    }
    return retained;
}
Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project incubator-systemml by Apache.
Class SpoofCompiler, method cleanupCPlans.
/**
 * Post-processes generated cplans: applies simplification and common
 * subexpression rewrites, restricts the input hops to those still referenced,
 * and drops plans that are invalid, trivial, or empty. Removing inputs left
 * over from incremental construction avoids unnecessary redundant computation.
 *
 * @param memo memoization table, used to look up the hop behind a row plan
 * @param cplans map of cplans, each value pairing the input hops with the template
 * @return a new map that contains only the retained (and rewritten) cplans
 */
private static HashMap<Long, Pair<Hop[], CNodeTpl>> cleanupCPlans(CPlanMemoTable memo, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans) {
    final HashMap<Long, Pair<Hop[], CNodeTpl>> retained = new HashMap<>();
    final CPlanOpRewriter rewriter = new CPlanOpRewriter();
    final CPlanCSERewriter cse = new CPlanCSERewriter();

    for (Entry<Long, Pair<Hop[], CNodeTpl>> entry : cplans.entrySet()) {
        final Long planKey = entry.getKey();
        final Hop[] rawInputs = entry.getValue().getKey();

        // a plan with a null input hop is invalid and skipped entirely
        if (Arrays.asList(rawInputs).contains(null)) {
            continue;
        }

        // simplification rewrites first, then common subexpression elimination
        final CNodeTpl simplified = rewriter.simplifyCPlan(entry.getValue().getValue());
        final CNodeTpl tpl = cse.eliminateCommonSubexpressions(simplified);

        // keep only the input hops the rewritten plan still uses (order-preserving)
        final HashSet<Long> usedIds = tpl.getInputHopIDs(false);
        final Hop[] inHops = Arrays.stream(rawInputs)
            .filter(h -> h != null)
            .filter(h -> usedIds.contains(h.getHopID()))
            .toArray(Hop[]::new);
        retained.put(planKey, new Pair<>(inHops, tpl));

        final boolean isRow = tpl instanceof CNodeRow;
        final boolean isMultiAgg = tpl instanceof CNodeMultiAgg;
        final CNodeData mainIn = (CNodeData) tpl.getInput().get(0);

        // discard plans that use column indexing on the main input
        if (isRow || tpl instanceof CNodeCell) {
            if (rHasLookupRC1(tpl.getOutput(), mainIn, !isRow) || isLookupRC1(tpl.getOutput(), mainIn, !isRow)) {
                retained.remove(planKey);
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
                }
            }
        } else if (isMultiAgg) {
            for (CNode root : ((CNodeMultiAgg) tpl).getOutputs()) {
                if (rHasLookupRC1(root, mainIn, true) || isLookupRC1(root, mainIn, true)) {
                    retained.remove(planKey);
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
                    }
                }
            }
        }

        // strip invalid lookups on the main input (applies to every template)
        if (isMultiAgg) {
            rFindAndRemoveLookupMultiAgg((CNodeMultiAgg) tpl, mainIn);
        } else {
            rFindAndRemoveLookup(tpl.getOutput(), mainIn, !isRow);
        }

        // discard invalid row templates (e.g., unsatisfied blocksize constraint)
        if (isRow) {
            final boolean noAggToScalar = ((CNodeRow) tpl).getRowType() == RowType.NO_AGG
                && tpl.getOutput().getDataType().isScalar();
            if (noAggToScalar) {
                // row plan without aggregation over a column vector
                retained.remove(planKey);
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed invalid row cplan w/o agg on column vector.");
                }
            } else if (OptimizerUtils.isSparkExecutionMode()) {
                Hop root = memo.getHopRefs().get(planKey);
                boolean isSpark = DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK
                    || OptimizerUtils.getTotalMemEstimate(inHops, root, true) > OptimizerUtils.getLocalMemBudget();
                boolean tooWide = root.getDataType().isMatrix()
                    && (HopRewriteUtils.isTransposeOperation(root)
                        ? root.getDim1() > root.getRowsInBlock()
                        : root.getDim2() > root.getColsInBlock());
                for (int i = 0; i < inHops.length; i++) {
                    Hop in = inHops[i];
                    tooWide |= (in.getDataType().isMatrix() && in.getDim2() > in.getColsInBlock());
                }
                if (isSpark && tooWide) {
                    retained.remove(planKey);
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed invalid row cplan w/ ncol>ncolpb.");
                    }
                }
            }
        }

        // discard plans consisting of a single operation without aggregation
        boolean noAggCell = tpl instanceof CNodeCell
            && ((CNodeCell) tpl).getCellType() == CellType.NO_AGG;
        boolean simpleRow = false;
        if (isRow) {
            RowType rowType = ((CNodeRow) tpl).getRowType();
            simpleRow = rowType == RowType.NO_AGG
                || rowType == RowType.NO_AGG_B1
                || rowType == RowType.ROW_AGG;
        }
        if (((noAggCell || simpleRow) && TemplateUtils.hasSingleOperation(tpl))
            || TemplateUtils.hasNoOperation(tpl)) {
            retained.remove(planKey);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed cplan with single operation.");
            }
        }

        // discard plans whose output is nothing but a data node
        if (tpl.getOutput() instanceof CNodeData) {
            retained.remove(planKey);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed empty cplan.");
            }
        }

        // rename inputs (for codegen and plan caching)
        tpl.renameInputs();
    }
    return retained;
}
Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project incubator-systemml by Apache.
Class CPlanCSERewriter, method rSetStrictDataNodeComparision.
/**
 * Walks the cnode graph below {@code current} and applies the strict-equals
 * flag to every data node, resetting the hash of each input along the way.
 * Nodes already marked as visited are left untouched.
 *
 * @param current node to process
 * @param flag value handed to {@code CNodeData.setStrictEquals}
 */
private void rSetStrictDataNodeComparision(CNode current, boolean flag) {
    if (!current.isVisited()) {
        // descend first; each input's hash is reset right after its subtree is done
        for (CNode child : current.getInput()) {
            rSetStrictDataNodeComparision(child, flag);
            child.resetHash();
        }
        // only data nodes carry the strict-equals setting
        if (current instanceof CNodeData) {
            CNodeData data = (CNodeData) current;
            data.setStrictEquals(flag);
        }
        // mark as done so that shared nodes are not re-evaluated
        current.setVisited();
    }
}
Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project incubator-systemml by Apache.
Class TemplateCell, method rConstructCplan.
/**
 * Recursively constructs the cnode for the given hop and registers it in
 * {@code tmp} under the hop's ID. Each child is either processed recursively
 * or turned into a data node via {@code TemplateUtils.createCNodeData}; in
 * the latter case the child hop is also added to {@code inHops}.
 *
 * @param hop current hop
 * @param memo memo table, queried for the best entry of the hop (template type CELL)
 * @param tmp map from hop ID to constructed cnode; doubles as memoization
 * @param inHops collects the hops that become data inputs of the plan
 * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData} and {@code TemplateUtils.skipTranspose}
 */
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
// memoization for common subexpression elimination and to avoid redundant work
if (tmp.containsKey(hop.getHopID()))
return;
MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CELL);
// if the best entry is a ROW or OUTER template, the hop itself becomes a
// data input of this plan and its children are not visited
if (me != null && me.type.isIn(TemplateType.ROW, TemplateType.OUTER)) {
CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
tmp.put(hop.getHopID(), cdata);
inHops.add(hop);
return;
}
// recursively process required children
for (int i = 0; i < hop.getInput().size(); i++) {
Hop c = hop.getInput().get(i);
// case 1: child i is a plan reference and not a DataOp; for MAGG entries the
// memo table must additionally contain a CELL entry for the child
if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MAGG || memo.contains(c.getHopID(), TemplateType.CELL)))
rConstructCplan(c, memo, tmp, inHops, compileLiterals);
// case 2: first input of a matrix multiply under a MAGG or CELL entry;
// recursion continues at the child's own first input
else if (me != null && (me.type == TemplateType.MAGG || me.type == TemplateType.CELL) && HopRewriteUtils.isMatrixMultiply(hop) && // skip transpose
i == 0)
rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
// case 3: everything else becomes a data node and an input hop
else {
CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
tmp.put(c.getHopID(), cdata);
inHops.add(c);
}
}
// construct cnode for current hop
// NOTE(review): out stays null if hop matches none of the cases below; the
// null is still stored in tmp at the end of the method
CNode out = null;
if (hop instanceof UnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
// the unary type is resolved by name from the hop's operator
String primitiveOpName = ((UnaryOp) hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
} else if (hop instanceof BinaryOp) {
BinaryOp bop = (BinaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
String primitiveOpName = bop.getOp().name();
// add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
// construct binary cnode
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
} else if (hop instanceof TernaryOp) {
TernaryOp top = (TernaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
// add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
// construct ternary cnode, primitive operation derived from OpOp3
out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
} else if (hop instanceof ParameterizedBuiltinOp) {
// operands are the target hop plus the "pattern" and "replacement" parameters
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
// a literal pattern named "Double.NaN" selects the dedicated REPLACE_NAN type
TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
} else if (hop instanceof IndexingOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
// LOOKUP_RC1 over the indexed input, with a literal of that input's dim2 and
// the hop's input at position 4 as further operands
// NOTE(review): input 4 is presumably a column index bound - confirm
out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
} else if (HopRewriteUtils.isTransposeOperation(hop)) {
out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
// correct indexing types of existing lookups
// (only when no parent of this hop is an AggBinaryOp)
if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class))
TemplateUtils.rFlipVectorLookups(out);
// maintain input hops
if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
inHops.add(hop.getInput().get(0));
} else if (hop instanceof AggUnaryOp) {
// aggregation handled in template implementation (note: we do not compile
// ^2 of SUM_SQ into the operator to simplify the detection of single operators)
out = tmp.get(hop.getInput().get(0).getHopID());
} else if (hop instanceof AggBinaryOp) {
// (1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
// column vectors are accessed through a LOOKUP_R node
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
out = new CNodeUnary(cdata1, UnaryType.POW2);
} else {
// left operand: see through the transpose and register the hop below it
// as input when the result is a plain data node
CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
if (cdata1 instanceof CNodeData && !inHops.contains(hop.getInput().get(0).getInput().get(0)))
inHops.add(hop.getInput().get(0).getInput().get(0));
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
if (TemplateUtils.isColVector(cdata2))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
}
}
// register the result for this hop (also marks the hop as processed)
tmp.put(hop.getHopID(), out);
}
Use of org.apache.sysml.hops.codegen.cplan.CNodeData in project incubator-systemml by Apache.
Class TemplateOuterProduct, method rConstructCplan.
/**
 * Recursively constructs the cnode for the given hop and registers it in
 * {@code tmp} under the hop's ID. Children flagged as plan references are
 * processed recursively; all other children become data nodes and are added
 * to {@code inHops}. Selected hops are additionally recorded in
 * {@code inHops2} under the keys "_X", "_U" and "_V".
 *
 * @param hop current hop
 * @param memo memo table, queried for the best entry of the hop (template types OUTER, CELL)
 * @param tmp map from hop ID to constructed cnode; doubles as memoization
 * @param inHops collects the hops that become data inputs of the plan
 * @param inHops2 named input hops ("_X", "_U", "_V") recorded during construction
 * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData} and {@code TemplateUtils.skipTranspose}
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
// memoization for common subexpression elimination and to avoid redundant work
if (tmp.containsKey(hop.getHopID()))
return;
// recursively process required children
// NOTE(review): me is dereferenced below without a null check; this assumes
// memo.getBest always returns an entry for hops reached here - confirm
MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OUTER, TemplateType.CELL);
for (int i = 0; i < hop.getInput().size(); i++) {
Hop c = hop.getInput().get(i);
if (me.isPlanRef(i))
rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
else {
// non-referenced child: becomes a data node and an input hop
CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
tmp.put(c.getHopID(), cdata);
inHops.add(c);
}
}
// construct cnode for current hop
// NOTE(review): out stays null if hop matches none of the cases below; the
// null is still stored in tmp at the end of the method
CNode out = null;
if (hop instanceof UnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
// the unary type is resolved from the string form of the hop's operator
String primitiveOpName = ((UnaryOp) hop).getOp().toString();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
} else if (hop instanceof BinaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
String primitiveOpName = ((BinaryOp) hop).getOp().toString();
// for sparse-safe binary ops, record a matrix input that is still a plain data
// node under "_X"; if both inputs qualify, the second overwrites the first
if (HopRewriteUtils.isBinarySparseSafe(hop)) {
if (TemplateUtils.isMatrix(hop.getInput().get(0)) && cdata1 instanceof CNodeData)
inHops2.put("_X", hop.getInput().get(0));
if (TemplateUtils.isMatrix(hop.getInput().get(1)) && cdata2 instanceof CNodeData)
inHops2.put("_X", hop.getInput().get(1));
}
// add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
} else if (hop instanceof AggBinaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
// handle transpose in outer or final product
cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
cdata2 = TemplateUtils.skipTranspose(cdata2, hop.getInput().get(1), tmp, compileLiterals);
// outer product U%*%t(V), see open
if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
// keep U and V for later reference
// (for V, the hop below the transpose is stored if the right input is a transpose)
inHops2.put("_U", hop.getInput().get(0));
if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(1)))
inHops2.put("_V", hop.getInput().get(1).getInput().get(0));
else
inHops2.put("_V", hop.getInput().get(1));
out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
} else // final left/right matrix mult, see close
{
// operands are swapped when the left one is a scalar, so the scalar
// always ends up as the second argument of VECT_MULT_ADD
if (cdata1.getDataType().isScalar())
out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
else
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
}
} else if (HopRewriteUtils.isTransposeOperation(hop)) {
// transpose adds no cnode of its own: reuse the cnode of its input
out = tmp.get(hop.getInput().get(0).getHopID());
} else if (hop instanceof AggUnaryOp && ((AggUnaryOp) hop).getOp() == AggOp.SUM && ((AggUnaryOp) hop).getDirection() == Direction.RowCol) {
// SUM with direction RowCol adds no cnode either: reuse the cnode of its input
out = tmp.get(hop.getInput().get(0).getHopID());
}
// register the result for this hop (also marks the hop as processed)
tmp.put(hop.getHopID(), out);
}
Aggregations