Use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by Apache.
Class PlanSelectionFuseCostBased, method createAndAddMultiAggPlans.
//within-partition multi-agg templates
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
//create index of plans that reference full aggregates to avoid circular dependencies
HashSet<Long> refHops = new HashSet<Long>();
for (Entry<Long, List<MemoTableEntry>> e : memo._plans.entrySet()) if (!e.getValue().isEmpty()) {
Hop hop = memo._hopRefs.get(e.getKey());
for (Hop c : hop.getInput()) refHops.add(c.getHopID());
}
//find all full aggregations (the fact that they are in the same partition guarantees
//that they also have common subexpressions, also full aggregations are by def root nodes)
ArrayList<Long> fullAggs = new ArrayList<Long>();
for (Long hopID : R) {
Hop root = memo._hopRefs.get(hopID);
if (!refHops.contains(hopID) && root instanceof AggUnaryOp && ((AggUnaryOp) root).getDirection() == Direction.RowCol)
fullAggs.add(hopID);
}
if (LOG.isTraceEnabled()) {
LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
}
//construct and add multiagg template plans (w/ max 3 aggregations)
for (int i = 0; i < fullAggs.size(); i += 3) {
int ito = Math.min(i + 3, fullAggs.size());
if (ito - i >= 2) {
MemoTableEntry me = new MemoTableEntry(TemplateType.MultiAggTpl, fullAggs.get(i), fullAggs.get(i + 1), ((ito - i) == 3) ? fullAggs.get(i + 2) : -1);
if (isValidMultiAggregate(memo, me)) {
for (int j = i; j < ito; j++) {
memo.add(memo._hopRefs.get(fullAggs.get(j)), me);
if (LOG.isTraceEnabled())
LOG.trace("Added multiagg plan: " + fullAggs.get(j) + " " + me);
}
} else if (LOG.isTraceEnabled()) {
LOG.trace("Removed invalid multiagg plan: " + me);
}
}
}
}
Aggregations