use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class PlanSelectionFuseCostBasedV2 method pruneInvalidAndSpecialCasePlans.
/**
 * Prunes memo table entries of the given partition in three passes:
 * (1) in spark execution mode, row template entries of operators that are
 * treated as spark operators and violate the blocksize constraint,
 * (2) row template entries of operators that are not on the blacklist of
 * irreplaceable row operators and only have exact row/cell matches, and
 * (3) one of exactly two alternative outer product entries, if
 * TemplateOuterProduct.dropAlternativePlan selects one of them.
 *
 * @param memo memo table, pruned in place
 * @param part plan partition whose hop IDs are inspected
 */
private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable memo, PlanPartition part) {
    // pass 1: prune invalid row entries w/ violated blocksize constraint
    if (OptimizerUtils.isSparkExecutionMode()) {
        for (Long id : part.getPartition()) {
            if (!memo.contains(id, TemplateType.ROW)) {
                continue;
            }
            Hop op = memo.getHopRefs().get(id);

            // spark if forced by the runtime platform or if the memory
            // estimate exceeds the local memory budget
            boolean runsInSpark = DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK;
            if (!runsInSpark) {
                runsInSpark = OptimizerUtils.getTotalMemEstimate(op.getInput().toArray(new Hop[0]), op, true) > OptimizerUtils.getLocalMemBudget();
            }

            // the output must be a scalar or fit into a single block
            // (rows for transpose operations, columns otherwise)
            boolean validNcol;
            if (op.getDataType().isScalar()) {
                validNcol = true;
            } else if (HopRewriteUtils.isTransposeOperation(op)) {
                validNcol = op.getDim1() <= op.getRowsInBlock();
            } else {
                validNcol = op.getDim2() <= op.getColsInBlock();
            }

            // every input must satisfy the constraint as well; all inputs
            // are always evaluated (no early exit)
            for (Hop input : op.getInput()) {
                boolean inputOk = input.getDataType().isScalar()
                    || (input.getDim2() <= input.getColsInBlock())
                    || (op instanceof AggBinaryOp && input.getDim1() <= input.getRowsInBlock() && HopRewriteUtils.isTransposeOperation(input));
                validNcol &= inputOk;
            }

            if (!runsInSpark || validNcol) {
                continue;
            }
            // fetch the entries before removal for the trace output
            List<MemoTableEntry> removed = memo.get(id, TemplateType.ROW);
            memo.remove(op, TemplateType.ROW);
            memo.removeAllRefTo(id, TemplateType.ROW);
            if (LOG.isTraceEnabled()) {
                MemoTableEntry[] removedArr = removed.toArray(new MemoTableEntry[0]);
                LOG.trace("Removed row memo table entries w/ violated blocksize constraint (" + id + "): " + Arrays.toString(removedArr));
            }
        }
    }

    // pass 2: prune row aggregates with pure cellwise operations
    // (the blacklist holds all operators in the partition that either
    // depend upon row aggregates or on which row aggregates depend)
    HashSet<Long> irreplaceable = collectIrreplaceableRowOps(memo, part);
    for (Long id : part.getPartition()) {
        if (irreplaceable.contains(id)) {
            continue;
        }
        MemoTableEntry best = memo.getBest(id, TemplateType.ROW);
        boolean pureCellwiseRow = best != null
            && best.type == TemplateType.ROW
            && memo.hasOnlyExactMatches(id, TemplateType.ROW, TemplateType.CELL);
        if (!pureCellwiseRow) {
            continue;
        }
        List<MemoTableEntry> rowEntries = memo.get(id, TemplateType.ROW);
        memo.remove(memo.getHopRefs().get(id), new HashSet<>(rowEntries));
        if (LOG.isTraceEnabled()) {
            MemoTableEntry[] rowArr = rowEntries.toArray(new MemoTableEntry[0]);
            LOG.trace("Removed row memo table entries w/o aggregation: " + Arrays.toString(rowArr));
        }
    }

    // pass 3: for operators with exactly two outer product entries, drop
    // the entry selected by TemplateOuterProduct.dropAlternativePlan (if any);
    // original note: we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because
    // this would unnecessarily destroy a fusion pattern.
    for (Long id : part.getPartition()) {
        if (memo.countEntries(id, TemplateType.OUTER) != 2) {
            continue;
        }
        List<MemoTableEntry> outerEntries = memo.get(id, TemplateType.OUTER);
        MemoTableEntry first = outerEntries.get(0);
        MemoTableEntry second = outerEntries.get(1);
        MemoTableEntry dominated = TemplateOuterProduct.dropAlternativePlan(memo, first, second);
        if (dominated == null) {
            continue;
        }
        memo.remove(memo.getHopRefs().get(id), Collections.singleton(dominated));
        // also remove the referenced input of the dropped entry from the blacklisted plans
        memo.getPlansBlacklisted().remove(dominated.input(dominated.getPlanRefIndex()));
        if (LOG.isTraceEnabled()) {
            LOG.trace("Removed dominated outer product memo table entry: " + dominated);
        }
    }
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class PlanSelectionFuseCostBasedV2 method rPruneInvalidPlans.
/**
 * Recursively (children first) inspects the row template entries of all
 * operators reachable from {@code current} that belong to the partition,
 * and converts entries to cell templates or removes them when they qualify
 * as invalid leaf or inner row entries.
 *
 * @param memo    memo table whose row entries are inspected
 * @param current operator currently visited
 * @param visited hop IDs already processed (memoization, updated in place)
 * @param part    plan partition restricting the operators to inspect
 * @param plan    only passed through to the recursive calls
 */
private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, boolean[] plan) {
    final long hopID = current.getHopID();

    // memoization (not via hops because in middle of dag)
    if (visited.contains(hopID)) {
        return;
    }

    // process children recursively
    for (Hop child : current.getInput()) {
        rPruneInvalidPlans(memo, child, visited, part, plan);
    }

    // find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs,
    // i.e., plans that become invalid after the previous pruning step
    if (part.getPartition().contains(hopID) && memo.contains(hopID, TemplateType.ROW)) {
        // NOTE(review): removal below goes through this iterator; whether it affects
        // the memo table depends on memo.get returning a live view or a copy - confirm
        for (Iterator<MemoTableEntry> it = memo.get(hopID, TemplateType.ROW).iterator(); it.hasNext(); ) {
            MemoTableEntry entry = it.next();

            // leaf node with pure vector inputs
            boolean applyLeaf = !(entry.hasPlanRef() || TemplateUtils.hasMatrixInput(current));

            // inner node that does not open a row template itself and
            // has no referenced input with a row template entry
            boolean applyInner = !(applyLeaf || ROW_TPL.open(current));
            int pos = 0;
            while (applyInner && pos < 3) {
                if (entry.isPlanRef(pos) && memo.contains(entry.input(pos), TemplateType.ROW)) {
                    applyInner = false;
                }
                pos++;
            }

            if (!applyLeaf && !applyInner) {
                continue;
            }
            String kind;
            if (applyLeaf) {
                kind = "leaf";
            } else {
                kind = "inner";
            }
            if (isValidRow2CellOp(current)) {
                // supported as cell operation: convert in place
                entry.type = TemplateType.CELL;
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Converted " + kind + " memo table entry from row to cell: " + entry);
                }
            } else {
                // unsupported as cell operation: drop the entry
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed " + kind + " memo table entry row (unsupported cell): " + entry);
                }
                it.remove();
            }
        }
    }

    visited.add(hopID);
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class TemplateMultiAgg method constructCplan.
/**
 * Constructs the multi-aggregate cplan for the given operator: collects the
 * root operators referenced by the best MAGG memo table entry, builds their
 * cnodes via the parent class, orders the inputs, and assembles a
 * CNodeMultiAgg template.
 *
 * @param hop             operator whose MAGG memo table entry is used
 * @param memo            memo table providing the entry and the hop references
 * @param compileLiterals passed through to the recursive cplan construction
 * @return pair of the ordered input hops and the constructed template
 */
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // get all root nodes for multi aggregation
    MemoTableEntry maggEntry = memo.getBest(hop.getHopID(), TemplateType.MAGG);
    ArrayList<Hop> roots = new ArrayList<>();
    for (int pos = 0; pos < 3; pos++) {
        if (!maggEntry.isPlanRef(pos)) {
            continue;
        }
        roots.add(memo._hopRefs.get(maggEntry.input(pos)));
    }
    Hop.resetVisitStatus(roots);

    // recursively process required cplan outputs
    // (use celltpl cplan construction)
    HashSet<Hop> inHops = new HashSet<>();
    HashMap<Long, CNode> tmp = new HashMap<>();
    for (Hop root : roots) {
        super.rConstructCplan(root, memo, tmp, inHops, compileLiterals);
    }
    Hop.resetVisitStatus(roots);

    // reorder inputs (ensure matrices/vectors come first) and prune literals
    // note: we order by number of cells and subsequently sparsity to ensure
    // that sparse inputs are used as the main input w/o unnecessary conversion
    Hop shared = getSparseSafeSharedInput(roots, inHops);
    Hop[] sinHops = inHops.stream()
        .filter(h -> !h.getDataType().isScalar() || !tmp.get(h.getHopID()).isLiteral())
        .sorted(new HopInputComparator(shared))
        .toArray(Hop[]::new);

    // construct template node: inputs in sorted order
    ArrayList<CNode> inputs = new ArrayList<>();
    for (Hop sorted : sinHops) {
        inputs.add(tmp.get(sorted.getHopID()));
    }

    // one output and one aggregation operator per root
    ArrayList<CNode> outputs = new ArrayList<>();
    ArrayList<AggOp> aggOps = new ArrayList<>();
    for (Hop root : roots) {
        CNode node = tmp.get(root.getHopID());
        // add indexing ops for sideways data inputs, i.e., data nodes
        // other than the first (main) input
        if (node instanceof CNodeData) {
            if (((CNodeData) inputs.get(0)).getHopID() != ((CNodeData) node).getHopID()) {
                UnaryType lookup;
                if (roots.get(0).getDim2() == 1) {
                    lookup = UnaryType.LOOKUP_R;
                } else {
                    lookup = UnaryType.LOOKUP_RC;
                }
                node = new CNodeUnary(node, lookup);
            }
        }
        outputs.add(node);
        aggOps.add(TemplateUtils.getAggOp(root));
    }

    CNodeMultiAgg tpl = new CNodeMultiAgg(inputs, outputs);
    tpl.setAggOps(aggOps);
    boolean sparseSafe = isSparseSafe(roots, sinHops[0], tpl.getOutputs(), tpl.getAggOps(), true);
    tpl.setSparseSafe(sparseSafe);
    tpl.setRootNodes(roots);
    tpl.setBeginLine(hop.getBeginLine());

    // return cplan instance
    return new Pair<>(sinHops, tpl);
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class TemplateRow method rConstructCplan.
/**
 * Recursively constructs the cnode for the given operator and registers it
 * in {@code tmp} under the operator's hop ID.
 *
 * Inputs that are plan references of the best ROW/CELL memo table entry are
 * processed recursively; all other inputs become data cnodes in {@code tmp}
 * and are added to {@code inHops}. Afterwards, the cnode of the operator is
 * created by dispatching on the operator type.
 *
 * @param hop             operator for which a cnode is constructed
 * @param memo            memo table queried for the best ROW or CELL entry
 * @param tmp             map from hop ID to constructed cnode, also used for memoization
 * @param inHops          set of template input hops, updated in place
 * @param inHops2         named inputs ("X", "B1") registered by specific operations
 * @param compileLiterals passed through to TemplateUtils.createCNodeData and skipTranspose
 * @throws RuntimeException if a unary/binary matrix operation is not in the supported
 *                          vector operation sets, or if no cnode could be constructed
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
// memoization for common subexpression elimination and to avoid redundant work
if (tmp.containsKey(hop.getHopID()))
return;
// recursively process required children: inputs referenced by the memo table
// entry are expanded, all other inputs become data nodes and template inputs
MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
for (int i = 0; i < hop.getInput().size(); i++) {
Hop c = hop.getInput().get(i);
if (me != null && me.isPlanRef(i))
rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
else {
CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
tmp.put(c.getHopID(), cdata);
inHops.add(c);
}
}
// construct cnode for current hop
CNode out = null;
// unary aggregates: supported row aggregates, column sums, and full sums
if (hop instanceof AggUnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
if (((AggUnaryOp) hop).getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
// single-column input: pass the scalar through or wrap a row lookup
if (hop.getInput().get(0).getDim2() == 1)
out = (cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
else {
// unary type name is derived from the aggregation operator name: ROW_<OP>S
String opcode = "ROW_" + ((AggUnaryOp) hop).getOp().name().toUpperCase() + "S";
out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
inHops2.put("X", hop.getInput().get(0));
}
} else if (((AggUnaryOp) hop).getDirection() == Direction.Col && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
// vector add without temporary copy
if (cdata1 instanceof CNodeBinary && ((CNodeBinary) cdata1).getType().isVectorScalarPrimitive())
out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), ((CNodeBinary) cdata1).getType().getVectorAddPrimitive());
else
out = cdata1;
} else if (((AggUnaryOp) hop).getDirection() == Direction.RowCol && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
// full sum: matrix inputs are reduced via row sums, others passed through
out = (cdata1.getDataType().isMatrix()) ? new CNodeUnary(cdata1, UnaryType.ROW_SUMS) : cdata1;
}
// binary aggregates (AggBinaryOp), with and without transposed left input
} else if (hop instanceof AggBinaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
// correct input under transpose
cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
// replace the transpose by its input in the set of template inputs
inHops.remove(hop.getInput().get(0));
if (cdata1 instanceof CNodeData)
inHops.add(hop.getInput().get(0).getInput().get(0));
// note: vectorMultAdd applicable to vector-scalar, and vector-vector
if (hop.getInput().get(1).getDim2() == 1)
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
else {
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD);
if (!inHops2.containsKey("B1")) {
// incl modification of X for consistency
if (cdata1 instanceof CNodeData)
inHops2.put("X", hop.getInput().get(0).getInput().get(0));
inHops2.put("B1", hop.getInput().get(1));
}
}
// register the input of the transpose as X if no X has been set yet
if (!inHops2.containsKey("X"))
inHops2.put("X", hop.getInput().get(0).getInput().get(0));
} else {
// both inputs have a single column: multiply (non-scalars via LOOKUP0)
if (hop.getInput().get(0).getDim2() == 1 && hop.getInput().get(1).getDim2() == 1)
out = new CNodeBinary((cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), (cdata2.getDataType() == DataType.SCALAR) ? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
// single-column right input: dot product, left input becomes X
else if (hop.getInput().get(1).getDim2() == 1) {
out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
inHops2.put("X", hop.getInput().get(0));
// otherwise: vector-matrix multiply with inputs registered as X and B1
} else {
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MATRIXMULT);
inHops2.put("X", hop.getInput().get(0));
inHops2.put("B1", hop.getInput().get(1));
}
}
// transpose operations are delegated to TemplateUtils.skipTranspose
} else if (HopRewriteUtils.isTransposeOperation(hop)) {
// NOTE(review): tmp presumably holds no entry for this hop yet (see the
// memoization check at the top), so the first argument is likely null - confirm
out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
inHops.add(hop.getInput().get(0));
// unary operations: vector primitive for matrix inputs, scalar primitive otherwise
} else if (hop instanceof UnaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
// if one input is a matrix then we need to do vector by scalar operations
if (hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1 || (!hop.dimsKnown() && cdata1.getDataType() == DataType.MATRIX)) {
if (HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY)) {
// unary type name is derived from the operator name: VECT_<OP>
String opname = "VECT_" + ((UnaryOp) hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(opname));
if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
inHops2.put("X", hop.getInput().get(0));
} else
throw new RuntimeException("Unsupported unary matrix " + "operation: " + ((UnaryOp) hop).getOp().name());
} else // general scalar case
{
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
// unary type name is taken from the operator's string representation
String primitiveOpName = ((UnaryOp) hop).getOp().toString();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
}
// binary cbind; checked before the general BinaryOp branch
} else if (HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
// special case for cbind with zeros
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = null;
if (HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) {
// constant-value datagen input is replaced by a data node of its constant value
cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
// rm 0-matrix
inHops.remove(hop.getInput().get(1));
} else {
cdata2 = tmp.get(hop.getInput().get(1).getHopID());
cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
}
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
inHops2.put("X", hop.getInput().get(0));
// remaining binary operations: vector primitives if a matrix is involved, scalar primitives otherwise
} else if (hop instanceof BinaryOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
// if one input is a matrix then we need to do vector by scalar operations
if ((hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1) || (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1) || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()) && // not a known vector output
(hop.getDim2() != 1) && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix()))) {
if (HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY)) {
// matrix with matrix or row vector: VECT_<OP>
if (TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2) || TemplateUtils.isRowVector(cdata2))) {
String opname = "VECT_" + ((BinaryOp) hop).getOp().name();
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
// otherwise: VECT_<OP>_SCALAR, with row lookups on column vector inputs
} else {
String opname = "VECT_" + ((BinaryOp) hop).getOp().name() + "_SCALAR";
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
if (TemplateUtils.isColVector(cdata2))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
}
// register the left input as X if it is a non-scalar data node and no X has been set yet
if (cdata1 instanceof CNodeData && !inHops2.containsKey("X") && !(cdata1.getDataType() == DataType.SCALAR)) {
inHops2.put("X", hop.getInput().get(0));
}
} else
throw new RuntimeException("Unsupported binary matrix " + "operation: " + ((BinaryOp) hop).getOp().name());
} else // one input is a vector/scalar other is a scalar
{
// binary type name is taken from the operator's string representation
String primitiveOpName = ((BinaryOp) hop).getOp().toString();
if (TemplateUtils.isColVector(cdata1))
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
if (// vector or vector can be inferred from lhs
TemplateUtils.isColVector(cdata2) || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix()))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
}
// ternary operations
} else if (hop instanceof TernaryOp) {
TernaryOp top = (TernaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
// add lookups if required (only for the first and third input)
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
// construct ternary cnode, primitive operation derived from OpOp3
out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString()));
// n-ary cbind over all inputs
} else if (HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
CNode[] inputs = new CNode[hop.getInput().size()];
for (int i = 0; i < hop.getInput().size(); i++) {
Hop c = hop.getInput().get(i);
CNode cdata = tmp.get(c.getHopID());
// lookups are only considered for column and row vectors
if (TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata))
cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
inputs[i] = cdata;
// the first input becomes X if it is a data node and no X has been set yet
if (i == 0 && cdata instanceof CNodeData && !inHops2.containsKey("X"))
inHops2.put("X", c);
}
out = new CNodeNary(inputs, NaryType.VECT_CBIND);
// parameterized builtin with "pattern" and "replacement" parameters, mapped to a replace ternary
} else if (hop instanceof ParameterizedBuiltinOp) {
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
// a literal pattern with variable name Double.NaN selects the dedicated NaN replace
TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
// indexing, mapped to a lookup ternary with the input's number of columns as second argument
} else if (hop instanceof IndexingOp) {
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
// NOTE(review): input 4 is presumably the column lower bound of the indexing op - confirm
out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), (hop.getDim2() != 1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
}
// none of the branches above produced a cnode: fail with hop ID and operator string
if (out == null) {
throw new RuntimeException(hop.getHopID() + " " + hop.getOpString());
}
// propagate the operator's dimensions to matrix outputs
if (out.getDataType().isMatrix()) {
out.setNumRows(hop.getDim1());
out.setNumCols(hop.getDim2());
}
// register the result, which also marks this hop as processed
tmp.put(hop.getHopID(), out);
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class SpoofCompiler method rExploreCPlans.
// //////////////////
// Codegen plan construction
/**
 * Recursively (inputs first) explores candidate operator plans for the given
 * operator: opens initial plans, fuses plans of the inputs, closes plans,
 * optionally prunes redundant plans, and finally marks the operator as
 * processed in the memo table.
 *
 * @param hop             operator to explore
 * @param memo            memo table that collects the enumerated plans
 * @param compileLiterals only passed through to the recursive calls
 */
