Search in sources :

Example 76 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project systemml by apache.

From the class RewriteAlgebraicSimplificationDynamic, method removeUnnecessaryCumulativeOp.

/**
 * Drops a cumulative unary operation (e.g., cumsum) whose input has known dimensions and exactly
 * one row. In that case the rewrite treats the operation as a no-op, and {@code parent} is rewired
 * to consume the input directly.
 *
 * @param parent hop that references {@code hi} at input position {@code pos}
 * @param hi     candidate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the input of {@code hi} if the rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof UnaryOp)) {
        return hi;
    }
    UnaryOp cumulative = (UnaryOp) hi;
    if (!cumulative.isCumulativeUnaryOperation()) {
        return hi;
    }

    Hop source = hi.getInput().get(0);
    // only applied when the input's dims are known and it consists of a single row
    boolean singleRow = HopRewriteUtils.isDimsKnown(source) && source.getDim1() == 1;
    if (!singleRow) {
        return hi;
    }

    // capture the op before rewiring, for the debug message
    OpOp1 removedOp = cumulative.getOp();
    HopRewriteUtils.replaceChildReference(parent, hi, source, pos);
    LOG.debug("Applied removeUnnecessaryCumulativeOp: " + removedOp);
    return source;
}
Also used : OpOp1(org.apache.sysml.hops.Hop.OpOp1) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop)

Example 77 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project systemml by apache.

From the class RewriteAlgebraicSimplificationDynamic, method simplifyNrowNcolComputation.

/**
 * Replaces {@code nrow(X)} or {@code ncol(X)} with a literal when the respective dimension of X is
 * known. Removing the data dependency on X avoids triggering its computation even when the
 * intermediate is otherwise not required (e.g., when it is part of a fused operator).
 *
 * @param parent hop that references {@code hi} at input position {@code pos}
 * @param hi     candidate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the new literal hop if the rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyNrowNcolComputation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof UnaryOp)) {
        return hi;
    }
    UnaryOp sizeOp = (UnaryOp) hi;

    Hop literal = null;
    boolean isRowCount = false;
    if (sizeOp.getOp() == OpOp1.NROW && hi.getInput().get(0).rowsKnown()) {
        literal = new LiteralOp(hi.getInput().get(0).getDim1());
        isRowCount = true;
    } else if (sizeOp.getOp() == OpOp1.NCOL && hi.getInput().get(0).colsKnown()) {
        literal = new LiteralOp(hi.getInput().get(0).getDim2());
    }
    if (literal == null) {
        return hi;
    }

    // swap in the literal without a size refresh, then drop hi if it became unreferenced
    HopRewriteUtils.replaceChildReference(parent, hi, literal, pos, false);
    HopRewriteUtils.cleanupUnreferenced(hi);
    if (isRowCount) {
        LOG.debug("Applied simplifyNrowComputation nrow(" + hi.getHopID() + ") -> " + literal.getName() + " (line " + hi.getBeginLine() + ").");
    } else {
        LOG.debug("Applied simplifyNcolComputation ncol(" + hi.getHopID() + ") -> " + literal.getName() + " (line " + hi.getBeginLine() + ").");
    }
    return literal;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 78 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project systemml by apache.

From the class RewriteAlgebraicSimplificationDynamic, method simplifyUnnecessaryAggregate.

/**
 * Replaces a full (row-and-column) aggregate over a 1x1 input with a scalar cast of that input,
 * e.g., {@code sum(X) -> as.scalar(X)}. This applies only to aggregate ops accepted by
 * {@code LOOKUP_VALID_UNNECESSARY_AGGREGATE} (per the original note: sum, min, max, prod, trace).
 *
 * @param parent hop that references {@code hi} at input position {@code pos}
 * @param hi     candidate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the new cast hop if the rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos) {
    boolean isFullAggregate = hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol;
    if (!isFullAggregate) {
        return hi;
    }

    AggUnaryOp aggregate = (AggUnaryOp) hi;
    Hop operand = aggregate.getInput().get(0);
    if (!HopRewriteUtils.isValidOp(aggregate.getOp(), LOOKUP_VALID_UNNECESSARY_AGGREGATE)) {
        return hi;
    }
    // only a 1x1 operand qualifies (unknown dims never equal 1)
    if (operand.getDim1() != 1 || operand.getDim2() != 1) {
        return hi;
    }

    UnaryOp asScalar = HopRewriteUtils.createUnary(operand, OpOp1.CAST_AS_SCALAR);
    HopRewriteUtils.replaceChildReference(parent, hi, asScalar, pos);
    LOG.debug("Applied simplifyUnncessaryAggregate");
    return asScalar;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop)

Example 79 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project systemml by apache.

From the class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSigmoidMMChains.

/**
 * Fuses an elementwise {@code W * f(Y %*% t(X))} chain into a single WSIGMOID {@link QuaternaryOp}.
 * The supported forms of f are checked in this order:
 * <ol>
 * <li>{@code sigmoid(P)} (basic)</li>
 * <li>{@code sigmoid(-P)} (minus, recognized as {@code 0 - P})</li>
 * <li>{@code log(sigmoid(P))} (log)</li>
 * <li>{@code log(sigmoid(-P))} (log_minus)</li>
 * </ol>
 * Here P is an {@link AggBinaryOp} (the {@code Y %*% t(X)} product) whose left input Y satisfies
 * {@code HopRewriteUtils.isSingleBlock(Y, true)}. The rewrite is only attempted when {@code hi}
 * meets all of the following:
 * <ul>
 * <li>it is a MULT binary op with more than one output column;</li>
 * <li>both of its inputs have equal size;</li>
 * <li>its left input W is a matrix;</li>
 * <li>its right input is a unary op.</li>
 * </ul>
 *
 * @param parent hop that references {@code hi} at input position {@code pos}
 * @param hi     candidate root of the pattern
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the new quaternary hop (already linked into {@code parent}) if a pattern matched,
 *         otherwise {@code hi} unchanged
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
    // all patterns are subrooted by W * unary(...)
    boolean rootMatches = HopRewriteUtils.isBinary(hi, OpOp2.MULT)
        && hi.getDim2() > 1 // not applied for vector-vector mult
        && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) // prevent mv
        && hi.getInput().get(0).getDataType() == DataType.MATRIX
        && hi.getInput().get(1) instanceof UnaryOp; // sigmoid/log
    if (!rootMatches) {
        return hi;
    }

    Hop weights = hi.getInput().get(0);
    UnaryOp outer = (UnaryOp) hi.getInput().get(1);
    Hop fused = null;

    if (outer.getOp() == OpOp1.SIGMOID) {
        Hop sigmoidArg = outer.getInput().get(0);
        if (isWSigmoidMatMult(sigmoidArg)) {
            // Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
            fused = createWSigmoidOp(hi, weights, sigmoidArg, false, false);
            LOG.debug("Applied simplifyWeightedSigmoid1 (line " + hi.getBeginLine() + ")");
        } else if (isNegatedWSigmoidMatMult(sigmoidArg)) {
            // Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
            fused = createWSigmoidOp(hi, weights, sigmoidArg.getInput().get(1), false, true);
            LOG.debug("Applied simplifyWeightedSigmoid2 (line " + hi.getBeginLine() + ")");
        }
    } else if (outer.getOp() == OpOp1.LOG
        && HopRewriteUtils.isUnary(outer.getInput().get(0), OpOp1.SIGMOID)) {
        Hop sigmoidArg = outer.getInput().get(0).getInput().get(0);
        if (isWSigmoidMatMult(sigmoidArg)) {
            // Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
            fused = createWSigmoidOp(hi, weights, sigmoidArg, true, false);
            LOG.debug("Applied simplifyWeightedSigmoid3 (line " + hi.getBeginLine() + ")");
        } else if (isNegatedWSigmoidMatMult(sigmoidArg)) {
            // Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
            fused = createWSigmoidOp(hi, weights, sigmoidArg.getInput().get(1), true, true);
            LOG.debug("Applied simplifyWeightedSigmoid4 (line " + hi.getBeginLine() + ")");
        }
    }

    if (fused == null) {
        return hi;
    }
    // relink the fused hop into the original position of hi
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    return fused;
}

/** Returns true if {@code h} is an AggBinaryOp whose left input passes {@code isSingleBlock(.., true)}. */
private static boolean isWSigmoidMatMult(Hop h) {
    return h instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(h.getInput().get(0), true);
}

