Search in sources :

Example 1 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class TemplateCell, method rConstructCplan.

/**
 * Recursively constructs the CNode plan of the cell template rooted at the given hop.
 * Every processed hop is mapped to its constructed CNode in {@code tmp}.
 *
 * <p>The memo entry for {@code hop} is obtained via {@code memo.getBest(id, CellTpl)}.
 * If that entry is of row or outer-product template type, the hop itself becomes a leaf
 * data input. Otherwise, children are compiled recursively when they are plan references
 * of the entry, are not {@link DataOp}s, and (for multi-aggregate entries) are contained
 * in the memo as cell-template candidates. All remaining children become leaf
 * {@link CNodeData} inputs. Leaf inputs are recorded in {@code inHops}. Hop types not
 * handled below are mapped to {@code null}.
 *
 * @param hop             current hop to compile
 * @param memo            memo table of candidate fusion plans
 * @param tmp             hop ID to constructed CNode; also serves as memoization
 *                        for common subexpressions
 * @param inHops          collects hops that become leaf inputs of the template
 * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData} and
 *                        {@code TemplateUtils.skipTranspose}
 */
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
    //memoization for common subexpression elimination and to avoid redundant work 
    if (tmp.containsKey(hop.getHopID()))
        return;
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CellTpl);
    //hops whose best plan is a row or outer-product template are not fused into
    //this cell template but consumed as leaf data inputs
    if (me != null && (me.type == TemplateType.RowTpl || me.type == TemplateType.OuterProdTpl)) {
        CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
        tmp.put(hop.getHopID(), cdata);
        inHops.add(hop);
        return;
    }
    //recursively process required childs; all other children become leaf inputs
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MultiAggTpl || memo.contains(c.getHopID(), TemplateType.CellTpl)))
            rConstructCplan(c, memo, tmp, inHops, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    //construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        String primitiveOpName = bop.getOp().name();
        //add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        //specialize X^2 and X*2 with a literal 2 into dedicated unary operations
        if (bop.getOp() == OpOp2.POW && cdata2.isLiteral() && cdata2.getVarname().equals("2"))
            out = new CNodeUnary(cdata1, UnaryType.POW2);
        else if (bop.getOp() == OpOp2.MULT && cdata2.isLiteral() && cdata2.getVarname().equals("2"))
            out = new CNodeUnary(cdata1, UnaryType.MULT2);
        else
            //default binary	
            out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        //add lookups if required (only the first and third input are wrapped)
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        //construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
    } else if (hop instanceof ParameterizedBuiltinOp) {
        ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp) hop;
        Hop target = pbop.getTargetHop();
        //cdata1 is taken from the target hop (looked up by name, not by position), so the
        //lookup decision must be based on that same hop rather than on input 0
        CNode cdata1 = tmp.get(target.getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, target);
        CNode cdata2 = tmp.get(pbop.getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(pbop.getParameterHop("replacement").getHopID());
        //a literal NaN pattern needs a dedicated replace operation
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        //column lookup: (input, number of columns of the input, input 4 of the indexing op)
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        //transpose is a pass-through at cell level
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggUnaryOp) {
        //aggregation handled in template implementation (note: we do not compile 
        //^2 of SUM_SQ into the operator to simplify the detection of single operators)
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggBinaryOp) {
        //(1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
        if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
            CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
            out = new CNodeUnary(cdata1, UnaryType.POW2);
        } else {
            CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            if (TemplateUtils.isColVector(cdata2))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
        }
    }
    tmp.put(hop.getHopID(), out);
}
Also used : TernaryType(org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeTernary(org.apache.sysml.hops.codegen.cplan.CNodeTernary) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) TernaryOp(org.apache.sysml.hops.TernaryOp) CNode(org.apache.sysml.hops.codegen.cplan.CNode) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) IndexingOp(org.apache.sysml.hops.IndexingOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 2 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class TemplateCell, method constructCplan.

/**
 * Constructs the cell-template CPlan rooted at the given hop.
 *
 * <p>Leaf inputs are collected recursively. Scalar inputs compiled as literals are pruned
 * from them, and the rest are ordered with {@code HopInputComparator}. The first ordered
 * input is treated as the main input. The template is marked sparse-safe for a MULT whose
 * operands include the main input, or a DIV whose numerator is the main input.
 *
 * @param hop             root hop of the template
 * @param memo            memo table of candidate fusion plans
 * @param compileLiterals whether scalar inputs may be compiled as literals
 * @return the ordered (non-literal) input hops paired with the constructed template node
 */
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    //recursively process required cplan output
    HashSet<Hop> inHops = new HashSet<Hop>();
    HashMap<Long, CNode> tmp = new HashMap<Long, CNode>();
    hop.resetVisitStatus();
    rConstructCplan(hop, memo, tmp, inHops, compileLiterals);
    hop.resetVisitStatus();
    //reorder inputs (ensure matrices/vectors come first) and prune literals
    //note: we order by number of cells and subsequently sparsity to ensure
    //that sparse inputs are used as the main input w/o unnecessary conversion
    List<Hop> sinHops = inHops.stream().filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral())).sorted(new HopInputComparator()).collect(Collectors.toList());
    //construct template node
    ArrayList<CNode> inputs = new ArrayList<CNode>();
    for (Hop in : sinHops)
        inputs.add(tmp.get(in.getHopID()));
    CNode output = tmp.get(hop.getHopID());
    CNodeCell tpl = new CNodeCell(inputs, output);
    tpl.setCellType(TemplateUtils.getCellType(hop));
    tpl.setAggOp(TemplateUtils.getAggOp(hop));
    //sparse-safety requires a main input; if all inputs were pruned as literals there is
    //none (previously sinHops.get(0) threw IndexOutOfBoundsException), so not sparse-safe
    Hop mainInput = sinHops.isEmpty() ? null : sinHops.get(0);
    tpl.setSparseSafe(mainInput != null
        && ((HopRewriteUtils.isBinary(hop, OpOp2.MULT) && hop.getInput().contains(mainInput))
            || (HopRewriteUtils.isBinary(hop, OpOp2.DIV) && hop.getInput().get(0) == mainInput)));
    tpl.setRequiresCastDtm(hop instanceof AggBinaryOp);
    // return cplan instance
    return new Pair<Hop[], CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
}
Also used : HashMap(java.util.HashMap) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) CNodeCell(org.apache.sysml.hops.codegen.cplan.CNodeCell) CNode(org.apache.sysml.hops.codegen.cplan.CNode) HashSet(java.util.HashSet) Pair(org.apache.sysml.runtime.matrix.data.Pair)

