Search in sources:

Example 46 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project SystemML by Apache.

From class TemplateOuterProduct, method rConstructCplan.

/**
 * Recursively constructs the CNode plan of the outer-product template rooted at {@code hop}.
 * <p>
 * Children marked as plan references by the best OUTER/CELL memo entry are compiled
 * recursively. All other children become {@link CNodeData} leaf inputs and are registered in
 * {@code inHops}; this also applies to all children if there is no memo entry. Special inputs
 * are recorded by name in {@code inHops2}: "_X" for a sparse-safe matrix data input, and
 * "_U"/"_V" for the operands of the outer product. The constructed node is stored in
 * {@code tmp}. It stays {@code null} for hop types not handled below.
 *
 * @param hop current high-level operator
 * @param memo memo table of fusion plan candidates
 * @param tmp map from hop ID to constructed CNode; doubles as the memoization table
 * @param inHops collected leaf input hops of the template
 * @param inHops2 named special input hops of the template
 * @param compileLiterals passed through to createCNodeData/skipTranspose
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    // recursively process required childs; guard against a missing memo entry (as the
    // row template does), in which case every child is treated as a data input
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OUTER, TemplateType.CELL);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i))
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        String primitiveOpName = ((UnaryOp) hop).getOp().toString();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        String primitiveOpName = ((BinaryOp) hop).getOp().toString();
        if (HopRewriteUtils.isBinarySparseSafe(hop)) {
            // remember a matrix data input as "_X"; if both inputs qualify, input 1 wins
            if (TemplateUtils.isMatrix(hop.getInput().get(0)) && cdata1 instanceof CNodeData)
                inHops2.put("_X", hop.getInput().get(0));
            if (TemplateUtils.isMatrix(hop.getInput().get(1)) && cdata2 instanceof CNodeData)
                inHops2.put("_X", hop.getInput().get(1));
        }
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof AggBinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // handle transpose in outer or final product
        cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
        cdata2 = TemplateUtils.skipTranspose(cdata2, hop.getInput().get(1), tmp, compileLiterals);
        // outer product U%*%t(V), see open
        if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
            // keep U and V for later reference (V is taken from under the transpose, if any)
            inHops2.put("_U", hop.getInput().get(0));
            if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(1)))
                inHops2.put("_V", hop.getInput().get(1).getInput().get(0));
            else
                inHops2.put("_V", hop.getInput().get(1));
            out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
        } else // final left/right matrix mult, see close
        {
            // a scalar left operand is moved to the second position
            if (cdata1.getDataType().isScalar())
                out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
            else
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        // transposes are absorbed: reuse the input's cnode
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggUnaryOp && ((AggUnaryOp) hop).getOp() == AggOp.SUM && ((AggUnaryOp) hop).getDirection() == Direction.RowCol) {
        // full sum: reuse the input's cnode
        out = tmp.get(hop.getInput().get(0).getHopID());
    }
    tmp.put(hop.getHopID(), out);
}
Also used: CNode (org.apache.sysml.hops.codegen.cplan.CNode), CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData), AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), UnaryOp (org.apache.sysml.hops.UnaryOp), CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary), MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry), AggBinaryOp (org.apache.sysml.hops.AggBinaryOp), Hop (org.apache.sysml.hops.Hop), CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary), BinaryOp (org.apache.sysml.hops.BinaryOp)

Example 47 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project SystemML by Apache.

From class TemplateRow, method rConstructCplan.

/**
 * Recursively constructs the CNode plan of the row template rooted at {@code hop}.
 * <p>
 * Children marked as plan references by the best ROW/CELL memo entry are compiled
 * recursively. All other children become {@link CNodeData} leaf inputs and are registered in
 * {@code inHops}; this also applies to all children if there is no memo entry. Named special
 * inputs are recorded in {@code inHops2}: "X", presumably the main row-wise input, and "B1",
 * a side matrix for vector-matrix products. The constructed node is stored in {@code tmp}.
 * Matrix-typed nodes receive the hop's dimensions.
 *
 * @param hop current high-level operator
 * @param memo memo table of fusion plan candidates
 * @param tmp map from hop ID to constructed CNode; doubles as the memoization table
 * @param inHops collected leaf input hops of the template
 * @param inHops2 named special input hops of the template ("X", "B1")
 * @param compileLiterals passed through to createCNodeData/skipTranspose
 * @throws RuntimeException if the hop, or its unary/binary matrix operation, is unsupported
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    // recursively process required childs; without a memo entry all children become data inputs
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i))
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof AggUnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        if (((AggUnaryOp) hop).getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
            // row aggregate of a single-column input: just the row's value (lookup unless scalar)
            if (hop.getInput().get(0).getDim2() == 1)
                out = (cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            else {
                // unary type derived from the aggregate name, e.g. SUM -> ROW_SUMS
                String opcode = "ROW_" + ((AggUnaryOp) hop).getOp().name().toUpperCase() + "S";
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
                    inHops2.put("X", hop.getInput().get(0));
            }
        } else if (((AggUnaryOp) hop).getDirection() == Direction.Col && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            // vector add without temporary copy
            if (cdata1 instanceof CNodeBinary && ((CNodeBinary) cdata1).getType().isVectorScalarPrimitive())
                out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), ((CNodeBinary) cdata1).getType().getVectorAddPrimitive());
            else
                out = cdata1;
        } else if (((AggUnaryOp) hop).getDirection() == Direction.RowCol && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            out = (cdata1.getDataType().isMatrix()) ? new CNodeUnary(cdata1, UnaryType.ROW_SUMS) : cdata1;
        }
    } else if (hop instanceof AggBinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
            // correct input under transpose
            cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
            inHops.remove(hop.getInput().get(0));
            if (cdata1 instanceof CNodeData)
                inHops.add(hop.getInput().get(0).getInput().get(0));
            // note: vectorMultAdd applicable to vector-scalar, and vector-vector
            if (hop.getInput().get(1).getDim2() == 1)
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
            else {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD);
                if (!inHops2.containsKey("B1")) {
                    // incl modification of X for consistency
                    if (cdata1 instanceof CNodeData)
                        inHops2.put("X", hop.getInput().get(0).getInput().get(0));
                    inHops2.put("B1", hop.getInput().get(1));
                }
            }
            if (!inHops2.containsKey("X"))
                inHops2.put("X", hop.getInput().get(0).getInput().get(0));
        } else {
            // single-column times single-column: scalar multiply of the looked-up values
            if (hop.getInput().get(0).getDim2() == 1 && hop.getInput().get(1).getDim2() == 1)
                out = new CNodeBinary((cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), (cdata2.getDataType() == DataType.SCALAR) ? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
            else if (hop.getInput().get(1).getDim2() == 1) {
                out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
                inHops2.put("X", hop.getInput().get(0));
            } else {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MATRIXMULT);
                inHops2.put("X", hop.getInput().get(0));
                inHops2.put("B1", hop.getInput().get(1));
            }
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        // NOTE(review): tmp has no entry for this hop yet (see memoization check above), so the
        // first argument is null here; this relies on skipTranspose resolving the transpose input
        out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
        if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
            inHops.add(hop.getInput().get(0));
    } else if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if (hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1 || (!hop.dimsKnown() && cdata1.getDataType() == DataType.MATRIX)) {
            if (HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY)) {
                String opname = "VECT_" + ((UnaryOp) hop).getOp().name();
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opname));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
                    inHops2.put("X", hop.getInput().get(0));
            } else
                throw new RuntimeException("Unsupported unary matrix " + "operation: " + ((UnaryOp) hop).getOp().name());
        } else // general scalar case
        {
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
            String primitiveOpName = ((UnaryOp) hop).getOp().toString();
            out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
        }
    } else if (HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
        // special case for cbind with zeros
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = null;
        if (HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) {
            // constant datagen right input is inlined as a literal instead of a matrix input
            cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
            // rm 0-matrix
            inHops.remove(hop.getInput().get(1));
        } else {
            cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        }
        out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
        if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
            inHops2.put("X", hop.getInput().get(0));
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if ((hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1) || (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1) || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()) && // not a known vector output
        (hop.getDim2() != 1) && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix()))) {
            if (HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY)) {
                if (TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2) || TemplateUtils.isRowVector(cdata2))) {
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name();
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                } else {
                    // vector-scalar variant; column-vector operands contribute their row's value
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name() + "_SCALAR";
                    if (TemplateUtils.isColVector(cdata1))
                        cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
                    if (TemplateUtils.isColVector(cdata2))
                        cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                }
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X") && !(cdata1.getDataType() == DataType.SCALAR)) {
                    inHops2.put("X", hop.getInput().get(0));
                }
            } else
                throw new RuntimeException("Unsupported binary matrix " + "operation: " + ((BinaryOp) hop).getOp().name());
        } else // one input is a vector/scalar other is a scalar
        {
            String primitiveOpName = ((BinaryOp) hop).getOp().toString();
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            if (// vector or vector can be inferred from lhs
            TemplateUtils.isColVector(cdata2) || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix()))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
        }
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        // construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString()));
    } else if (HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
        CNode[] inputs = new CNode[hop.getInput().size()];
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop c = hop.getInput().get(i);
            CNode cdata = tmp.get(c.getHopID());
            if (TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata))
                cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
            inputs[i] = cdata;
            // only the first cbind input may become the main input "X"
            if (i == 0 && cdata instanceof CNodeData && !inHops2.containsKey("X"))
                inHops2.put("X", c);
        }
        out = new CNodeNary(inputs, NaryType.VECT_CBIND);
    } else if (hop instanceof ParameterizedBuiltinOp) {
        ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp) hop;
        Hop target = pbop.getTargetHop();
        CNode cdata1 = tmp.get(target.getHopID());
        // decide the lookup on the target hop itself rather than on positional input 0, since
        // nothing here establishes that the target is stored as the first input
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, target);
        CNode cdata2 = tmp.get(pbop.getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(pbop.getParameterHop("replacement").getHopID());
        // a literal NaN pattern selects the dedicated NaN replacement primitive
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), (hop.getDim2() != 1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
    }
    if (out == null) {
        throw new RuntimeException(hop.getHopID() + " " + hop.getOpString());
    }
    if (out.getDataType().isMatrix()) {
        out.setNumRows(hop.getDim1());
        out.setNumCols(hop.getDim2());
    }
    tmp.put(hop.getHopID(), out);
}
Also used: TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType), CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData), AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), UnaryOp (org.apache.sysml.hops.UnaryOp), CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary), AggBinaryOp (org.apache.sysml.hops.AggBinaryOp), Hop (org.apache.sysml.hops.Hop), CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary), CNodeNary (org.apache.sysml.hops.codegen.cplan.CNodeNary), TernaryOp (org.apache.sysml.hops.TernaryOp), CNode (org.apache.sysml.hops.codegen.cplan.CNode), ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp), CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary), IndexingOp (org.apache.sysml.hops.IndexingOp), MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry), LiteralOp (org.apache.sysml.hops.LiteralOp), BinaryOp (org.apache.sysml.hops.BinaryOp)

Example 48 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project SystemML by Apache.

From class RewriteAlgebraicSimplificationDynamic, method simplifySumMatrixMult.

/**
 * Pushes a sum aggregate below the matrix multiply it consumes, so that the full product is
 * never materialized:
 * <pre>
 * sum(A%*%B)     -> sum(t(colSums(A)) * rowSums(B))   (later rewritten to a dot product)
 * colSums(A%*%B) -> colSums(A) %*% B
 * rowSums(A%*%B) -> A %*% rowSums(B)
 * </pre>
 * The rewrite only applies if the product is known to have more than one row or column (not a
 * dot product). It also requires that this aggregate is the matrix multiply's only consumer,
 * since otherwise the product would be computed redundantly.
 *
 * @param parent parent high-level operator (not used by this rewrite)
 * @param hi candidate sum aggregate; its input is replaced in place
 * @param pos position of hi within parent (not used by this rewrite)
 * @return hi, whether or not the rewrite was applied
 */
