Search in sources:

Example 41 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBased, method createAndAddMultiAggPlans.

// within-partition multi-agg templates
/**
 * Builds multi-aggregate (MAGG) plans from the full-aggregation roots in {@code R} and
 * registers every valid plan with each of its roots.
 *
 * <p>A root qualifies when {@code isMultiAggregateRoot} holds for its hop and no hop with a
 * non-empty plan list consumes it as an input (this avoids circular dependencies). Qualifying
 * roots are chunked in iteration order into groups of at most three. Groups of two or three
 * become one MAGG memo table entry each; a trailing single root is left alone. Entries
 * rejected by {@code isValidMultiAggregate} are only trace-logged.
 *
 * @param memo      memo table that is read and extended in place
 * @param partition ids of the current partition (not read by this method)
 * @param R         ids of the partition's root hops
 */
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
    // ids of all hops that serve as input of some hop with at least one plan
    HashSet<Long> refHops = new HashSet<>();
    for (Entry<Long, List<MemoTableEntry>> plans : memo.getPlans().entrySet()) {
        if (plans.getValue().isEmpty())
            continue;
        Hop consumer = memo.getHopRefs().get(plans.getKey());
        for (Hop input : consumer.getInput())
            refHops.add(input.getHopID());
    }
    // unreferenced full aggregations among the roots (being in the same partition implies
    // common subexpressions; full aggregations are by definition root nodes)
    ArrayList<Long> fullAggs = new ArrayList<>();
    for (Long rootID : R) {
        Hop root = memo.getHopRefs().get(rootID);
        if (refHops.contains(rootID))
            continue;
        if (isMultiAggregateRoot(root))
            fullAggs.add(rootID);
    }
    if (LOG.isTraceEnabled())
        LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
    // chunk into groups of (at most) three aggregations each
    final int numAggs = fullAggs.size();
    for (int start = 0; start < numAggs; start += 3) {
        final int end = Math.min(start + 3, numAggs);
        final int groupSize = end - start;
        if (groupSize < 2)
            continue; // a single aggregation cannot form a multi-agg plan
        long thirdInput = (groupSize == 3) ? fullAggs.get(start + 2) : -1;
        MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, fullAggs.get(start), fullAggs.get(start + 1), thirdInput, groupSize);
        if (!isValidMultiAggregate(memo, me)) {
            if (LOG.isTraceEnabled())
                LOG.trace("Removed invalid multiagg plan: " + me);
            continue;
        }
        for (int k = start; k < end; k++) {
            Long rootID = fullAggs.get(k);
            memo.add(memo.getHopRefs().get(rootID), me);
            if (LOG.isTraceEnabled())
                LOG.trace("Added multiagg plan: " + rootID + " " + me);
        }
    }
}
Also used: MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) List(java.util.List) HashSet(java.util.HashSet)

Example 42 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBased, method rPruneSuboptimalPlans.

/**
 * Recursively prunes memo table entries of all hops reachable from {@code current}.
 *
 * <p>For each hop whose id is in {@code partition} and that has memo table entries, every
 * entry for which {@code hasNoRefToMaterialization(entry, M, plan)} is false is removed from
 * the memo table, except entries of type OUTER, which are always kept. Inputs are processed
 * before the hop itself is marked as visited; already-visited hops are skipped.
 *
 * @param memo      memo table whose entry lists are modified in place
 * @param current   hop at which the traversal starts
 * @param visited   ids of hops already processed (memoization within the DAG)
 * @param partition ids of hops whose entries are eligible for pruning
 * @param M         hop ids forwarded to hasNoRefToMaterialization (presumably the
 *                  materialization points — confirm against caller)
 * @param plan      decision vector forwarded to hasNoRefToMaterialization
 */
private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
    final long hopID = current.getHopID();
    // memoization (not via hops because in middle of dag)
    if (visited.contains(hopID))
        return;
    if (partition.contains(hopID) && memo.contains(hopID)) {
        for (Iterator<MemoTableEntry> entries = memo.get(hopID).iterator(); entries.hasNext(); ) {
            MemoTableEntry entry = entries.next();
            boolean refsMaterialization = !hasNoRefToMaterialization(entry, M, plan);
            if (!refsMaterialization || entry.type == TemplateType.OUTER)
                continue; // entry survives
            entries.remove();
            if (LOG.isTraceEnabled())
                LOG.trace("Removed memo table entry: " + entry);
        }
    }
    // descend into all inputs, then mark this hop as done
    for (Hop input : current.getInput())
        rPruneSuboptimalPlans(memo, input, visited, partition, M, plan);
    visited.add(hopID);
}
Also used : MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop)

