Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
In class PlanSelectionFuseNoRedundancy, method rSelectPlans:
/**
 * Recursively selects plans for a hop and its inputs without redundant computation.
 * Per (hop, requested template type) pair, this method runs once and:
 * (0) drops entries that reference a partial plan of an input hop with more than one consumer,
 * (1) drops entries subsumed by another entry of the same hop,
 * (2) records the best remaining entry, and
 * (3) descends into the inputs.
 *
 * @param memo        memo table with the candidate plans per hop (entries are removed in place)
 * @param current     hop to process
 * @param currentType template type requested by the consumer, or {@code null} when there is
 *                    no such request (e.g., at a root, or when the consumer's best entry does
 *                    not reference this hop)
 */
private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) {
    final long hopID = current.getHopID();
    if (isVisited(hopID, currentType)) {
        return;
    }

    // step 0: drop entries referencing a common partial plan, i.e., a plan
    // reference to an input hop that has more than one consumer
    if (memo.contains(hopID)) {
        HashSet<MemoTableEntry> sharedRefs = new HashSet<MemoTableEntry>();
        for (MemoTableEntry entry : memo.get(hopID)) {
            for (int pos = 0; pos < 3; pos++) {
                if (entry.isPlanRef(pos) && current.getInput().get(pos).getParent().size() > 1) {
                    sharedRefs.add(entry);
                }
            }
        }
        memo.remove(current, sharedRefs);
    }

    // step 1: drop entries that are subsumed by another entry of this hop
    if (memo.contains(hopID)) {
        List<MemoTableEntry> entries = memo.get(hopID);
        HashSet<MemoTableEntry> subsumed = new HashSet<MemoTableEntry>();
        for (MemoTableEntry outer : entries) {
            for (MemoTableEntry inner : entries) {
                if (outer != inner && outer.subsumes(inner)) {
                    subsumed.add(inner);
                }
            }
        }
        memo.remove(current, subsumed);
    }

    // step 2: pick the best remaining entry for the requested template type
    MemoTableEntry best = null;
    if (memo.contains(hopID)) {
        List<MemoTableEntry> candidates = memo.get(hopID);
        if (currentType != null) {
            // only entries of the requested type or cell entries qualify; lower score wins:
            // an entry of the requested type gets a bonus of 4, each plan reference lowers it by 1
            best = candidates.stream()
                .filter(p -> p.type == currentType || p.type == TemplateType.CellTpl)
                .min(Comparator.comparingInt(p -> 7 - ((p.type == currentType) ? 4 : 0) - p.countPlanRefs()))
                .orElse(null);
        } else {
            best = candidates.stream()
                .filter(p -> isValid(p, current))
                .min(new BasicPlanComparator())
                .orElse(null);
        }
        addBestPlan(hopID, best);
    }

    // step 3: descend into the inputs; the chosen entry's type is requested only
    // from inputs that this entry actually references
    for (int pos = 0; pos < current.getInput().size(); pos++) {
        TemplateType childType = null;
        if (best != null && best.isPlanRef(pos)) {
            childType = best.type;
        }
        rSelectPlans(memo, current.getInput().get(pos), childType);
    }
    setVisited(hopID, currentType);
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
In class PlanSelectionFuseCostBased, method rPruneInvalidPlans:
/**
 * Recursively (children first) converts row-template memo table entries of partition hops
 * into cell-template entries when they are no longer valid row plans:
 * <ul>
 *   <li>leaf entries (no plan reference) whose hop has no matrix input, and</li>
 *   <li>inner entries whose hop does not open a row template and none of whose referenced
 *       inputs still has a row-template entry.</li>
 * </ul>
 *
 * @param memo      memo table whose entries are modified in place
 * @param current   hop to process
 * @param visited   ids of hops already processed (memoization across the DAG)
 * @param partition ids of the hops in the current partition; only these hops are modified
 * @param M         materialization points (only forwarded to the recursive calls here)
 * @param plan      selected assignment (only forwarded to the recursive calls here)
 */
private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
    final long hopID = current.getHopID();
    //memoization (not via hops because in middle of dag)
    if (visited.contains(hopID)) {
        return;
    }

    // post-order: inputs are processed first, so their conversions are visible below
    for (Hop input : current.getInput()) {
        rPruneInvalidPlans(memo, input, visited, partition, M, plan);
    }

    // find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs,
    // i.e., plans that become invalid after the previous pruning step
    boolean hasRowEntries = partition.contains(hopID) && memo.contains(hopID, TemplateType.RowTpl);
    if (hasRowEntries) {
        for (MemoTableEntry me : memo.get(hopID)) {
            if (me.type != TemplateType.RowTpl) {
                continue;
            }
            if (!me.hasPlanRef()) {
                // leaf node with pure vector inputs
                if (!TemplateUtils.hasMatrixInput(current)) {
                    me.type = TemplateType.CellTpl;
                    if (LOG.isTraceEnabled())
                        LOG.trace("Converted leaf memo table entry from row to cell: " + me);
                }
            } else if (!ROW_TPL.open(current)) {
                // inner node: keep it a row plan only if some referenced input is still row
                boolean hasRowInput = false;
                for (int k = 0; k < 3; k++) {
                    if (me.isPlanRef(k) && memo.contains(me.input(k), TemplateType.RowTpl)) {
                        hasRowInput = true;
                    }
                }
                if (!hasRowInput) {
                    me.type = TemplateType.CellTpl;
                    if (LOG.isTraceEnabled())
                        LOG.trace("Converted inner memo table entry from row to cell: " + me);
                }
            }
        }
    }
    visited.add(hopID);
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
In class PlanSelectionFuseCostBased, method getPartitionRootNodes:
/**
 * Determines the root nodes of a partition. A root is a partition hop that is not referenced
 * as input1, input2 or input3 by any memo table entry of a hop in the same partition.
 *
 * @param memo      memo table with the candidate plans
 * @param partition ids of the hops in the partition
 * @return a new set with the ids of the partition's root hops
 */
private static HashSet<Long> getPartitionRootNodes(CPlanMemoTable memo, HashSet<Long> partition) {
    // collect every input id referenced by entries of partition hops; unused input slots are
    // collected too (presumably placeholder ids such as -1 that never match a partition hop)
    HashSet<Long> referenced = new HashSet<Long>();
    for (Long hopID : partition) {
        if (!memo.contains(hopID)) {
            continue;
        }
        for (MemoTableEntry entry : memo.get(hopID)) {
            referenced.add(entry.input1);
            referenced.add(entry.input2);
            referenced.add(entry.input3);
        }
    }

    // roots are the partition hops that no entry refers to
    HashSet<Long> roots = new HashSet<Long>();
    for (Long hopID : partition) {
        if (!referenced.contains(hopID)) {
            roots.add(hopID);
        }
    }
    return roots;
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
In class PlanSelectionFuseCostBased, method createAndAddMultiAggPlans:
/**
 * Creates within-partition multi-aggregate templates.
 *
 * <p>Full aggregations ({@code AggUnaryOp} with direction {@code RowCol}) among the partition
 * roots that are not inputs of any hop owning a plan are grouped, in order, into consecutive
 * chunks of at most three. A {@code MultiAggTpl} entry is built for each chunk of two or three
 * aggregations. If {@code isValidMultiAggregate} accepts it, the entry is added to every
 * participating aggregation. A trailing single aggregation gets no multi-agg plan.
 *
 * @param memo      memo table to which valid multi-agg entries are added
 * @param partition ids of the hops in the partition (not read by this method)
 * @param R         ids of the partition's root hops; candidates are drawn from these
 */
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
    // index the inputs of all hops that own at least one plan, so that full
    // aggregates referenced by other plans are excluded (avoids circular dependencies)
    HashSet<Long> refHops = new HashSet<Long>();
    for (Entry<Long, List<MemoTableEntry>> e : memo._plans.entrySet()) {
        if (e.getValue().isEmpty()) {
            continue;
        }
        for (Hop input : memo._hopRefs.get(e.getKey()).getInput()) {
            refHops.add(input.getHopID());
        }
    }

    // find all full aggregations (the fact that they are in the same partition guarantees
    // that they also have common subexpressions, also full aggregations are by def root nodes)
    ArrayList<Long> fullAggs = new ArrayList<Long>();
    for (Long hopID : R) {
        if (refHops.contains(hopID)) {
            continue;
        }
        Hop root = memo._hopRefs.get(hopID);
        boolean isFullAgg = root instanceof AggUnaryOp
            && ((AggUnaryOp) root).getDirection() == Direction.RowCol;
        if (isFullAgg) {
            fullAggs.add(hopID);
        }
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
    }

    // construct and add multiagg template plans (w/ max 3 aggregations)
    final int numAggs = fullAggs.size();
    for (int start = 0; start < numAggs; start += 3) {
        int end = Math.min(start + 3, numAggs);
        int groupSize = end - start;
        if (groupSize < 2) {
            continue; // a single leftover aggregation cannot form a multi-agg plan
        }
        long third = (groupSize == 3) ? fullAggs.get(start + 2) : -1;
        MemoTableEntry me = new MemoTableEntry(TemplateType.MultiAggTpl, fullAggs.get(start), fullAggs.get(start + 1), third);
        if (!isValidMultiAggregate(memo, me)) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed invalid multiagg plan: " + me);
            }
            continue;
        }
        // the same entry instance is registered at every participating aggregation
        for (int j = start; j < end; j++) {
            memo.add(memo._hopRefs.get(fullAggs.get(j)), me);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Added multiagg plan: " + fullAggs.get(j) + " " + me);
            }
        }
    }
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
In class PlanSelectionFuseCostBased, method selectPlans:
/**
 * Selects the fused plans of one partition.
 *
 * <p>The memo table is first pruned in two ways:
 * <ul>
 *   <li>row-template entries of root hops whose row plans contain no aggregation are removed
 *       when a cell-template alternative exists;</li>
 *   <li>dominated outer-product entries are removed.</li>
 * </ul>
 * Without materialization points, plans are then selected with the partition-aware fuse-all
 * heuristic. Otherwise, all 2^|M| materialization assignments are enumerated and costed. The
 * memo table is pruned with respect to the cheapest assignment, and plans are finally selected
 * with fuse-all.
 *
 * @param memo      memo table with the candidate plans (modified in place)
 * @param partition ids of the hops in the partition
 * @param R         ids of the partition's root hops
 * @param M         ids of the partition's materialization points; may be null or empty
 */