private static Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos) {
    // only sum aggregates qualify
    if (!(hi instanceof AggUnaryOp) || ((AggUnaryOp) hi).getOp() != AggOp.SUM)
        return hi;
    Hop mmult = hi.getInput().get(0);
    // input must be a non-dot-product matrix multiply whose sole consumer is this sum
    boolean isCandidate = mmult instanceof AggBinaryOp
        && (mmult.getDim1() > 1 || mmult.getDim2() > 1)
        && mmult.getParent().size() == 1;
    if (!isCandidate)
        return hi;

    Hop lhs = mmult.getInput().get(0);
    Hop rhs = mmult.getInput().get(1);
    // detach the matrix multiply before building the replacement sub-DAG
    HopRewriteUtils.removeChildReference(hi, mmult);

    Direction dir = ((AggUnaryOp) hi).getDirection();
    Hop replacement = null;
    if (dir == Direction.RowCol) {
        AggUnaryOp lhsColSums = HopRewriteUtils.createAggUnaryOp(lhs, AggOp.SUM, Direction.Col);
        ReorgOp lhsColSumsT = HopRewriteUtils.createTranspose(lhsColSums);
        AggUnaryOp rhsRowSums = HopRewriteUtils.createAggUnaryOp(rhs, AggOp.SUM, Direction.Row);
        replacement = HopRewriteUtils.createBinary(lhsColSumsT, rhsRowSums, OpOp2.MULT);
        LOG.debug("Applied simplifySumMatrixMult RC.");
    } else if (dir == Direction.Col) {
        AggUnaryOp lhsColSums = HopRewriteUtils.createAggUnaryOp(lhs, AggOp.SUM, Direction.Col);
        replacement = HopRewriteUtils.createMatrixMultiply(lhsColSums, rhs);
        LOG.debug("Applied simplifySumMatrixMult C.");
    } else if (dir == Direction.Row) {
        AggUnaryOp rhsRowSums = HopRewriteUtils.createAggUnaryOp(rhs, AggOp.SUM, Direction.Row);
        replacement = HopRewriteUtils.createMatrixMultiply(lhs, rhsRowSums);
        LOG.debug("Applied simplifySumMatrixMult R.");
    }

    // hang the new sub-DAG under the unchanged sum operator and refresh its sizes
    HopRewriteUtils.addChildReference(hi, replacement, 0);
    hi.refreshSizeInformation();
    // drop the old matrix multiply if nothing references it anymore
    HopRewriteUtils.cleanupUnreferenced(mmult);
    return hi;
}
Also used: AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), AggBinaryOp (org.apache.sysml.hops.AggBinaryOp), Hop (org.apache.sysml.hops.Hop), ReorgOp (org.apache.sysml.hops.ReorgOp)