/** Returns true if {@code h} is {@code 0 - P}, where P satisfies {@link #isWSigmoidMatMult(Hop)}. */
private static boolean isNegatedWSigmoidMatMult(Hop h) {
    return HopRewriteUtils.isBinary(h, OpOp2.MINUS)
        && h.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) h.getInput().get(0)) == 0
        && isWSigmoidMatMult(h.getInput().get(1));
}

/**
 * Builds the WSIGMOID quaternary op from W, the left input Y of {@code product}, and X.
 * X is the input of {@code product}'s right operand if that operand is a transpose; otherwise it
 * is a newly created transpose of that operand. The output block sizes are taken from W.
 */
private static Hop createWSigmoidOp(Hop hi, Hop W, Hop product, boolean logOut, boolean minusIn) {
    Hop Y = product.getInput().get(0);
    Hop right = product.getInput().get(1);
    Hop X = HopRewriteUtils.isTransposeOperation(right)
        ? right.getInput().get(0)
        : HopRewriteUtils.createTranspose(right);
    Hop quaternary = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
        OpOp4.WSIGMOID, W, Y, X, logOut, minusIn);
    quaternary.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
    quaternary.refreshSizeInformation();
    return quaternary;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 80 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project systemml by apache.

From the class RewriteAlgebraicSimplificationDynamic, method simplifyColwiseAggregate.

