Search in sources :

Example 16 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.

the class PlanSelectionFuseCostBasedV2 method rGetPlanCosts.

/**
 * Recursively accumulates the estimated costs of the plan rooted at {@code current}
 * under the given plan choice, with early termination once a cost bound is reached.
 *
 * @param memo          memo table queried for candidate entries and hop references
 * @param current       hop whose (sub-)plan costs are computed
 * @param visited       visit marks (hop ID plus cost-vector ID) used to avoid
 *                      counting the same hop twice for the same cost vector
 * @param part          partition; its member set and externally consumed set are consulted
 * @param matPoints     interesting points, forwarded to {@code hasNoRefToMatPoint}
 * @param plan          boolean plan assignment, forwarded to {@code hasNoRefToMatPoint}
 * @param computeCosts  per-hop-ID compute costs that are added to the cost vector
 * @param costsCurrent  cost vector of the enclosing fused operator; {@code null} is
 *                      passed by the recursive calls that use a {@code null} type
 * @param currentType   template type of the enclosing fused operator, or {@code null},
 *                      in which case a new cost vector is created for {@code current}
 * @param costBound     upper bound; once the accumulated costs reach it the method
 *                      returns {@code Double.POSITIVE_INFINITY}
 * @return the accumulated costs; 0 if the hop was already visited for this cost vector
 *         or if it is a multi-aggregate root other than the first one
 * @throws RuntimeException if the final cost is negative, NaN, or infinite
 */
