Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable in the Apache incubator-systemml project.
Class PlanSelectionFuseCostBased, method createAndAddMultiAggPlans.
// across-partition multi-agg templates with shared reads
/**
 * Groups full aggregations (collected from the given roots) that do not yet have a
 * multi-aggregate (MAGG) plan into mergeable clusters, and registers one MAGG memo
 * entry per cluster of at least two aggregations. Each entry is added to the memo
 * table for every member and recorded as that member's best plan.
 *
 * @param memo  memo table of partial fusion plans; receives the new MAGG entries
 * @param roots root nodes of the HOP DAG scanned for full aggregations
 */
private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
    // collect full aggregations as initial set of candidates
    HashSet<Long> fullAggs = new HashSet<>();
    Hop.resetVisitStatus(roots);
    for (Hop hop : roots)
        rCollectFullAggregates(hop, fullAggs);
    Hop.resetVisitStatus(roots);

    // remove operators with assigned multi-agg plans
    fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG));

    // check applicability for further analysis (grouping needs >= 2 candidates)
    if (fullAggs.size() <= 1)
        return;

    if (LOG.isTraceEnabled()) {
        LOG.trace("Found across-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
    }

    // collect information for all candidates
    // (subsumed aggregations, and inputs to fused operators)
    List<AggregateInfo> aggInfos = new ArrayList<>();
    for (Long hopID : fullAggs) {
        Hop aggHop = memo.getHopRefs().get(hopID);
        AggregateInfo tmp = new AggregateInfo(aggHop);
        for (int i = 0; i < aggHop.getInput().size(); i++) {
            // for matrix multiplies, the first input is analyzed through its own
            // first input (mirrors the fused-input fallback below)
            Hop c = HopRewriteUtils.isMatrixMultiply(aggHop) && i == 0 ?
                aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
            rExtractAggregateInfo(memo, c, tmp, TemplateType.CELL);
        }
        // fallback: if no fused inputs were found, use the aggregate's direct inputs
        if (tmp._fusedInputs.isEmpty()) {
            if (HopRewriteUtils.isMatrixMultiply(aggHop)) {
                tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
                tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
            } else
                tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
        }
        aggInfos.add(tmp);
    }

    if (LOG.isTraceEnabled()) {
        LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
        for (AggregateInfo info : aggInfos)
            LOG.trace(info);
    }

    // sort aggregations by num dependencies to simplify merging
    // clusters of aggregations with parallel dependencies
    // note: sort in place (stable, like Stream.sorted) because the list is mutated
    // via remove() below, and Collectors.toList() does not guarantee mutability
    aggInfos.sort(Comparator.comparingInt(a -> a._inputAggs.size()));

    // greedy grouping of multi-agg candidates: repeat passes until a full pass
    // performs no merge
    boolean converged = false;
    while (!converged) {
        AggregateInfo merged = null;
        for (int i = 0; i < aggInfos.size(); i++) {
            AggregateInfo current = aggInfos.get(i);
            for (int j = i + 1; j < aggInfos.size(); j++) {
                AggregateInfo that = aggInfos.get(j);
                if (current.isMergable(that)) {
                    merged = current.merge(that);
                    aggInfos.remove(j);
                    j--; // re-check the element shifted into position j
                }
            }
        }
        converged = (merged == null);
    }

    if (LOG.isTraceEnabled()) {
        LOG.trace("Merged across-partition ua(RC) aggregation info: ");
        for (AggregateInfo info : aggInfos)
            LOG.trace(info);
    }

    // construct and add multiagg template plans (w/ max 3 aggregations)
    for (AggregateInfo info : aggInfos) {
        if (info._aggregates.size() <= 1)
            continue;
        Long[] aggs = info._aggregates.keySet().toArray(new Long[0]);
        // up to three aggregate hop ids are stored; -1 marks an unused third slot
        MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG,
            aggs[0], aggs[1], (aggs.length > 2) ? aggs[2] : -1, aggs.length);
        for (int i = 0; i < aggs.length; i++) {
            memo.add(memo.getHopRefs().get(aggs[i]), me);
            addBestPlan(aggs[i], me);
            if (LOG.isTraceEnabled())
                LOG.trace("Added multiagg* plan: " + aggs[i] + " " + me);
        }
    }
}
Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable in the Apache incubator-systemml project.
Class SpoofCompiler, method optimize.
/**
 * Main interface of sum-product optimizer, statement block dag.
 * <p>
 * Explores candidate partial fusion plans, prunes them to the selected plan,
 * constructs and cleans up the cplan representations, then generates and compiles
 * the operator classes. Classes are taken from the plan cache when present. Finally,
 * the DAG is rewritten to use the fused operators, followed by common subexpression
 * elimination.
 *
 * @param roots dag root nodes
 * @param recompile true if invoked during dynamic recompilation
 * @return dag root nodes of modified dag ({@code roots} itself if null or empty,
 *         or if no cplans were constructed)
 * @throws DMLRuntimeException wrapping any exception raised during optimization
 */
