Search in sources:

Example 6 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Method rRemoveConstantBinaryOp of class IPAPassRemoveConstantBinaryOps.

/**
 * Recursively walks the hop DAG below {@code hop} and, for every elementwise multiplication
 * (MULT, not an outer vector operation) whose left input is a matrix and whose right input is
 * a {@link DataOp} whose name is a key of {@code mOnes}, replaces that right input (a matrix
 * of ones) by the literal 1. The resulting multiplication by 1 is left for the algebraic
 * simplification rewrites to remove; doing it here would require more complex recursive
 * processing of the children and rewiring.
 *
 * <p>Hops whose visited flag is already set are skipped; the flag is set on a hop after all
 * of its inputs have been processed.
 *
 * @param hop   current hop of the traversal
 * @param mOnes map keyed by the names of variables known to hold matrices of ones
 */
private static void rRemoveConstantBinaryOp(Hop hop, HashMap<String, Hop> mOnes) {
    if (!hop.isVisited()) {
        if (hop instanceof BinaryOp) {
            BinaryOp mult = (BinaryOp) hop;
            boolean elementwiseMatrixMult = mult.getOp() == OpOp2.MULT
                && !mult.isOuterVectorOperator()
                && mult.getInput().get(0).getDataType() == DataType.MATRIX;
            if (elementwiseMatrixMult) {
                Hop onesCandidate = mult.getInput().get(1);
                if (onesCandidate instanceof DataOp && mOnes.containsKey(onesCandidate.getName())) {
                    // swap the ones-matrix operand for the scalar literal 1 at the same position
                    HopRewriteUtils.removeChildReferenceByPos(hop, onesCandidate, 1);
                    HopRewriteUtils.addChildReference(hop, new LiteralOp(1), 1);
                }
            }
        }
        // descend into all (possibly just rewired) inputs before marking this hop as done
        for (Hop child : hop.getInput()) {
            rRemoveConstantBinaryOp(child, mOnes);
        }
        hop.setVisited();
    }
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 7 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Method isNotMatrixVectorBinaryOperation of class HopRewriteUtils.

/**
 * Checks that {@code hop} is not a matrix-vector binary operation.
 *
 * <p>Any hop that is not a {@link BinaryOp} qualifies. A {@code BinaryOp} qualifies only if
 * its own dimensions are known ({@code isDimsKnown}) and its right input neither has a single
 * row while the left input has several, nor a single column while the left input has several.
 *
 * @param hop the hop to check
 * @return {@code true} if {@code hop} is not a BinaryOp, or is a BinaryOp with known
 *         dimensions and no matrix-vector shape; {@code false} otherwise
 */
public static boolean isNotMatrixVectorBinaryOperation(Hop hop) {
    if (!(hop instanceof BinaryOp)) {
        return true;
    }
    BinaryOp bop = (BinaryOp) hop;
    Hop lhs = bop.getInput().get(0);
    Hop rhs = bop.getInput().get(1);
    // right operand collapses a dimension in which the left operand has extent > 1
    boolean singleRowRhs = lhs.getDim1() > 1 && rhs.getDim1() == 1;
    boolean singleColRhs = lhs.getDim2() > 1 && rhs.getDim2() == 1;
    return isDimsKnown(bop) && !singleRowRhs && !singleColRhs;
}
Also used : Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 8 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Method isBinarySparseSafe of class HopRewriteUtils.

/**
 * Determines whether the binary operation {@code hop} is treated as sparse-safe.
 *
 * <p>Returns {@code false} for anything that is not a {@link BinaryOp} and {@code true} for
 * MULT and DIV. For other operations it delegates to
 * {@code OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe} with the first literal
 * operand found (the left input is checked before the right one). Without any literal
 * operand the result is {@code false}.
 *
 * @param hop the hop to check
 * @return whether {@code hop} is considered sparse-safe
 */
public static boolean isBinarySparseSafe(Hop hop) {
    if (hop instanceof BinaryOp) {
        if (isBinary(hop, OpOp2.MULT, OpOp2.DIV)) {
            return true;
        }
        BinaryOp bop = (BinaryOp) hop;
        LiteralOp literal = null;
        if (bop.getInput().get(0) instanceof LiteralOp) {
            literal = (LiteralOp) bop.getInput().get(0);
        } else if (bop.getInput().get(1) instanceof LiteralOp) {
            literal = (LiteralOp) bop.getInput().get(1);
        }
        // NOTE(review): the literal may come from either side, but only the op and the literal
        // (not its position) are passed on. For non-commutative ops (e.g. POW: lit ^ X vs.
        // X ^ lit) confirm that the conditional check is valid for both operand orders.
        return literal != null
            && OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe(bop.getOp(), literal);
    }
    return false;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 9 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Method simplifyWeightedSquaredLoss of class RewriteAlgebraicSimplificationDynamic.

/**
 * Searches for weighted squared loss expressions and replaces them with a quaternary
 * WSLOSS operator. Currently, this search includes the following three patterns, each
 * matched in both operand orders of the inner minus:
 * 1) sum (W * (X - U %*% t(V)) ^ 2) or sum (W * (U %*% t(V) - X) ^ 2) (post weighting)
 * 2) sum ((X - W * (U %*% t(V))) ^ 2) or sum ((W * (U %*% t(V)) - X) ^ 2) (pre weighting)
 * 3) sum ((X - (U %*% t(V))) ^ 2) or sum (((U %*% t(V)) - X) ^ 2) (no weighting)
 * The patterns are tried in this order and at most one is applied. In pattern 1, a W that
 * HopRewriteUtils.isNonZeroIndicator(W, X) recognizes is replaced by the literal 1 (post_nz);
 * in pattern 3, W is the literal 1.
 *
 * NOTE: We include transpose into the pattern because during runtime we need to compute
 * U %*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
 * without additional memory requirements for internal transpose. Accordingly, the quaternary
 * operator receives V: an existing transpose on the right factor is unwrapped, otherwise a new
 * transpose is created.
 *
 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
 * rewrite (via HopRewriteUtils.isSingleBlock on U) because this is an uncommon case. Also, the
 * intention is to remove this constraint as soon as we generalized the runtime or hop/lop
 * compilation.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate sum() root of the pattern)
 * @param pos position of hi within the inputs of parent (used when relinking)
 * @return the new quaternary operator (already relinked into parent at pos) if a pattern
 *         was applied, otherwise hi unchanged
 */
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) {
    // NOTE: there might be also a general simplification without custom operator
    // via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
    Hop hnew = null;
    if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol && // all patterns are rooted by a full (RowCol) aggregate...
    ((AggUnaryOp) hi).getOp() == AggOp.SUM && // ...that is a sum()...
    hi.getInput().get(0) instanceof BinaryOp && // ...subrooted by a binary op...
    hi.getInput().get(0).getDim2() > 1) { // ...with more than one column (not applied for vector-vector mult)
        BinaryOp bop = (BinaryOp) hi.getInput().get(0);
        boolean appliedPattern = false;
        // pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2) or sum (W * (U %*% t(V) - X) ^ 2)
        if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // W is a matrix; next line: equal sizes prevent mv, exponent must be literal 2
        HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) {
            Hop W = bop.getInput().get(0);
            // base of the power: (X - U %*% t(V)) or (U %*% t(V) - X)
            Hop tmp = bop.getInput().get(1).getInput().get(0);
            if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS) && // next line: equal sizes prevent mv; left operand must be a matrix
            HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
                // a) sum (W * (X - U %*% t(V)) ^ 2): product is the right operand of the minus
                int uvIndex = -1;
                if (// AggBinaryOp guarantees matrices
                tmp.getInput().get(1) instanceof AggBinaryOp && // next line: BLOCKSIZE CONSTRAINT on U
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    uvIndex = 1;
                } else // b) sum (W * (U %*% t(V) - X) ^ 2): product is the left operand of the minus
                if (// AggBinaryOp guarantees matrices
                tmp.getInput().get(0) instanceof AggBinaryOp && // next line: BLOCKSIZE CONSTRAINT on U
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) {
                    uvIndex = 0;
                }
                if (// rewrite match; uvIndex is the minus operand holding U %*% t(V)
                uvIndex >= 0) {
                    Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
                    Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
                    Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
                    // pass V rather than t(V): unwrap an existing transpose, otherwise create one
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    // special case post_nz: a W recognized as the non-zero indicator of X becomes literal 1
                    if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
                        W = new LiteralOp(1);
                    }
                    // construct scalar quaternary hop; last argument is true only in this
                    // post-weighting pattern (patterns 2 and 3 pass false)
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        // pattern 2 (pre weighting), only if pattern 1 did not apply:
        // sum ((X - W * (U %*% t(V))) ^ 2) or sum ((W * (U %*% t(V)) - X) ^ 2)
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // next line: equal sizes prevent mv; left operand must be a matrix
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            // a) sum ((X - W * (U %*% t(V))) ^ 2)
            int wuvIndex = -1;
            // NOTE: only the shape BinaryOp(_, AggBinaryOp) is checked here; if the right
            // candidate has this shape but fails the checks below, the left candidate is not
            // tried (pattern 3 may still apply)
            if (lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 1;
            } else // b) sum ((W * (U %*% t(V)) - X) ^ 2)
            if (lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 0;
            }
            if (// rewrite match
            wuvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
                // (W * (U %*% t(V))); the cast below is safe because the candidate passed instanceof BinaryOp
                Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
                if (((BinaryOp) tmp).getOp() == OpOp2.MULT && tmp.getInput().get(0).getDataType() == DataType.MATRIX && // next line: equal sizes prevent mv
                HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && // next line: BLOCKSIZE CONSTRAINT on U
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    Hop W = tmp.getInput().get(0);
                    Hop U = tmp.getInput().get(1).getInput().get(0);
                    Hop V = tmp.getInput().get(1).getInput().get(1);
                    // pass V rather than t(V): unwrap an existing transpose, otherwise create one
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        // pattern 3 (no weighting), only if patterns 1 and 2 did not apply:
        // sum ((X - (U %*% t(V))) ^ 2) or sum (((U %*% t(V)) - X) ^ 2)
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // next line: equal sizes prevent mv; left operand must be a matrix
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            // a) sum ((X - (U %*% t(V))) ^ 2)
            int uvIndex = -1;
            if (// AggBinaryOp guarantees matrices
            lright instanceof AggBinaryOp && // next line: BLOCKSIZE CONSTRAINT on U
            HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) {
                uvIndex = 1;
            } else // b) sum (((U %*% t(V)) - X) ^ 2)
            if (// AggBinaryOp guarantees matrices
            lleft instanceof AggBinaryOp && // next line: BLOCKSIZE CONSTRAINT on U
            HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) {
                uvIndex = 0;
            }
            if (// rewrite match
            uvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
                // (U %*% t(V))
                Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
                // no weighting: W is the literal 1
                Hop W = new LiteralOp(1);
                Hop U = tmp.getInput().get(0);
                Hop V = tmp.getInput().get(1);
                // pass V rather than t(V): unwrap an existing transpose, otherwise create one
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                HopRewriteUtils.setOutputParametersForScalar(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
            }
        }
    }
    // relink new hop into the original position of hi within parent
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 10 with BinaryOp

use of org.apache.sysml.hops.BinaryOp in project incubator-systemml by apache.

Method fuseAxpyBinaryOperationChain of class RewriteAlgebraicSimplificationDynamic.

/**
 * Fuses an addition/subtraction with a scalar-matrix multiplication into a single ternary
 * operator, for the patterns (a) X + s*Y -> X +* sY, (b) s*Y + X -> X +* sY and
 * (c) X - s*Y -> X -* sY. The operation must not be an outer vector operation, X must be a
 * matrix of the same size as s*Y, and s*Y must have a single consumer. If s is a literal
 * zero, the whole expression is replaced by X instead.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate PLUS/MINUS binary op)
 * @param pos position of hi within the inputs of parent (used when relinking)
 * @return the replacement operator (already relinked into parent at pos) if a pattern
 *         applied, otherwise hi unchanged
 */
private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof BinaryOp) || ((BinaryOp) hi).isOuterVectorOperator()) {
        return hi;
    }
    BinaryOp bop = (BinaryOp) hi;
    boolean isPlus = bop.getOp() == OpOp2.PLUS;
    boolean isMinus = bop.getOp() == OpOp2.MINUS;
    if (!isPlus && !isMinus) {
        return hi;
    }
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);
    Hop replacement = null;
    if (isPlus && left.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(right)
        && HopRewriteUtils.isEqualSize(left, right)
        && right.getParent().size() == 1) {
        // pattern (a) X + s*Y -> X +* sY (s*Y has a single consumer)
        replacement = createAxpyOrOperand(left, right, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " + hi.getBeginLine() + ")");
    } else if (isPlus && right.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(left)
        && HopRewriteUtils.isEqualSize(left, right)
        && left.getParent().size() == 1) {
        // pattern (b) s*Y + X -> X +* sY (s*Y has a single consumer)
        replacement = createAxpyOrOperand(right, left, OpOp3.PLUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " + hi.getBeginLine() + ")");
    } else if (isMinus && left.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isScalarMatrixBinaryMult(right)
        && HopRewriteUtils.isEqualSize(left, right)
        && right.getParent().size() == 1) {
        // pattern (c) X - s*Y -> X -* sY (s*Y has a single consumer)
        replacement = createAxpyOrOperand(left, right, OpOp3.MINUS_MULT);
        LOG.debug("Applied fuseAxpyBinaryOperationChain3. (line " + hi.getBeginLine() + ")");
    }
    if (replacement == null) {
        return hi;
    }
    // rewire the parent to consume the replacement instead of hi
    HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
    return replacement;
}

