Search in sources:

Example 31 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseNoRedundancy, method rSelectPlans.

/**
 * Recursively selects one memo table entry for {@code current} and then descends into its
 * inputs, avoiding plans that would redundantly share partial plans.
 * <p>
 * Per (hop, preferred type) pair this runs once (tracked via isVisited/setVisited):
 * <ol>
 * <li>drops entries that reference an input consumed by more than one parent,</li>
 * <li>drops entries that another entry of the same hop subsumes,</li>
 * <li>picks the best remaining entry (see below) and records it via addBestPlan,</li>
 * <li>recurses into every input, passing the chosen entry's type as preference for
 *     inputs it references and {@code null} otherwise.</li>
 * </ol>
 * Without a preferred type, the minimum under BasicPlanComparator among entries accepted
 * by isValid is chosen. With a preferred type, only entries of that type or CellTpl are
 * considered; a type match lowers the score by 4 and each plan reference lowers it by 1.
 * The lowest score wins.
 *
 * @param memo        memo table of candidate plans (entries are removed in place)
 * @param current     hop to process
 * @param currentType template type preferred by the consumer, or {@code null}
 */
private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) {
    final long id = current.getHopID();
    if (isVisited(id, currentType))
        return;

    // step 0: drop entries that reference a partial plan shared by multiple consumers
    if (memo.contains(id)) {
        HashSet<MemoTableEntry> shared = new HashSet<MemoTableEntry>();
        for (MemoTableEntry entry : memo.get(id)) {
            for (int pos = 0; pos < 3; pos++) {
                boolean multiConsumerRef = entry.isPlanRef(pos)
                    && current.getInput().get(pos).getParent().size() > 1;
                if (multiConsumerRef)
                    shared.add(entry);
            }
        }
        memo.remove(current, shared);
    }

    // step 1: drop entries subsumed by another entry of this hop
    if (memo.contains(id)) {
        List<MemoTableEntry> candidates = memo.get(id);
        HashSet<MemoTableEntry> subsumed = new HashSet<MemoTableEntry>();
        for (MemoTableEntry outer : candidates) {
            for (MemoTableEntry inner : candidates) {
                if (outer != inner && outer.subsumes(inner))
                    subsumed.add(inner);
            }
        }
        memo.remove(current, subsumed);
    }

    // step 2: choose the plan for the current path (may remain null)
    MemoTableEntry best = null;
    if (memo.contains(id)) {
        if (currentType != null) {
            best = memo.get(id).stream()
                .filter(p -> p.type == currentType || p.type == TemplateType.CellTpl)
                .min(Comparator.comparing(p -> 7 - ((p.type == currentType) ? 4 : 0) - p.countPlanRefs()))
                .orElse(null);
        } else {
            best = memo.get(id).stream()
                .filter(p -> isValid(p, current))
                .min(new BasicPlanComparator())
                .orElse(null);
        }
        addBestPlan(id, best);
    }

    // step 3: descend into inputs, propagating the chosen type only along plan references
    for (int idx = 0; idx < current.getInput().size(); idx++) {
        TemplateType childPref = null;
        if (best != null && best.isPlanRef(idx))
            childPref = best.type;
        rSelectPlans(memo, current.getInput().get(idx), childPref);
    }
    setVisited(id, currentType);
}
Also used: HashSet (java.util.HashSet), MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry), List (java.util.List), TemplateType (org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType), Entry (java.util.Map.Entry), Comparator (java.util.Comparator), Hop (org.apache.sysml.hops.Hop), ArrayList (java.util.ArrayList)

Example 32 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBased, method rPruneInvalidPlans.

/**
 * Recursively (inputs first) retypes row-template memo entries of partition hops to
 * cell-template entries when they are invalid after the preceding pruning step
 * (see TemplateRow.open).
 * <p>
 * A RowTpl entry of a hop contained in {@code partition} is converted to CellTpl if
 * <ul>
 * <li>it has no plan references and the hop has no matrix input, or</li>
 * <li>it has plan references, {@code ROW_TPL.open(current)} is false, and none of its
 *     referenced inputs has a RowTpl entry in the memo table.</li>
 * </ul>
 * Entries are mutated in place.
 *
 * @param memo      memo table whose entries may be retyped
 * @param current   hop to process
 * @param visited   ids of hops already processed (memoization; updated by this method)
 * @param partition ids of the hops belonging to the current partition
 * @param M         materialization points (not read by this method)
 * @param plan      materialization assignment (not read by this method)
 */
