Search in sources :

Example 1 with OpOp1

use of org.apache.sysml.hops.Hop.OpOp1 in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method removeUnnecessaryCumulativeOp.

private Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos) {
    if (hi instanceof UnaryOp && ((UnaryOp) hi).isCumulativeUnaryOperation()) {
        //input matrix
        Hop input = hi.getInput().get(0);
        if (//dims input known
        HopRewriteUtils.isDimsKnown(input) && //1 row
        input.getDim1() == 1) {
            OpOp1 op = ((UnaryOp) hi).getOp();
            //remove unnecessary unary cumsum operator				
            HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
            hi = input;
            LOG.debug("Applied removeUnnecessaryCumulativeOp: " + op);
        }
    }
    return hi;
}
Also used : OpOp1(org.apache.sysml.hops.Hop.OpOp1) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop)

Example 2 with OpOp1

use of org.apache.sysml.hops.Hop.OpOp1 in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedUnaryMM.

private Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) throws HopsException {
    Hop hnew = null;
    boolean appliedPattern = false;
    //Pattern 1) (W*uop(U%*%t(V)))
    if (hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && //prevent mv
    HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) && //not applied for vector-vector mult
    hi.getDim2() > 1 && hi.getInput().get(0).getDataType() == DataType.MATRIX && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() && hi.getInput().get(1) instanceof UnaryOp && HopRewriteUtils.isValidOp(((UnaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_UNARY) && hi.getInput().get(1).getInput().get(0) instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT			
    HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
        Hop W = hi.getInput().get(0);
        Hop U = hi.getInput().get(1).getInput().get(0).getInput().get(0);
        Hop V = hi.getInput().get(1).getInput().get(0).getInput().get(1);
        boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
        OpOp1 op = ((UnaryOp) hi.getInput().get(1)).getOp();
        if (!HopRewriteUtils.isTransposeOperation(V))
            V = HopRewriteUtils.createTranspose(V);
        else
            V = V.getInput().get(0);
        hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, op, null);
        hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
        hnew.refreshSizeInformation();
        appliedPattern = true;
        LOG.debug("Applied simplifyWeightedUnaryMM1 (line " + hi.getBeginLine() + ")");
    }
    //Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
    if (!appliedPattern && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && //prevent mv
    HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) && //not applied for vector-vector mult
    hi.getDim2() > 1 && hi.getInput().get(0).getDataType() == DataType.MATRIX && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() && hi.getInput().get(1) instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY)) {
        Hop left = hi.getInput().get(1).getInput().get(0);
        Hop right = hi.getInput().get(1).getInput().get(1);
        Hop abop = null;
        //pattern 2a) matrix-scalar operations
        if (right.getDataType() == DataType.SCALAR && right instanceof LiteralOp && //pow2, mult2
        HopRewriteUtils.getDoubleValue((LiteralOp) right) == 2 && left instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT			
        HopRewriteUtils.isSingleBlock(left.getInput().get(0), true)) {
            abop = left;
        } else //pattern 2b) scalar-matrix operations
        if (left.getDataType() == DataType.SCALAR && left instanceof LiteralOp && //mult2
        HopRewriteUtils.getDoubleValue((LiteralOp) left) == 2 && ((BinaryOp) hi.getInput().get(1)).getOp() == OpOp2.MULT && right instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT			
        HopRewriteUtils.isSingleBlock(right.getInput().get(0), true)) {
            abop = right;
        }
        if (abop != null) {
            Hop W = hi.getInput().get(0);
            Hop U = abop.getInput().get(0);
            Hop V = abop.getInput().get(1);
            boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
            OpOp2 op = ((BinaryOp) hi.getInput().get(1)).getOp();
            if (!HopRewriteUtils.isTransposeOperation(V))
                V = HopRewriteUtils.createTranspose(V);
            else
                V = V.getInput().get(0);
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, null, op);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedUnaryMM2 (line " + hi.getBeginLine() + ")");
        }
    }
    //relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) OpOp1(org.apache.sysml.hops.Hop.OpOp1) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) OpOp2(org.apache.sysml.hops.Hop.OpOp2) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)2 Hop (org.apache.sysml.hops.Hop)2 OpOp1 (org.apache.sysml.hops.Hop.OpOp1)2 UnaryOp (org.apache.sysml.hops.UnaryOp)2 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)1 BinaryOp (org.apache.sysml.hops.BinaryOp)1 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)1 LiteralOp (org.apache.sysml.hops.LiteralOp)1 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)1