Search in sources:

Example 6 with CPlanMemoTable

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable in the Apache incubator-systemml project.

Used in class PlanSelectionFuseCostBased, method createAndAddMultiAggPlans.

/**
 * Builds across-partition multi-aggregate (MAGG) templates for full
 * aggregations that share reads, and registers them in the memo table and
 * as best plans.
 *
 * <p>Steps: collect full aggregations reachable from {@code roots}, drop
 * those that already have a MAGG plan, describe each remaining candidate
 * by an {@code AggregateInfo}, greedily merge mergeable candidates until a
 * fixpoint is reached, and finally emit one MAGG entry per group of two or
 * more aggregations.
 *
 * @param memo  memo table of partial fusion plans; receives the new MAGG entries
 * @param roots root hops of the DAG to scan; their visit status is reset
 *              before and after the collection pass
 */
private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
    // -- candidate collection (visit status reset on both sides of the traversal)
    HashSet<Long> fullAggs = new HashSet<>();
    Hop.resetVisitStatus(roots);
    for (Hop root : roots) {
        rCollectFullAggregates(root, fullAggs);
    }
    Hop.resetVisitStatus(roots);

    // aggregations already covered by a multi-agg plan are not re-considered
    fullAggs.removeIf(hopId -> memo.contains(hopId, TemplateType.MAGG));

    // grouping needs at least two candidates
    if (fullAggs.size() < 2) {
        return;
    }
    if (LOG.isTraceEnabled()) {
        Long[] found = fullAggs.toArray(new Long[0]);
        LOG.trace("Found across-partition ua(RC) aggregations: " + Arrays.toString(found));
    }

    // -- per-candidate info: subsumed aggregations and inputs to fused operators
    List<AggregateInfo> aggInfos = new ArrayList<>();
    for (Long hopId : fullAggs) {
        Hop agg = memo.getHopRefs().get(hopId);
        AggregateInfo info = new AggregateInfo(agg);
        boolean isMatMult = HopRewriteUtils.isMatrixMultiply(agg);
        for (int pos = 0; pos < agg.getInput().size(); pos++) {
            // for a matrix multiply, input 0 is looked through and its own
            // first input is analyzed instead
            Hop in;
            if (isMatMult && pos == 0) {
                in = agg.getInput().get(0).getInput().get(0);
            } else {
                in = agg.getInput().get(pos);
            }
            rExtractAggregateInfo(memo, in, info, TemplateType.CELL);
        }
        if (info._fusedInputs.isEmpty()) {
            // nothing fused was found: fall back to the aggregation's direct data inputs
            if (isMatMult) {
                info.addFusedInput(agg.getInput().get(0).getInput().get(0).getHopID());
                info.addFusedInput(agg.getInput().get(1).getHopID());
            } else {
                info.addFusedInput(agg.getInput().get(0).getHopID());
            }
        }
        aggInfos.add(info);
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
        for (AggregateInfo ai : aggInfos) {
            LOG.trace(ai);
        }
    }

    // -- order by number of input aggregations (stable) so that clusters with
    //    parallel dependencies merge more easily
    aggInfos.sort(Comparator.comparingInt(ai -> ai._inputAggs.size()));

    // -- greedy grouping: repeat passes until a pass performs no merge
    //    (convergence follows the result of the last merge in a pass)
    AggregateInfo lastMerge;
    do {
        lastMerge = null;
        for (int i = 0; i < aggInfos.size(); i++) {
            AggregateInfo base = aggInfos.get(i);
            int j = i + 1;
            while (j < aggInfos.size()) {
                AggregateInfo other = aggInfos.get(j);
                if (base.isMergable(other)) {
                    lastMerge = base.merge(other);
                    aggInfos.remove(j); // next candidate shifts into slot j
                } else {
                    j++;
                }
            }
        }
    } while (lastMerge != null);

    if (LOG.isTraceEnabled()) {
        LOG.trace("Merged across-partition ua(RC) aggregation info: ");
        for (AggregateInfo ai : aggInfos) {
            LOG.trace(ai);
        }
    }

    // -- emit one MAGG entry per group of >= 2 aggregations; the entry stores
    //    up to three aggregation ids plus the group size (NOTE(review): the
    //    three-aggregation cap is presumably enforced by isMergable - verify)
    for (AggregateInfo group : aggInfos) {
        if (group._aggregates.size() < 2) {
            continue;
        }
        Long[] aggs = group._aggregates.keySet().toArray(new Long[0]);
        long third = (aggs.length > 2) ? aggs[2] : -1;
        MemoTableEntry entry = new MemoTableEntry(TemplateType.MAGG, aggs[0], aggs[1], third, aggs.length);
        for (Long aggId : aggs) {
            memo.add(memo.getHopRefs().get(aggId), entry);
            addBestPlan(aggId, entry);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Added multiagg* plan: " + aggId + " " + entry);
            }
        }
    }
}
Also used : TemplateRow(org.apache.sysml.hops.codegen.template.TemplateRow) Arrays(java.util.Arrays) IndexingOp(org.apache.sysml.hops.IndexingOp) HashMap(java.util.HashMap) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) TemplateOuterProduct(org.apache.sysml.hops.codegen.template.TemplateOuterProduct) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggOp(org.apache.sysml.hops.Hop.AggOp) ArrayList(java.util.ArrayList) LiteralOp(org.apache.sysml.hops.LiteralOp) HashSet(java.util.HashSet) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Pair(org.apache.commons.lang3.tuple.Pair) ReorgOp(org.apache.sysml.hops.ReorgOp) IDSequence(org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence) CollectionUtils(org.apache.commons.collections.CollectionUtils) InfrastructureAnalyzer(org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Statistics(org.apache.sysml.utils.Statistics) TernaryOp(org.apache.sysml.hops.TernaryOp) Iterator(java.util.Iterator) Collection(java.util.Collection) TemplateType(org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType) BinaryOp(org.apache.sysml.hops.BinaryOp) TemplateUtils(org.apache.sysml.hops.codegen.template.TemplateUtils) Collectors(java.util.stream.Collectors) Direction(org.apache.sysml.hops.Hop.Direction) Hop(org.apache.sysml.hops.Hop) List(java.util.List) Entry(java.util.Map.Entry) DMLScript(org.apache.sysml.api.DMLScript) Log(org.apache.commons.logging.Log) LogFactory(org.apache.commons.logging.LogFactory) UtilFunctions(org.apache.sysml.runtime.util.UtilFunctions) Comparator(java.util.Comparator) Collections(java.util.Collections) HopRewriteUtils(org.apache.sysml.hops.rewrite.HopRewriteUtils) UnaryOp(org.apache.sysml.hops.UnaryOp) CPlanMemoTable(org.apache.sysml.hops.codegen.template.CPlanMemoTable) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) 
Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet)

