Search in sources :

Example 56 with MemoTableEntry

use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project systemml by apache.

the class SpoofCompiler method rExploreCPlans.

// //////////////////
// Codegen plan construction
/**
 * Recursively explores candidate codegen plans for the DAG rooted at {@code hop}
 * and records them in the memo table.
 *
 * <p>Inputs are processed before the hop itself. For the hop, plans are opened
 * for every template that accepts it, fused with the templates already recorded
 * for each input, closed where the template requires it, and optionally pruned.
 * The hop is finally marked as visited, even when no plan was found.
 *
 * @param hop             DAG node to explore
 * @param memo            memo table that collects the plans and the visited hops
 * @param compileLiterals flag that is only forwarded to the recursive calls here
 */
private static void rExploreCPlans(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    final long hopId = hop.getHopID();

    // memoization: nothing to do for nodes that were already processed
    if (memo.contains(hopId)) {
        return;
    }
    if (memo.containsHop(hop)) {
        return;
    }

    // explore all inputs first (recursive candidate exploration)
    for (Hop input : hop.getInput()) {
        rExploreCPlans(input, memo, compileLiterals);
    }

    // open initial operator plans for every template that accepts this hop
    for (TemplateBase template : TemplateUtils.TEMPLATES) {
        if (template.open(hop)) {
            memo.addAll(hop, enumPlans(hop, null, template, memo));
        }
    }

    // fuse and merge: extend the templates recorded for each input to this hop
    for (Hop input : hop.getInput()) {
        for (TemplateBase template : memo.getDistinctTemplates(input.getHopID())) {
            if (template.fuse(hop, input)) {
                memo.addAll(hop, enumPlans(hop, input, template, memo));
            }
        }
    }

    // close operator plans if required; invalid plans are dropped from the memo table
    if (memo.contains(hopId)) {
        for (Iterator<MemoTableEntry> it = memo.get(hopId).iterator(); it.hasNext(); ) {
            MemoTableEntry entry = it.next();
            CloseType closeType = TemplateUtils.createTemplate(entry.type).close(hop);
            if (closeType == CloseType.CLOSED_INVALID) {
                it.remove();
            }
            // the close type is recorded on the entry in either case
            entry.ctype = closeType;
        }
    }

    // prune subsumed / redundant plans
    if (PRUNE_REDUNDANT_PLANS) {
        memo.pruneRedundant(hopId, PLAN_SEL_POLICY.isHeuristic(), null);
    }

    // mark visited even if no plans were found (e.g., unsupported ops)
    memo.addHop(hop);
}
Also used : CloseType(org.apache.sysml.hops.codegen.template.TemplateBase.CloseType) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop) TemplateBase(org.apache.sysml.hops.codegen.template.TemplateBase)

Example 57 with MemoTableEntry

use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project systemml by apache.

the class PlanSelectionFuseCostBased method selectPlans.

/**
 * Prunes the memo table for one partition and selects the plans to keep.
 *
 * <p>The method first removes ROW entries of hops in {@code R} that also have a
 * CELL entry and for which {@code isRowTemplateWithoutAgg} holds, then drops the
 * alternative OUTER entry reported by {@code TemplateOuterProduct.dropAlternativePlan}
 * for hops with exactly two OUTER entries. Without materialization points it
 * falls back to fuse-all selection; otherwise it enumerates every boolean
 * assignment over {@code M}, keeps the cheapest one, prunes the memo table
 * accordingly and then runs fuse-all selection.
 *
 * @param memo      memo table to prune and select from
 * @param partition hop IDs of the partition, passed on to the helper routines
 * @param R         hop IDs used as starting points for pruning, costing and selection
 * @param M         materialization points; {@code null} or empty disables enumeration
 */