Example 3 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class TemplateCell, method isValidOperation.

/**
 * Checks whether the given hop can be compiled as an operation of the cell template.
 *
 * <p>The hop must produce a matrix, and {@code TemplateUtils.isOperationSupported} must
 * accept it. It must also be one of the following:
 * <ul>
 *   <li>a unary operation;</li>
 *   <li>a binary operation that is matrix-scalar, matrix-vector (known dims), or dense
 *       equal-size matrix-matrix (known dims);</li>
 *   <li>a ternary matrix-scalar-matrix operation (known dims) over two vectors, or over
 *       two dense equal-size matrices;</li>
 *   <li>a parameterized builtin REPLACE.</li>
 * </ul>
 *
 * @param hop candidate hop
 * @return true if the hop is a valid cell-template operation
 */
protected static boolean isValidOperation(Hop hop) {
    //binary candidates: matrix-scalar, matrix-vector, dense equal-size matrix-matrix
    boolean validBinary = false;
    if (hop instanceof BinaryOp && hop.getDataType().isMatrix()) {
        Hop in1 = hop.getInput().get(0);
        Hop in2 = hop.getInput().get(1);
        DataType dt1 = in1.getDataType();
        DataType dt2 = in2.getDataType();
        boolean matrixScalar = dt1.isScalar() || dt2.isScalar();
        boolean matrixVector = hop.dimsKnown()
            && ((dt1.isMatrix() && TemplateUtils.isVectorOrScalar(in2))
                || (dt2.isMatrix() && TemplateUtils.isVectorOrScalar(in1)));
        boolean denseMatrixMatrix = hop.dimsKnown()
            && HopRewriteUtils.isEqualSize(in1, in2)
            && dt1.isMatrix() && dt2.isMatrix()
            && !HopRewriteUtils.isSparse(in1) && !HopRewriteUtils.isSparse(in2);
        validBinary = matrixScalar || matrixVector || denseMatrixMatrix;
    }

    //ternary candidates: matrix-scalar-matrix over vectors or dense equal-size matrices
    boolean validTernary = false;
    if (hop instanceof TernaryOp && hop.getInput().size() == 3 && hop.dimsKnown()
        && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) {
        Hop first = hop.getInput().get(0);
        Hop third = hop.getInput().get(2);
        boolean vectorScalarVector = TemplateUtils.isVector(first) && TemplateUtils.isVector(third);
        boolean denseMatrixScalarMatrix = HopRewriteUtils.isEqualSize(first, third)
            && !HopRewriteUtils.isSparse(first) && !HopRewriteUtils.isSparse(third);
        validTernary = vectorScalarVector || denseMatrixScalarMatrix;
    }

    //only matrix-valued, generally supported operations qualify
    if (hop.getDataType() != DataType.MATRIX || !TemplateUtils.isOperationSupported(hop))
        return false;

    return hop instanceof UnaryOp
        || validBinary
        || validTernary
        || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp() == ParamBuiltinOp.REPLACE);
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) DataType(org.apache.sysml.parser.Expression.DataType) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 4 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class TemplateOuterProduct, method rConstructCplan.

