Search in sources :

Example 6 with QuaternaryOp

Use of org.apache.sysml.hops.QuaternaryOp in the Apache project incubator-systemml.

Method simplifyWeightedCrossEntropy of class RewriteAlgebraicSimplificationDynamic.

/**
 * Fuses a weighted cross-entropy aggregate into a single WCEMM {@link QuaternaryOp}.
 *
 * <p>Recognized patterns, tried in order (the first match wins):
 * <ol>
 *   <li>{@code sum(X * log(U %*% t(V)))} - built with int argument 0 and a
 *       {@code 0.0} literal in the epsilon slot</li>
 *   <li>{@code sum(X * log(U %*% t(V) + eps))} with a scalar literal {@code eps} -
 *       built with int argument 1 (BASIC_EPS)</li>
 * </ol>
 * In both cases X must be a matrix of the same size as the {@code log(...)}
 * term, and U must pass {@code HopRewriteUtils.isSingleBlock(U, true)}. The
 * right operand of the product is passed un-transposed: {@code t(V)} is
 * unwrapped to {@code V}; any other operand gets a new transpose hop.
 *
 * @param parent parent hop whose child reference is replaced on success
 * @param hi     candidate root hop
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new WCEMM hop if a pattern applied, otherwise {@code hi}
 */
private static Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) {
    // Root must be a full (row and column) sum aggregate.
    boolean isFullSum = hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol
        && ((AggUnaryOp) hi).getOp() == AggOp.SUM;
    if (!isFullSum) {
        return hi;
    }

    // Pattern is sub-rooted by a binary op; vector-vector products are excluded.
    Hop sumInput = hi.getInput().get(0);
    if (!(sumInput instanceof BinaryOp) || sumInput.getDim2() <= 1) {
        return hi;
    }

    BinaryOp bop = (BinaryOp) sumInput;
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);

    // Common prefix of both patterns: X * log(...). The equal-size check
    // prevents matrix-vector broadcasting.
    boolean isWeightedLog = bop.getOp() == OpOp2.MULT
        && left.getDataType() == DataType.MATRIX
        && HopRewriteUtils.isEqualSize(left, right)
        && HopRewriteUtils.isUnary(right, OpOp1.LOG);
    if (!isWeightedLog) {
        return hi;
    }

    Hop logArg = right.getInput().get(0);
    Hop product;   // the U %*% t(V) matrix multiplication
    Hop epsilon;   // additive scalar literal, or null for the basic pattern
    if (logArg instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(logArg.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
        product = logArg;
        epsilon = null;
    } else if (HopRewriteUtils.isBinary(logArg, OpOp2.PLUS)
        && logArg.getInput().get(0) instanceof AggBinaryOp
        && logArg.getInput().get(1) instanceof LiteralOp
        && logArg.getInput().get(1).getDataType() == DataType.SCALAR
        && HopRewriteUtils.isSingleBlock(logArg.getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
        product = logArg.getInput().get(0);
        epsilon = logArg.getInput().get(1);
    } else {
        return hi;
    }

    Hop X = left;
    Hop U = product.getInput().get(0);
    Hop V = product.getInput().get(1);
    V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);

    Hop hnew;
    if (epsilon == null) {
        // Pattern 1) sum( X * log(U %*% t(V)))
        hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V, new LiteralOp(0.0), 0, false, false);
        hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
        LOG.debug("Applied simplifyWeightedCEMM (line " + hi.getBeginLine() + ")");
    } else {
        // Pattern 2) sum( X * log(U %*% t(V) + eps)); 1 => BASIC_EPS
        hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V, epsilon, 1, false, false);
        hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
        LOG.debug("Applied simplifyWeightedCEMMEps (line " + hi.getBeginLine() + ")");
    }

    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 7 with QuaternaryOp

Use of org.apache.sysml.hops.QuaternaryOp in the Apache project incubator-systemml.

Method simplifyWeightedUnaryMM of class RewriteAlgebraicSimplificationDynamic.