private double rGetPlanCosts(CPlanMemoTable memo, final Hop current, HashSet<VisitMarkCost> visited, PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateType currentType, final double costBound) {
    final long currentHopId = current.getHopID();
    // costs of complex operation DAGs within a single fused operator:
    // the visit mark combines the hop ID with the ID of the current cost vector
    // (-1 if there is no cost vector or the current type is MAGG)
    if (!visited.add(new VisitMarkCost(currentHopId, (costsCurrent == null || currentType == TemplateType.MAGG) ? -1 : costsCurrent.ID)))
        // already existing
        return 0;
    // open template if necessary, including memoization
    // under awareness of current plan choice
    MemoTableEntry best = null;
    boolean opened = (currentType == null);
    if (memo.contains(currentHopId)) {
        // NOTE(review): the original comment here appears truncated; the selection
        // below uses hand-written loops rather than streams or lambda expressions,
        // presumably to avoid unnecessary overhead on this path - confirm
        if (currentType == null) {
            // untyped: best valid entry without reference to a materialization point
            for (MemoTableEntry me : memo.get(currentHopId)) best = me.isValid() && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) && BasicPlanComparator.icompare(me, best) < 0 ? me : best;
            // (opened is already true in this branch, see its initialization above)
            opened = true;
        } else {
            // typed: only entries of the current type or CELL entries are candidates
            for (MemoTableEntry me : memo.get(currentHopId)) best = (me.type == currentType || me.type == TemplateType.CELL) && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) && TypedPlanComparator.icompare(me, best, currentType) < 0 ? me : best;
        }
    }
    // create new cost vector if opened, initialized with write costs
    // NOTE(review): if not opened, costsCurrent is reused as is; it is dereferenced
    // below without a null check, so callers presumably always pass a non-null
    // vector together with a non-null type - TODO confirm
    CostVector costVect = !opened ? costsCurrent : new CostVector(getSize(current));
    double costs = 0;
    // add other roots for multi-agg template to account for shared costs
    if (opened && best != null && best.type == TemplateType.MAGG) {
        // account costs to first multi-agg root
        if (best.input1 == currentHopId)
            // inputs at positions 1 and 2 are the other roots if they are plan references
            for (int i = 1; i < 3; i++) {
                if (!best.isPlanRef(i))
                    continue;
                costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited, part, matPoints, plan, computeCosts, costVect, TemplateType.MAGG, costBound - costs);
                // early termination once the bound is reached
                if (costs >= costBound)
                    return Double.POSITIVE_INFINITY;
            }
        else
            // any other multi-agg root contributes no costs of its own
            return 0;
    }
    // add compute costs of current operator to costs vector
    // NOTE(review): unboxes the map value; presumably computeCosts has an entry
    // for every hop reached here - TODO confirm
    costVect.computeCosts += computeCosts.get(currentHopId);
    // process children recursively
    for (int i = 0; i < current.getInput().size(); i++) {
        Hop c = current.getInput().get(i);
        if (best != null && best.isPlanRef(i))
            // child is fused into the same operator: share the cost vector and type
            costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, costVect, best.type, costBound - costs);
        else if (best != null && isImplicitlyFused(current, i, best.type))
            // implicitly fused child: account the child's size under the ID of its first input
            costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c));
        else {
            // include children and I/O costs
            if (part.getPartition().contains(c.getHopID()))
                costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null, costBound - costs);
            if (costVect != null && c.getDataType().isMatrix())
                costVect.addInputSize(c.getHopID(), getSize(c));
        }
        // early termination once the bound is reached
        if (costs >= costBound)
            return Double.POSITIVE_INFINITY;
    }
    // add costs for opened fused operator
    if (opened) {
        double memInputs = sumInputMemoryEstimates(memo, costVect);
        // write time plus the maximum of read time and compute time
        double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH_MEM + Math.max(memInputs / READ_BANDWIDTH_MEM, costVect.computeCosts / COMPUTE_BANDWIDTH);
        // read correction for distributed computation
        if (memInputs > OptimizerUtils.getLocalMemBudget())
            tmpCosts += costVect.getSideInputSize() * 8 / READ_BANDWIDTH_BROADCAST;
        // sparsity correction for outer-product template (and sparse-safe cell)
        // NOTE(review): driver is looked up unconditionally but only dereferenced
        // in the OUTER case below
        Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
        if (best != null && best.type == TemplateType.OUTER)
            tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
        else // write correction for known evictions in CP
        if (memInputs <= OptimizerUtils.getLocalMemBudget() && sumTmpInputOutputSize(memo, costVect) * 8 > LazyWriteBuffer.getWriteBufferLimit())
            tmpCosts += costVect.outSize * 8 / WRITE_BANDWIDTH_IO;
        costs += tmpCosts;
        if (LOG.isTraceEnabled()) {
            String type = (best != null) ? best.type.name() : "HOP";
            LOG.trace("Cost vector (" + type + " " + currentHopId + "): " + costVect + " -> " + tmpCosts);
        }
    } else // add costs for non-partition read in the middle of fused operator
    if (part.getExtConsumed().contains(current.getHopID())) {
        // re-cost the same hop with a null cost vector and null type
        // NOTE(review): unlike the other recursive calls, no costs >= costBound check
        // follows; a POSITIVE_INFINITY result would reach the sanity check below
        // and raise the RuntimeException - confirm this is intended
        costs += rGetPlanCosts(memo, current, visited, part, matPoints, plan, computeCosts, null, null, costBound - costs);
    }
    // sanity check non-negative costs
    if (costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs))
        throw new RuntimeException("Wrong cost estimate: " + costs);
    return costs;
}
Also used : MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop)

Example 17 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.

the class PlanSelectionFuseCostBasedV2 method pruneInvalidAndSpecialCasePlans.

/**
 * Removes memo table entries of a partition in three passes: (1) ROW entries
 * that violate the blocksize constraint when running on Spark, (2) ROW entries
 * that have only exact ROW/CELL matches and are not part of the irreplaceable
 * row-operator set, and (3) one of two alternative OUTER entries if
 * {@code TemplateOuterProduct.dropAlternativePlan} selects one for removal.
 *
 * @param memo memo table that is pruned in place
 * @param part partition whose hop IDs are inspected
 */