private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) {
    //prune row aggregates with pure cellwise operations
    for (Long hopID : R) {
        MemoTableEntry me = memo.getBest(hopID, TemplateType.RowTpl);
        //defensive: skip roots for which no best entry is available instead of
        //dereferencing a null entry
        if (me != null && me.type == TemplateType.RowTpl && memo.contains(hopID, TemplateType.CellTpl)
            && rIsRowTemplateWithoutAgg(memo, memo._hopRefs.get(hopID), new HashSet<Long>())) {
            List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.RowTpl);
            memo.remove(memo._hopRefs.get(hopID), new HashSet<MemoTableEntry>(blacklist));
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed row memo table entries w/o aggregation: " + Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
            }
        }
    }

    //prune dominated outer product plans: for hops with exactly two outer product entries,
    //TemplateOuterProduct.dropAlternativePlan decides which alternative (if any) to drop;
    //e.g., we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily
    //destroy a fusion pattern. The input referenced by the dropped entry is also removed
    //from the plan blacklist.
    for (Long hopID : partition) {
        if (memo.countEntries(hopID, TemplateType.OuterProdTpl) == 2) {
            List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OuterProdTpl);
            MemoTableEntry me1 = entries.get(0);
            MemoTableEntry me2 = entries.get(1);
            MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
            if (rmEntry != null) {
                memo.remove(memo._hopRefs.get(hopID), new HashSet<MemoTableEntry>(Arrays.asList(rmEntry)));
                memo._plansBlacklist.remove(rmEntry.input(rmEntry.getPlanRefIndex()));
                if (LOG.isTraceEnabled())
                    LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
            }
        }
    }

    //if no materialization points, use basic fuse-all w/ partition awareness
    if (M == null || M.isEmpty()) {
        for (Long hopID : R)
            rSelectPlansFuseAll(memo, memo._hopRefs.get(hopID), null, partition);
        return;
    }

    //TODO branch and bound pruning, right now we use exhaustive enum for early experiments
    //via skip ahead in below enumeration algorithm

    //obtain hop compute costs per cell once
    HashMap<Long, Double> computeCosts = new HashMap<Long, Double>();
    for (Long hopID : R)
        rGetComputeCosts(memo._hopRefs.get(hopID), partition, computeCosts);

    //scan linearized search space, w/ skips for branch and bound pruning
    //NOTE: exhaustive over 2^|M| assignments; for |M| >= 31 the (int) cast of the double
    //saturates at Integer.MAX_VALUE, so only the first Integer.MAX_VALUE assignments are scanned
    int len = (int) Math.pow(2, M.size());
    boolean[] bestPlan = null;
    double bestC = Double.MAX_VALUE;
    for (int i = 0; i < len; i++) {
        //construct assignment
        boolean[] plan = createAssignment(M.size(), i);
        //cost assignment on hops
        double C = getPlanCost(memo, partition, R, M, plan, computeCosts);
        if (LOG.isTraceEnabled())
            LOG.trace("Enum: " + Arrays.toString(plan) + " -> " + C);
        //cost comparisons (the first assignment always initializes the best plan)
        if (bestPlan == null || C < bestC) {
            bestC = C;
            bestPlan = plan;
            if (LOG.isTraceEnabled())
                LOG.trace("Enum: Found new best plan.");
        }
    }

    //prune memo table wrt best plan and select plans
    HashSet<Long> visited = new HashSet<Long>();
    for (Long hopID : R)
        rPruneSuboptimalPlans(memo, memo._hopRefs.get(hopID), visited, partition, M, bestPlan);
    HashSet<Long> visited2 = new HashSet<Long>();
    for (Long hopID : R)
        rPruneInvalidPlans(memo, memo._hopRefs.get(hopID), visited2, partition, M, bestPlan);
    for (Long hopID : R)
        rSelectPlansFuseAll(memo, memo._hopRefs.get(hopID), null, partition);
}
Aggregations