private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) {
    // prune row aggregates with pure cellwise operations
    for (Long rootID : R) {
        MemoTableEntry best = memo.getBest(rootID, TemplateType.ROW);
        boolean rowWithoutAgg = best.type == TemplateType.ROW
            && memo.contains(rootID, TemplateType.CELL)
            && isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(rootID), new HashSet<Long>());
        if (!rowWithoutAgg) {
            continue;
        }
        List<MemoTableEntry> rowEntries = memo.get(rootID, TemplateType.ROW);
        memo.remove(memo.getHopRefs().get(rootID), new HashSet<>(rowEntries));
        if (LOG.isTraceEnabled()) {
            LOG.trace("Removed row memo table entries w/o aggregation: " + Arrays.toString(rowEntries.toArray(new MemoTableEntry[0])));
        }
    }

    // for hops with exactly two outer-product entries, remove the entry that
    // dropAlternativePlan reports as the one to drop (if any)
    for (Long hopID : partition) {
        if (memo.countEntries(hopID, TemplateType.OUTER) != 2) {
            continue;
        }
        List<MemoTableEntry> outerEntries = memo.get(hopID, TemplateType.OUTER);
        MemoTableEntry first = outerEntries.get(0);
        MemoTableEntry second = outerEntries.get(1);
        MemoTableEntry dropped = TemplateOuterProduct.dropAlternativePlan(memo, first, second);
        if (dropped == null) {
            continue;
        }
        memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(dropped));
        memo.getPlansBlacklisted().remove(dropped.input(dropped.getPlanRefIndex()));
        if (LOG.isTraceEnabled()) {
            LOG.trace("Removed dominated outer product memo table entry: " + dropped);
        }
    }

    // if no materialization points, use basic fuse-all w/ partition awareness
    if (M == null || M.isEmpty()) {
        for (Long rootID : R) {
            rSelectPlansFuseAll(memo, memo.getHopRefs().get(rootID), null, partition);
        }
        return;
    }

    // TODO branch and bound pruning, right now we use exhaustive enum for early experiments
    // via skip ahead in below enumeration algorithm

    // obtain hop compute costs per cell once
    HashMap<Long, Double> computeCosts = new HashMap<>();
    for (Long rootID : R) {
        rGetComputeCosts(memo.getHopRefs().get(rootID), partition, computeCosts);
    }

    // scan the linearized search space of 2^|M| assignments
    // NOTE(review): the int cast limits the number of enumerated plans for very large M
    int numPlans = (int) Math.pow(2, M.size());
    boolean[] bestPlan = null;
    double bestCost = Double.MAX_VALUE;
    for (int planIx = 0; planIx < numPlans; planIx++) {
        // construct assignment and cost it on the hops
        boolean[] candidate = createAssignment(M.size(), planIx);
        double cost = getPlanCost(memo, partition, R, M, candidate, computeCosts);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Enum: " + Arrays.toString(candidate) + " -> " + cost);
        }
        // keep the first plan unconditionally, afterwards only strictly cheaper ones
        if (bestPlan == null || cost < bestCost) {
            bestCost = cost;
            bestPlan = candidate;
            if (LOG.isTraceEnabled()) {
                LOG.trace("Enum: Found new best plan.");
            }
        }
    }

    if (DMLScript.STATISTICS) {
        Statistics.incrementCodegenEnumAllP(numPlans);
        Statistics.incrementCodegenEnumEval(numPlans);
    }

    // prune memo table wrt best plan; each pass runs over all of R before the next starts
    HashSet<Long> visitedSuboptimal = new HashSet<>();
    for (Long rootID : R) {
        rPruneSuboptimalPlans(memo, memo.getHopRefs().get(rootID), visitedSuboptimal, partition, M, bestPlan);
    }
    HashSet<Long> visitedInvalid = new HashSet<>();
    for (Long rootID : R) {
        rPruneInvalidPlans(memo, memo.getHopRefs().get(rootID), visitedInvalid, partition, M, bestPlan);
    }

    // finally select plans on the pruned memo table
    for (Long rootID : R) {
        rSelectPlansFuseAll(memo, memo.getHopRefs().get(rootID), null, partition);
    }
}
Also used : HashMap(java.util.HashMap) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) HashSet(java.util.HashSet)

Example 58 with MemoTableEntry

use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project systemml by apache.

the class PlanSelectionFuseCostBasedV2 method createAndAddMultiAggPlans.

