Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
Class PlanSelectionFuseCostBased, method createAndAddMultiAggPlans.
/**
 * Creates within-partition multi-aggregate (MAGG) template plans and adds them
 * to the memo table.
 *
 * <p>Candidate hops are taken from {@code R}; a candidate is used only if it is
 * not an input of any hop that currently has memo table plans and if
 * {@code isMultiAggregateRoot} accepts it. Candidates are grouped, in iteration
 * order, into chunks of at most three; a trailing chunk of a single candidate
 * gets no plan. Each chunk that passes {@code isValidMultiAggregate} is
 * registered under every hop of the chunk.
 *
 * @param memo      memo table that is read for existing plans and extended with MAGG plans
 * @param partition not read by this method
 * @param R         hop IDs to consider as aggregation roots (presumably the partition's root nodes; verify against caller)
 */
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
    // index every hop that is consumed by a hop with at least one plan;
    // excluding these as roots avoids circular dependencies
    HashSet<Long> consumedHopIDs = new HashSet<>();
    for (Entry<Long, List<MemoTableEntry>> planEntry : memo.getPlans().entrySet()) {
        if (planEntry.getValue().isEmpty())
            continue;
        Hop planHop = memo.getHopRefs().get(planEntry.getKey());
        for (Hop input : planHop.getInput())
            consumedHopIDs.add(input.getHopID());
    }

    // collect all full aggregations (being in the same partition guarantees
    // common subexpressions; full aggregations are by definition root nodes)
    ArrayList<Long> fullAggs = new ArrayList<>();
    for (Long candidateID : R) {
        Hop candidate = memo.getHopRefs().get(candidateID);
        if (consumedHopIDs.contains(candidateID))
            continue;
        if (isMultiAggregateRoot(candidate))
            fullAggs.add(candidateID);
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
    }

    // build one multi-agg plan per chunk of up to 3 aggregations
    final int numAggs = fullAggs.size();
    for (int from = 0; from < numAggs; from += 3) {
        final int to = Math.min(from + 3, numAggs);
        final int chunkSize = to - from;
        if (chunkSize < 2)
            continue; // a single leftover aggregate cannot form a multi-aggregate

        MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, fullAggs.get(from), fullAggs.get(from + 1),
            (chunkSize == 3) ? fullAggs.get(from + 2) : -1, chunkSize);

        if (!isValidMultiAggregate(memo, me)) {
            if (LOG.isTraceEnabled())
                LOG.trace("Removed invalid multiagg plan: " + me);
            continue;
        }

        // register the shared plan under each participating aggregate
        for (int pos = from; pos < to; pos++) {
            memo.add(memo.getHopRefs().get(fullAggs.get(pos)), me);
            if (LOG.isTraceEnabled())
                LOG.trace("Added multiagg plan: " + fullAggs.get(pos) + " " + me);
        }
    }
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
Class PlanSelectionFuseCostBased, method rPruneSuboptimalPlans.
/**
 * Recursively traverses the hop DAG rooted at {@code current} and removes memo
 * table entries of partition hops that reference a materialization point,
 * except for entries of type {@code TemplateType.OUTER}, which are always kept.
 *
 * @param memo      memo table whose entries are pruned in place
 * @param current   hop to process; its inputs are processed recursively
 * @param visited   IDs of hops already processed; extended by this call
 * @param partition hop IDs for which pruning applies
 * @param M         passed through to {@code hasNoRefToMaterialization} (presumably the materialization point hop IDs; verify)
 * @param plan      passed through to {@code hasNoRefToMaterialization} (presumably the per-point materialization decision; verify)
 */