private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
    final long hopID = current.getHopID();
    //memoization through an explicit set (hop visit flags can't be used in the middle of the dag)
    if (visited.contains(hopID))
        return;

    //post-order traversal: handle all inputs before the current hop
    for (Hop input : current.getInput())
        rPruneInvalidPlans(memo, input, visited, partition, M, plan);

    if (partition.contains(hopID) && memo.contains(hopID, TemplateType.RowTpl)) {
        for (MemoTableEntry entry : memo.get(hopID)) {
            if (entry.type != TemplateType.RowTpl)
                continue;
            if (!entry.hasPlanRef()) {
                //leaf entry whose hop only has vector (non-matrix) inputs
                if (!TemplateUtils.hasMatrixInput(current)) {
                    entry.type = TemplateType.CellTpl;
                    if (LOG.isTraceEnabled())
                        LOG.trace("Converted leaf memo table entry from row to cell: " + entry);
                }
            } else if (!ROW_TPL.open(current)) {
                //inner entry that cannot open a row template: requires some row-template input
                boolean rowInputFound = false;
                for (int k = 0; k < 3; k++) {
                    if (entry.isPlanRef(k) && memo.contains(entry.input(k), TemplateType.RowTpl))
                        rowInputFound = true;
                }
                if (!rowInputFound) {
                    entry.type = TemplateType.CellTpl;
                    if (LOG.isTraceEnabled())
                        LOG.trace("Converted inner memo table entry from row to cell: " + entry);
                }
            }
        }
    }
    visited.add(hopID);
}
Also used: MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry), Hop (org.apache.sysml.hops.Hop)

Example 33 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBased, method getPartitionRootNodes.

/**
 * Determines the root nodes of a partition: the partition hops that no memo table entry
 * of any partition hop references through {@code input1}, {@code input2} or {@code input3}.
 *
 * @param memo      memo table of candidate plans
 * @param partition hop ids of the partition; members without memo entries are allowed
 * @return a new set with the unreferenced partition hop ids, filled in the iteration
 *         order of {@code partition}
 */
private static HashSet<Long> getPartitionRootNodes(CPlanMemoTable memo, HashSet<Long> partition) {
    //collect every hop id referenced by an entry of some partition hop
    HashSet<Long> referenced = new HashSet<Long>();
    for (Long hopID : partition) {
        if (!memo.contains(hopID))
            continue;
        for (MemoTableEntry entry : memo.get(hopID)) {
            referenced.add(entry.input1);
            referenced.add(entry.input2);
            referenced.add(entry.input3);
        }
    }
    //roots = partition minus referenced, built in partition order
    HashSet<Long> roots = new HashSet<Long>();
    for (Long candidate : partition) {
        if (referenced.contains(candidate))
            continue;
        roots.add(candidate);
    }
    return roots;
}
Also used: MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry), HashSet (java.util.HashSet)

Example 34 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBased, method createAndAddMultiAggPlans.

