Search in sources :

Example 1 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class PlanSelectionFuseCostBased method rIsRowTemplateWithoutAgg.

/**
 * Checks recursively whether the row-template plan rooted at {@code current}
 * is free of aggregation hops ({@code AggUnaryOp} / {@code AggBinaryOp}).
 *
 * @param memo    memo table queried for the best RowTpl entry of each hop
 * @param current hop at which the check starts
 * @param visited IDs of hops that were already processed; updated in place
 * @return true if neither {@code current} nor any input reachable through
 *         plan references is an aggregation hop; hops that were already
 *         visited always report true
 */
private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
    // already processed hops do not constrain the result any further
    if (visited.contains(current.getHopID())) {
        return true;
    }
    MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.RowTpl);
    boolean inputsWithoutAgg = true;
    for (int pos = 0; pos < 3; pos++) {
        if (!me.isPlanRef(pos)) {
            continue;
        }
        // always recurse (no short-circuit) so every referenced input gets marked as visited
        boolean childWithoutAgg = rIsRowTemplateWithoutAgg(memo, current.getInput().get(pos), visited);
        inputsWithoutAgg = inputsWithoutAgg && childWithoutAgg;
    }
    boolean isAggregate = current instanceof AggUnaryOp || current instanceof AggBinaryOp;
    visited.add(current.getHopID());
    return inputsWithoutAgg && !isAggregate;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp)

Example 2 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class PlanSelectionFuseCostBased method rIsRowTemplateWithoutAgg.

/**
 * Checks recursively whether the ROW-template plan rooted at {@code current}
 * is free of aggregation hops ({@code AggUnaryOp} / {@code AggBinaryOp}).
 *
 * @param memo    memo table queried for the best ROW entry of each hop
 * @param current hop at which the check starts
 * @param visited IDs of hops that were already processed; updated in place
 * @return true if neither {@code current} nor any input reachable through
 *         plan references is an aggregation hop; hops that were already
 *         visited always report true
 */
private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
    // already processed hops do not constrain the result any further
    if (visited.contains(current.getHopID())) {
        return true;
    }
    MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
    boolean inputsWithoutAgg = true;
    for (int pos = 0; pos < 3; pos++) {
        if (!me.isPlanRef(pos)) {
            continue;
        }
        // always recurse (no short-circuit) so every referenced input gets marked as visited
        boolean childWithoutAgg = rIsRowTemplateWithoutAgg(memo, current.getInput().get(pos), visited);
        inputsWithoutAgg = inputsWithoutAgg && childWithoutAgg;
    }
    boolean isAggregate = current instanceof AggUnaryOp || current instanceof AggBinaryOp;
    visited.add(current.getHopID());
    return inputsWithoutAgg && !isAggregate;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp)

Example 3 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class TemplateCell method constructCplan.