private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
    final long hopID = current.getHopID();

    // memoization via explicit set (not via hops because in middle of dag)
    if (visited.contains(hopID))
        return;

    // drop memo table entries of this hop where necessary
    if (partition.contains(hopID) && memo.contains(hopID)) {
        for (Iterator<MemoTableEntry> it = memo.get(hopID).iterator(); it.hasNext();) {
            MemoTableEntry entry = it.next();
            boolean keep = hasNoRefToMaterialization(entry, M, plan) || entry.type == TemplateType.OUTER;
            if (keep)
                continue;
            it.remove();
            if (LOG.isTraceEnabled())
                LOG.trace("Removed memo table entry: " + entry);
        }
    }

    // descend into all inputs, then mark this hop as done
    for (Hop child : current.getInput())
        rPruneSuboptimalPlans(memo, child, visited, partition, M, plan);
    visited.add(hopID);
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
Class PlanSelectionFuseCostBasedV2, method rGetPlanCosts.
/**
 * Recursively estimates the costs of the (partially fused) plan rooted at
 * {@code current} under the given materialization decisions.
 *
 * <p>If no template is currently open ({@code currentType == null}), the best
 * valid memo table entry without reference to a materialization point is
 * selected and a new cost vector is started; otherwise the hop is accounted to
 * the cost vector of the enclosing fused operator. Inputs are either followed
 * as part of the fused operator, treated as implicitly fused, or costed as
 * separate operators when they belong to the partition.
 *
 * <p>Cost-bound protocol: as soon as the accumulated costs reach
 * {@code costBound}, {@code Double.POSITIVE_INFINITY} is returned so callers
 * can abandon the plan early. Every recursive call is therefore followed by a
 * bound check.
 *
 * @param memo          memo table with candidate plans and hop references
 * @param current       hop whose costs are estimated
 * @param visited       visit marks (hop ID plus cost vector ID) that prevent double counting; extended by this call
 * @param part          plan partition, read for membership and externally consumed hops
 * @param matPoints     passed through to {@code hasNoRefToMatPoint}
 * @param plan          passed through to {@code hasNoRefToMatPoint}
 * @param computeCosts  compute costs per hop ID; must contain an entry for every costed hop
 * @param costsCurrent  cost vector of the enclosing fused operator, or {@code null} if none is open
 * @param currentType   template type of the enclosing fused operator, or {@code null} if none is open
 * @param costBound     upper bound on costs used for pruning
 * @return the estimated costs, {@code 0} if the hop was already accounted for,
 *         or {@code Double.POSITIVE_INFINITY} if the bound was reached
 * @throws RuntimeException if the final estimate is negative, NaN, or infinite
 */
