Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.
In class PlanSelectionFuseCostBased, method isRowTemplateWithoutAgg:
/**
 * Checks the inputs referenced by the best ROW plan of {@code current} for aggregations.
 * The root operation itself is not inspected; only the referenced inputs are.
 *
 * @param memo    memo table holding the candidate plans
 * @param current hop whose best ROW plan is inspected
 * @param visited hop IDs passed through to the recursive check
 * @return true if {@code rIsRowTemplateWithoutAgg} holds for every referenced input
 */
private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
    MemoTableEntry best = memo.getBest(current.getHopID(), TemplateType.ROW);
    boolean noAgg = true;
    for (int pos = 0; pos < 3; pos++) {
        if (!best.isPlanRef(pos)) {
            continue;
        }
        // Always invoke the recursive check (no short-circuit), since it receives
        // `visited` and may update it even after an earlier input returned false.
        boolean inputOk = rIsRowTemplateWithoutAgg(memo, current.getInput().get(pos), visited);
        noAgg = noAgg && inputOk;
    }
    return noAgg;
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.
In class PlanSelectionFuseCostBased, method rPruneSuboptimalPlans:
/**
 * Recursively walks the DAG below {@code current} and, for hops contained in
 * {@code partition}, removes memo table entries for which
 * {@code hasNoRefToMaterialization(entry, M, plan)} is false. OUTER entries are never removed.
 * A hop's own entries are pruned before its inputs are visited.
 *
 * @param memo      memo table to prune in place
 * @param current   hop to process
 * @param visited   IDs of hops already fully processed (memoization by ID, not hop visit
 *                  status, because this runs in the middle of a DAG)
 * @param partition IDs of the hops whose entries may be pruned
 * @param M         passed through to {@code hasNoRefToMaterialization}
 * @param plan      passed through to {@code hasNoRefToMaterialization}
 */
private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
    final long id = current.getHopID();
    if (visited.contains(id)) {
        return;
    }
    if (partition.contains(id) && memo.contains(id)) {
        for (Iterator<MemoTableEntry> it = memo.get(id).iterator(); it.hasNext(); ) {
            MemoTableEntry entry = it.next();
            boolean hasMatRef = !hasNoRefToMaterialization(entry, M, plan);
            if (hasMatRef && entry.type != TemplateType.OUTER) {
                it.remove();
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed memo table entry: " + entry);
                }
            }
        }
    }
    for (Hop input : current.getInput()) {
        rPruneSuboptimalPlans(memo, input, visited, partition, M, plan);
    }
    visited.add(id);
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.
In class PlanSelectionFuseCostBasedV2, method rPruneInvalidPlans:
/**
 * Recursively fixes ROW memo table entries that became invalid after an earlier pruning
 * step. All inputs of a hop are processed before the hop itself.
 *
 * <p>For a hop inside {@code part} that has ROW entries, an entry is affected if it is
 * <ul>
 *   <li>a leaf: it references no other plan and {@code current} has no matrix input, or</li>
 *   <li>an inner node: {@code ROW_TPL.open(current)} is false and none of its referenced
 *       inputs still has a ROW entry in the memo table.</li>
 * </ul>
 * An affected entry is re-typed to CELL if {@code isValidRow2CellOp(current)} holds;
 * otherwise it is removed through the iterator.
 *
 * @param memo    memo table to update
 * @param current hop to process
 * @param visited IDs of hops already processed (memoization by ID, not hop visit status,
 *                because this runs in the middle of a DAG)
 * @param part    plan partition; only hops in {@code part.getPartition()} are modified
 * @param plan    not read by this method
 */