private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable memo, PlanPartition part) {
    // pass 1: row entries with violated blocksize constraint (spark execution mode only)
    if (OptimizerUtils.isSparkExecutionMode()) {
        for (Long id : part.getPartition()) {
            if (memo.contains(id, TemplateType.ROW)) {
                Hop op = memo.getHopRefs().get(id);
                // spark if forced by the platform or if the memory estimate exceeds the local budget
                boolean runsOnSpark = DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK
                    || OptimizerUtils.getTotalMemEstimate(op.getInput().toArray(new Hop[0]), op, true) > OptimizerUtils.getLocalMemBudget();
                // the output must fit into a single block (rows for transpose, columns otherwise)
                boolean fitsBlock;
                if (op.getDataType().isScalar()) {
                    fitsBlock = true;
                } else if (HopRewriteUtils.isTransposeOperation(op)) {
                    fitsBlock = op.getDim1() <= op.getRowsInBlock();
                } else {
                    fitsBlock = op.getDim2() <= op.getColsInBlock();
                }
                // the same must hold for every input
                for (Hop input : op.getInput()) {
                    boolean inputFits = input.getDataType().isScalar()
                        || (input.getDim2() <= input.getColsInBlock())
                        || (op instanceof AggBinaryOp && input.getDim1() <= input.getRowsInBlock() && HopRewriteUtils.isTransposeOperation(input));
                    fitsBlock = fitsBlock & inputFits;
                }
                if (!runsOnSpark || fitsBlock) {
                    continue;
                }
                // remember the entries for tracing before they are removed
                List<MemoTableEntry> removed = memo.get(id, TemplateType.ROW);
                memo.remove(memo.getHopRefs().get(id), TemplateType.ROW);
                memo.removeAllRefTo(id, TemplateType.ROW);
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed row memo table entries w/ violated blocksize constraint (" + id + "): " + Arrays.toString(removed.toArray(new MemoTableEntry[0])));
                }
            }
        }
    }

    // pass 2: row entries with pure cellwise operations; operators in the
    // irreplaceable set (returned by collectIrreplaceableRowOps) are skipped
    HashSet<Long> irreplaceable = collectIrreplaceableRowOps(memo, part);
    for (Long id : part.getPartition()) {
        if (irreplaceable.contains(id)) {
            continue;
        }
        MemoTableEntry bestRow = memo.getBest(id, TemplateType.ROW);
        boolean onlyCellwise = bestRow != null
            && bestRow.type == TemplateType.ROW
            && memo.hasOnlyExactMatches(id, TemplateType.ROW, TemplateType.CELL);
        if (!onlyCellwise) {
            continue;
        }
        List<MemoTableEntry> rowEntries = memo.get(id, TemplateType.ROW);
        memo.remove(memo.getHopRefs().get(id), new HashSet<>(rowEntries));
        if (LOG.isTraceEnabled()) {
            LOG.trace("Removed row memo table entries w/o aggregation: " + Arrays.toString(rowEntries.toArray(new MemoTableEntry[0])));
        }
    }

    // pass 3: for hops with exactly two outer-product entries, drop the one chosen
    // by TemplateOuterProduct.dropAlternativePlan (if any)
    // NOTE(review): the original comment was truncated; its remaining text read
    // "we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would
    // unnecessarily destroy a fusion pattern."
    for (Long id : part.getPartition()) {
        if (memo.countEntries(id, TemplateType.OUTER) != 2) {
            continue;
        }
        List<MemoTableEntry> outerEntries = memo.get(id, TemplateType.OUTER);
        MemoTableEntry first = outerEntries.get(0);
        MemoTableEntry second = outerEntries.get(1);
        MemoTableEntry dropped = TemplateOuterProduct.dropAlternativePlan(memo, first, second);
        if (dropped == null) {
            continue;
        }
        memo.remove(memo.getHopRefs().get(id), Collections.singleton(dropped));
        // also clear the referenced plan from the blacklisted plans
        memo.getPlansBlacklisted().remove(dropped.input(dropped.getPlanRefIndex()));
        if (LOG.isTraceEnabled()) {
            LOG.trace("Removed dominated outer product memo table entry: " + dropped);
        }
    }
}
Also used : AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop)

Example 18 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.

the class TemplateMultiAgg method constructCplan.

