Search in sources:

Example 1 with CPlanMemoTable

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable in project incubator-systemml by Apache.

Class SpoofCompiler, method constructCPlans.

////////////////////
// Codegen plan construction

/**
 * Builds the fused-operator plans (cplans) for the given DAG roots.
 * <p>
 * Runs in three phases: exploration of cplan candidates into a fresh memo table,
 * pruning of suboptimal candidates, and construction of the actual cplan representations.
 *
 * @param roots           root hops of the DAG to compile
 * @param compileLiterals flag forwarded to both the exploration and the construction phase
 * @return constructed cplans keyed by hop id, in the insertion order of the construction phase
 * @throws DMLException propagated from the exploration/construction steps
 */
private static HashMap<Long, Pair<Hop[], CNodeTpl>> constructCPlans(ArrayList<Hop> roots, boolean compileLiterals) throws DMLException {
    // phase 1: explore cplan candidates reachable from every root
    CPlanMemoTable memoTable = new CPlanMemoTable();
    for (Hop root : roots) {
        rExploreCPlans(root, memoTable, compileLiterals);
    }

    // phase 2: keep only the optimal candidates
    memoTable.pruneSuboptimal(roots);

    // phase 3: construct the actual cplan representations.
    // A dedicated set of visited hop ids is used instead of the hop visit status, because
    // jumps over fused operators would corrupt subsequent resets and leave partial hop
    // DAGs in visited status.
    HashSet<Long> seenHopIds = new HashSet<Long>();
    LinkedHashMap<Long, Pair<Hop[], CNodeTpl>> cplans = new LinkedHashMap<Long, Pair<Hop[], CNodeTpl>>();
    for (Hop root : roots) {
        rConstructCPlans(root, memoTable, cplans, compileLiterals, seenHopIds);
    }
    return cplans;
}
Also used: CNodeTpl(org.apache.sysml.hops.codegen.cplan.CNodeTpl) CPlanMemoTable(org.apache.sysml.hops.codegen.template.CPlanMemoTable) Hop(org.apache.sysml.hops.Hop) LinkedHashMap(java.util.LinkedHashMap) Pair(org.apache.sysml.runtime.matrix.data.Pair) HashSet(java.util.HashSet)

Example 2 with CPlanMemoTable

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBased, method rGetPlanCosts.

/**
 * Recursively estimates the costs of the fusion plan rooted at {@code current}, under the
 * materialization decisions given by {@code M} and {@code plan}.
 * <p>
 * A new fused operator is opened at {@code current} when no enclosing template type is given
 * ({@code currentType == null}) and the memo table has entries for this hop. Opening creates a
 * fresh cost vector initialized with the output size. Hops contained in {@code partition} add
 * their entry of {@code computeCosts} to the active cost vector. After the inputs have been
 * processed, an opened operator contributes its output write cost plus
 * max(compute cost, input read cost).
 *
 * @param memo         memo table of plan candidates
 * @param current      hop to cost
 * @param visited      memoization set of (hop id, cost vector id) pairs; updated in place
 * @param partition    ids of the hops in the partition being costed
 * @param M            materialization point candidates (hop ids), passed to hasNoRefToMaterialization
 * @param plan         materialization decisions, passed to hasNoRefToMaterialization
 *                     (presumably one flag per entry of {@code M} -- TODO confirm)
 * @param computeCosts per-hop compute costs; read for every visited hop in {@code partition}
 *                     (a missing entry would throw NullPointerException on unboxing)
 * @param costsCurrent cost vector of the enclosing fused operator, or null if none
 * @param currentType  template type of the enclosing fused operator, or null if none
 * @return accumulated costs of this subtree; 0 if the (hop, cost vector) pair was already visited
 * @throws RuntimeException if the computed costs are negative, NaN, or infinite
 */