/**
 * Fuses weighted unary matrix-multiplication patterns into a single WUMM
 * {@link QuaternaryOp}. Patterns are tried in order and the first match wins:
 * <ol>
 *   <li>Pattern 1: {@code W op uop(U %*% t(V))}, where op is in
 *       LOOKUP_VALID_WDIVMM_BINARY and uop is in LOOKUP_VALID_WUMM_UNARY</li>
 *   <li>Pattern 2.7: {@code (W * (U %*% t(V))) * 2} or {@code 2 * (W * (U %*% t(V)))}</li>
 *   <li>Pattern 2: {@code W op sop(U %*% t(V), 2)}, or {@code W op (2 * (U %*% t(V)))},
 *       where sop is in LOOKUP_VALID_WUMM_BINARY</li>
 * </ol>
 * The {@code mult} flag of the fused op is true iff the root operation is MULT.
 * The right operand of the product is passed un-transposed: {@code t(V)} is
 * unwrapped, and any other operand gets a new transpose hop.
 *
 * @param parent parent hop whose child reference is replaced on success
 * @param hi     candidate root hop
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new WUMM hop if a pattern applied, otherwise {@code hi}
 */
private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
    // Every pattern is rooted by a binary operation.
    if (!(hi instanceof BinaryOp)) {
        return hi;
    }
    final OpOp2 rootOp = ((BinaryOp) hi).getOp();
    final Hop lhs = hi.getInput().get(0);
    final Hop rhs = hi.getInput().get(1);

    // Guard shared by patterns 1 and 2: a valid weighting op over equally sized
    // operands (prevents mv). It also excludes vector-vector products and
    // requires W to be a matrix spanning more than one column block.
    final boolean weightedRoot = HopRewriteUtils.isValidOp(rootOp, LOOKUP_VALID_WDIVMM_BINARY)
        && HopRewriteUtils.isEqualSize(lhs, rhs)
        && hi.getDim2() > 1
        && lhs.getDataType() == DataType.MATRIX
        && lhs.getDim2() > lhs.getColsInBlock();

    Hop hnew = null;

    // Pattern 1) (W*uop(U%*%t(V)))
    if (weightedRoot
        && rhs instanceof UnaryOp
        && HopRewriteUtils.isValidOp(((UnaryOp) rhs).getOp(), LOOKUP_VALID_WUMM_UNARY)
        && rhs.getInput().get(0) instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(rhs.getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
        Hop product = rhs.getInput().get(0);
        Hop U = product.getInput().get(0);
        Hop V = product.getInput().get(1);
        OpOp1 unaryOp = ((UnaryOp) rhs).getOp();
        V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
        hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, lhs, U, V, rootOp == OpOp2.MULT, unaryOp, null);
        hnew.setOutputBlocksizes(lhs.getRowsInBlock(), lhs.getColsInBlock());
        hnew.refreshSizeInformation();
        LOG.debug("Applied simplifyWeightedUnaryMM1 (line " + hi.getBeginLine() + ")");
    }

    // Pattern 2.7) (W*(U%*%t(V)))*2 or 2*(W*(U%*%t(V)))
    if (hnew == null
        && rootOp == OpOp2.MULT
        && (HopRewriteUtils.isLiteralOfValue(lhs, 2) || HopRewriteUtils.isLiteralOfValue(rhs, 2))) {
        // the operand that is not the literal 2
        final Hop nonLiteral = (lhs instanceof LiteralOp) ? rhs : lhs;
        if (HopRewriteUtils.isBinary(nonLiteral, OpOp2.MULT)) {
            final Hop weight = nonLiteral.getInput().get(0);
            final Hop product = nonLiteral.getInput().get(1);
            if (nonLiteral.getParent().size() == 1 // ensure no foreign parents
                && HopRewriteUtils.isEqualSize(weight, product) // prevent mv
                && nonLiteral.getDim2() > 1 // not applied for vector-vector mult
                && weight.getDataType() == DataType.MATRIX
                && weight.getDim2() > weight.getColsInBlock()
                && HopRewriteUtils.isOuterProductLikeMM(product)
                && (((AggBinaryOp) product).checkMapMultChain() == ChainType.NONE
                    || product.getInput().get(1).getDim2() > 1) // no mmchain
                && HopRewriteUtils.isSingleBlock(product.getInput().get(0), true)) {
                final Hop U = product.getInput().get(0);
                Hop V = product.getInput().get(1);
                V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, weight, U, V, true, null, OpOp2.MULT);
                hnew.setOutputBlocksizes(weight.getRowsInBlock(), weight.getColsInBlock());
                hnew.refreshSizeInformation();
                LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line " + hi.getBeginLine() + ")");
            }
        }
    }

    // Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
    if (hnew == null
        && weightedRoot
        && rhs instanceof BinaryOp
        && HopRewriteUtils.isValidOp(((BinaryOp) rhs).getOp(), LOOKUP_VALID_WUMM_BINARY)) {
        OpOp2 scalarOp = ((BinaryOp) rhs).getOp();
        Hop opLeft = rhs.getInput().get(0);
        Hop opRight = rhs.getInput().get(1);
        Hop product;
        if (opRight.getDataType() == DataType.SCALAR
            && opRight instanceof LiteralOp
            && HopRewriteUtils.getDoubleValue((LiteralOp) opRight) == 2 // pow2, mult2
            && opLeft instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(opLeft.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            // pattern 2a) matrix-scalar operations
            product = opLeft;
        } else if (opLeft.getDataType() == DataType.SCALAR
            && opLeft instanceof LiteralOp
            && HopRewriteUtils.getDoubleValue((LiteralOp) opLeft) == 2 // mult2
            && scalarOp == OpOp2.MULT
            && opRight instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(opRight.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            // pattern 2b) scalar-matrix operations
            product = opRight;
        } else {
            product = null;
        }
        if (product != null) {
            Hop U = product.getInput().get(0);
            Hop V = product.getInput().get(1);
            V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, lhs, U, V, rootOp == OpOp2.MULT, null, scalarOp);
            hnew.setOutputBlocksizes(lhs.getRowsInBlock(), lhs.getColsInBlock());
            hnew.refreshSizeInformation();
            LOG.debug("Applied simplifyWeightedUnaryMM2 (line " + hi.getBeginLine() + ")");
        }
    }

    // relink new hop into original position
    if (hnew == null) {
        return hi;
    }
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) OpOp1(org.apache.sysml.hops.Hop.OpOp1) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) OpOp2(org.apache.sysml.hops.Hop.OpOp2) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 8 with QuaternaryOp