Example 7 with CPlanMemoTable

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable in the Apache incubator-systemml project.

Used in class SpoofCompiler, method optimize.

/**
 * Main interface of sum-product optimizer, statement block dag.
 *
 * <p>Explores candidate partial fusion plans into a memo table, prunes
 * suboptimal ones, constructs and cleans up cplans, generates and compiles
 * Java source for every cplan that is not already in the plan cache, and
 * finally rewrites the HOP DAG to use the generated operators (followed by
 * common subexpression elimination).
 *
 * @param roots dag root nodes
 * @param recompile true if invoked during dynamic recompilation
 * @return dag root nodes of modified dag; {@code roots} itself if it is
 *         null/empty or if no cplans were constructed
 * @throws DMLRuntimeException wrapping any exception raised during optimization
 */
public static ArrayList<Hop> optimize(ArrayList<Hop> roots, boolean recompile) {
    if (roots == null || roots.isEmpty())
        return roots;
    long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
    ArrayList<Hop> ret = roots;
    try {
        // context-sensitive literal replacement (only integers during recompile)
        boolean compileLiterals = (PLAN_CACHE_POLICY == PlanCachePolicy.CONSTANT) || !recompile;
        // candidate exploration of valid partial fusion plans
        CPlanMemoTable memo = new CPlanMemoTable();
        for (Hop hop : roots) rExploreCPlans(hop, memo, compileLiterals);
        // candidate selection of optimal fusion plan
        memo.pruneSuboptimal(roots);
        // construct actual cplan representations
        // note: we do not use the hop visit status due to jumps over fused operators which would
        // corrupt subsequent resets, leaving partial hops dags in visited status
        HashMap<Long, Pair<Hop[], CNodeTpl>> cplans = new LinkedHashMap<>();
        HashSet<Long> visited = new HashSet<>();
        for (Hop hop : roots) rConstructCPlans(hop, memo, cplans, compileLiterals, visited);
        // cleanup codegen plans (remove unnecessary inputs, fix hop-cnodedata mapping,
        // remove empty templates with single cnodedata input, remove spurious lookups,
        // perform common subexpression elimination)
        cplans = cleanupCPlans(memo, cplans);
        // explain before modification
        if (LOG.isTraceEnabled() && !cplans.isEmpty()) {
            // existing cplans
            LOG.trace("Codegen EXPLAIN (before optimize): \n" + Explain.explainHops(roots));
        }
        // source code generation for all cplans
        HashMap<Long, Pair<Hop[], Class<?>>> clas = new HashMap<>();
        for (Entry<Long, Pair<Hop[], CNodeTpl>> cplan : cplans.entrySet()) {
            Pair<Hop[], CNodeTpl> tmp = cplan.getValue();
            Class<?> cla = planCache.getPlan(tmp.getValue());
            if (cla == null) {
                // cache miss: generate java source code
                String src = tmp.getValue().codegen(false);
                // explain debug output cplans or generated source code
                // (intentionally logged at info level so explain output is visible)
                if (LOG.isTraceEnabled() || DMLScript.EXPLAIN.isHopsType(recompile)) {
                    LOG.info("Codegen EXPLAIN (generated cplan for HopID: " + cplan.getKey() + ", line " + tmp.getValue().getBeginLine() + ", hash=" + tmp.getValue().hashCode() + "):");
                    LOG.info(tmp.getValue().getClassname() + Explain.explainCPlan(cplan.getValue().getValue()));
                }
                if (LOG.isTraceEnabled() || DMLScript.EXPLAIN.isRuntimeType(recompile)) {
                    LOG.info("Codegen EXPLAIN (generated code for HopID: " + cplan.getKey() + ", line " + tmp.getValue().getBeginLine() + ", hash=" + tmp.getValue().hashCode() + "):");
                    LOG.info(src);
                }
                // compile generated java source code
                cla = CodegenUtils.compileClass("codegen." + tmp.getValue().getClassname(), src);
                // maintain plan cache
                if (PLAN_CACHE_POLICY != PlanCachePolicy.NONE)
                    planCache.putPlan(tmp.getValue(), cla);
            } else if (DMLScript.STATISTICS) {
                Statistics.incrementCodegenOpCacheHits();
            }
            // make class available and maintain hits
            if (cla != null)
                clas.put(cplan.getKey(), new Pair<Hop[], Class<?>>(tmp.getKey(), cla));
            if (DMLScript.STATISTICS)
                Statistics.incrementCodegenOpCacheTotal();
        }
        // create modified hop dag (operator replacement and CSE)
        if (!cplans.isEmpty()) {
            // generate final hop dag
            ret = constructModifiedHopDag(roots, cplans, clas);
            // run common subexpression elimination and other rewrites
            ret = rewriteCSE.rewriteHopDAG(ret, new ProgramRewriteStatus());
            // explain after modification: show the DAG actually returned to the caller
            if (LOG.isTraceEnabled()) {
                LOG.trace("Codegen EXPLAIN (after optimize): \n" + Explain.explainHops(ret));
            }
        }
    } catch (Exception ex) {
        // The DAG may be partially rewritten at this point, so explaining it can
        // fail too; never let that secondary failure mask the original cause.
        String dagExplain;
        try {
            dagExplain = Explain.explainHops(roots);
        } catch (Exception explainEx) {
            ex.addSuppressed(explainEx);
            dagExplain = "<explain failed: " + explainEx + ">";
        }
        LOG.error("Codegen failed to optimize the following HOP DAG: \n" + dagExplain);
        throw new DMLRuntimeException(ex);
    }
    if (DMLScript.STATISTICS) {
        Statistics.incrementCodegenDAGCompile();
        Statistics.incrementCodegenCompileTime(System.nanoTime() - t0);
    }
    Hop.resetVisitStatus(roots);
    return ret;
}
Also used : CNodeTpl(org.apache.sysml.hops.codegen.cplan.CNodeTpl) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Hop(org.apache.sysml.hops.Hop) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) LinkedHashMap(java.util.LinkedHashMap) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) CPlanMemoTable(org.apache.sysml.hops.codegen.template.CPlanMemoTable) ProgramRewriteStatus(org.apache.sysml.hops.rewrite.ProgramRewriteStatus) Pair(org.apache.sysml.runtime.matrix.data.Pair) HashSet(java.util.HashSet)

Aggregations

HashSet (java.util.HashSet)7 Hop (org.apache.sysml.hops.Hop)7 CPlanMemoTable (org.apache.sysml.hops.codegen.template.CPlanMemoTable)7 ArrayList (java.util.ArrayList)5 Comparator (java.util.Comparator)5 HashMap (java.util.HashMap)5 List (java.util.List)5 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)5 TemplateType (org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType)5 Entry (java.util.Map.Entry)4 UtilFunctions (org.apache.sysml.runtime.util.UtilFunctions)4 Arrays (java.util.Arrays)3 Collection (java.util.Collection)3 Collections (java.util.Collections)3 Iterator (java.util.Iterator)3 LinkedHashMap (java.util.LinkedHashMap)3 Collectors (java.util.stream.Collectors)3 CollectionUtils (org.apache.commons.collections.CollectionUtils)3 Pair (org.apache.commons.lang3.tuple.Pair)3 Log (org.apache.commons.logging.Log)3