Search in sources :

Example 36 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in the project incubator-systemml by Apache.

From class PlanSelectionFuseCostBased, method createAndAddMultiAggPlans.

//within-partition multi-agg templates
/**
 * Adds within-partition multi-aggregate ({@code MultiAggTpl}) plans to the memo table.
 *
 * <p>Candidate roots are the hops in {@code R} that are full (row+column) unary aggregates
 * ({@link AggUnaryOp} with {@code Direction.RowCol}) and are not consumed as an input by any
 * hop that already has memo-table plans. The latter restriction avoids circular dependencies.
 * Candidates are packed greedily, in the iteration order of {@code R}, into consecutive groups
 * of at most three. A group of size one cannot form a multi-aggregate and is skipped. Each
 * group becomes one {@link MemoTableEntry}. If {@code isValidMultiAggregate} accepts that
 * entry, it is added to the memo for every member hop; otherwise the entry is dropped.
 *
 * @param memo      memo table whose plans are inspected and extended
 * @param partition the current partition (not read by this method)
 * @param R         IDs of the partition's root hops
 */
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
    // Index every hop that feeds a hop with at least one plan; such hops may not serve
    // as multi-agg roots, which would otherwise introduce circular dependencies.
    HashSet<Long> referencedIds = new HashSet<Long>();
    for (Entry<Long, List<MemoTableEntry>> planEntry : memo._plans.entrySet()) {
        if (planEntry.getValue().isEmpty())
            continue;
        Hop consumer = memo._hopRefs.get(planEntry.getKey());
        for (Hop input : consumer.getInput())
            referencedIds.add(input.getHopID());
    }

    // Being in the same partition implies shared common subexpressions among these
    // full aggregations; full aggregations are root nodes by definition.
    ArrayList<Long> candidates = new ArrayList<Long>();
    for (Long id : R) {
        Hop hop = memo._hopRefs.get(id);
        boolean isFullAgg = (hop instanceof AggUnaryOp)
            && ((AggUnaryOp) hop).getDirection() == Direction.RowCol;
        if (isFullAgg && !referencedIds.contains(id))
            candidates.add(id);
    }
    if (LOG.isTraceEnabled()) {
        // ArrayList#toString renders "[a, b, c]", identical to Arrays.toString on its array form
        LOG.trace("Found within-partition ua(RC) aggregations: " + candidates);
    }

    // Pack candidates into consecutive groups of up to three aggregations per template.
    final int total = candidates.size();
    for (int start = 0; start < total; start += 3) {
        int end = Math.min(start + 3, total);
        int groupSize = end - start;
        if (groupSize < 2)
            continue; // a lone aggregation cannot form a multi-aggregate

        // -1 marks an absent third input when the group holds only two aggregations
        long third = (groupSize == 3) ? candidates.get(start + 2) : -1;
        MemoTableEntry multiAgg = new MemoTableEntry(TemplateType.MultiAggTpl,
            candidates.get(start), candidates.get(start + 1), third);

        if (!isValidMultiAggregate(memo, multiAgg)) {
            if (LOG.isTraceEnabled())
                LOG.trace("Removed invalid multiagg plan: " + multiAgg);
            continue;
        }
        for (int k = start; k < end; k++) {
            Long rootId = candidates.get(k);
            memo.add(memo._hopRefs.get(rootId), multiAgg);
            if (LOG.isTraceEnabled())
                LOG.trace("Added multiagg plan: " + rootId + " " + multiAgg);
        }
    }
}
Also used: AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry), Hop (org.apache.sysml.hops.Hop), ArrayList (java.util.ArrayList), List (java.util.List), HashSet (java.util.HashSet)

Aggregations (usage counts across examples):

AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 36; Hop (org.apache.sysml.hops.Hop): 33; AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 19; LiteralOp (org.apache.sysml.hops.LiteralOp): 14; BinaryOp (org.apache.sysml.hops.BinaryOp): 12; ReorgOp (org.apache.sysml.hops.ReorgOp): 11; UnaryOp (org.apache.sysml.hops.UnaryOp): 11; IndexingOp (org.apache.sysml.hops.IndexingOp): 7; MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 6; ArrayList (java.util.ArrayList): 5; TernaryOp (org.apache.sysml.hops.TernaryOp): 5; DataOp (org.apache.sysml.hops.DataOp): 4; ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 4; CNode (org.apache.sysml.hops.codegen.cplan.CNode): 3; CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 3; CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 3; CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary): 3; OpOp2 (org.apache.sysml.hops.Hop.OpOp2): 2; QuaternaryOp (org.apache.sysml.hops.QuaternaryOp): 2; CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary): 2