/**
 * Builds the multi-aggregate cplan for the given hop: collects the root hops
 * referenced by the best MAGG memo entry, constructs cplan nodes for each root
 * via the superclass, orders the inputs, and assembles a {@code CNodeMultiAgg}.
 *
 * @param hop             hop whose best MAGG memo entry defines the roots
 * @param memo            memo table providing the entry and the hop references
 * @param compileLiterals forwarded to the recursive cplan construction
 * @return pair of the ordered input hops and the constructed template node
 */
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // gather all root nodes of the multi aggregation (up to three plan references)
    MemoTableEntry maggEntry = memo.getBest(hop.getHopID(), TemplateType.MAGG);
    ArrayList<Hop> roots = new ArrayList<>();
    for (int pos = 0; pos < 3; pos++) {
        if (maggEntry.isPlanRef(pos)) {
            roots.add(memo._hopRefs.get(maggEntry.input(pos)));
        }
    }
    Hop.resetVisitStatus(roots);

    // recursively process required cplan outputs (reuses celltpl cplan construction)
    HashSet<Hop> inputHops = new HashSet<>();
    HashMap<Long, CNode> nodeByHopId = new HashMap<>();
    for (Hop root : roots) {
        super.rConstructCplan(root, memo, nodeByHopId, inputHops, compileLiterals);
    }
    Hop.resetVisitStatus(roots);

    // drop scalar literals and order the remaining inputs with a
    // HopInputComparator built around the shared sparse-safe input
    Hop sharedInput = getSparseSafeSharedInput(roots, inputHops);
    Hop[] sinHops = inputHops.stream()
        .filter(h -> !h.getDataType().isScalar() || !nodeByHopId.get(h.getHopID()).isLiteral())
        .sorted(new HopInputComparator(sharedInput))
        .toArray(Hop[]::new);

    // construct template node
    ArrayList<CNode> inputs = new ArrayList<>();
    for (Hop inputHop : sinHops) {
        inputs.add(nodeByHopId.get(inputHop.getHopID()));
    }
    ArrayList<CNode> outputs = new ArrayList<>();
    ArrayList<AggOp> aggOps = new ArrayList<>();
    for (Hop root : roots) {
        CNode out = nodeByHopId.get(root.getHopID());
        // add indexing ops for sideways data inputs, i.e., data nodes
        // that are not the first (main) input
        boolean sidewaysData = out instanceof CNodeData
            && ((CNodeData) inputs.get(0)).getHopID() != ((CNodeData) out).getHopID();
        if (sidewaysData) {
            if (roots.get(0).getDim2() == 1) {
                out = new CNodeUnary(out, UnaryType.LOOKUP_R);
            } else {
                out = new CNodeUnary(out, UnaryType.LOOKUP_RC);
            }
        }
        outputs.add(out);
        aggOps.add(TemplateUtils.getAggOp(root));
    }
    CNodeMultiAgg tpl = new CNodeMultiAgg(inputs, outputs);
    tpl.setAggOps(aggOps);
    tpl.setSparseSafe(isSparseSafe(roots, sinHops[0], tpl.getOutputs(), tpl.getAggOps(), true));
    tpl.setRootNodes(roots);
    tpl.setBeginLine(hop.getBeginLine());

    // return cplan instance
    return new Pair<>(sinHops, tpl);
}
Also used : CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) HashMap(java.util.HashMap) AggOp(org.apache.sysml.hops.Hop.AggOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) CNodeMultiAgg(org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg) HashSet(java.util.HashSet) Pair(org.apache.sysml.runtime.matrix.data.Pair)

Example 19 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.

the class PlanAnalyzer method getPartitionRootNodes.

/**
 * Determines the root nodes of a partition, i.e., all hop IDs of the partition
 * that are not referenced as input1, input2, or input3 by any memo table entry
 * of a hop in the partition.
 *
 * @param memo      memo table providing the entries per hop ID
 * @param partition hop IDs of the partition
 * @return the hop IDs of the partition that are not referenced by any entry
 */
private static HashSet<Long> getPartitionRootNodes(CPlanMemoTable memo, HashSet<Long> partition) {
    // collect every input ID referenced by an entry of the partition
    HashSet<Long> referenced = new HashSet<>();
    for (Long id : partition) {
        if (!memo.contains(id)) {
            continue;
        }
        for (MemoTableEntry entry : memo.get(id)) {
            referenced.add(entry.input1);
            referenced.add(entry.input2);
            referenced.add(entry.input3);
        }
    }

    // roots are the partition members nobody refers to
    HashSet<Long> roots = new HashSet<>();
    for (Long id : partition) {
        if (referenced.contains(id)) {
            continue;
        }
        roots.add(id);
    }

    if (LOG.isTraceEnabled()) {
        LOG.trace("Partition root points: " + Arrays.toString(roots.toArray(new Long[0])));
    }
    return roots;
}
Also used : MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) HashSet(java.util.HashSet)