private static void rExploreCPlans(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    final long hopID = hop.getHopID();

    // top-down memoization of processed dag nodes
    if (memo.contains(hopID) || memo.containsHop(hop)) {
        return;
    }

    // recursive candidate exploration
    for (Hop input : hop.getInput()) {
        rExploreCPlans(input, memo, compileLiterals);
    }

    // open initial operator plans, if possible
    for (TemplateBase candidate : TemplateUtils.TEMPLATES) {
        if (candidate.open(hop)) {
            memo.addAll(hop, enumPlans(hop, null, candidate, memo));
        }
    }

    // fuse and merge operator plans
    for (Hop input : hop.getInput()) {
        for (TemplateBase candidate : memo.getDistinctTemplates(input.getHopID())) {
            if (candidate.fuse(hop, input)) {
                memo.addAll(hop, enumPlans(hop, input, candidate, memo));
            }
        }
    }

    // close operator plans, if required
    if (memo.contains(hopID)) {
        for (Iterator<MemoTableEntry> it = memo.get(hopID).iterator(); it.hasNext(); ) {
            MemoTableEntry entry = it.next();
            CloseType closeCode = TemplateUtils.createTemplate(entry.type).close(hop);
            // invalid entries are dropped; the close type is recorded in any case
            if (closeCode == CloseType.CLOSED_INVALID) {
                it.remove();
            }
            entry.ctype = closeCode;
        }
    }

    // prune subsumed / redundant plans
    if (PRUNE_REDUNDANT_PLANS) {
        memo.pruneRedundant(hopID, PLAN_SEL_POLICY.isHeuristic(), null);
    }

    // mark visited even if no plans found (e.g., unsupported ops)
    memo.addHop(hop);
}
Aggregations