Use of org.apache.sysml.hops.codegen.cplan.CNode in project SystemML by Apache.
The class Explain, method explainCNode.
// ////////////
// internal explain CNODE
private static String explainCNode(CNode cnode, int level) {
    // nodes whose visit flag is already set produce no output in this traversal
    if (cnode.isVisited()) {
        return "";
    }
    final StringBuilder out = new StringBuilder();
    final String indent = createOffset(level);

    // post-order: every input is explained (at the same level) before the node itself
    for (CNode child : cnode.getInput()) {
        out.append(explainCNode(child, level));
    }

    out.append(indent);
    if (SHOW_DATA_DEPENDENCIES) {
        out.append("(" + cnode.getID() + ") ");
    }
    out.append(cnode.toString());

    // optional " (id1,id2,...)" suffix listing the ids of the inputs; omitted when
    // the node has no inputs
    if (SHOW_DATA_DEPENDENCIES) {
        final StringBuilder inputIds = new StringBuilder();
        boolean first = true;
        for (CNode child : cnode.getInput()) {
            if (!first) {
                inputIds.append(',');
            }
            inputIds.append(child.getID());
            first = false;
        }
        if (!first) {
            out.append(" (").append(inputIds).append(')');
        }
    }

    out.append('\n');
    cnode.setVisited();
    return out.toString();
}
Use of org.apache.sysml.hops.codegen.cplan.CNode in project SystemML by Apache.
The class CPlanCSERewriter, method rSetStrictDataNodeComparision.
/**
 * Recursively sets the strict-equals flag of every {@link CNodeData} reachable from
 * {@code current} and invalidates the cached hashes of all traversed nodes.
 *
 * <p>Cnode hash signatures differ between strict and imprecise data-node comparison, so every
 * node whose subtree may contain a flipped data node must drop its cached hash. Previously only
 * inputs were reset (by their parent), which left the root nodes passed in by the caller with a
 * stale hash computed under the other comparison mode.
 *
 * @param current root of the (sub-)DAG to process; skipped if already visited
 * @param flag    {@code true} for strict comparison, {@code false} for imprecise comparison
 */
