use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class PlanSelectionFuseCostBased method selectPlans.
/**
 * Prunes the memo table for one partition and selects the final fusion plans.
 *
 * <p>Steps, in order: (1) remove ROW entries of hops in {@code R} whose row
 * template contains no aggregation and that also have a CELL entry;
 * (2) remove the outer-product entry reported by
 * {@code TemplateOuterProduct.dropAlternativePlan} for hops with exactly two
 * OUTER entries; (3) either run fuse-all directly (no materialization points)
 * or exhaustively enumerate all 2^|M| assignments, keep the cheapest one,
 * prune the memo table accordingly and then run fuse-all.
 *
 * @param memo      memo table that is pruned in place
 * @param partition hop ids of the partition under consideration
 * @param R         hop ids used as entry points for all recursive passes
 *                  (presumably the partition roots -- confirm against caller)
 * @param M         materialization points; may be {@code null} or empty
 */
private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) {
    // step 1: drop row-template entries that only wrap cellwise operations
    for (Long rootID : R) {
        MemoTableEntry rootEntry = memo.getBest(rootID, TemplateType.ROW);
        boolean rowWithoutAgg = rootEntry.type == TemplateType.ROW
            && memo.contains(rootID, TemplateType.CELL)
            && isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(rootID), new HashSet<Long>());
        if (!rowWithoutAgg) {
            continue;
        }
        List<MemoTableEntry> rowEntries = memo.get(rootID, TemplateType.ROW);
        memo.remove(memo.getHopRefs().get(rootID), new HashSet<>(rowEntries));
        if (LOG.isTraceEnabled()) {
            LOG.trace("Removed row memo table entries w/o aggregation: " + Arrays.toString(rowEntries.toArray(new MemoTableEntry[0])));
        }
    }

    // step 2: for hops with exactly two outer-product entries, remove the one
    // that dropAlternativePlan reports (if any) and un-blacklist its plan reference.
    // NOTE(review): the previous comment cited sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2)
    // as a fusion pattern that must not be destroyed -- confirm intended scope.
    for (Long hopID : partition) {
        if (memo.countEntries(hopID, TemplateType.OUTER) != 2) {
            continue;
        }
        List<MemoTableEntry> outerEntries = memo.get(hopID, TemplateType.OUTER);
        MemoTableEntry first = outerEntries.get(0);
        MemoTableEntry second = outerEntries.get(1);
        MemoTableEntry dominated = TemplateOuterProduct.dropAlternativePlan(memo, first, second);
        if (dominated == null) {
            continue;
        }
        memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(dominated));
        memo.getPlansBlacklisted().remove(dominated.input(dominated.getPlanRefIndex()));
        if (LOG.isTraceEnabled()) {
            LOG.trace("Removed dominated outer product memo table entry: " + dominated);
        }
    }

    // step 3a: without materialization points, plain fuse-all is sufficient
    if (M == null || M.isEmpty()) {
        for (Long rootID : R) {
            rSelectPlansFuseAll(memo, memo.getHopRefs().get(rootID), null, partition);
        }
        return;
    }

    // step 3b: exhaustive enumeration over all assignments of M
    // (TODO: branch-and-bound pruning via skip-ahead in the enumeration)

    // compute costs per hop are independent of the assignment, so obtain them once
    HashMap<Long, Double> computeCosts = new HashMap<>();
    for (Long rootID : R) {
        rGetComputeCosts(memo.getHopRefs().get(rootID), partition, computeCosts);
    }

    // linear scan of the search space, remembering the cheapest assignment
    int numPlans = (int) Math.pow(2, M.size());
    boolean[] bestPlan = null;
    double bestCost = Double.MAX_VALUE;
    for (int planNo = 0; planNo < numPlans; planNo++) {
        boolean[] candidate = createAssignment(M.size(), planNo);
        double cost = getPlanCost(memo, partition, R, M, candidate, computeCosts);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Enum: " + Arrays.toString(candidate) + " -> " + cost);
        }
        if (bestPlan == null || cost < bestCost) {
            bestCost = cost;
            bestPlan = candidate;
            if (LOG.isTraceEnabled()) {
                LOG.trace("Enum: Found new best plan.");
            }
        }
    }
    if (DMLScript.STATISTICS) {
        Statistics.incrementCodegenEnumAllP(numPlans);
        Statistics.incrementCodegenEnumEval(numPlans);
    }

    // prune the memo table w.r.t. the best assignment; each pass covers all
    // entry points with its own visited set before the next pass starts
    HashSet<Long> seenSuboptimal = new HashSet<>();
    for (Long rootID : R) {
        rPruneSuboptimalPlans(memo, memo.getHopRefs().get(rootID), seenSuboptimal, partition, M, bestPlan);
    }
    HashSet<Long> seenInvalid = new HashSet<>();
    for (Long rootID : R) {
        rPruneInvalidPlans(memo, memo.getHopRefs().get(rootID), seenInvalid, partition, M, bestPlan);
    }
    for (Long rootID : R) {
        rSelectPlansFuseAll(memo, memo.getHopRefs().get(rootID), null, partition);
    }
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class PlanSelectionFuseCostBased method isRowTemplateWithoutAgg.
/**
 * Checks whether the row template rooted at {@code current} is free of
 * aggregations below the root. The root operation itself is not inspected;
 * only inputs that the root's best ROW-preferred memo entry marks as plan
 * references (positions 0..2) are examined recursively.
 *
 * @param memo    memo table used to look up the best entry per hop
 * @param current root hop of the row template
 * @param visited hop ids already examined; updated by the recursive check
 * @return {@code true} if no examined input reports an aggregation
 */