Use of org.apache.sysml.hops.QuaternaryOp in the Apache project incubator-systemml.

Method simplifyWeightedDivMM of class RewriteAlgebraicSimplificationDynamic.

/**
 * Fuses weighted divide/multiply matrix-multiplication patterns into a single
 * WDIVMM {@link QuaternaryOp}. Patterns are tried in order and the first match
 * wins ("op" is a binary op from LOOKUP_VALID_WDIVMM_BINARY, x is a scalar):
 * <ul>
 *   <li>1)  {@code t(U) %*% (W op (U%*%t(V)))}</li>
 *   <li>1e) {@code t(U) %*% (W / (U%*%t(V) + x))}</li>
 *   <li>2)  {@code (W op (U%*%t(V))) %*% V}</li>
 *   <li>2e) {@code (W / (U%*%t(V) + x)) %*% V}</li>
 *   <li>3)  {@code t(U) %*% ((X!=0) * (U%*%t(V) - X))}</li>
 *   <li>4)  {@code ((X!=0) * (U%*%t(V) - X)) %*% V}</li>
 *   <li>5)  {@code t(U) %*% (W * (U%*%t(V) - X))}</li>
 *   <li>6)  {@code (W * (U%*%t(V) - X)) %*% V}</li>
 *   <li>7)  {@code W * (U%*%t(V))} (not wrapped in a matrix multiply)</li>
 * </ul>
 * The t(U)-prefixed patterns (1, 1e, 3, 5) wrap the fused op in an output
 * transpose for efficient target indexing.
 *
 * <p>Patterns 3-6 require equally sized operands at both binary levels: W must
 * match the size of {@code (U%*%t(V) - X)}, and X must match {@code U%*%t(V)}.
 * This mirrors the "prevent mv" checks of patterns 1, 1e, 2, 2e and 7. Without
 * it, a scalar weight (e.g. {@code t(U) %*% (0.5*(U%*%t(V)-X))}) or a broadcast
 * row/column vector would be handed to the fused operator, which receives W
 * and X as matrix inputs.
 *
 * @param parent parent hop whose child reference is replaced on success
 * @param hi     candidate root hop
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new (possibly transposed) WDIVMM hop if a pattern applied,
 *         otherwise {@code hi}
 */