Example 43 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBasedV2, method rGetPlanCosts.

/**
 * Recursively estimates the costs of the sub-DAG rooted at {@code current} under the
 * materialization decisions in {@code plan}.
 *
 * <p>When no template is currently open ({@code currentType == null}), the best valid memo
 * table entry of the hop opens a new fused operator with a fresh {@link CostVector}. Compute
 * costs and input sizes of fused inputs are accumulated into that vector. The write, read and
 * compute costs of an opened operator are charged once, when the call that opened it returns.
 * For a MAGG plan, only its first root ({@code input1}) charges the shared costs; the other
 * roots return 0.
 *
 * @param memo         memo table with the candidate plans
 * @param current      hop whose costs are estimated
 * @param visited      marks of (hop id, cost vector id) pairs already accounted for
 * @param part         the plan partition being costed
 * @param matPoints    interesting (materialization) points of the partition
 * @param plan         materialization decisions, indexed like {@code matPoints}
 * @param computeCosts per-hop compute costs; must contain an entry for every visited hop
 *                     (a missing entry fails on unboxing)
 * @param costsCurrent cost vector of the enclosing fused operator, or null if none is open
 * @param currentType  template type of the enclosing fused operator, or null if none is open
 * @param costBound    remaining cost budget; once reached, the search is cut short
 * @return the estimated costs, or {@link Double#POSITIVE_INFINITY} if {@code costBound} was
 *         reached
 * @throws RuntimeException if the resulting estimate is negative, NaN or infinite
 */
private double rGetPlanCosts(CPlanMemoTable memo, final Hop current, HashSet<VisitMarkCost> visited, PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateType currentType, final double costBound) {
    final long currentHopId = current.getHopID();
    // costs of complex operation DAGs within a single fused operator
    if (!visited.add(new VisitMarkCost(currentHopId, (costsCurrent == null || currentType == TemplateType.MAGG) ? -1 : costsCurrent.ID)))
        // already existing
        return 0;
    // open template if necessary, including memoization
    // under awareness of current plan choice
    MemoTableEntry best = null;
    boolean opened = (currentType == null);
    if (memo.contains(currentHopId)) {
        // note: this is the inner loop of plan enumeration, hence we deliberately do NOT
        // use streams, lambda expressions, etc to avoid unnecessary overhead
        if (currentType == null) {
            for (MemoTableEntry me : memo.get(currentHopId)) best = me.isValid() && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) && BasicPlanComparator.icompare(me, best) < 0 ? me : best;
            opened = true;
        } else {
            for (MemoTableEntry me : memo.get(currentHopId)) best = (me.type == currentType || me.type == TemplateType.CELL) && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) && TypedPlanComparator.icompare(me, best, currentType) < 0 ? me : best;
        }
    }
    // create new cost vector if opened, initialized with write costs
    CostVector costVect = !opened ? costsCurrent : new CostVector(getSize(current));
    double costs = 0;
    // add other roots for multi-agg template to account for shared costs
    if (opened && best != null && best.type == TemplateType.MAGG) {
        // account costs to first multi-agg root
        if (best.input1 == currentHopId)
            for (int i = 1; i < 3; i++) {
                if (!best.isPlanRef(i))
                    continue;
                costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited, part, matPoints, plan, computeCosts, costVect, TemplateType.MAGG, costBound - costs);
                if (costs >= costBound)
                    return Double.POSITIVE_INFINITY;
            }
        else
            return 0;
    }
    // add compute costs of current operator to costs vector
    costVect.computeCosts += computeCosts.get(currentHopId);
    // process children recursively
    for (int i = 0; i < current.getInput().size(); i++) {
        Hop c = current.getInput().get(i);
        if (best != null && best.isPlanRef(i))
            costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, costVect, best.type, costBound - costs);
        else if (best != null && isImplicitlyFused(current, i, best.type))
            costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c));
        else {
            // include children and I/O costs
            if (part.getPartition().contains(c.getHopID()))
                costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null, costBound - costs);
            if (costVect != null && c.getDataType().isMatrix())
                costVect.addInputSize(c.getHopID(), getSize(c));
        }
        if (costs >= costBound)
            return Double.POSITIVE_INFINITY;
    }
    // add costs for opened fused operator
    if (opened) {
        double memInputs = sumInputMemoryEstimates(memo, costVect);
        double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH_MEM + Math.max(memInputs / READ_BANDWIDTH_MEM, costVect.computeCosts / COMPUTE_BANDWIDTH);
        // read correction for distributed computation
        if (memInputs > OptimizerUtils.getLocalMemBudget())
            tmpCosts += costVect.getSideInputSize() * 8 / READ_BANDWIDTH_BROADCAST;
        // sparsity correction for outer-product template (and sparse-safe cell)
        Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
        if (best != null && best.type == TemplateType.OUTER)
            tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
        else // write correction for known evictions in CP
        if (memInputs <= OptimizerUtils.getLocalMemBudget() && sumTmpInputOutputSize(memo, costVect) * 8 > LazyWriteBuffer.getWriteBufferLimit())
            tmpCosts += costVect.outSize * 8 / WRITE_BANDWIDTH_IO;
        costs += tmpCosts;
        if (LOG.isTraceEnabled()) {
            String type = (best != null) ? best.type.name() : "HOP";
            LOG.trace("Cost vector (" + type + " " + currentHopId + "): " + costVect + " -> " + tmpCosts);
        }
    } else // add costs for non-partition read in the middle of fused operator
    if (part.getExtConsumed().contains(current.getHopID())) {
        costs += rGetPlanCosts(memo, current, visited, part, matPoints, plan, computeCosts, null, null, costBound - costs);
        // the nested call signals an exceeded bound via POSITIVE_INFINITY; propagate it like
        // the other early exits instead of failing the sanity check below
        if (costs >= costBound)
            return Double.POSITIVE_INFINITY;
    }
    // sanity check non-negative costs
    if (costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs))
        throw new RuntimeException("Wrong cost estimate: " + costs);
    return costs;
}
Also used : MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop)