Example 49 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project SystemML by Apache.

From class RewriteAlgebraicSimplificationStatic, method simplifyReverseOperation.

/**
 * Replaces a matrix multiply with a reversal permutation matrix by a reverse operation:
 * {@code table(seq(1,N), seq(N,1)) %*% X -> rev(X)}.
 * <p>
 * The left operand must be a CTABLE ternary operator. Its first input must satisfy
 * HopRewriteUtils.isBasic1NSequence, its second input must satisfy isBasicN1Sequence, and both
 * must report the same number of rows.
 * <p>
 * NOTE: conceptually this is a dynamic rewrite. It is applied as a static rewrite so that it
 * runs before DAG splitting, which would hide the table information if dimensions are not
 * specified.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate matrix multiply)
 * @param pos position of hi within parent's inputs
 * @return the new reverse operator if the rewrite applied, otherwise hi
 */
private static Hop simplifyReverseOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggBinaryOp) || !(hi.getInput().get(0) instanceof TernaryOp))
        return hi;
    TernaryOp table = (TernaryOp) hi.getInput().get(0);
    if (table.getOp() != OpOp3.CTABLE)
        return hi;

    Hop ascending = table.getInput().get(0);
    Hop descending = table.getInput().get(1);
    boolean isReversal = HopRewriteUtils.isBasic1NSequence(ascending)
        && HopRewriteUtils.isBasicN1Sequence(descending)
        && ascending.getDim1() == descending.getDim1();
    if (!isReversal)
        return hi;

    // the permutation is dropped; only the right-hand operand survives, reversed
    ReorgOp reversed = HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV);
    HopRewriteUtils.replaceChildReference(parent, hi, reversed, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, table);
    LOG.debug("Applied simplifyReverseOperation.");
    return reversed;
}
Also used: AggBinaryOp (org.apache.sysml.hops.AggBinaryOp), ReorgOp (org.apache.sysml.hops.ReorgOp), TernaryOp (org.apache.sysml.hops.TernaryOp)