private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
    MemoTableEntry rootEntry = memo.getBest(current.getHopID(), TemplateType.ROW);
    boolean withoutAgg = true;
    int pos = 0;
    while (pos < 3) {
        if (rootEntry.isPlanRef(pos)) {
            // always descend (no short-circuit) so that visited is populated
            // for every referenced input regardless of earlier results
            boolean inputWithoutAgg = rIsRowTemplateWithoutAgg(memo, current.getInput().get(pos), visited);
            withoutAgg = withoutAgg & inputWithoutAgg;
        }
        pos++;
    }
    return withoutAgg;
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class PlanSelectionFuseCostBased method rIsRowTemplateWithoutAgg.
/**
 * Recursive helper: returns {@code true} if neither {@code current} nor any
 * input reachable via plan references (positions 0..2 of the best
 * ROW-preferred memo entry) is an {@code AggUnaryOp} or {@code AggBinaryOp}.
 * Hops already contained in {@code visited} are reported as {@code true}
 * without being re-examined.
 *
 * @param memo    memo table used to look up the best entry per hop
 * @param current hop to examine
 * @param visited hop ids already examined; {@code current} is added on exit
 * @return {@code true} if no aggregation was found in this sub-plan
 */
private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
    // already examined via another path
    if (visited.contains(current.getHopID())) {
        return true;
    }

    // descend into all referenced inputs; no short-circuit so that every
    // referenced input is visited even after a negative result
    MemoTableEntry entry = memo.getBest(current.getHopID(), TemplateType.ROW);
    boolean inputsWithoutAgg = true;
    for (int pos = 0; pos < 3; pos++) {
        if (!entry.isPlanRef(pos)) {
            continue;
        }
        inputsWithoutAgg &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(pos), visited);
    }

    // the current operator itself must not be an aggregation either
    boolean currentIsAgg = current instanceof AggUnaryOp || current instanceof AggBinaryOp;

    // mark as visited only after the inputs have been processed
    visited.add(current.getHopID());
    return inputsWithoutAgg && !currentIsAgg;
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class PlanSelectionFuseCostBased method rGetPlanCosts.
/**
 * Recursively estimates the costs of the sub-DAG rooted at {@code current}
 * under a given assignment {@code plan}.
 *
 * <p>Each (hop id, cost vector id) pair is costed at most once; a repeated
 * visit returns 0. When no template is currently open ({@code currentType}
 * is {@code null}) and the memo table has entries for the hop, a new fused
 * operator is opened with its own {@code CostVector}; otherwise costs are
 * accumulated into {@code costsCurrent}.
 *
 * @param memo         memo table with the candidate entries per hop
 * @param current      hop to cost
 * @param visited      already costed (hop id, cost vector id) tags; updated in place
 * @param partition    hop ids of the partition; only these hops contribute
 *                     compute costs and fused-operator costs
 * @param M            passed to {@code hasNoRefToMaterialization}; presumably the
 *                     materialization-point hop ids -- confirm against caller
 * @param plan         passed to {@code hasNoRefToMaterialization}; presumably the
 *                     per-point assignment matching {@code M} -- confirm
 * @param computeCosts compute costs per hop id; must contain an entry for every
 *                     hop of {@code partition} that is reached
 * @param costsCurrent cost vector of the enclosing fused operator, or {@code null}
 * @param currentType  template type of the enclosing fused operator, or {@code null}
 * @return accumulated costs of this sub-DAG, 0 if the tag was already visited
 * @throws RuntimeException if the resulting costs are negative, NaN or infinite
 */