private static Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) {
    Hop hnew = null;
    boolean appliedPattern = false;
    // note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops)
    if (HopRewriteUtils.isMatrixMultiply(hi)
        && ((hi.getInput().get(0) instanceof BinaryOp
                && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY))
            || (hi.getInput().get(1) instanceof BinaryOp
                && hi.getDim2() > 1 // not applied for vector-vector mult
                && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)))) {
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);

        // Pattern 1) t(U) %*% (W*(U%*%t(V)))
        if (right instanceof BinaryOp
            && HopRewriteUtils.isValidOp(((BinaryOp) right).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
            && HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) // prevent mv
            && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1))
            && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(left, U)) {
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                boolean mult = ((BinaryOp) right).getOp() == OpOp2.MULT;
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 1, mult, false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM1 (line " + hi.getBeginLine() + ")");
            }
        }

        // Pattern 1e) t(U) %*% (W/(U%*%t(V) + x))
        if (!appliedPattern
            && HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[1]) // DIV
            && HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) // prevent mv
            && HopRewriteUtils.isBinary(right.getInput().get(1), Hop.OpOp2.PLUS)
            && right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
            && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
            && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = right.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(left, U)) {
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                // 3 => DIV_LEFT_EPS
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 3, false, false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM1e (line " + hi.getBeginLine() + ")");
            }
        }

        // Pattern 2) (W*(U%*%t(V))) %*% V
        if (!appliedPattern
            && left instanceof BinaryOp
            && HopRewriteUtils.isValidOp(((BinaryOp) left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
            && HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) // prevent mv
            && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1))
            && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(right, V)) {
                // right is t(V'), hence V' is the un-transposed factor
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = right;
                } else {
                    V = V.getInput().get(0);
                }
                boolean mult = ((BinaryOp) left).getOp() == OpOp2.MULT;
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 2, mult, false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM2 (line " + hi.getBeginLine() + ")");
            }
        }

        // Pattern 2e) (W/(U%*%t(V) + x)) %*% V
        if (!appliedPattern
            && HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[1]) // DIV
            && HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) // prevent mv
            && HopRewriteUtils.isBinary(left.getInput().get(1), Hop.OpOp2.PLUS)
            && left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
            && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
            && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = left.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(right, V)) {
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = right;
                } else {
                    V = V.getInput().get(0);
                }
                // 4 => DIV_RIGHT_EPS
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 4, false, false);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM2e (line " + hi.getBeginLine() + ")");
            }
        }

        // Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
        if (!appliedPattern
            && HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
            && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
            && isEqualSizeWeightedMinus(right) // prevent scalar / mv operands
            && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
            && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
            && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = right.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isNonZeroIndicator(W, X) // W-X constraint
                && HopRewriteUtils.isTransposeOfItself(left, U)) { // t(U)-U constraint
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 1, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM3 (line " + hi.getBeginLine() + ")");
            }
        }

        // Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
        if (!appliedPattern
            && HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
            && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
            && isEqualSizeWeightedMinus(left) // prevent scalar / mv operands
            && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
            && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
            && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = left.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isNonZeroIndicator(W, X) // W-X constraint
                && HopRewriteUtils.isTransposeOfItself(right, V)) { // V-t(V) constraint
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = right;
                } else {
                    V = V.getInput().get(0);
                }
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 2, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM4 (line " + hi.getBeginLine() + ")");
            }
        }

        // Pattern 5) t(U) %*% (W*(U%*%t(V)-X))
        if (!appliedPattern
            && HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
            && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
            && isEqualSizeWeightedMinus(right) // prevent scalar / mv operands
            && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
            && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
            && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop W = right.getInput().get(0);
            Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = right.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(left, U)) { // t(U)-U constraint
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                // note: x and w exchanged compared to patterns 1-4, 7
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 1, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                // add output transpose for efficient target indexing (redundant t() removed by other rewrites)
                hnew = HopRewriteUtils.createTranspose(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM5 (line " + hi.getBeginLine() + ")");
            }
        }

        // Pattern 6) (W*(U%*%t(V)-X)) %*% V
        if (!appliedPattern
            && HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
            && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
            && isEqualSizeWeightedMinus(left) // prevent scalar / mv operands
            && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
            && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
            && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop W = left.getInput().get(0);
            Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
            Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
            Hop X = left.getInput().get(1).getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(right, V)) { // V-t(V) constraint
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = right;
                } else {
                    V = V.getInput().get(0);
                }
                // note: x and w exchanged compared to patterns 1-4, 7
                hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, X, 2, true, true);
                hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
                hnew.refreshSizeInformation();
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedDivMM6 (line " + hi.getBeginLine() + ")");
            }
        }
    }

    // Pattern 7) (W*(U%*%t(V)))
    if (!appliedPattern
        && HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) // MULT
        && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) // prevent mv
        && hi.getDim2() > 1 // not applied for vector-vector mult
        && hi.getInput().get(0).getDataType() == DataType.MATRIX
        && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
        && HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1))
        && (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE
            || hi.getInput().get(1).getInput().get(1).getDim2() > 1) // no mmchain
        && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
        Hop W = hi.getInput().get(0);
        Hop U = hi.getInput().get(1).getInput().get(0);
        Hop V = hi.getInput().get(1).getInput().get(1);
        // W is sparse and U/V unknown or dense; or if U/V are dense
        if ((HopRewriteUtils.isSparse(W) && !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V))
            || (HopRewriteUtils.isDense(U) && HopRewriteUtils.isDense(V))) {
            V = !HopRewriteUtils.isTransposeOperation(V) ? HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedDivMM7 (line " + hi.getBeginLine() + ")");
        }
    }

    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}