public static ArrayList<Hop> optimize(ArrayList<Hop> roots, boolean recompile) {
    if (roots == null || roots.isEmpty())
        return roots;

    long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
    ArrayList<Hop> ret = roots;

    try {
        // context-sensitive literal replacement (only integers during recompile)
        boolean compileLiterals = (PLAN_CACHE_POLICY == PlanCachePolicy.CONSTANT) || !recompile;

        // candidate exploration of valid partial fusion plans
        CPlanMemoTable memo = new CPlanMemoTable();
        for (Hop hop : roots)
            rExploreCPlans(hop, memo, compileLiterals);

        // candidate selection of optimal fusion plan
        memo.pruneSuboptimal(roots);

        // construct actual cplan representations
        // note: we do not use the hop visit status due to jumps over fused operators which would
        // corrupt subsequent resets, leaving partial hops dags in visited status
        HashMap<Long, Pair<Hop[], CNodeTpl>> cplans = new LinkedHashMap<>();
        HashSet<Long> visited = new HashSet<>();
        for (Hop hop : roots)
            rConstructCPlans(hop, memo, cplans, compileLiterals, visited);

        // cleanup codegen plans (remove unnecessary inputs, fix hop-cnodedata mapping,
        // remove empty templates with single cnodedata input, remove spurious lookups,
        // perform common subexpression elimination)
        cplans = cleanupCPlans(memo, cplans);

        // explain before modification
        if (LOG.isTraceEnabled() && !cplans.isEmpty()) {
            LOG.trace("Codegen EXPLAIN (before optimize): \n" + Explain.explainHops(roots));
        }

        // source code generation for all cplans
        HashMap<Long, Pair<Hop[], Class<?>>> clas = new HashMap<>();
        for (Entry<Long, Pair<Hop[], CNodeTpl>> cplan : cplans.entrySet()) {
            Pair<Hop[], CNodeTpl> tmp = cplan.getValue();
            CNodeTpl tpl = tmp.getValue();
            Class<?> cla = planCache.getPlan(tpl);

            if (cla == null) {
                // generate java source code
                String src = tpl.codegen(false);

                // explain debug output cplans or generated source code
                if (LOG.isTraceEnabled() || DMLScript.EXPLAIN.isHopsType(recompile)) {
                    LOG.info("Codegen EXPLAIN (generated cplan for HopID: " + cplan.getKey() + ", line " + tpl.getBeginLine() + ", hash=" + tpl.hashCode() + "):");
                    LOG.info(tpl.getClassname() + Explain.explainCPlan(tpl));
                }
                if (LOG.isTraceEnabled() || DMLScript.EXPLAIN.isRuntimeType(recompile)) {
                    LOG.info("Codegen EXPLAIN (generated code for HopID: " + cplan.getKey() + ", line " + tpl.getBeginLine() + ", hash=" + tpl.hashCode() + "):");
                    LOG.info(src);
                }

                // compile generated java source code
                cla = CodegenUtils.compileClass("codegen." + tpl.getClassname(), src);

                // maintain plan cache
                if (PLAN_CACHE_POLICY != PlanCachePolicy.NONE)
                    planCache.putPlan(tpl, cla);
            } else if (DMLScript.STATISTICS) {
                Statistics.incrementCodegenOpCacheHits();
            }

            // make class available and maintain hits
            if (cla != null)
                clas.put(cplan.getKey(), new Pair<Hop[], Class<?>>(tmp.getKey(), cla));
            if (DMLScript.STATISTICS)
                Statistics.incrementCodegenOpCacheTotal();
        }

        // create modified hop dag (operator replacement and CSE)
        if (!cplans.isEmpty()) {
            // generate final hop dag
            ret = constructModifiedHopDag(roots, cplans, clas);

            // run common subexpression elimination and other rewrites
            ret = rewriteCSE.rewriteHopDAG(ret, new ProgramRewriteStatus());

            // explain after modification: show the optimized dag that is returned
            if (LOG.isTraceEnabled()) {
                LOG.trace("Codegen EXPLAIN (after optimize): \n" + Explain.explainHops(ret));
            }
        }
    } catch (Exception ex) {
        LOG.error("Codegen failed to optimize the following HOP DAG: \n" + Explain.explainHops(roots));
        throw new DMLRuntimeException(ex);
    }

    if (DMLScript.STATISTICS) {
        Statistics.incrementCodegenDAGCompile();
        Statistics.incrementCodegenCompileTime(System.nanoTime() - t0);
    }

    Hop.resetVisitStatus(roots);
    // ret comes from constructModifiedHopDag/CSE and may be a different list;
    // reset it as well so the returned dag is not left in visited status
    if (ret != roots)
        Hop.resetVisitStatus(ret);
    return ret;
}
Aggregations