private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<Pair<Long, Long>> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateType currentType) {
    // memoization per hop id and cost vector to account for redundant
    // computation without double counting materialized results or compute
    // costs of complex operation DAGs within a single fused operator
    // (cost vector id 0 is used when there is no enclosing cost vector)
    Pair<Long, Long> tag = Pair.of(current.getHopID(), (costsCurrent == null) ? 0 : costsCurrent.ID);
    if (visited.contains(tag))
        return 0;
    visited.add(tag);
    // open template if necessary, including memoization
    // under awareness of current plan choice
    MemoTableEntry best = null;
    boolean opened = false;
    if (memo.contains(current.getHopID())) {
        if (currentType == null) {
            // no enclosing operator: pick the best valid plan and open a new fused operator here
            best = memo.get(current.getHopID()).stream().filter(p -> p.isValid()).filter(p -> hasNoRefToMaterialization(p, M, plan)).min(new BasicPlanComparator()).orElse(null);
            opened = true;
        } else {
            // inside an operator of currentType: only plans of that type or CELL qualify;
            // the comparator ranks plans of currentType first, then plans with more plan references
            best = memo.get(current.getHopID()).stream().filter(p -> p.type == currentType || p.type == TemplateType.CELL).filter(p -> hasNoRefToMaterialization(p, M, plan)).min(Comparator.comparing(p -> 7 - ((p.type == currentType) ? 4 : 0) - p.countPlanRefs())).orElse(null);
        }
    }
    // create new cost vector if opened, initialized with write costs;
    // otherwise continue with the enclosing cost vector (which may be null)
    CostVector costVect = !opened ? costsCurrent : new CostVector(Math.max(current.getDim1(), 1) * Math.max(current.getDim2(), 1));
    // add compute costs of current operator to costs vector
    // NOTE(review): costVect is not null-checked here; this assumes every partition hop
    // either opens an operator or is reached with a non-null costsCurrent -- confirm
    if (partition.contains(current.getHopID()))
        costVect.computeCosts += computeCosts.get(current.getHopID());
    // process children recursively
    double costs = 0;
    for (int i = 0; i < current.getInput().size(); i++) {
        Hop c = current.getInput().get(i);
        if (best != null && best.isPlanRef(i))
            // input fused into the same operator: cost it against the same cost vector and type
            costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, costVect, best.type);
        else if (best != null && isImplicitlyFused(current, i, best.type))
            // implicitly fused input: only record the read of the child's first input, no recursion
            costVect.addInputSize(c.getInput().get(0).getHopID(), Math.max(c.getDim1(), 1) * Math.max(c.getDim2(), 1));
        else {
            // include children and I/O costs
            // (child is costed on its own; matrix inputs count as reads of the active operator)
            costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, null, null);
            if (costVect != null && c.getDataType().isMatrix())
                costVect.addInputSize(c.getHopID(), Math.max(c.getDim1(), 1) * Math.max(c.getDim2(), 1));
        }
    }
    // add costs for opened fused operator
    if (partition.contains(current.getHopID())) {
        if (opened) {
            if (LOG.isTraceEnabled())
                LOG.trace("Cost vector for fused operator (hop " + current.getHopID() + "): " + costVect);
            // time for output write (factor 8 presumably = bytes per double value)
            costs += costVect.outSize * 8 / WRITE_BANDWIDTH;
            // fused operator time: max of compute time and input read time
            costs += Math.max(costVect.computeCosts * costVect.getMaxInputSize() / COMPUTE_BANDWIDTH, costVect.getSumInputSizes() * 8 / READ_BANDWIDTH);
        } else // add costs for non-partition read in the middle of fused operator
        // (re-costs current without an enclosing cost vector, i.e. under tag (hop id, 0))
        if (hasNonPartitionConsumer(current, partition)) {
            costs += rGetPlanCosts(memo, current, visited, partition, M, plan, computeCosts, null, null);
        }
    }
    // sanity check non-negative costs
    if (costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs))
        throw new RuntimeException("Wrong cost estimate: " + costs);
    return costs;
}
Also used : TemplateRow(org.apache.sysml.hops.codegen.template.TemplateRow) Arrays(java.util.Arrays) IndexingOp(org.apache.sysml.hops.IndexingOp) HashMap(java.util.HashMap) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) TemplateOuterProduct(org.apache.sysml.hops.codegen.template.TemplateOuterProduct) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggOp(org.apache.sysml.hops.Hop.AggOp) ArrayList(java.util.ArrayList) LiteralOp(org.apache.sysml.hops.LiteralOp) HashSet(java.util.HashSet) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Pair(org.apache.commons.lang3.tuple.Pair) ReorgOp(org.apache.sysml.hops.ReorgOp) IDSequence(org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence) CollectionUtils(org.apache.commons.collections.CollectionUtils) InfrastructureAnalyzer(org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Statistics(org.apache.sysml.utils.Statistics) TernaryOp(org.apache.sysml.hops.TernaryOp) Iterator(java.util.Iterator) Collection(java.util.Collection) TemplateType(org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType) BinaryOp(org.apache.sysml.hops.BinaryOp) TemplateUtils(org.apache.sysml.hops.codegen.template.TemplateUtils) Collectors(java.util.stream.Collectors) Direction(org.apache.sysml.hops.Hop.Direction) Hop(org.apache.sysml.hops.Hop) List(java.util.List) Entry(java.util.Map.Entry) DMLScript(org.apache.sysml.api.DMLScript) Log(org.apache.commons.logging.Log) LogFactory(org.apache.commons.logging.LogFactory) UtilFunctions(org.apache.sysml.runtime.util.UtilFunctions) Comparator(java.util.Comparator) Collections(java.util.Collections) HopRewriteUtils(org.apache.sysml.hops.rewrite.HopRewriteUtils) UnaryOp(org.apache.sysml.hops.UnaryOp) CPlanMemoTable(org.apache.sysml.hops.codegen.template.CPlanMemoTable) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) 
Hop(org.apache.sysml.hops.Hop)