Example 44 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBasedV2, method createAndAddMultiAggPlans.

// within-partition multi-agg templates
/**
 * Builds multi-aggregate (MAGG) plans from the full-aggregation roots in {@code R} and
 * registers every valid plan with each of its roots.
 *
 * <p>A root qualifies when {@code isMultiAggregateRoot} holds for its hop and no hop with a
 * non-empty plan list consumes it as an input (this avoids circular dependencies). Qualifying
 * roots are chunked in iteration order into groups of at most three. Groups of two or three
 * become one MAGG memo table entry each; a trailing single root is left alone. Entries
 * rejected by {@code isValidMultiAggregate} are only trace-logged.
 *
 * @param memo      memo table that is read and extended in place
 * @param partition ids of the current partition (not read by this method)
 * @param R         ids of the partition's root hops
 */
private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
    // ids of all hops that serve as input of some hop with at least one plan
    HashSet<Long> refHops = new HashSet<>();
    for (Entry<Long, List<MemoTableEntry>> plans : memo.getPlans().entrySet()) {
        if (plans.getValue().isEmpty())
            continue;
        Hop consumer = memo.getHopRefs().get(plans.getKey());
        for (Hop input : consumer.getInput())
            refHops.add(input.getHopID());
    }
    // unreferenced full aggregations among the roots (being in the same partition implies
    // common subexpressions; full aggregations are by definition root nodes)
    ArrayList<Long> fullAggs = new ArrayList<>();
    for (Long rootID : R) {
        Hop root = memo.getHopRefs().get(rootID);
        if (refHops.contains(rootID))
            continue;
        if (isMultiAggregateRoot(root))
            fullAggs.add(rootID);
    }
    if (LOG.isTraceEnabled())
        LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
    // chunk into groups of (at most) three aggregations each
    final int numAggs = fullAggs.size();
    for (int start = 0; start < numAggs; start += 3) {
        final int end = Math.min(start + 3, numAggs);
        final int groupSize = end - start;
        if (groupSize < 2)
            continue; // a single aggregation cannot form a multi-agg plan
        long thirdInput = (groupSize == 3) ? fullAggs.get(start + 2) : -1;
        MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG, fullAggs.get(start), fullAggs.get(start + 1), thirdInput, groupSize);
        if (!isValidMultiAggregate(memo, me)) {
            if (LOG.isTraceEnabled())
                LOG.trace("Removed invalid multiagg plan: " + me);
            continue;
        }
        for (int k = start; k < end; k++) {
            Long rootID = fullAggs.get(k);
            memo.add(memo.getHopRefs().get(rootID), me);
            if (LOG.isTraceEnabled())
                LOG.trace("Added multiagg plan: " + rootID + " " + me);
        }
    }
}
Also used: MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) List(java.util.List) HashSet(java.util.HashSet)