/**
 * Simplifies column-wise aggregates ({@code Direction.Col}) of ops accepted by
 * {@code LOOKUP_VALID_ROW_COL_AGGREGATE} when the input has a single row or a single column.
 * <ul>
 * <li>Single row, VAR: {@code colVars(X)} is replaced by a row vector of zeros, since each column
 * variance over one value is zero.</li>
 * <li>Single row, other ops: the aggregate is removed and the parent consumes X directly.</li>
 * <li>Single column: the aggregate is converted in place to a full (RowCol, scalar) aggregate and
 * wrapped in a cast-as-matrix, which is rehung under all former parents.</li>
 * </ul>
 *
 * @param parent hop that references {@code hi} at input position {@code pos}
 * @param hi     candidate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the replacement hop if a rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyColwiseAggregate(Hop parent, Hop hi, int pos) {
    if (hi instanceof AggUnaryOp) {
        AggUnaryOp uhi = (AggUnaryOp) hi;
        Hop input = uhi.getInput().get(0);
        if (HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE)
            && uhi.getDirection() == Direction.Col) {
            if (input.getDim1() == 1) {
                if (uhi.getOp() == AggOp.VAR) {
                    // column variances over a single row are all zero:
                    // rewrite colVars(X) to a row vector of zeros
                    Hop emptyRow = HopRewriteUtils.createDataGenOp(uhi, input, 0);
                    HopRewriteUtils.replaceChildReference(parent, hi, emptyRow, pos);
                    HopRewriteUtils.cleanupUnreferenced(hi, input);
                    hi = emptyRow;
                    LOG.debug("Applied simplifyColwiseAggregate for colVars");
                } else {
                    // every other valid column aggregate over a row vector is the row vector itself,
                    // so the aggregate is unnecessary
                    HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
                    HopRewriteUtils.cleanupUnreferenced(hi);
                    hi = input;
                    LOG.debug("Applied simplifyColwiseAggregate1");
                }
            } else if (input.getDim2() == 1) {
                // Snapshot the current parents before the cast below is created over the aggregate.
                // The copy constructor replaces the former unchecked (ArrayList<Hop>) clone() cast.
                ArrayList<Hop> parents = new ArrayList<>(hi.getParent());
                // simplify col-aggregate to full aggregate
                uhi.setDirection(Direction.RowCol);
                uhi.setDataType(DataType.SCALAR);
                // create cast to keep same output datatype
                UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX);
                // rehang cast under all former parents
                for (Hop p : parents) {
                    int ix = HopRewriteUtils.getChildReferencePos(p, hi);
                    HopRewriteUtils.replaceChildReference(p, hi, cast, ix);
                }
                hi = cast;
                LOG.debug("Applied simplifyColwiseAggregate2");
            }
        }
    }
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList)

Aggregations

UnaryOp (org.apache.sysml.hops.UnaryOp)82 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)76 Hop (org.apache.sysml.hops.Hop)68 LiteralOp (org.apache.sysml.hops.LiteralOp)47 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)34 BinaryOp (org.apache.sysml.hops.BinaryOp)34 DataOp (org.apache.sysml.hops.DataOp)18 IndexingOp (org.apache.sysml.hops.IndexingOp)15 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)11 TernaryOp (org.apache.sysml.hops.TernaryOp)11 ArrayList (java.util.ArrayList)10 HashMap (java.util.HashMap)10 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)10 ReorgOp (org.apache.sysml.hops.ReorgOp)9 DataGenOp (org.apache.sysml.hops.DataGenOp)8 HopsException (org.apache.sysml.hops.HopsException)8 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)8 MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject)8 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)6 CNode (org.apache.sysml.hops.codegen.cplan.CNode)6