/**
 * Creates within-partition multi-aggregate (MultiAggTpl) plans.
 * <p>
 * Collects, in iteration order of {@code R}, the ids whose hop is an {@code AggUnaryOp}
 * with direction {@code RowCol} and that are not consumed by any hop currently holding
 * memo entries. It then walks them in consecutive chunks of at most three. For each chunk
 * with at least two aggregates that passes {@code isValidMultiAggregate}, it adds one shared
 * MultiAggTpl entry to the memo table of every aggregate in the chunk. A trailing single
 * aggregate gets no multi-agg plan.
 *
 * @param memo      memo table (new entries are added)
 * @param partition hop ids of the partition (not read by this method)
 * @param R         candidate root hop ids of the partition
 */
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
    //index every hop consumed by a hop that has memo entries, so that full aggregates
    //referenced by other plans are skipped (avoids circular dependencies)
    HashSet<Long> refHops = new HashSet<Long>();
    for (Entry<Long, List<MemoTableEntry>> plans : memo._plans.entrySet()) {
        if (plans.getValue().isEmpty())
            continue;
        Hop consumer = memo._hopRefs.get(plans.getKey());
        for (Hop input : consumer.getInput())
            refHops.add(input.getHopID());
    }

    //collect unreferenced full aggregations; sharing a partition guarantees common
    //subexpressions, and full aggregations are by definition root nodes
    ArrayList<Long> fullAggs = new ArrayList<Long>();
    for (Long hopID : R) {
        Hop root = memo._hopRefs.get(hopID);
        boolean isFullAgg = root instanceof AggUnaryOp
            && ((AggUnaryOp) root).getDirection() == Direction.RowCol;
        if (isFullAgg && !refHops.contains(hopID))
            fullAggs.add(hopID);
    }
    if (LOG.isTraceEnabled())
        LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));

    //one multiagg plan per chunk of up to 3 aggregates; a chunk needs at least 2 members
    for (int start = 0; start + 2 <= fullAggs.size(); start += 3) {
        int end = Math.min(start + 3, fullAggs.size());
        long third = (end - start == 3) ? fullAggs.get(start + 2) : -1;
        MemoTableEntry multiAgg = new MemoTableEntry(TemplateType.MultiAggTpl, fullAggs.get(start), fullAggs.get(start + 1), third);
        if (!isValidMultiAggregate(memo, multiAgg)) {
            if (LOG.isTraceEnabled())
                LOG.trace("Removed invalid multiagg plan: " + multiAgg);
            continue;
        }
        for (int j = start; j < end; j++) {
            memo.add(memo._hopRefs.get(fullAggs.get(j)), multiAgg);
            if (LOG.isTraceEnabled())
                LOG.trace("Added multiagg plan: " + fullAggs.get(j) + " " + multiAgg);
        }
    }
}
Also used: AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry), Hop (org.apache.sysml.hops.Hop), ArrayList (java.util.ArrayList), List (java.util.List), HashSet (java.util.HashSet)

Example 35 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBased, method selectPlans.

/**
 * Selects fused-operator plans for one plan partition.
 * <ol>
 * <li>Removes the RowTpl entries of a hop in {@code R} when its best row entry exists,
 *     a CellTpl alternative exists, and {@code rIsRowTemplateWithoutAgg} holds.</li>
 * <li>For every partition hop with exactly two OuterProdTpl entries, removes the entry
 *     returned by {@code TemplateOuterProduct.dropAlternativePlan} (if any) and removes
 *     that entry's referenced input from {@code memo._plansBlacklist}.</li>
 * <li>Without materialization points, selects plans via {@code rSelectPlansFuseAll}.
 *     Otherwise it enumerates all 2^|M| materialization assignments, keeps the first
 *     assignment with minimal {@code getPlanCost}, prunes suboptimal and invalid plans
 *     w.r.t. that assignment, and then selects plans via {@code rSelectPlansFuseAll}.</li>
 * </ol>
 *
 * @param memo      memo table of candidate plans (mutated in place)
 * @param partition hop ids of the partition
 * @param R         root hop ids of the partition (presumably from getPartitionRootNodes)
 * @param M         candidate materialization points; may be {@code null} or empty
 */