private double rGetPlanCosts(CPlanMemoTable memo, final Hop current, HashSet<VisitMarkCost> visited, PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateType currentType, final double costBound) {
    final long currentHopId = current.getHopID();
    // costs of complex operation DAGs within a single fused operator:
    // a hop is counted once per cost vector (or once globally for MAGG / no vector)
    if (!visited.add(new VisitMarkCost(currentHopId, (costsCurrent == null || currentType == TemplateType.MAGG) ? -1 : costsCurrent.ID))) {
        // already existing
        return 0;
    }
    // open template if necessary, including memoization
    // under awareness of current plan choice
    MemoTableEntry best = null;
    boolean opened = (currentType == null);
    if (memo.contains(currentHopId)) {
        // plain loops instead of streams/lambdas to avoid unnecessary overhead
        if (currentType == null) {
            for (MemoTableEntry me : memo.get(currentHopId)) {
                best = me.isValid() && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) && BasicPlanComparator.icompare(me, best) < 0 ? me : best;
            }
            opened = true;
        } else {
            for (MemoTableEntry me : memo.get(currentHopId)) {
                best = (me.type == currentType || me.type == TemplateType.CELL) && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) && TypedPlanComparator.icompare(me, best, currentType) < 0 ? me : best;
            }
        }
    }
    // create new cost vector if opened, initialized with write costs
    CostVector costVect = !opened ? costsCurrent : new CostVector(getSize(current));
    double costs = 0;
    // add other roots for multi-agg template to account for shared costs
    if (opened && best != null && best.type == TemplateType.MAGG) {
        // account costs to first multi-agg root; other roots contribute nothing
        if (best.input1 == currentHopId) {
            for (int i = 1; i < 3; i++) {
                if (!best.isPlanRef(i))
                    continue;
                costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited, part, matPoints, plan, computeCosts, costVect, TemplateType.MAGG, costBound - costs);
                if (costs >= costBound)
                    return Double.POSITIVE_INFINITY;
            }
        } else {
            return 0;
        }
    }
    // add compute costs of current operator to costs vector
    costVect.computeCosts += computeCosts.get(currentHopId);
    // process children recursively
    for (int i = 0; i < current.getInput().size(); i++) {
        Hop c = current.getInput().get(i);
        if (best != null && best.isPlanRef(i)) {
            // input is part of the same fused operator
            costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, costVect, best.type, costBound - costs);
        } else if (best != null && isImplicitlyFused(current, i, best.type)) {
            costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c));
        } else {
            // include children and I/O costs
            if (part.getPartition().contains(c.getHopID()))
                costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null, costBound - costs);
            if (costVect != null && c.getDataType().isMatrix())
                costVect.addInputSize(c.getHopID(), getSize(c));
        }
        if (costs >= costBound)
            return Double.POSITIVE_INFINITY;
    }
    // add costs for opened fused operator
    if (opened) {
        double memInputs = sumInputMemoryEstimates(memo, costVect);
        double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH_MEM + Math.max(memInputs / READ_BANDWIDTH_MEM, costVect.computeCosts / COMPUTE_BANDWIDTH);
        // read correction for distributed computation
        if (memInputs > OptimizerUtils.getLocalMemBudget())
            tmpCosts += costVect.getSideInputSize() * 8 / READ_BANDWIDTH_BROADCAST;
        // sparsity correction for outer-product template (and sparse-safe cell)
        Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
        if (best != null && best.type == TemplateType.OUTER) {
            tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
        } else if (memInputs <= OptimizerUtils.getLocalMemBudget() && sumTmpInputOutputSize(memo, costVect) * 8 > LazyWriteBuffer.getWriteBufferLimit()) {
            // write correction for known evictions in CP
            tmpCosts += costVect.outSize * 8 / WRITE_BANDWIDTH_IO;
        }
        costs += tmpCosts;
        if (LOG.isTraceEnabled()) {
            String type = (best != null) ? best.type.name() : "HOP";
            LOG.trace("Cost vector (" + type + " " + currentHopId + "): " + costVect + " -> " + tmpCosts);
        }
    } else if (part.getExtConsumed().contains(current.getHopID())) {
        // add costs for non-partition read in the middle of fused operator
        costs += rGetPlanCosts(memo, current, visited, part, matPoints, plan, computeCosts, null, null, costBound - costs);
        // the nested call reports a reached bound as +Infinity; propagate that
        // pruning signal like every other recursive call site instead of letting
        // it trip the sanity check below
        if (costs >= costBound)
            return Double.POSITIVE_INFINITY;
    }
    // sanity check non-negative costs
    if (costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs))
        throw new RuntimeException("Wrong cost estimate: " + costs);
    return costs;
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
Class PlanSelectionFuseCostBasedV2, method createAndAddMultiAggPlans.
/**
 * Creates within-partition multi-aggregate (MAGG) template plans and adds them
 * to the memo table.
 *
 * <p>Candidate hops are taken from {@code R}; a candidate is used only if it is
 * not an input of any hop that currently has memo table plans and if
 * {@code isMultiAggregateRoot} accepts it. Candidates are grouped, in iteration
 * order, into chunks of at most three; a trailing chunk of a single candidate
 * gets no plan. Each chunk that passes {@code isValidMultiAggregate} is
 * registered under every hop of the chunk.
 *
 * @param memo      memo table that is read for existing plans and extended with MAGG plans
 * @param partition not read by this method
 * @param R         hop IDs to consider as aggregation roots (presumably the partition's root nodes; verify against caller)
 */
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
    // index every hop that is consumed by a hop with at least one plan;
    // excluding these as roots avoids circular dependencies
    HashSet<Long> consumedHopIDs = new HashSet<>();
    for (Entry<Long, List<MemoTableEntry>> planEntry : memo.getPlans().entrySet()) {
        if (planEntry.getValue().isEmpty())
            continue;
        Hop planHop = memo.getHopRefs().get(planEntry.getKey());
        for (Hop input : planHop.getInput())
            consumedHopIDs.add(input.getHopID());
    }

    // collect all full aggregations (being in the same partition guarantees
    // common subexpressions; full aggregations are by definition root nodes)
    ArrayList<Long> fullAggs = new ArrayList<>();
    for (Long candidateID : R) {
        Hop candidate = memo.getHopRefs().get(candidateID);
        if (consumedHopIDs.contains(candidateID))
            continue;
        if (isMultiAggregateRoot(candidate))
            fullAggs.add(candidateID);
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
    }

    // build one multi-agg plan per chunk of up to 3 aggregations
    final int numAggs = fullAggs.size();
    for (int from = 0; from < numAggs; from += 3) {
        final int to = Math.min(from + 3, numAggs);
        final int chunkSize = to - from;
        if (chunkSize < 2)
            continue; // a single leftover aggregate cannot form a multi-aggregate

        MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, fullAggs.get(from), fullAggs.get(from + 1),
            (chunkSize == 3) ? fullAggs.get(from + 2) : -1, chunkSize);

        if (!isValidMultiAggregate(memo, me)) {
            if (LOG.isTraceEnabled())
                LOG.trace("Removed invalid multiagg plan: " + me);
            continue;
        }

        // register the shared plan under each participating aggregate
        for (int pos = from; pos < to; pos++) {
            memo.add(memo.getHopRefs().get(fullAggs.get(pos)), me);
            if (LOG.isTraceEnabled())
                LOG.trace("Added multiagg plan: " + fullAggs.get(pos) + " " + me);
        }
    }
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.
Class PlanSelectionFuseCostBasedV2, method rCollectDependentRowOps.
/**
 * Recursively collects, into {@code blacklist}, the IDs of partition hops that
 * are row aggregation operations, that have differing ROW/CELL plans while
 * external materialization points exist, or that depend on such hops through
 * fused ROW-template inputs.
 *
 * <p>A hop is evaluated once per combination of hop ID, incoming template type
 * and {@code foundRowOp} flag; hops outside the partition are skipped.
 *
 * @param hop        hop to process
 * @param memo       memo table used to look up the best plan per hop
 * @param part       plan partition, read for membership and external materialization points
 * @param blacklist  output set of blacklisted hop IDs; extended by this call
 * @param visited    already processed (hop ID, state) keys; extended by this call
 * @param type       template type of the consuming plan, or {@code null} if none
 * @param foundRowOp whether a row operation was already found on the fused path leading to this hop
 */