private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, boolean[] plan) {
    // memoization (not via hops because in middle of dag)
    if (visited.contains(current.getHopID()))
        return;
    // process children recursively, so the inputs are already pruned when this hop is checked
    for (Hop c : current.getInput())
        rPruneInvalidPlans(memo, c, visited, part, plan);
    // find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs,
    // i.e., plans that become invalid after the previous pruning step
    long hopID = current.getHopID();
    if (part.getPartition().contains(hopID) && memo.contains(hopID, TemplateType.ROW)) {
        // NOTE(review): iter.remove() only prunes the memo table if memo.get(hopID, ROW)
        // returns a live view of the stored entries rather than a filtered copy -- confirm
        // against CPlanMemoTable.get(long, TemplateType).
        Iterator<MemoTableEntry> iter = memo.get(hopID, TemplateType.ROW).iterator();
        while (iter.hasNext()) {
            MemoTableEntry me = iter.next();
            // convert leaf node with pure vector inputs
            boolean applyLeaf = (!me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current));
            // convert inner node without row template input
            boolean applyInner = !applyLeaf && !ROW_TPL.open(current);
            // conditional-and (previously the non-short-circuit '&'): stop scanning the
            // inputs as soon as one of them still has a ROW plan
            for (int i = 0; i < 3 && applyInner; i++) {
                if (me.isPlanRef(i))
                    applyInner &= !memo.contains(me.input(i), TemplateType.ROW);
            }
            if (applyLeaf || applyInner) {
                String type = applyLeaf ? "leaf" : "inner";
                if (isValidRow2CellOp(current)) {
                    me.type = TemplateType.CELL;
                    if (LOG.isTraceEnabled())
                        LOG.trace("Converted " + type + " memo table entry from row to cell: " + me);
                } else {
                    if (LOG.isTraceEnabled())
                        LOG.trace("Removed " + type + " memo table entry row (unsupported cell): " + me);
                    iter.remove();
                }
            }
        }
    }
    visited.add(current.getHopID());
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.
In class PlanSelectionFuseCostBasedV2, method rPruneSuboptimalPlans:
/**
 * Recursively walks the DAG below {@code current} and, for hops of partition {@code part},
 * removes memo table entries for which {@code hasNoRefToMatPoint(id, entry, matPoints, plan)}
 * is false. OUTER entries are never removed. A hop's own entries are pruned before its
 * inputs are visited.
 *
 * @param memo      memo table to prune in place
 * @param current   hop to process
 * @param visited   IDs of hops already fully processed (memoization by ID, not hop visit
 *                  status, because this runs in the middle of a DAG)
 * @param part      plan partition; only hops in {@code part.getPartition()} are pruned
 * @param matPoints passed through to {@code hasNoRefToMatPoint}
 * @param plan      passed through to {@code hasNoRefToMatPoint}
 */
private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, InterestingPoint[] matPoints, boolean[] plan) {
    final long id = current.getHopID();
    if (visited.contains(id)) {
        return;
    }
    if (part.getPartition().contains(id) && memo.contains(id)) {
        Iterator<MemoTableEntry> entries = memo.get(id).iterator();
        while (entries.hasNext()) {
            MemoTableEntry entry = entries.next();
            boolean hasMatRef = !hasNoRefToMatPoint(id, entry, matPoints, plan);
            if (hasMatRef && entry.type != TemplateType.OUTER) {
                entries.remove();
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed memo table entry: " + entry);
                }
            }
        }
    }
    for (Hop input : current.getInput()) {
        rPruneSuboptimalPlans(memo, input, visited, part, matPoints, plan);
    }
    visited.add(id);
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.
In class PlanSelectionFuseCostBasedV2, method createAndAddMultiAggPlans:
// across-partition multi-agg templates with shared reads
/**
 * Groups full aggregations from different plan partitions into multi-aggregate (MAGG)
 * plans, and registers those plans in the memo table and via {@code addBestPlan}.
 *
 * <p>Steps:
 * <ol>
 *   <li>Collect the full aggregations reachable from {@code roots} that do not yet have
 *       a MAGG plan. Nothing else happens if fewer than two remain.</li>
 *   <li>For each one, extract an {@code AggregateInfo} describing subsumed aggregations
 *       and fused inputs.</li>
 *   <li>Stably sort the infos by number of input aggregations, then greedily merge
 *       mergeable infos until a full pass performs no merge.</li>
 *   <li>For every group with at least two aggregations, create one shared MAGG entry and
 *       attach it to each member.</li>
 * </ol>
 *
 * @param memo  memo table to extend with MAGG entries
 * @param roots DAG roots; their hop visit status is reset before and after collection
 */