Example 20 with MemoTableEntry

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry in the Apache SystemML project.

the class PlanSelection method rSelectPlansFuseAll.

/**
 * Recursively selects plans for the given hop and its inputs: prunes memo
 * entries subsumed by another entry of the same hop, records the best remaining
 * entry (typed or untyped), and descends into the inputs, passing the selected
 * type along plan references.
 *
 * @param memo        memo table that is pruned and queried
 * @param current     hop to process
 * @param currentType template type imposed by the parent, or {@code null}
 * @param partition   optional set of hop IDs to restrict the traversal to;
 *                    {@code null} disables the restriction
 */
protected void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition) {
    final long hopId = current.getHopID();
    // nothing to do for visited hops or hops outside the partition
    if (isVisited(hopId, currentType)) {
        return;
    }
    if (partition != null && !partition.contains(hopId)) {
        return;
    }

    // step 1: prune subsumed plans of same type
    if (memo.contains(hopId)) {
        HashSet<MemoTableEntry> subsumed = new HashSet<>();
        List<MemoTableEntry> candidates = memo.get(hopId);
        for (MemoTableEntry outer : candidates) {
            for (MemoTableEntry inner : candidates) {
                if (outer == inner) {
                    continue;
                }
                if (outer.subsumes(inner)) {
                    subsumed.add(inner);
                }
            }
        }
        memo.remove(current, subsumed);
    }

    // step 2: select plan for current path
    MemoTableEntry best = null;
    if (memo.contains(hopId)) {
        List<MemoTableEntry> remaining = memo.get(hopId);
        if (currentType != null) {
            // typed: only entries of the imposed type or CELL entries qualify
            _typedCompare.setType(currentType);
            best = remaining.stream()
                .filter(p -> p.type == currentType || p.type == TemplateType.CELL)
                .min(_typedCompare)
                .orElse(null);
        } else {
            // untyped: any valid entry qualifies
            best = remaining.stream()
                .filter(MemoTableEntry::isValid)
                .min(BASE_COMPARE)
                .orElse(null);
        }
        addBestPlan(hopId, best);
    }

    // step 3: recursively process children
    for (int pos = 0; pos < current.getInput().size(); pos++) {
        TemplateType childType = null;
        if (best != null && best.isPlanRef(pos)) {
            childType = best.type;
        }
        rSelectPlansFuseAll(memo, current.getInput().get(pos), childType, partition);
    }
    setVisited(hopId, currentType);
}
Also used : HashSet(java.util.HashSet) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) List(java.util.List) TemplateType(org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType) HashMap(java.util.HashMap) UtilFunctions(org.apache.sysml.runtime.util.UtilFunctions) Comparator(java.util.Comparator) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) CPlanMemoTable(org.apache.sysml.hops.codegen.template.CPlanMemoTable) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) TemplateType(org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType) HashSet(java.util.HashSet)

Aggregations

MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)59 Hop (org.apache.sysml.hops.Hop)46 HashSet (java.util.HashSet)28 ArrayList (java.util.ArrayList)22 List (java.util.List)20 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)20 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)19 HashMap (java.util.HashMap)16 Comparator (java.util.Comparator)15 BinaryOp (org.apache.sysml.hops.BinaryOp)15 UnaryOp (org.apache.sysml.hops.UnaryOp)15 TemplateType (org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType)15 Entry (java.util.Map.Entry)13 IndexingOp (org.apache.sysml.hops.IndexingOp)13 LiteralOp (org.apache.sysml.hops.LiteralOp)13 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)13 TernaryOp (org.apache.sysml.hops.TernaryOp)13 AggOp (org.apache.sysml.hops.Hop.AggOp)11 CPlanMemoTable (org.apache.sysml.hops.codegen.template.CPlanMemoTable)10 Arrays (java.util.Arrays)9