Example 3 with CPlanMemoTable

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable in project incubator-systemml by Apache.

Class PlanSelectionFuseCostBasedV2, method createAndAddMultiAggPlans.

// across-partition multi-agg templates with shared reads
/**
 * Creates multi-aggregate (MAGG) plans that group full aggregations across partitions so that
 * they can share reads, and adds them to the memo table and to the best-plan selection.
 * <p>
 * Candidates are the full aggregations collected from {@code roots}, minus those that already
 * have a MAGG entry. For each candidate, aggregate info (subsumed aggregations and inputs to
 * fused operators) is extracted. The candidates are then merged greedily. Every resulting group
 * with at least two aggregates yields one MAGG entry, which is registered for each of its
 * aggregates.
 *
 * @param memo  memo table that provides hop references and receives the new MAGG entries
 * @param roots root hops; their visit status is reset before and after candidate collection
 */
private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
    // 1) collect full aggregations as the initial set of candidates
    HashSet<Long> fullAggs = new HashSet<>();
    Hop.resetVisitStatus(roots);
    for (Hop root : roots) {
        rCollectFullAggregates(root, fullAggs);
    }
    Hop.resetVisitStatus(roots);

    // drop operators that already carry a multi-agg plan
    fullAggs.removeIf(id -> memo.contains(id, TemplateType.MAGG));

    // a multi-agg plan needs at least two aggregations
    if (fullAggs.size() < 2) {
        return;
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Found across-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
    }

    // 2) extract per-candidate info (subsumed aggregations, inputs to fused operators)
    List<AggregateInfo> aggInfos = new ArrayList<>();
    for (Long hopID : fullAggs) {
        Hop aggHop = memo.getHopRefs().get(hopID);
        boolean isMatMult = HopRewriteUtils.isMatrixMultiply(aggHop);
        AggregateInfo info = new AggregateInfo(aggHop);
        for (int k = 0; k < aggHop.getInput().size(); k++) {
            // for matrix multiply, input 0 is replaced by its own first input
            Hop in = (isMatMult && k == 0) ? aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(k);
            rExtractAggregateInfo(memo, in, info, TemplateType.CELL);
        }
        // no fused inputs found: fall back to the aggregation's direct inputs
        if (info._fusedInputs.isEmpty()) {
            if (isMatMult) {
                info.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
                info.addFusedInput(aggHop.getInput().get(1).getHopID());
            } else {
                info.addFusedInput(aggHop.getInput().get(0).getHopID());
            }
        }
        aggInfos.add(info);
    }
    if (LOG.isTraceEnabled()) {
        LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
        for (AggregateInfo info : aggInfos) {
            LOG.trace(info);
        }
    }

    // 3) stable sort by number of input aggregations to simplify merging
    //    clusters of aggregations with parallel dependencies
    aggInfos.sort(Comparator.comparing(a -> a._inputAggs.size()));

    // greedy grouping: repeat passes until a pass yields no merge result
    AggregateInfo lastMerged;
    do {
        lastMerged = null;
        for (int i = 0; i < aggInfos.size(); i++) {
            AggregateInfo current = aggInfos.get(i);
            int j = i + 1;
            while (j < aggInfos.size()) {
                AggregateInfo that = aggInfos.get(j);
                if (current.isMergable(that)) {
                    lastMerged = current.merge(that);
                    // the next candidate shifts into slot j, so j is not advanced
                    aggInfos.remove(j);
                } else {
                    j++;
                }
            }
        }
    } while (lastMerged != null);

    if (LOG.isTraceEnabled()) {
        LOG.trace("Merged across-partition ua(RC) aggregation info: ");
        for (AggregateInfo info : aggInfos) {
            LOG.trace(info);
        }
    }

    // 4) one MAGG entry per group of two or more aggregates; the entry references at most
    //    the first three aggregate ids and is also passed the group size
    for (AggregateInfo group : aggInfos) {
        if (group._aggregates.size() < 2) {
            continue;
        }
        Long[] aggIds = group._aggregates.keySet().toArray(new Long[0]);
        long thirdId = (aggIds.length > 2) ? aggIds[2] : -1;
        MemoTableEntry entry = new MemoTableEntry(TemplateType.MAGG, aggIds[0], aggIds[1], thirdId, aggIds.length);
        for (Long aggId : aggIds) {
            memo.add(memo.getHopRefs().get(aggId), entry);
            addBestPlan(aggId, entry);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Added multiagg* plan: " + aggId + " " + entry);
            }
        }
    }
}
Also used : Arrays(java.util.Arrays) IndexingOp(org.apache.sysml.hops.IndexingOp) AggOp(org.apache.sysml.hops.Hop.AggOp) Pair(org.apache.commons.lang3.tuple.Pair) ReorgOp(org.apache.sysml.hops.ReorgOp) InfrastructureAnalyzer(org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) TernaryOp(org.apache.sysml.hops.TernaryOp) Collection(java.util.Collection) TemplateType(org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType) BinaryOp(org.apache.sysml.hops.BinaryOp) TemplateUtils(org.apache.sysml.hops.codegen.template.TemplateUtils) Collectors(java.util.stream.Collectors) Direction(org.apache.sysml.hops.Hop.Direction) Hop(org.apache.sysml.hops.Hop) List(java.util.List) Entry(java.util.Map.Entry) LogFactory(org.apache.commons.logging.LogFactory) TemplateRow(org.apache.sysml.hops.codegen.template.TemplateRow) HashMap(java.util.HashMap) ArrayUtils(org.apache.commons.lang3.ArrayUtils) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) TemplateOuterProduct(org.apache.sysml.hops.codegen.template.TemplateOuterProduct) DataOpTypes(org.apache.sysml.hops.Hop.DataOpTypes) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) ArrayList(java.util.ArrayList) LiteralOp(org.apache.sysml.hops.LiteralOp) HashSet(java.util.HashSet) LinkedHashMap(java.util.LinkedHashMap) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) IDSequence(org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence) CollectionUtils(org.apache.commons.collections.CollectionUtils) Statistics(org.apache.sysml.utils.Statistics) SubProblem(org.apache.sysml.hops.codegen.opt.ReachabilityGraph.SubProblem) OpOpN(org.apache.sysml.hops.Hop.OpOpN) RUNTIME_PLATFORM(org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM) Iterator(java.util.Iterator) LibSpoofPrimitives(org.apache.sysml.runtime.codegen.LibSpoofPrimitives) LazyWriteBuffer(org.apache.sysml.runtime.controlprogram.caching.LazyWriteBuffer) 
OptimizerUtils(org.apache.sysml.hops.OptimizerUtils) ReOrgOp(org.apache.sysml.hops.Hop.ReOrgOp) DMLScript(org.apache.sysml.api.DMLScript) OpOp2(org.apache.sysml.hops.Hop.OpOp2) Log(org.apache.commons.logging.Log) UtilFunctions(org.apache.sysml.runtime.util.UtilFunctions) Comparator(java.util.Comparator) Collections(java.util.Collections) HopRewriteUtils(org.apache.sysml.hops.rewrite.HopRewriteUtils) UnaryOp(org.apache.sysml.hops.UnaryOp) CPlanMemoTable(org.apache.sysml.hops.codegen.template.CPlanMemoTable) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet)