private void rSetStrictDataNodeComparision(CNode current, boolean flag) {
    // avoid redundant re-evaluation of shared sub-DAGs
    if (current.isVisited())
        return;
    // process inputs recursively, invalidating each input's hash after its subtree
    for (CNode input : current.getInput()) {
        rSetStrictDataNodeComparision(input, flag);
        input.resetHash();
    }
    if (current instanceof CNodeData)
        ((CNodeData) current).setStrictEquals(flag);
    // also invalidate this node's own hash so that root nodes (which have no parent
    // performing the reset above) do not keep a signature from the previous mode
    current.resetHash();
    current.setVisited();
}
Use of org.apache.sysml.hops.codegen.cplan.CNode in project SystemML by Apache.
The class CPlanCSERewriter, method eliminateCommonSubexpressions.
public CNodeTpl eliminateCommonSubexpressions(CNodeTpl tpl) {
    // Unlike the traditional hop-level CSE, cplans carry no parent references, so a
    // collect-merge approach is not applicable. Instead, this rewrite relies on the cnode
    // hash signatures that are also used by the plan cache. Those signatures ignore input
    // hops by default (for better plan cache hit rates), so data nodes are temporarily
    // switched to strict comparison for the duration of this rewrite.
    final List<CNode> roots;
    if (tpl instanceof CNodeMultiAgg) {
        roots = ((CNodeMultiAgg) tpl).getOutputs();
    } else {
        roots = Collections.singletonList(tpl.getOutput());
    }

    // step 1: switch data nodes to strict comparison
    tpl.resetVisitStatusOutputs();
    for (CNode root : roots) {
        rSetStrictDataNodeComparision(root, true);
    }

    // step 2: merge common subexpressions, using the map of already seen cnodes
    final HashMap<CNode, CNode> cseMap = new HashMap<>();
    tpl.resetVisitStatusOutputs();
    for (CNode root : roots) {
        rEliminateCommonSubexpression(root, cseMap);
    }

    // step 3: switch data nodes back to imprecise comparison
    tpl.resetVisitStatusOutputs();
    for (CNode root : roots) {
        rSetStrictDataNodeComparision(root, false);
    }

    // return the (in-place modified) template with cleared visit flags
    tpl.resetVisitStatusOutputs();
    return tpl;
}
Use of org.apache.sysml.hops.codegen.cplan.CNode in project SystemML by Apache.
The class TemplateCell, method constructCplan.
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // build the cnode DAG for hop: cnodes keyed by hop id, plus the hops acting as data inputs
    final HashMap<Long, CNode> cnodes = new HashMap<>();
    final HashSet<Hop> dataHops = new HashSet<>();
    hop.resetVisitStatus();
    rConstructCplan(hop, memo, cnodes, dataHops, compileLiterals);
    hop.resetVisitStatus();

    // Drop scalar inputs whose cnode is a literal and order the rest so that matrices/vectors
    // come first. Ordering by number of cells and then sparsity makes a sparse input the
    // main input, avoiding unnecessary conversion.
    final Hop[] sinHops = dataHops.stream()
        .filter(h -> !h.getDataType().isScalar() || !cnodes.get(h.getHopID()).isLiteral())
        .sorted(new HopInputComparator())
        .toArray(Hop[]::new);

    // template inputs are the cnodes of the retained hops, in sorted order
    final ArrayList<CNode> inputs = new ArrayList<>();
    for (int i = 0; i < sinHops.length; i++) {
        inputs.add(cnodes.get(sinHops[i].getHopID()));
    }
    final CNodeCell tpl = new CNodeCell(inputs, cnodes.get(hop.getHopID()));
    tpl.setCellType(TemplateUtils.getCellType(hop));
    tpl.setAggOp(TemplateUtils.getAggOp(hop));
    // NOTE(review): sinHops[0] assumes at least one non-literal input; an empty array throws here
    tpl.setSparseSafe(isSparseSafe(Arrays.asList(hop), sinHops[0],
        Arrays.asList(tpl.getOutput()), Arrays.asList(tpl.getAggOp()), false));
    tpl.setRequiresCastDtm(hop instanceof AggBinaryOp);
    tpl.setBeginLine(hop.getBeginLine());

    return new Pair<>(sinHops, tpl);
}
Use of org.apache.sysml.hops.codegen.cplan.CNode in project SystemML by Apache.
The class TemplateCell, method rConstructCplan.
/**
 * Recursively constructs the cell-template cnode for {@code hop}.
 *
 * <p>Inputs that are plan references of the best CELL memo entry are compiled into the same
 * template; all other inputs become data nodes that are recorded in {@code inHops}. The cnode
 * built for {@code hop} is stored in {@code tmp} under its hop id.
 *
 * @param hop             the hop to compile
 * @param memo            memo table queried for the best CELL plan entry of each hop
 * @param tmp             cnodes constructed so far, keyed by hop id (also used for memoization)
 * @param inHops          collects the hops that become template data inputs
 * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData} and
 *                        {@code TemplateUtils.skipTranspose}
 */
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CELL);
    // a hop whose best plan is a ROW or OUTER template is consumed as a data input
    if (me != null && me.type.isIn(TemplateType.ROW, TemplateType.OUTER)) {
        CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
        tmp.put(hop.getHopID(), cdata);
        inHops.add(hop);
        return;
    }
    // recursively process required childs
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MAGG || memo.contains(c.getHopID(), TemplateType.CELL))) {
            // fused input: compile it into this template
            rConstructCplan(c, memo, tmp, inHops, compileLiterals);
        } else if (me != null && (me.type == TemplateType.MAGG || me.type == TemplateType.CELL) && HopRewriteUtils.isMatrixMultiply(hop) && i == 0) {
            // skip transpose: compile the input of the left matrix-multiply operand
            // NOTE(review): c itself is not checked to be a transpose here; presumably the
            // memo table only selects this plan for t(X) %*% Y -- confirm
            rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
        } else {
            // unfused input: becomes a data node of the template
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    // construct cnode for current hop
    // NOTE(review): hops not matched by any branch below leave out == null, which is stored as is
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        String primitiveOpName = bop.getOp().name();
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        // construct binary cnode
        out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        // construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
    } else if (hop instanceof ParameterizedBuiltinOp) {
        ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp) hop;
        // Wrap the target cnode based on the target hop itself, consistent with the other
        // branches that pass the hop matching the wrapped cnode, instead of assuming that
        // the target is input 0 of the parameterized builtin.
        Hop target = pbop.getTargetHop();
        CNode cdata1 = tmp.get(target.getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, target);
        CNode cdata2 = tmp.get(pbop.getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(pbop.getParameterHop("replacement").getHopID());
        // a literal NaN pattern selects the dedicated NaN replacement type
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        // lookup with the number of columns of the input and the index given by input 4
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
        // correct indexing types of existing lookups
        if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class))
            TemplateUtils.rFlipVectorLookups(out);
        // maintain input hops
        if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
            inHops.add(hop.getInput().get(0));
    } else if (hop instanceof AggUnaryOp) {
        // aggregation handled in template implementation (note: we do not compile
        // ^2 of SUM_SQ into the operator to simplify the detection of single operators)
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggBinaryOp) {
        // (1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
        if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
            CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            out = new CNodeUnary(cdata1, UnaryType.POW2);
        } else {
            CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
            if (cdata1 instanceof CNodeData && !inHops.contains(hop.getInput().get(0).getInput().get(0)))
                inHops.add(hop.getInput().get(0).getInput().get(0));
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            if (TemplateUtils.isColVector(cdata2))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
        }
    }
    tmp.put(hop.getHopID(), out);
}
Aggregations