Example 50 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project SystemML by Apache.

From class RewriteAlgebraicSimplificationStatic, method simplifyBushyBinaryOperation.

/**
 * Re-associates nested element-wise binary operations that feed a matrix multiply, so that the
 * matrix-multiply operand ends up at the root of a bushy tree:
 * <pre>
 * (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
 * (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
 * ((Z%*%v)*X)*Y   -> (Z%*%v)*(X*Y)
 * </pre>
 * Both operands of {@code hi} must be matrices, and the operator must be in
 * LOOKUP_VALID_ASSOCIATIVE_BINARY.
 * <p>
 * Note: the rewrite is restricted to a matrix multiply (ba()) at the leaf and at the root,
 * rather than any data leaf, so that it does not reorganize too eagerly. Over-eager
 * reorganization would lose additional rewrite potential. The rewrite has two goals: (1)
 * enable XtwXv, and (2) increase piggybacking potential by creating bushy trees.
 *
 * @param parent parent high-level operator (must be a matrix multiply)
 * @param hi high-level operator (candidate binary operation)
 * @param pos position of hi within parent's inputs
 * @return the new root binary operator if a rewrite applied, otherwise hi
 */
private static Hop simplifyBushyBinaryOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp) || !(parent instanceof AggBinaryOp))
        return hi;
    BinaryOp outer = (BinaryOp) hi;
    Hop lhs = outer.getInput().get(0);
    Hop rhs = outer.getInput().get(1);
    OpOp2 outerOp = outer.getOp();
    if (!(lhs.getDataType() == DataType.MATRIX && rhs.getDataType() == DataType.MATRIX && HopRewriteUtils.isValidOp(outerOp, LOOKUP_VALID_ASSOCIATIVE_BINARY)))
        return hi;

    // pattern 1: X op (Y op ba()) -> (X op Y) op ba()
    if (rhs instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) rhs;
        Hop innerLhs = inner.getInput().get(0);
        Hop innerRhs = inner.getInput().get(1);
        if (outerOp == inner.getOp() && innerRhs.getDataType() == DataType.MATRIX && innerRhs instanceof AggBinaryOp) {
            BinaryOp grouped = HopRewriteUtils.createBinary(lhs, innerLhs, outerOp);
            BinaryOp newRoot = HopRewriteUtils.createBinary(grouped, innerRhs, outerOp);
            HopRewriteUtils.replaceChildReference(parent, outer, newRoot, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation1");
            return newRoot;
        }
    }

    // pattern 2: (ba() op X) op Y -> ba() op (X op Y); X must not be a vector unless Y is a
    // vector in the same dimension, so that the new X op Y keeps valid broadcasting
    if (lhs instanceof BinaryOp) {
        BinaryOp inner = (BinaryOp) lhs;
        Hop innerLhs = inner.getInput().get(0);
        Hop innerRhs = inner.getInput().get(1);
        if (outerOp == inner.getOp() && innerLhs.getDataType() == DataType.MATRIX && innerLhs instanceof AggBinaryOp
            && (innerRhs.getDim2() > 1 || rhs.getDim2() == 1)
            && (innerRhs.getDim1() > 1 || rhs.getDim1() == 1)) {
            BinaryOp grouped = HopRewriteUtils.createBinary(innerRhs, rhs, outerOp);
            BinaryOp newRoot = HopRewriteUtils.createBinary(innerLhs, grouped, outerOp);
            HopRewriteUtils.replaceChildReference(parent, outer, newRoot, pos);
            HopRewriteUtils.cleanupUnreferenced(outer, inner);
            LOG.debug("Applied simplifyBushyBinaryOperation2");
            return newRoot;
        }
    }
    return hi;
}
Also used: AggBinaryOp (org.apache.sysml.hops.AggBinaryOp), Hop (org.apache.sysml.hops.Hop), OpOp2 (org.apache.sysml.hops.Hop.OpOp2), BinaryOp (org.apache.sysml.hops.BinaryOp)

Aggregations

AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 56, Hop (org.apache.sysml.hops.Hop): 47, AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 32, BinaryOp (org.apache.sysml.hops.BinaryOp): 27, LiteralOp (org.apache.sysml.hops.LiteralOp): 21, ReorgOp (org.apache.sysml.hops.ReorgOp): 15, UnaryOp (org.apache.sysml.hops.UnaryOp): 15, TernaryOp (org.apache.sysml.hops.TernaryOp): 11, MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 11, QuaternaryOp (org.apache.sysml.hops.QuaternaryOp): 10, IndexingOp (org.apache.sysml.hops.IndexingOp): 9, ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 9, CNode (org.apache.sysml.hops.codegen.cplan.CNode): 8, CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 6, CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 6, CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary): 6, DataOp (org.apache.sysml.hops.DataOp): 4, OpOp2 (org.apache.sysml.hops.Hop.OpOp2): 4, CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary): 4, TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType): 4