private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable memo, PlanPartition part, HashSet<Long> blacklist, HashSet<Pair<Long, Integer>> visited, TemplateType type, boolean foundRowOp) {
    final long hopID = hop.getHopID();

    // avoid redundant evaluation of processed and non-partition nodes;
    // the key encodes the row-op flag and the incoming template type
    final int rowOpOffset = foundRowOp ? Short.MAX_VALUE : 0;
    final int typeCode = (type == null) ? 0 : type.ordinal() + 1;
    Pair<Long, Integer> key = Pair.of(hopID, rowOpOffset + typeCode);
    if (visited.contains(key) || !part.getPartition().contains(hopID))
        return;

    // process node itself (top-down)
    MemoTableEntry me;
    if (type == null)
        me = memo.getBest(hopID);
    else
        me = memo.getBest(hopID, type);
    final boolean inRow = (me != null && me.type == TemplateType.ROW && type == TemplateType.ROW);
    // guard against plan differences
    final boolean diffPlans = part.getMatPointsExt().length > 0
        && memo.contains(hopID, TemplateType.ROW)
        && !memo.hasOnlyExactMatches(hopID, TemplateType.ROW, TemplateType.CELL);
    if (inRow && foundRowOp)
        blacklist.add(hopID);
    if (isRowAggOp(hop, inRow) || diffPlans) {
        blacklist.add(hopID);
        foundRowOp = true;
    }

    // process children recursively; the flag is only carried along fused inputs
    final TemplateType childType = (me != null) ? me.type : null;
    for (int pos = 0; pos < hop.getInput().size(); pos++) {
        boolean childFoundRowOp = false;
        if (foundRowOp && me != null)
            childFoundRowOp = me.isPlanRef(pos) || isImplicitlyFused(hop, pos, me.type);
        rCollectDependentRowOps(hop.getInput().get(pos), memo, part, blacklist, visited, childType, childFoundRowOp);
    }

    // process node itself (bottom-up): inherit blacklisting from fused ROW inputs
    if (!blacklist.contains(hopID) && me != null && me.type == TemplateType.ROW) {
        for (int pos = 0; pos < hop.getInput().size(); pos++) {
            boolean fusedInput = me.isPlanRef(pos) || isImplicitlyFused(hop, pos, me.type);
            if (fusedInput && blacklist.contains(hop.getInput().get(pos).getHopID()))
                blacklist.add(hopID);
        }
    }
    visited.add(key);
}
Aggregations