Example 4 with CPlanMemoTable

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable in project incubator-systemml by Apache.

Class PlanSelectionFuseNoRedundancy, method rSelectPlans.

/**
 * Selects, without redundant computation, the best memo-table plan for {@code current} and
 * recurses into its inputs.
 * <p>
 * Before selection, the memo entries of the current hop are pruned in two steps. First, entries
 * whose plan references point to inputs with more than one consumer are removed. Second, entries
 * subsumed by another entry of the same hop are removed. A hop already visited under the same
 * template type is skipped.
 *
 * @param memo        memo table of plan candidates; pruned in place
 * @param current     hop to process
 * @param currentType template type of the enclosing fused operator, or null if none
 */
private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) {
    long hopId = current.getHopID();
    if (isVisited(hopId, currentType)) {
        return;
    }

    // step 0: remove plans that reference an input with multiple consumers
    // (such an input would be a partial plan shared by several operators)
    if (memo.contains(hopId)) {
        HashSet<MemoTableEntry> shared = new HashSet<>();
        for (MemoTableEntry entry : memo.get(hopId)) {
            for (int pos = 0; pos < 3; pos++) {
                if (entry.isPlanRef(pos) && current.getInput().get(pos).getParent().size() > 1) {
                    shared.add(entry);
                }
            }
        }
        memo.remove(current, shared);
    }

    // step 1: prune plans subsumed by another plan of this hop
    if (memo.contains(hopId)) {
        List<MemoTableEntry> plans = memo.get(hopId);
        HashSet<MemoTableEntry> subsumed = new HashSet<>();
        for (MemoTableEntry outer : plans) {
            for (MemoTableEntry inner : plans) {
                if (outer != inner && outer.subsumes(inner)) {
                    subsumed.add(inner);
                }
            }
        }
        memo.remove(current, subsumed);
    }

    // step 2: select the plan for the current path
    MemoTableEntry best = null;
    if (memo.contains(hopId)) {
        List<MemoTableEntry> candidates = memo.get(hopId);
        if (currentType != null) {
            // inside an operator: plans of the same type rank first, then more plan references
            best = candidates.stream()
                .filter(p -> p.type == currentType || p.type == TemplateType.CELL)
                .min(Comparator.comparing(p -> 7 - ((p.type == currentType) ? 4 : 0) - p.countPlanRefs()))
                .orElse(null);
        } else {
            best = candidates.stream()
                .filter(MemoTableEntry::isValid)
                .min(new BasicPlanComparator())
                .orElse(null);
        }
        addBestPlan(hopId, best);
    }

    // step 3: recurse into inputs, passing the chosen type only along plan references
    List<Hop> inputs = current.getInput();
    for (int i = 0; i < inputs.size(); i++) {
        TemplateType childType = null;
        if (best != null && best.isPlanRef(i)) {
            childType = best.type;
        }
        rSelectPlans(memo, inputs.get(i), childType);
    }
    setVisited(hopId, currentType);
}
Also used : HashSet(java.util.HashSet) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) List(java.util.List) TemplateType(org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType) Entry(java.util.Map.Entry) Comparator(java.util.Comparator) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) CPlanMemoTable(org.apache.sysml.hops.codegen.template.CPlanMemoTable) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) TemplateType(org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType) HashSet(java.util.HashSet)