/**
 * Builds the cell-template cplan for the given root hop.
 *
 * @param hop             root hop whose cplan is constructed
 * @param memo            memo table handed to the recursive construction
 * @param compileLiterals handed through to the recursive construction
 * @return pair of the ordered input hops and the configured cell template node
 */
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // recursively build the cnodes of the required cplan output
    HashSet<Hop> inHops = new HashSet<>();
    HashMap<Long, CNode> tmp = new HashMap<>();
    hop.resetVisitStatus();
    rConstructCplan(hop, memo, tmp, inHops, compileLiterals);
    hop.resetVisitStatus();

    // drop scalar literals and order the remaining inputs via HopInputComparator
    // note: ordering by number of cells and subsequently sparsity ensures that
    // sparse inputs are used as the main input w/o unnecessary conversion
    Hop[] sinHops = inHops.stream()
        .filter(h -> !h.getDataType().isScalar() || !tmp.get(h.getHopID()).isLiteral())
        .sorted(new HopInputComparator())
        .toArray(Hop[]::new);

    // look up the cnodes of all remaining inputs, in sorted order
    ArrayList<CNode> inputs = new ArrayList<>();
    for (Hop in : sinHops) {
        inputs.add(tmp.get(in.getHopID()));
    }

    // assemble the template node; the agg op must be set before the sparse-safe
    // check because that check reads it back from the template
    CNode output = tmp.get(hop.getHopID());
    CNodeCell tpl = new CNodeCell(inputs, output);
    tpl.setCellType(TemplateUtils.getCellType(hop));
    tpl.setAggOp(TemplateUtils.getAggOp(hop));
    boolean sparseSafe = isSparseSafe(
        Arrays.asList(hop), sinHops[0],
        Arrays.asList(tpl.getOutput()), Arrays.asList(tpl.getAggOp()), false);
    tpl.setSparseSafe(sparseSafe);
    tpl.setRequiresCastDtm(hop instanceof AggBinaryOp);
    tpl.setBeginLine(hop.getBeginLine());

    // hand back the cplan instance together with its ordered inputs
    return new Pair<>(sinHops, tpl);
}
Also used : HashMap(java.util.HashMap) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) CNodeCell(org.apache.sysml.hops.codegen.cplan.CNodeCell) CNode(org.apache.sysml.hops.codegen.cplan.CNode) HashSet(java.util.HashSet) Pair(org.apache.sysml.runtime.matrix.data.Pair)

Example 4 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class TemplateCell method isSparseSafe.

/**
 * Decides whether all given outputs are sparse-safe with respect to the main input.
 * Evaluation stops at the first output that fails the check.
 *
 * @param roots     root hops, one per output
 * @param mainInput hop used as the main input
 * @param outputs   output cnodes, one per root
 * @param aggOps    aggregation ops, one per output (only read if {@code onlySum})
 * @param onlySum   if true, additionally require SUM or SUM_SQ as aggregation op
 * @return true if every output passes the sparse-safe check
 */
protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, List<CNode> outputs, List<AggOp> aggOps, boolean onlySum) {
    boolean sparseSafe = true;
    int pos = 0;
    while (pos < outputs.size() && sparseSafe) {
        // for aggregation roots, inspect the aggregated input instead of the root itself
        Hop candidate = roots.get(pos);
        Hop root = candidate;
        if (candidate instanceof AggUnaryOp || candidate instanceof AggBinaryOp) {
            root = candidate.getInput().get(0);
        }
        CNode output = outputs.get(pos);
        // three alternative conditions, evaluated left to right with short-circuiting
        sparseSafe = (HopRewriteUtils.isBinarySparseSafe(root) && root.getInput().contains(mainInput))
            || (HopRewriteUtils.isBinary(root, OpOp2.DIV) && root.getInput().get(0) == mainInput)
            || (TemplateUtils.rIsSparseSafeOnly(output, BinType.MULT)
                && TemplateUtils.rContainsInput(output, mainInput.getHopID()));
        if (onlySum) {
            AggOp aggOp = aggOps.get(pos);
            sparseSafe &= (aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ);
        }
        pos++;
    }
    return sparseSafe;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop)

Example 5 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class TemplateCell method rConstructCplan.