/**
 * Builds the ternary operator {@code op(x, s, y)} from a matrix {@code x} and a
 * scalar-matrix product {@code sy} (scalar operand on either side). Returns {@code x}
 * unchanged when the scalar is a literal zero.
 */
private static Hop createAxpyOrOperand(Hop x, Hop sy, OpOp3 op) {
    boolean scalarFirst = sy.getInput().get(0).getDataType() == DataType.SCALAR;
    Hop scalar = sy.getInput().get(scalarFirst ? 0 : 1);
    Hop matrix = sy.getInput().get(scalarFirst ? 1 : 0);
    if (scalar instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) scalar) == 0) {
        return x;
    }
    return HopRewriteUtils.createTernaryOp(x, scalar, matrix, op);
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

BinaryOp (org.apache.sysml.hops.BinaryOp)58 Hop (org.apache.sysml.hops.Hop)57 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)50 LiteralOp (org.apache.sysml.hops.LiteralOp)27 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)23 UnaryOp (org.apache.sysml.hops.UnaryOp)18 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)9 IndexingOp (org.apache.sysml.hops.IndexingOp)6 ReorgOp (org.apache.sysml.hops.ReorgOp)6 TernaryOp (org.apache.sysml.hops.TernaryOp)6 ArrayList (java.util.ArrayList)5 DataOp (org.apache.sysml.hops.DataOp)5 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)5 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)5 DataGenOp (org.apache.sysml.hops.DataGenOp)4 HashMap (java.util.HashMap)3 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)3 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)3 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)3