private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) {
    //prune row aggregates with pure cellwise operations
    for (Long hopID : R) {
        MemoTableEntry me = memo.getBest(hopID, TemplateType.RowTpl);
        //partition members are not guaranteed to have memo entries (cf. the memo.contains
        //checks in getPartitionRootNodes); guard the dereference so a root without entries
        //does not raise a NullPointerException.
        //NOTE(review): assumes getBest returns null when no entry exists - confirm
        if (me != null && me.type == TemplateType.RowTpl && memo.contains(hopID, TemplateType.CellTpl) && rIsRowTemplateWithoutAgg(memo, memo._hopRefs.get(hopID), new HashSet<Long>())) {
            List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.RowTpl);
            memo.remove(memo._hopRefs.get(hopID), new HashSet<MemoTableEntry>(blacklist));
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed row memo table entries w/o aggregation: " + Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
            }
        }
    }
    //prune dominated outer product plans: if a hop has exactly two outer product entries,
    //drop the one reported as dominated by TemplateOuterProduct.dropAlternativePlan.
    //Example from the original comment (partially truncated): we'd prune
    //sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern.
    for (Long hopID : partition) {
        if (memo.countEntries(hopID, TemplateType.OuterProdTpl) == 2) {
            List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OuterProdTpl);
            MemoTableEntry me1 = entries.get(0);
            MemoTableEntry me2 = entries.get(1);
            MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
            if (rmEntry != null) {
                memo.remove(memo._hopRefs.get(hopID), new HashSet<MemoTableEntry>(Arrays.asList(rmEntry)));
                memo._plansBlacklist.remove(rmEntry.input(rmEntry.getPlanRefIndex()));
                if (LOG.isTraceEnabled())
                    LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
            }
        }
    }
    //if no materialization points, use basic fuse-all w/ partition awareness
    if (M == null || M.isEmpty()) {
        for (Long hopID : R) rSelectPlansFuseAll(memo, memo._hopRefs.get(hopID), null, partition);
    } else {
        //TODO branch and bound pruning, right now we use exhaustive enum for early experiments
        //via skip ahead in below enumeration algorithm
        //obtain hop compute costs per cell once
        HashMap<Long, Double> computeCosts = new HashMap<Long, Double>();
        for (Long hopID : R) rGetComputeCosts(memo._hopRefs.get(hopID), partition, computeCosts);
        //scan linearized search space of 2^|M| assignments (exponential in |M|;
        //NOTE(review): the int cast saturates at Integer.MAX_VALUE for |M| >= 31)
        int len = (int) Math.pow(2, M.size());
        boolean[] bestPlan = null;
        double bestC = Double.MAX_VALUE;
        for (int i = 0; i < len; i++) {
            //construct assignment
            boolean[] plan = createAssignment(M.size(), i);
            //cost assignment on hops
            double C = getPlanCost(memo, partition, R, M, plan, computeCosts);
            if (LOG.isTraceEnabled())
                LOG.trace("Enum: " + Arrays.toString(plan) + " -> " + C);
            //cost comparisons (strict '<': ties keep the earlier assignment)
            if (bestPlan == null || C < bestC) {
                bestC = C;
                bestPlan = plan;
                if (LOG.isTraceEnabled())
                    LOG.trace("Enum: Found new best plan.");
            }
        }
        //prune memo table wrt best plan and select plans
        HashSet<Long> visited = new HashSet<Long>();
        for (Long hopID : R) rPruneSuboptimalPlans(memo, memo._hopRefs.get(hopID), visited, partition, M, bestPlan);
        HashSet<Long> visited2 = new HashSet<Long>();
        for (Long hopID : R) rPruneInvalidPlans(memo, memo._hopRefs.get(hopID), visited2, partition, M, bestPlan);
        for (Long hopID : R) rSelectPlansFuseAll(memo, memo._hopRefs.get(hopID), null, partition);
    }
}
Also used: HashMap (java.util.HashMap), MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry), HashSet (java.util.HashSet)

Aggregations

MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 59; Hop (org.apache.sysml.hops.Hop): 46; HashSet (java.util.HashSet): 28; ArrayList (java.util.ArrayList): 22; List (java.util.List): 20; AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 20; AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 19; HashMap (java.util.HashMap): 16; Comparator (java.util.Comparator): 15; BinaryOp (org.apache.sysml.hops.BinaryOp): 15; UnaryOp (org.apache.sysml.hops.UnaryOp): 15; TemplateType (org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType): 15; Entry (java.util.Map.Entry): 13; IndexingOp (org.apache.sysml.hops.IndexingOp): 13; LiteralOp (org.apache.sysml.hops.LiteralOp): 13; ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 13; TernaryOp (org.apache.sysml.hops.TernaryOp): 13; AggOp (org.apache.sysml.hops.Hop.AggOp): 11; CPlanMemoTable (org.apache.sysml.hops.codegen.template.CPlanMemoTable): 10; Arrays (java.util.Arrays): 9