Search in sources :

Example 1 with TernaryOp

use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.

the class TemplateCell method isValidOperation.

protected static boolean isValidOperation(Hop hop) {
    // prepare indicators for binary operations
    boolean isBinaryMatrixScalar = false;
    boolean isBinaryMatrixVector = false;
    boolean isBinaryMatrixMatrix = false;
    if (hop instanceof BinaryOp && hop.getDataType().isMatrix()) {
        Hop left = hop.getInput().get(0);
        Hop right = hop.getInput().get(1);
        DataType ldt = left.getDataType();
        DataType rdt = right.getDataType();
        isBinaryMatrixScalar = (ldt.isScalar() || rdt.isScalar());
        isBinaryMatrixVector = hop.dimsKnown() && ((ldt.isMatrix() && TemplateUtils.isVectorOrScalar(right)) || (rdt.isMatrix() && TemplateUtils.isVectorOrScalar(left)));
        isBinaryMatrixMatrix = hop.dimsKnown() && HopRewriteUtils.isEqualSize(left, right) && ldt.isMatrix() && rdt.isMatrix();
    }
    // prepare indicators for ternary operations
    boolean isTernaryVectorScalarVector = false;
    boolean isTernaryMatrixScalarMatrixDense = false;
    boolean isTernaryIfElse = (HopRewriteUtils.isTernary(hop, OpOp3.IFELSE) && hop.getDataType().isMatrix());
    if (hop instanceof TernaryOp && hop.getInput().size() == 3 && hop.dimsKnown() && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) {
        Hop left = hop.getInput().get(0);
        Hop right = hop.getInput().get(2);
        isTernaryVectorScalarVector = TemplateUtils.isVector(left) && TemplateUtils.isVector(right);
        isTernaryMatrixScalarMatrixDense = HopRewriteUtils.isEqualSize(left, right) && !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right);
    }
    // check supported unary, binary, ternary operations
    return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrix || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense || isTernaryIfElse || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp() == ParamBuiltinOp.REPLACE));
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) DataType(org.apache.sysml.parser.Expression.DataType) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 2 with TernaryOp

use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.

the class TemplateCell method rConstructCplan.

protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CELL);
    // recursively process required childs
    if (me != null && me.type.isIn(TemplateType.ROW, TemplateType.OUTER)) {
        CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
        tmp.put(hop.getHopID(), cdata);
        inHops.add(hop);
        return;
    }
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MAGG || memo.contains(c.getHopID(), TemplateType.CELL)))
            rConstructCplan(c, memo, tmp, inHops, compileLiterals);
        else if (me != null && (me.type == TemplateType.MAGG || me.type == TemplateType.CELL) && HopRewriteUtils.isMatrixMultiply(hop) && // skip transpose
        i == 0)
            rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        String primitiveOpName = bop.getOp().name();
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        // construct binary cnode
        out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        // construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
    } else if (hop instanceof ParameterizedBuiltinOp) {
        CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
        // correct indexing types of existing lookups
        if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class))
            TemplateUtils.rFlipVectorLookups(out);
        // maintain input hops
        if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
            inHops.add(hop.getInput().get(0));
    } else if (hop instanceof AggUnaryOp) {
        // aggregation handled in template implementation (note: we do not compile
        // ^2 of SUM_SQ into the operator to simplify the detection of single operators)
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggBinaryOp) {
        // (1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
        if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
            CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            out = new CNodeUnary(cdata1, UnaryType.POW2);
        } else {
            CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
            if (cdata1 instanceof CNodeData && !inHops.contains(hop.getInput().get(0).getInput().get(0)))
                inHops.add(hop.getInput().get(0).getInput().get(0));
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            if (TemplateUtils.isColVector(cdata2))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
        }
    }
    tmp.put(hop.getHopID(), out);
}
Also used : TernaryType(org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeTernary(org.apache.sysml.hops.codegen.cplan.CNodeTernary) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) TernaryOp(org.apache.sysml.hops.TernaryOp) CNode(org.apache.sysml.hops.codegen.cplan.CNode) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) IndexingOp(org.apache.sysml.hops.IndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 3 with TernaryOp

use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.

the class HopRewriteUtils method createTernaryOp.

public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 op) {
    TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, mleft, smid, mright);
    ternOp.setOutputBlocksizes(mleft.getRowsInBlock(), mleft.getColsInBlock());
    copyLineNumbers(mleft, ternOp);
    ternOp.refreshSizeInformation();
    return ternOp;
}
Also used : TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 4 with TernaryOp

use of org.apache.sysml.hops.TernaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationStatic method simplifyReverseOperation.

/**
 * NOTE: this would be by definition a dynamic rewrite; however, we apply it as a static
 * rewrite in order to apply it before splitting dags which would hide the table information
 * if dimensions are not specified.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return high-level operator
 */
private static Hop simplifyReverseOperation(Hop parent, Hop hi, int pos) {
    if (hi instanceof AggBinaryOp && hi.getInput().get(0) instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hi.getInput().get(0);
        if (top.getOp() == OpOp3.CTABLE && HopRewriteUtils.isBasic1NSequence(top.getInput().get(0)) && HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1)) && top.getInput().get(0).getDim1() == top.getInput().get(1).getDim1()) {
            ReorgOp rop = HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV);
            HopRewriteUtils.replaceChildReference(parent, hi, rop, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, top);
            hi = rop;
            LOG.debug("Applied simplifyReverseOperation.");
        }
    }
    return hi;
}
Also used : AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) ReorgOp(org.apache.sysml.hops.ReorgOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 5 with TernaryOp

use of org.apache.sysml.hops.TernaryOp in project systemml by apache.

the class HopRewriteUtils method createTernaryOp.

public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 op) {
    TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, mleft, smid, mright);
    ternOp.setOutputBlocksizes(mleft.getRowsInBlock(), mleft.getColsInBlock());
    copyLineNumbers(mleft, ternOp);
    ternOp.refreshSizeInformation();
    return ternOp;
}
Also used : TernaryOp(org.apache.sysml.hops.TernaryOp)

Aggregations

TernaryOp (org.apache.sysml.hops.TernaryOp)19 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)15 Hop (org.apache.sysml.hops.Hop)15 LiteralOp (org.apache.sysml.hops.LiteralOp)13 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)13 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)11 BinaryOp (org.apache.sysml.hops.BinaryOp)11 UnaryOp (org.apache.sysml.hops.UnaryOp)11 IndexingOp (org.apache.sysml.hops.IndexingOp)7 ReorgOp (org.apache.sysml.hops.ReorgOp)7 HashMap (java.util.HashMap)4 DataOp (org.apache.sysml.hops.DataOp)4 CNode (org.apache.sysml.hops.codegen.cplan.CNode)4 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)4 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)4 CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary)4 TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType)4 CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary)4 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)4 ArrayList (java.util.ArrayList)2