Example 45 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBasedV2, method rCollectDependentRowOps.

/**
 * Recursively collects into {@code blacklist} the ids of partition hops affected by row
 * operations.
 *
 * <p>Top-down, a hop is blacklisted when either of these holds:
 * <ul>
 *   <li>its chosen entry is a ROW plan inside a ROW context while a row operation was
 *       already found above it;</li>
 *   <li>{@code isRowAggOp} holds, or its ROW plans differ from its CELL plans while the
 *       partition has external materialization points. This also marks a row operation as
 *       found for the fused inputs below.</li>
 * </ul>
 * Bottom-up, a hop with a ROW entry is blacklisted if any of its fused inputs was
 * blacklisted. Each (hop, context type, found-row-op flag) combination is processed at most
 * once; hops outside the partition are ignored.
 *
 * @param hop        hop to process
 * @param memo       memo table used to look up the best entries
 * @param part       plan partition restricting the traversal
 * @param blacklist  output set of blacklisted hop ids (extended in place)
 * @param visited    already processed (hop id, context code) pairs
 * @param type       template type of the enclosing plan, or null for none
 * @param foundRowOp whether a row operation was found on the fused path above this hop
 */
private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable memo, PlanPartition part, HashSet<Long> blacklist, HashSet<Pair<Long, Integer>> visited, TemplateType type, boolean foundRowOp) {
    // avoid redundant evaluation of processed and non-partition nodes
    int contextCode = (type != null) ? type.ordinal() + 1 : 0;
    if (foundRowOp)
        contextCode += Short.MAX_VALUE;
    Pair<Long, Integer> key = Pair.of(hop.getHopID(), contextCode);
    if (visited.contains(key))
        return;
    if (!part.getPartition().contains(hop.getHopID()))
        return;
    final long hopID = hop.getHopID();

    // top-down: process the node itself
    MemoTableEntry me;
    if (type == null)
        me = memo.getBest(hopID);
    else
        me = memo.getBest(hopID, type);
    final boolean inRow = me != null && me.type == TemplateType.ROW && type == TemplateType.ROW;
    // guard against plan differences
    final boolean diffPlans = part.getMatPointsExt().length > 0
        && memo.contains(hopID, TemplateType.ROW)
        && !memo.hasOnlyExactMatches(hopID, TemplateType.ROW, TemplateType.CELL);
    if (foundRowOp && inRow)
        blacklist.add(hopID);
    if (isRowAggOp(hop, inRow) || diffPlans) {
        blacklist.add(hopID);
        foundRowOp = true;
    }

    // recurse into inputs, propagating the row-op flag only along fused edges
    final TemplateType childType = (me == null) ? null : me.type;
    final int numInputs = hop.getInput().size();
    for (int k = 0; k < numInputs; k++) {
        boolean childFoundRowOp = foundRowOp && me != null && (me.isPlanRef(k) || isImplicitlyFused(hop, k, me.type));
        rCollectDependentRowOps(hop.getInput().get(k), memo, part, blacklist, visited, childType, childFoundRowOp);
    }

    // bottom-up: a ROW plan fusing a blacklisted input is blacklisted as well
    if (!blacklist.contains(hopID) && me != null && me.type == TemplateType.ROW) {
        for (int k = 0; k < numInputs; k++) {
            boolean fused = me.isPlanRef(k) || isImplicitlyFused(hop, k, me.type);
            if (fused && blacklist.contains(hop.getInput().get(k).getHopID()))
                blacklist.add(hopID);
        }
    }
    visited.add(key);
}
Also used : MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)

Aggregations

MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 59
Hop (org.apache.sysml.hops.Hop): 46
HashSet (java.util.HashSet): 28
ArrayList (java.util.ArrayList): 22
List (java.util.List): 20
AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 20
AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 19
HashMap (java.util.HashMap): 16
Comparator (java.util.Comparator): 15
BinaryOp (org.apache.sysml.hops.BinaryOp): 15
UnaryOp (org.apache.sysml.hops.UnaryOp): 15
TemplateType (org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType): 15
Entry (java.util.Map.Entry): 13
IndexingOp (org.apache.sysml.hops.IndexingOp): 13
LiteralOp (org.apache.sysml.hops.LiteralOp): 13
ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 13
TernaryOp (org.apache.sysml.hops.TernaryOp): 13
AggOp (org.apache.sysml.hops.Hop.AggOp): 11
CPlanMemoTable (org.apache.sysml.hops.codegen.template.CPlanMemoTable): 10
Arrays (java.util.Arrays): 9