private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<Pair<Long, Long>> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateType currentType) {
// memoization per hop id and cost vector to account for redundant
// computation without double counting materialized results or compute
// costs of complex operation DAGs within a single fused operator
// (a null cost vector is represented by id 0 in the tag)
Pair<Long, Long> tag = Pair.of(current.getHopID(), (costsCurrent == null) ? 0 : costsCurrent.ID);
if (visited.contains(tag))
return 0;
visited.add(tag);
// open template if necessary, including memoization
// under awareness of current plan choice
MemoTableEntry best = null;
boolean opened = false;
if (memo.contains(current.getHopID())) {
if (currentType == null) {
// no enclosing template: take the minimum (per BasicPlanComparator) among valid
// entries that pass hasNoRefToMaterialization, and open a new fused operator
// (opened is set even if no entry qualifies and best stays null)
best = memo.get(current.getHopID()).stream().filter(p -> p.isValid()).filter(p -> hasNoRefToMaterialization(p, M, plan)).min(new BasicPlanComparator()).orElse(null);
opened = true;
} else {
// inside a template: only entries of the current type or CELL qualify; the key
// 7 - (4 if same type) - countPlanRefs() is minimized, i.e., same-type entries
// and entries with more plan references are preferred
best = memo.get(current.getHopID()).stream().filter(p -> p.type == currentType || p.type == TemplateType.CELL).filter(p -> hasNoRefToMaterialization(p, M, plan)).min(Comparator.comparing(p -> 7 - ((p.type == currentType) ? 4 : 0) - p.countPlanRefs())).orElse(null);
}
}
// create new cost vector if opened, initialized with write costs
// (output size is max(dim1,1) * max(dim2,1) of the current hop)
CostVector costVect = !opened ? costsCurrent : new CostVector(Math.max(current.getDim1(), 1) * Math.max(current.getDim2(), 1));
// add compute costs of current operator to costs vector
// NOTE(review): dereferences costVect without a null check; relies on partition
// hops always having an open or enclosing cost vector -- confirm
if (partition.contains(current.getHopID()))
costVect.computeCosts += computeCosts.get(current.getHopID());
// process children recursively
double costs = 0;
for (int i = 0; i < current.getInput().size(); i++) {
Hop c = current.getInput().get(i);
// input is fused into the same operator: continue with the same cost vector
if (best != null && best.isPlanRef(i))
costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, costVect, best.type);
// implicitly fused input: register the size of c under the hop id of c's first input,
// without descending into c
else if (best != null && isImplicitlyFused(current, i, best.type))
costVect.addInputSize(c.getInput().get(0).getHopID(), Math.max(c.getDim1(), 1) * Math.max(c.getDim2(), 1));
else {
// include children and I/O costs
// (child is costed independently, i.e., without enclosing cost vector and type)
costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, null, null);
if (costVect != null && c.getDataType().isMatrix())
costVect.addInputSize(c.getHopID(), Math.max(c.getDim1(), 1) * Math.max(c.getDim2(), 1));
}
}
// add costs for opened fused operator
if (partition.contains(current.getHopID())) {
if (opened) {
if (LOG.isTraceEnabled())
LOG.trace("Cost vector for fused operator (hop " + current.getHopID() + "): " + costVect);
// time for output write
// (factor 8 is presumably bytes per double value -- confirm)
costs += costVect.outSize * 8 / WRITE_BANDWIDTH;
// operator time is the maximum of compute time and input read time
costs += Math.max(costVect.computeCosts * costVect.getMaxInputSize() / COMPUTE_BANDWIDTH, costVect.getSumInputSizes() * 8 / READ_BANDWIDTH);
} else // add costs for non-partition read in the middle of fused operator
if (hasNonPartitionConsumer(current, partition)) {
// re-cost the same hop without enclosing cost vector and type (tag becomes (hop id, 0))
costs += rGetPlanCosts(memo, current, visited, partition, M, plan, computeCosts, null, null);
}
}
// sanity check non-negative costs
if (costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs))
throw new RuntimeException("Wrong cost estimate: " + costs);
return costs;
}
use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by apache.
the class PlanSelectionFuseCostBasedV2 method createAndAddMultiAggPlans.
// across-partition multi-agg templates with shared reads
/**
 * Creates multi-aggregate (MAGG) memo entries for full aggregations that are
 * not yet covered by a MAGG plan, by collecting per-aggregation information,
 * greedily merging mergable aggregations, and registering one shared entry
 * for every merged group with more than one aggregate.
 *
 * @param memo  memo table that is inspected and extended in place
 * @param roots root hops from which full aggregations are collected
 */