/**
 * Recursively constructs the CNode for the given hop and stores it in {@code tmp}
 * under the hop's ID.
 *
 * The inputs of the hop are handled first: depending on the memo table entry, an
 * input is either constructed recursively or turned into a CNodeData that is also
 * recorded in {@code inHops}. Afterwards the CNode for the hop itself is created
 * according to the concrete hop type. If none of the type branches applies,
 * {@code null} is stored for the hop.
 *
 * @param hop             hop to construct the CNode for
 * @param memo            memo table queried for the best entry of the hop
 * @param tmp             map from hop ID to constructed CNode, also used for memoization
 * @param inHops          collects the hops that were turned into data inputs
 * @param compileLiterals passed through to TemplateUtils.createCNodeData and
 *                        TemplateUtils.skipTranspose
 */
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CELL);
    // recursively process required childs
    // (a hop whose best entry is a ROW or OUTER template is not expanded here;
    // it becomes a plain data input of this plan)
    if (me != null && me.type.isIn(TemplateType.ROW, TemplateType.OUTER)) {
        CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
        tmp.put(hop.getHopID(), cdata);
        inHops.add(hop);
        return;
    }
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        // case 1: input is a plan reference (and no DataOp) -> construct it recursively;
        // for MAGG entries only if the memo table also holds a CELL entry for the input
        if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MAGG || memo.contains(c.getHopID(), TemplateType.CELL)))
            rConstructCplan(c, memo, tmp, inHops, compileLiterals);
        // case 2: first input of a matrix multiply under a MAGG or CELL entry ->
        // recurse into the input's own first input instead of the input itself
        else if (me != null && (me.type == TemplateType.MAGG || me.type == TemplateType.CELL) && HopRewriteUtils.isMatrixMultiply(hop) && // skip transpose
        i == 0)
            rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
        // case 3: everything else becomes a data input of the plan
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        // the unary type is resolved by name from the hop's operation
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // the binary type is resolved by name from the hop's operation
        String primitiveOpName = bop.getOp().name();
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        // construct binary cnode
        out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        // construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
    } else if (hop instanceof ParameterizedBuiltinOp) {
        // operands are the target hop plus the "pattern" and "replacement" parameter hops
        CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
        // a literal pattern named "Double.NaN" selects REPLACE_NAN, anything else REPLACE
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        // LOOKUP_RC1 over the indexed input, a literal holding getDim2() of that input,
        // and the hop's input at position 4
        // NOTE(review): input 4 is presumably the column upper bound of the indexing op - confirm
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
        // correct indexing types of existing lookups
        // (only if no parent of the transpose is an AggBinaryOp)
        if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class))
            TemplateUtils.rFlipVectorLookups(out);
        // maintain input hops
        if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
            inHops.add(hop.getInput().get(0));
    } else if (hop instanceof AggUnaryOp) {
        // aggregation handled in template implementation (note: we do not compile
        // ^2 of SUM_SQ into the operator to simplify the detection of single operators)
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggBinaryOp) {
        // (1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
        if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
            // t(X)%*%X: POW2 over the second input, wrapped in LOOKUP_R for column vectors
            CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            out = new CNodeUnary(cdata1, UnaryType.POW2);
        } else {
            // t(X)%*%Y: MULT over both sides, the left side obtained via skipTranspose
            CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
            // register the first input of the left-hand side as input hop if it resolved to plain data
            if (cdata1 instanceof CNodeData && !inHops.contains(hop.getInput().get(0).getInput().get(0)))
                inHops.add(hop.getInput().get(0).getInput().get(0));
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            if (TemplateUtils.isColVector(cdata2))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
        }
    }
    // out is still null here if no branch above matched the hop type
    tmp.put(hop.getHopID(), out);
}
Also used : TernaryType(org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeTernary(org.apache.sysml.hops.codegen.cplan.CNodeTernary) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) TernaryOp(org.apache.sysml.hops.TernaryOp) CNode(org.apache.sysml.hops.codegen.cplan.CNode) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) IndexingOp(org.apache.sysml.hops.IndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 56, Hop (org.apache.sysml.hops.Hop): 47, AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 32, BinaryOp (org.apache.sysml.hops.BinaryOp): 27, LiteralOp (org.apache.sysml.hops.LiteralOp): 21, ReorgOp (org.apache.sysml.hops.ReorgOp): 15, UnaryOp (org.apache.sysml.hops.UnaryOp): 15, TernaryOp (org.apache.sysml.hops.TernaryOp): 11, MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 11, QuaternaryOp (org.apache.sysml.hops.QuaternaryOp): 10, IndexingOp (org.apache.sysml.hops.IndexingOp): 9, ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 9, CNode (org.apache.sysml.hops.codegen.cplan.CNode): 8, CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 6, CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 6, CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary): 6, DataOp (org.apache.sysml.hops.DataOp): 4, OpOp2 (org.apache.sysml.hops.Hop.OpOp2): 4, CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary): 4, TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType): 4