// within-partition multi-agg templates
/**
 * Creates multi-aggregate (MAGG) memo table entries for the full aggregations
 * found among the hops in {@code R} and adds each valid entry to every
 * aggregate it covers.
 *
 * <p>Candidates are hops from {@code R} that are not an input of any hop with a
 * non-empty plan list and for which {@code isMultiAggregateRoot} holds. They are
 * grouped in consecutive chunks of at most three; chunks with fewer than two
 * members are skipped.
 *
 * @param memo      memo table that is read and extended
 * @param partition hop IDs of the partition (not read by this method)
 * @param R         hop IDs considered as candidate aggregates
 */
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
    // create index of hops referenced as inputs by hops with plans, to avoid circular dependencies
    HashSet<Long> referencedIDs = new HashSet<>();
    for (Entry<Long, List<MemoTableEntry>> planEntry : memo.getPlans().entrySet()) {
        if (planEntry.getValue().isEmpty()) {
            continue;
        }
        Hop planHop = memo.getHopRefs().get(planEntry.getKey());
        for (Hop input : planHop.getInput()) {
            referencedIDs.add(input.getHopID());
        }
    }

    // find all full aggregations (the fact that they are in the same partition guarantees
    // that they also have common subexpressions, also full aggregations are by def root nodes)
    ArrayList<Long> fullAggs = new ArrayList<>();
    for (Long candidateID : R) {
        Hop candidate = memo.getHopRefs().get(candidateID);
        if (referencedIDs.contains(candidateID)) {
            continue;
        }
        if (isMultiAggregateRoot(candidate)) {
            fullAggs.add(candidateID);
        }
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
    }

    // construct and add multiagg template plans (w/ max 3 aggregations)
    for (int from = 0; from < fullAggs.size(); from += 3) {
        int to = Math.min(from + 3, fullAggs.size());
        int numAggs = to - from;
        if (numAggs < 2) {
            // a single remaining aggregate does not form a multi-aggregate
            continue;
        }
        // the third input is -1 when the chunk holds only two aggregates
        MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, fullAggs.get(from), fullAggs.get(from + 1), (numAggs == 3) ? fullAggs.get(from + 2) : -1, numAggs);
        if (!isValidMultiAggregate(memo, me)) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed invalid multiagg plan: " + me);
            }
            continue;
        }
        // register the shared entry with every aggregate of the chunk
        for (int pos = from; pos < to; pos++) {
            memo.add(memo.getHopRefs().get(fullAggs.get(pos)), me);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Added multiagg plan: " + fullAggs.get(pos) + " " + me);
            }
        }
    }
}
Also used : MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) List(java.util.List) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet)

Example 59 with MemoTableEntry

use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project systemml by apache.

the class TemplateUtils method getRowTemplateMatrixInput.

/**
 * Searches the inputs of {@code current}, in order, for the hop ID of a matrix
 * input of its ROW template plan.
 *
 * <p>For an input that the best ROW entry references as a plan and that itself
 * has a ROW entry, the search descends into that input. For an input that is
 * not a plan reference and is a matrix, its hop ID is taken. The search stops
 * as soon as a non-negative ID has been found.
 *
 * @param current hop whose inputs are inspected
 * @param memo    memo table providing the ROW entries
 * @return the hop ID found, or the last value assigned ({@code -1} if none)
 */
public static long getRowTemplateMatrixInput(Hop current, CPlanMemoTable memo) {
    MemoTableEntry rowEntry = memo.getBest(current.getHopID(), TemplateType.ROW);
    long matrixInputID = -1;
    int pos = 0;
    // continue until an ID was found or all inputs have been inspected
    while (matrixInputID < 0 && pos < current.getInput().size()) {
        Hop child = current.getInput().get(pos);
        if (rowEntry.isPlanRef(pos)) {
            // fused input: descend only if the input has a ROW entry of its own
            if (memo.contains(child.getHopID(), TemplateType.ROW)) {
                matrixInputID = getRowTemplateMatrixInput(child, memo);
            }
        } else if (isMatrix(child)) {
            // non-fused matrix input
            matrixInputID = child.getHopID();
        }
        pos++;
    }
    return matrixInputID;
}
Also used : MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop)

Aggregations

MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)59 Hop (org.apache.sysml.hops.Hop)46 HashSet (java.util.HashSet)28 ArrayList (java.util.ArrayList)22 List (java.util.List)20 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)20 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)19 HashMap (java.util.HashMap)16 Comparator (java.util.Comparator)15 BinaryOp (org.apache.sysml.hops.BinaryOp)15 UnaryOp (org.apache.sysml.hops.UnaryOp)15 TemplateType (org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType)15 Entry (java.util.Map.Entry)13 IndexingOp (org.apache.sysml.hops.IndexingOp)13 LiteralOp (org.apache.sysml.hops.LiteralOp)13 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)13 TernaryOp (org.apache.sysml.hops.TernaryOp)13 AggOp (org.apache.sysml.hops.Hop.AggOp)11 CPlanMemoTable (org.apache.sysml.hops.codegen.template.CPlanMemoTable)10 Arrays (java.util.Arrays)9