private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
    // collect full aggregations as initial set of candidates
    HashSet<Long> fullAggs = new HashSet<>();
    Hop.resetVisitStatus(roots);
    for (Hop hop : roots)
        rCollectFullAggregates(hop, fullAggs);
    Hop.resetVisitStatus(roots);
    // remove operators with assigned multi-agg plans
    fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG));
    // check applicability for further analysis
    if (fullAggs.size() <= 1)
        return;
    if (LOG.isTraceEnabled()) {
        LOG.trace("Found across-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
    }
    // collect information for all candidates
    // (subsumed aggregations, and inputs to fused operators)
    List<AggregateInfo> aggInfos = new ArrayList<>();
    for (Long hopID : fullAggs) {
        Hop aggHop = memo.getHopRefs().get(hopID);
        boolean isMatMult = HopRewriteUtils.isMatrixMultiply(aggHop);
        AggregateInfo tmp = new AggregateInfo(aggHop);
        for (int i = 0; i < aggHop.getInput().size(); i++) {
            // for a matrix multiply, input 0 is replaced by its own first input
            // (presumably skipping a transpose -- TODO confirm)
            Hop c = (isMatMult && i == 0) ? aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
            rExtractAggregateInfo(memo, c, tmp, TemplateType.CELL);
        }
        // fall back to the direct inputs when no fused inputs were extracted
        if (tmp._fusedInputs.isEmpty()) {
            if (isMatMult) {
                tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
                tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
            } else {
                tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
            }
        }
        aggInfos.add(tmp);
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
        for (AggregateInfo info : aggInfos) LOG.trace(info);
    }
    // sort aggregations by num dependencies to simplify merging
    // clusters of aggregations with parallel dependencies.
    // Sorted in place (List.sort is stable, like the former stream sort) because the
    // greedy merge below removes elements, and Collectors.toList() does not guarantee
    // a mutable result list.
    aggInfos.sort(Comparator.comparing(a -> a._inputAggs.size()));
    // greedy grouping of multi-agg candidates; repeat full passes until none merges
    boolean converged = false;
    while (!converged) {
        AggregateInfo merged = null;
        for (int i = 0; i < aggInfos.size(); i++) {
            AggregateInfo current = aggInfos.get(i);
            for (int j = i + 1; j < aggInfos.size(); j++) {
                AggregateInfo that = aggInfos.get(j);
                if (current.isMergable(that)) {
                    merged = current.merge(that);
                    aggInfos.remove(j);
                    j--; // re-examine the element that shifted into position j
                }
            }
        }
        converged = (merged == null);
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Merged across-partition ua(RC) aggregation info: ");
        for (AggregateInfo info : aggInfos) LOG.trace(info);
    }
    // construct and add multiagg template plans (w/ max 3 aggregations)
    for (AggregateInfo info : aggInfos) {
        if (info._aggregates.size() <= 1)
            continue;
        Long[] aggs = info._aggregates.keySet().toArray(new Long[0]);
        // NOTE(review): the entry references at most aggs[0..2]; a group of more than
        // three aggregates would lose references. Presumably isMergable caps groups at
        // three -- confirm.
        MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, aggs[0], aggs[1], (aggs.length > 2) ? aggs[2] : -1, aggs.length);
        for (int i = 0; i < aggs.length; i++) {
            memo.add(memo.getHopRefs().get(aggs[i]), me);
            addBestPlan(aggs[i], me);
            if (LOG.isTraceEnabled())
                LOG.trace("Added multiagg* plan: " + aggs[i] + " " + me);
        }
    }
}
Aggregations