private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
    // gather the candidate full aggregations reachable from the roots
    HashSet<Long> candidates = new HashSet<>();
    Hop.resetVisitStatus(roots);
    for (Hop root : roots) {
        rCollectFullAggregates(root, candidates);
    }
    Hop.resetVisitStatus(roots);

    // candidates that already have a multi-aggregate plan are not reconsidered
    candidates.removeIf(id -> memo.contains(id, TemplateType.MAGG));

    // at least two candidates are required to form a multi-aggregate
    if (candidates.size() <= 1) {
        return;
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Found across-partition ua(RC) aggregations: " + Arrays.toString(candidates.toArray(new Long[0])));
    }

    // per candidate: extract subsumed aggregations and inputs of fused operators
    List<AggregateInfo> aggInfos = new ArrayList<>();
    for (Long candidateID : candidates) {
        Hop aggHop = memo.getHopRefs().get(candidateID);
        AggregateInfo aggInfo = new AggregateInfo(aggHop);
        for (int pos = 0; pos < aggHop.getInput().size(); pos++) {
            Hop input;
            if (HopRewriteUtils.isMatrixMultiply(aggHop) && pos == 0) {
                // for matrix multiplies, look through the first input to its own first input
                input = aggHop.getInput().get(0).getInput().get(0);
            } else {
                input = aggHop.getInput().get(pos);
            }
            rExtractAggregateInfo(memo, input, aggInfo, TemplateType.CELL);
        }
        // fall back to the direct inputs if nothing was registered as fused input
        if (aggInfo._fusedInputs.isEmpty()) {
            if (HopRewriteUtils.isMatrixMultiply(aggHop)) {
                aggInfo.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
                aggInfo.addFusedInput(aggHop.getInput().get(1).getHopID());
            } else {
                aggInfo.addFusedInput(aggHop.getInput().get(0).getHopID());
            }
        }
        aggInfos.add(aggInfo);
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
        for (AggregateInfo info : aggInfos) {
            LOG.trace(info);
        }
    }

    // order by number of input aggregations (stable) to simplify merging
    // clusters of aggregations with parallel dependencies
    aggInfos.sort(Comparator.comparing((AggregateInfo a) -> a._inputAggs.size()));

    // greedy grouping: repeat full passes until the last merge of a pass
    // yields null or a pass performs no merge at all
    AggregateInfo lastMerged;
    do {
        lastMerged = null;
        for (int i = 0; i < aggInfos.size(); i++) {
            AggregateInfo left = aggInfos.get(i);
            int j = i + 1;
            while (j < aggInfos.size()) {
                AggregateInfo right = aggInfos.get(j);
                if (left.isMergable(right)) {
                    lastMerged = left.merge(right);
                    // the element at j was absorbed; stay at the same index
                    aggInfos.remove(j);
                } else {
                    j++;
                }
            }
        }
    } while (lastMerged != null);
    if (LOG.isTraceEnabled()) {
        LOG.trace("Merged across-partition ua(RC) aggregation info: ");
        for (AggregateInfo info : aggInfos) {
            LOG.trace(info);
        }
    }

    // register one multi-aggregate entry per group (entry holds up to three
    // aggregate ids, with -1 marking an unused third slot)
    for (AggregateInfo group : aggInfos) {
        if (group._aggregates.size() <= 1) {
            continue;
        }
        Long[] aggs = group._aggregates.keySet().toArray(new Long[0]);
        long third = (aggs.length > 2) ? aggs[2] : -1;
        MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, aggs[0], aggs[1], third, aggs.length);
        for (Long aggID : aggs) {
            memo.add(memo.getHopRefs().get(aggID), me);
            addBestPlan(aggID, me);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Added multiagg* plan: " + aggID + " " + me);
            }
        }
    }
}
Aggregations