Example 5 with CPlanMemoTable

Use of org.apache.sysml.hops.codegen.template.CPlanMemoTable in project incubator-systemml by Apache.

Class PlanSelection, method rSelectPlansFuseAll.

/**
 * Selects the best memo-table plan for {@code current} under the fuse-all policy and recurses
 * into its inputs.
 * <p>
 * Entries subsumed by another entry of the same hop are pruned from the memo table first. A hop
 * is skipped if it was already visited under the same template type, or if a non-null
 * {@code partition} is given and does not contain it.
 *
 * @param memo        memo table of plan candidates; pruned in place
 * @param current     hop to process
 * @param currentType template type of the enclosing fused operator, or null if none
 * @param partition   hop ids to restrict the traversal to, or null for no restriction
 */
protected void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition) {
    long hopId = current.getHopID();
    if (isVisited(hopId, currentType) || (partition != null && !partition.contains(hopId))) {
        return;
    }

    // step 1: prune plans subsumed by another plan of this hop
    if (memo.contains(hopId)) {
        List<MemoTableEntry> plans = memo.get(hopId);
        HashSet<MemoTableEntry> subsumed = new HashSet<>();
        for (MemoTableEntry outer : plans) {
            for (MemoTableEntry inner : plans) {
                if (outer != inner && outer.subsumes(inner)) {
                    subsumed.add(inner);
                }
            }
        }
        memo.remove(current, subsumed);
    }

    // step 2: select the plan for the current path
    MemoTableEntry best = null;
    if (memo.contains(hopId)) {
        List<MemoTableEntry> candidates = memo.get(hopId);
        if (currentType != null) {
            // the shared typed comparator is configured for the current type right before use
            _typedCompare.setType(currentType);
            best = candidates.stream()
                .filter(p -> p.type == currentType || p.type == TemplateType.CELL)
                .min(_typedCompare)
                .orElse(null);
        } else {
            best = candidates.stream()
                .filter(MemoTableEntry::isValid)
                .min(BASE_COMPARE)
                .orElse(null);
        }
        addBestPlan(hopId, best);
    }

    // step 3: recurse into inputs, passing the chosen type only along plan references
    List<Hop> inputs = current.getInput();
    for (int i = 0; i < inputs.size(); i++) {
        TemplateType childType = null;
        if (best != null && best.isPlanRef(i)) {
            childType = best.type;
        }
        rSelectPlansFuseAll(memo, inputs.get(i), childType, partition);
    }
    setVisited(hopId, currentType);
}
Also used : HashSet(java.util.HashSet) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) List(java.util.List) TemplateType(org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType) HashMap(java.util.HashMap) UtilFunctions(org.apache.sysml.runtime.util.UtilFunctions) Comparator(java.util.Comparator) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) CPlanMemoTable(org.apache.sysml.hops.codegen.template.CPlanMemoTable) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) TemplateType(org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType) HashSet(java.util.HashSet)

Aggregations

HashSet (java.util.HashSet): 7, Hop (org.apache.sysml.hops.Hop): 7, CPlanMemoTable (org.apache.sysml.hops.codegen.template.CPlanMemoTable): 7, ArrayList (java.util.ArrayList): 5, Comparator (java.util.Comparator): 5, HashMap (java.util.HashMap): 5, List (java.util.List): 5, MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 5, TemplateType (org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType): 5, Entry (java.util.Map.Entry): 4, UtilFunctions (org.apache.sysml.runtime.util.UtilFunctions): 4, Arrays (java.util.Arrays): 3, Collection (java.util.Collection): 3, Collections (java.util.Collections): 3, Iterator (java.util.Iterator): 3, LinkedHashMap (java.util.LinkedHashMap): 3, Collectors (java.util.stream.Collectors): 3, CollectionUtils (org.apache.commons.collections.CollectionUtils): 3, Pair (org.apache.commons.lang3.tuple.Pair): 3, Log (org.apache.commons.logging.Log): 3