/**
 * Recursively constructs the CNode plan of the outer-product template rooted at the given
 * hop. Every processed hop is mapped to its constructed CNode in {@code tmp}.
 *
 * <p>Children that are plan references of the hop's best outer-product memo entry are
 * compiled recursively. All others become leaf {@link CNodeData} inputs recorded in
 * {@code inHops}. Special inputs are recorded in {@code inHops2}:
 * <ul>
 *   <li>{@code "_X"}: for equal-size binary ops, the operand whose CNode is a
 *       {@link CNodeData} (input 0 if so, otherwise input 1);</li>
 *   <li>{@code "_U"}, {@code "_V"}: the left and (untransposed) right operands of an
 *       outer-product-like matrix multiply.</li>
 * </ul>
 * Hop types not handled below are mapped to {@code null}.
 *
 * @param hop             current hop to compile
 * @param memo            memo table of candidate fusion plans
 * @param tmp             hop ID to constructed CNode; also serves as memoization
 * @param inHops          collects hops that become leaf inputs of the template
 * @param inHops2         collects the special inputs {@code _X}, {@code _U}, {@code _V}
 * @param compileLiterals forwarded to {@code TemplateUtils.createCNodeData} and
 *                        {@code TemplateUtils.skipTranspose}
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    //memoization for common subexpression elimination and to avoid redundant work 
    if (tmp.containsKey(hop.getHopID()))
        return;
    //recursively process required childs
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OuterProdTpl);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        //guard against a missing memo entry (consistent with TemplateCell): without an
        //entry no child is a plan reference, so all children become leaf inputs
        if (me != null && me.isPlanRef(i))
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    //construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        //use name() (the constant identifier valueOf expects) rather than toString()
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        String primitiveOpName = ((BinaryOp) hop).getOp().name();
        //remember the equal-size data operand as the main input X
        if (HopRewriteUtils.isEqualSize(hop.getInput().get(0), hop.getInput().get(1))) {
            Hop main = hop.getInput().get((cdata1 instanceof CNodeData) ? 0 : 1);
            inHops2.put("_X", main);
        }
        //add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof AggBinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        //handle transpose in outer or final product
        cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
        cdata2 = TemplateUtils.skipTranspose(cdata2, hop.getInput().get(1), tmp, compileLiterals);
        //outer product U%*%t(V), see open
        if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
            //keep U and V for later reference
            inHops2.put("_U", hop.getInput().get(0));
            if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(1)))
                inHops2.put("_V", hop.getInput().get(1).getInput().get(0));
            else
                inHops2.put("_V", hop.getInput().get(1));
            out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
        } else //final left/right matrix mult, see close
        {
            //scalar operand goes second in the vector multiply-add
            if (cdata1.getDataType().isScalar())
                out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
            else
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggUnaryOp && ((AggUnaryOp) hop).getOp() == AggOp.SUM && ((AggUnaryOp) hop).getDirection() == Direction.RowCol) {
        //full sum aggregation is handled by the template itself
        out = tmp.get(hop.getInput().get(0).getHopID());
    }
    tmp.put(hop.getHopID(), out);
}
Also used : CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 5 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

Class TemplateOuterProduct, method constructCplan.

/**
 * Constructs the outer-product-template CPlan rooted at the given hop.
 *
 * <p>The special inputs X (main input), U and V are moved to the front of the input list,
 * in positional order {@code [X, U, V, ...]}. A special input that was not found is kept
 * as a {@code null} placeholder in the returned hop array, but is skipped when building
 * the CNode inputs.
 *
 * @param hop             root hop of the template
 * @param memo            memo table of candidate fusion plans
 * @param compileLiterals whether scalar inputs may be compiled as literals
 * @return the ordered input hops paired with the constructed template node
 */
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    //recursively process required cplan output
    HashSet<Hop> inHops = new HashSet<Hop>();
    HashMap<String, Hop> inHops2 = new HashMap<String, Hop>();
    HashMap<Long, CNode> tmp = new HashMap<Long, CNode>();
    hop.resetVisitStatus();
    rConstructCplan(hop, memo, tmp, inHops, inHops2, compileLiterals);
    hop.resetVisitStatus();
    //reorder inputs (ensure matrix is first input)
    Hop X = inHops2.get("_X");
    Hop U = inHops2.get("_U");
    Hop V = inHops2.get("_V");
    LinkedList<Hop> sinHops = new LinkedList<Hop>(inHops);
    //all removes must precede all adds: interleaving them would let remove(X) delete an
    //already placed U or V when X==U or X==V, corrupting the positional [X, U, V] order
    sinHops.remove(V);
    sinHops.remove(U);
    sinHops.remove(X);
    sinHops.addFirst(V);
    sinHops.addFirst(U);
    sinHops.addFirst(X);
    //construct template node (null placeholders carry no cnode)
    ArrayList<CNode> inputs = new ArrayList<CNode>();
    for (Hop in : sinHops)
        if (in != null)
            inputs.add(tmp.get(in.getHopID()));
    CNode output = tmp.get(hop.getHopID());
    CNodeOuterProduct tpl = new CNodeOuterProduct(inputs, output);
    tpl.setOutProdType(TemplateUtils.getOuterProductType(X, U, V, hop));
    tpl.setTransposeOutput(!HopRewriteUtils.isTransposeOperation(hop) && tpl.getOutProdType() == OutProdType.LEFT_OUTER_PRODUCT);
    return new Pair<Hop[], CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
}
Also used : HashMap(java.util.HashMap) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeOuterProduct(org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct) HashSet(java.util.HashSet) Pair(org.apache.sysml.runtime.matrix.data.Pair)

Aggregations

Hop (org.apache.sysml.hops.Hop)230 LiteralOp (org.apache.sysml.hops.LiteralOp)75 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)55 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)54 BinaryOp (org.apache.sysml.hops.BinaryOp)51 ArrayList (java.util.ArrayList)49 UnaryOp (org.apache.sysml.hops.UnaryOp)38 DataOp (org.apache.sysml.hops.DataOp)37 HashMap (java.util.HashMap)29 ReorgOp (org.apache.sysml.hops.ReorgOp)28 StatementBlock (org.apache.sysml.parser.StatementBlock)21 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)21 HashSet (java.util.HashSet)20 DataGenOp (org.apache.sysml.hops.DataGenOp)20 HopsException (org.apache.sysml.hops.HopsException)19 IndexingOp (org.apache.sysml.hops.IndexingOp)19 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)18 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)18 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)18 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)15