/**
 * Checks operand sizes of a weighted-minus expression {@code W * (UV - X)}.
 * Only call this after verifying that {@code mult} is a binary MULT whose
 * second input is a binary MINUS, so that both hops have two inputs.
 *
 * @param mult the {@code W * (UV - X)} hop
 * @return true if W has the same size as {@code (UV - X)} and UV has the same
 *         size as X, i.e. no scalar or matrix-vector broadcasting is involved
 */
private static boolean isEqualSizeWeightedMinus(Hop mult) {
    Hop minus = mult.getInput().get(1);
    return HopRewriteUtils.isEqualSize(mult.getInput().get(0), minus)
        && HopRewriteUtils.isEqualSize(minus.getInput().get(0), minus.getInput().get(1));
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 9 with QuaternaryOp

Use of org.apache.sysml.hops.QuaternaryOp in the Apache project systemml.

Method simplifyWeightedSigmoidMMChains of class RewriteAlgebraicSimplificationDynamic.

/**
 * Fuses weighted sigmoid matrix-multiplication chains into a single WSIGMOID
 * {@link QuaternaryOp}. Every pattern is rooted by an element-wise {@code W *}
 * over equally sized operands, with W a matrix:
 * <ol>
 *   <li>{@code W * sigmoid(Y%*%t(X))} (basic)</li>
 *   <li>{@code W * sigmoid(-(Y%*%t(X)))} (minus; negation written as {@code 0 - ...})</li>
 *   <li>{@code W * log(sigmoid(Y%*%t(X)))} (log)</li>
 *   <li>{@code W * log(sigmoid(-(Y%*%t(X))))} (log_minus)</li>
 * </ol>
 * The two boolean constructor flags of the fused op encode the outer
 * {@code log} and the inner negation, respectively. Y must pass
 * {@code HopRewriteUtils.isSingleBlock(Y, true)}. The right factor is passed
 * un-transposed: {@code t(X)} is unwrapped, and any other operand gets a new
 * transpose hop.
 *
 * @param parent parent hop whose child reference is replaced on success
 * @param hi     candidate root hop
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new WSIGMOID hop if a pattern applied, otherwise {@code hi}
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
    // Root guard: W * uop(...) over equally sized operands (prevents mv), with
    // W a matrix; vector-vector products are excluded.
    boolean weightedUnary = HopRewriteUtils.isBinary(hi, OpOp2.MULT)
        && hi.getDim2() > 1
        && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1))
        && hi.getInput().get(0).getDataType() == DataType.MATRIX
        && hi.getInput().get(1) instanceof UnaryOp;
    if (!weightedUnary) {
        return hi;
    }
    UnaryOp uop = (UnaryOp) hi.getInput().get(1);

    // Peel an optional outer log(): sigmoidArg becomes the argument of sigmoid(...).
    boolean logOut;
    Hop sigmoidArg;
    if (uop.getOp() == OpOp1.SIGMOID) {
        logOut = false;
        sigmoidArg = uop.getInput().get(0);
    } else if (uop.getOp() == OpOp1.LOG && HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)) {
        logOut = true;
        sigmoidArg = uop.getInput().get(0).getInput().get(0);
    } else {
        return hi;
    }

    // Detect an optional negation 0 - (Y %*% t(X)) and locate the product.
    boolean minusIn;
    Hop product;
    if (sigmoidArg instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(sigmoidArg.getInput().get(0), true)) {
        minusIn = false;
        product = sigmoidArg;
    } else if (HopRewriteUtils.isBinary(sigmoidArg, OpOp2.MINUS)
        && sigmoidArg.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) sigmoidArg.getInput().get(0)) == 0
        && sigmoidArg.getInput().get(1) instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(sigmoidArg.getInput().get(1).getInput().get(0), true)) {
        minusIn = true;
        product = sigmoidArg.getInput().get(1);
    } else {
        return hi;
    }

    Hop W = hi.getInput().get(0);
    Hop Y = product.getInput().get(0);
    Hop tX = product.getInput().get(1);
    tX = HopRewriteUtils.isTransposeOperation(tX) ? tX.getInput().get(0) : HopRewriteUtils.createTranspose(tX);

    Hop hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WSIGMOID, W, Y, tX, logOut, minusIn);
    hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
    hnew.refreshSizeInformation();

    if (logOut) {
        if (minusIn) {
            LOG.debug("Applied simplifyWeightedSigmoid4 (line " + hi.getBeginLine() + ")");
        } else {
            LOG.debug("Applied simplifyWeightedSigmoid3 (line " + hi.getBeginLine() + ")");
        }
    } else {
        if (minusIn) {
            LOG.debug("Applied simplifyWeightedSigmoid2 (line " + hi.getBeginLine() + ")");
        } else {
            LOG.debug("Applied simplifyWeightedSigmoid1 (line " + hi.getBeginLine() + ")");
        }
    }

    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 10 with QuaternaryOp

use of org.apache.sysml.hops.QuaternaryOp in project systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedUnaryMM.

/**
 * Rewrites a weighted element-wise operation over a matrix multiply into a single
 * {@code OpOp4.WUMM} {@link QuaternaryOp}. The patterns are tried in order and the first match wins:
 * <ol>
 *   <li>Pattern 1: {@code W * uop(U %*% t(V))}. The operator of {@code hi} must be in
 *       {@code LOOKUP_VALID_WDIVMM_BINARY} and {@code uop} must be in {@code LOOKUP_VALID_WUMM_UNARY}.</li>
 *   <li>Pattern 2.7: {@code (W * (U %*% t(V))) * 2} or {@code 2 * (W * (U %*% t(V)))}. The inner
 *       multiply must have no other parents, its matrix multiply must pass
 *       {@code isOuterProductLikeMM}, and it must not be an mmchain unless the right factor has
 *       more than one column.</li>
 *   <li>Pattern 2: {@code W * sop(U %*% t(V), 2)} with {@code sop} in {@code LOOKUP_VALID_WUMM_BINARY}.
 *       The literal {@code 2} may be the right operand (2a), or the left operand when {@code sop}
 *       is a multiply (2b).</li>
 * </ol>
 * In every pattern the weighting binary op must combine two equal-size inputs (no matrix-vector
 * broadcast) and have more than one column. Its left input must be a matrix with more columns
 * than its column block size. The left factor {@code U} must pass
 * {@code HopRewriteUtils.isSingleBlock(U, true)} (the "BLOCKSIZE CONSTRAINT").
 *
 * @param parent the hop whose child reference to {@code hi} is replaced on a match
 * @param hi     the candidate root of the pattern
 * @param pos    the input position of {@code hi} in {@code parent}, passed to
 *               {@code HopRewriteUtils.replaceChildReference}
 * @return the new WUMM hop, already relinked into {@code parent}, or {@code hi} unchanged if no
 *         pattern matched
 */
private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
    Hop hnew = null;

    // Pattern 1) (W*uop(U%*%t(V)))
    if (hi instanceof BinaryOp
        && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
        && isWummCandidateShape(hi)
        && hi.getInput().get(1) instanceof UnaryOp
        && HopRewriteUtils.isValidOp(((UnaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_UNARY)
        && hi.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
        Hop mm = hi.getInput().get(1).getInput().get(0);
        Hop W = hi.getInput().get(0);
        Hop U = mm.getInput().get(0);
        Hop V = wummRightFactor(mm.getInput().get(1));
        boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
        OpOp1 op = ((UnaryOp) hi.getInput().get(1)).getOp();
        hnew = createWummHop(hi, W, U, V, mult, op, null);
        LOG.debug("Applied simplifyWeightedUnaryMM1 (line " + hi.getBeginLine() + ")");
    }

    // Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
    if (hnew == null
        && HopRewriteUtils.isBinary(hi, OpOp2.MULT)
        && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
            || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2))) {
        // The non-literal operand. If both operands are literals, nl is a literal and the
        // isBinary check below rejects it.
        final Hop nl = (hi.getInput().get(0) instanceof LiteralOp) ? hi.getInput().get(1) : hi.getInput().get(0);
        if (HopRewriteUtils.isBinary(nl, OpOp2.MULT)
            && nl.getParent().size() == 1 // ensure no foreign parents
            && isWummCandidateShape(nl)
            && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
            && (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE
                || nl.getInput().get(1).getInput().get(1).getDim2() > 1) // no mmchain
            && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            final Hop W = nl.getInput().get(0);
            final Hop U = nl.getInput().get(1).getInput().get(0);
            final Hop V = wummRightFactor(nl.getInput().get(1).getInput().get(1));
            hnew = createWummHop(hi, W, U, V, true, null, OpOp2.MULT);
            LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line " + hi.getBeginLine() + ")");
        }
    }

    // Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
    if (hnew == null
        && hi instanceof BinaryOp
        && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
        && isWummCandidateShape(hi)
        && hi.getInput().get(1) instanceof BinaryOp
        && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY)) {
        BinaryOp sop = (BinaryOp) hi.getInput().get(1);
        Hop left = sop.getInput().get(0);
        Hop right = sop.getInput().get(1);
        Hop abop = null;
        if (HopRewriteUtils.isLiteralOfValue(right, 2) // pow2, mult2
            && left instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(left.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            // pattern 2a) matrix-scalar operations
            abop = left;
        } else if (HopRewriteUtils.isLiteralOfValue(left, 2) // mult2
            && sop.getOp() == OpOp2.MULT
            && right instanceof AggBinaryOp
            && HopRewriteUtils.isSingleBlock(right.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            // pattern 2b) scalar-matrix operations
            abop = right;
        }
        if (abop != null) {
            Hop W = hi.getInput().get(0);
            Hop U = abop.getInput().get(0);
            Hop V = wummRightFactor(abop.getInput().get(1));
            boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
            hnew = createWummHop(hi, W, U, V, mult, null, sop.getOp());
            LOG.debug("Applied simplifyWeightedUnaryMM2 (line " + hi.getBeginLine() + ")");
        }
    }

    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}

/**
 * Checks the shape preconditions shared by all WUMM patterns for the weighting binary hop
 * {@code bop}. Both inputs must have equal size (no matrix-vector broadcast), {@code bop} must have
 * more than one column, and its left input must be a matrix with more columns than its column
 * block size.
 */
private static boolean isWummCandidateShape(Hop bop) {
    Hop weight = bop.getInput().get(0);
    return HopRewriteUtils.isEqualSize(weight, bop.getInput().get(1)) // prevent mv
        && bop.getDim2() > 1 // not applied for vector-vector mult
        && weight.getDataType() == DataType.MATRIX
        && weight.getDim2() > weight.getColsInBlock();
}

/**
 * Returns the {@code V} operand for WUMM, given the right factor {@code rhs} of {@code U %*% rhs}.
 * If {@code rhs} is already a transpose {@code t(X)}, the result is {@code X}. Otherwise the result
 * is a new transpose of {@code rhs}.
 */
private static Hop wummRightFactor(Hop rhs) {
    return HopRewriteUtils.isTransposeOperation(rhs)
        ? rhs.getInput().get(0)
        : HopRewriteUtils.createTranspose(rhs);
}

/**
 * Builds the WUMM quaternary hop that replaces {@code hi}. It takes output block sizes from
 * {@code W} and refreshes its size information.
 */
private static Hop createWummHop(Hop hi, Hop W, Hop U, Hop V, boolean mult, OpOp1 uop, OpOp2 sop) {
    Hop wumm = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, uop, sop);
    wumm.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
    wumm.refreshSizeInformation();
    return wumm;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) OpOp1(org.apache.sysml.hops.Hop.OpOp1) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) OpOp2(org.apache.sysml.hops.Hop.OpOp2) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)10 BinaryOp (org.apache.sysml.hops.BinaryOp)10 Hop (org.apache.sysml.hops.Hop)10 LiteralOp (org.apache.sysml.hops.LiteralOp)10 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)10 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)8 UnaryOp (org.apache.sysml.hops.UnaryOp)4 OpOp1 (org.apache.sysml.hops.Hop.OpOp1)2 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)2