Search in sources:

Example 11 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Method createSeqDataGenOp of class HopRewriteUtils.

/**
 * Creates a SEQ data-generation hop that enumerates the rows of {@code input}.
 * <p>
 * For {@code asc == true} the sequence is {@code seq(1, nrow(input), 1)}; otherwise it is
 * {@code seq(nrow(input), 1, -1)}. When the row count of {@code input} is known, it is
 * embedded as a literal; otherwise an NROW unary operator over {@code input} supplies it.
 *
 * @param input hop whose number of rows defines the sequence length
 * @param asc   {@code true} for an ascending sequence, {@code false} for descending
 * @return the new SEQ data-generation hop, with block sizes and line numbers taken from
 *         {@code input}
 */
public static DataGenOp createSeqDataGenOp(Hop input, boolean asc) {
    // number of rows of the input: a literal if known, otherwise nrow(input)
    final Hop nrows;
    if (input.rowsKnown()) {
        nrows = new LiteralOp(input.getDim1());
    } else {
        nrows = new UnaryOp("tmprows", DataType.SCALAR, ValueType.INT, OpOp1.NROW, input);
    }

    // sequence bounds and increment (literal hops created in the same order as before)
    final Hop from;
    final Hop upTo;
    final Hop step;
    if (asc) {
        from = new LiteralOp(1);
        upTo = nrows;
        step = new LiteralOp(1);
    } else {
        from = nrows;
        upTo = new LiteralOp(1);
        step = new LiteralOp(-1);
    }

    HashMap<String, Hop> seqParams = new HashMap<>();
    seqParams.put(Statement.SEQ_FROM, from);
    seqParams.put(Statement.SEQ_TO, upTo);
    seqParams.put(Statement.SEQ_INCR, step);

    // per the original note, size information is refreshed internally by the DataGenOp
    DataGenOp seq = new DataGenOp(DataGenMethod.SEQ, new DataIdentifier("tmp"), seqParams);
    seq.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
    copyLineNumbers(input, seq);
    return seq;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) DataIdentifier(org.apache.sysml.parser.DataIdentifier) HashMap(java.util.HashMap) DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 12 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Method simplifyColwiseAggregate of class RewriteAlgebraicSimplificationDynamic.

/**
 * Simplifies a column-wise aggregate whose input shape makes the column direction trivial.
 * <ul>
 *   <li>Input with exactly one row: {@code colVars(X)} is replaced by a row of zeros, and
 *       every other supported column aggregate is replaced by its input.</li>
 *   <li>Input with exactly one column: the aggregate is converted in place into a full
 *       (RowCol) scalar aggregate, and a cast back to matrix is hung under all of its
 *       former parents.</li>
 * </ul>
 *
 * @param parent parent hop of {@code hi}
 * @param hi     candidate aggregate hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the hop that now occupies {@code hi}'s position ({@code hi} itself if unchanged)
 */
@SuppressWarnings("unchecked")
private static Hop simplifyColwiseAggregate(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp agg = (AggUnaryOp) hi;
    Hop input = agg.getInput().get(0);

    // only supported aggregation types, and only column-wise aggregates
    if (!HopRewriteUtils.isValidOp(agg.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE)
            || agg.getDirection() != Direction.Col) {
        return hi;
    }

    if (input.getDim1() == 1) {
        if (agg.getOp() == AggOp.VAR) {
            // colVars over a single row: each column variance is zero -> row vector of zeros
            Hop zeros = HopRewriteUtils.createDataGenOp(agg, input, 0);
            HopRewriteUtils.replaceChildReference(parent, hi, zeros, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, input);
            LOG.debug("Applied simplifyColwiseAggregate for colVars");
            return zeros;
        }
        // any other supported column aggregate over a single row yields that row itself
        HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
        HopRewriteUtils.cleanupUnreferenced(hi);
        LOG.debug("Applied simplifyColwiseAggregate1");
        return input;
    }

    if (input.getDim2() != 1) {
        return hi;
    }

    // snapshot the parents before the cast is attached above the aggregate
    ArrayList<Hop> formerParents = (ArrayList<Hop>) hi.getParent().clone();

    // single column: turn the col-aggregate into a full scalar aggregate ...
    agg.setDirection(Direction.RowCol);
    agg.setDataType(DataType.SCALAR);

    // ... and restore the matrix output type via a cast, rehung under every former parent
    UnaryOp toMatrix = HopRewriteUtils.createUnary(agg, OpOp1.CAST_AS_MATRIX);
    for (Hop p : formerParents) {
        int childPos = HopRewriteUtils.getChildReferencePos(p, hi);
        HopRewriteUtils.replaceChildReference(p, hi, toMatrix, childPos);
    }
    LOG.debug("Applied simplifyColwiseAggregate2");
    return toMatrix;
}
Also used: AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList)

Example 13 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Method simplifyScalarMatrixMult of class RewriteAlgebraicSimplificationDynamic.

/**
 * Rewrites a matrix multiplication in which one operand is a 1x1 matrix of known size into
 * an element-wise scalar multiplication:
 * {@code y %*% X -> as.scalar(y) * X} and {@code X %*% y -> as.scalar(y) * X}.
 * The left operand is checked first.
 *
 * @param parent parent hop of {@code hi}
 * @param hi     candidate matrix-multiply hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the new multiplication hop if rewritten, otherwise {@code hi}
 */
private static Hop simplifyScalarMatrixMult(Hop parent, Hop hi, int pos) {
    // only X %*% Y is considered
    if (!HopRewriteUtils.isMatrixMultiply(hi)) {
        return hi;
    }
    Hop lhs = hi.getInput().get(0);
    Hop rhs = hi.getInput().get(1);

    final Hop oneByOne;
    final Hop other;
    final String appliedMsg;
    if (HopRewriteUtils.isDimsKnown(lhs) && lhs.getDim1() == 1 && lhs.getDim2() == 1) {
        // y %*% X -> as.scalar(y) * X
        oneByOne = lhs;
        other = rhs;
        appliedMsg = "Applied simplifyScalarMatrixMult1";
    } else if (HopRewriteUtils.isDimsKnown(rhs) && rhs.getDim1() == 1 && rhs.getDim2() == 1) {
        // X %*% y -> as.scalar(y) * X (multiplication is commutative)
        oneByOne = rhs;
        other = lhs;
        appliedMsg = "Applied simplifyScalarMatrixMult2";
    } else {
        return hi;
    }

    UnaryOp asScalar = HopRewriteUtils.createUnary(oneByOne, OpOp1.CAST_AS_SCALAR);
    BinaryOp scaled = HopRewriteUtils.createBinary(asScalar, other, OpOp2.MULT);
    // hang the scalar multiplication under the parent in place of the matmult
    HopRewriteUtils.replaceChildReference(parent, hi, scaled, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug(appliedMsg);
    return scaled;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 14 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Method removeUnnecessaryCumulativeOp of class RewriteAlgebraicSimplificationDynamic.

/**
 * Removes a cumulative unary operator (e.g. cumsum) when its input is known to have exactly
 * one row. The rewrite relies on the cumulative operation being the identity for a
 * single-row input, so the input replaces the operator under {@code parent}.
 *
 * @param parent parent hop of {@code hi}
 * @param hi     candidate cumulative unary hop
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the input hop if the operator was removed, otherwise {@code hi}
 */
private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos) {
    if (hi instanceof UnaryOp && ((UnaryOp) hi).isCumulativeUnaryOperation()) {
        UnaryOp cumOp = (UnaryOp) hi;
        // input matrix
        Hop input = cumOp.getInput().get(0);
        // dims of input known and exactly 1 row
        if (HopRewriteUtils.isDimsKnown(input) && input.getDim1() == 1) {
            OpOp1 op = cumOp.getOp();
            // remove unnecessary unary cumulative operator
            HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
            // release the removed operator, as the other rewrites do after replacing hi
            // (presumably detaches hi from input once hi has no remaining parents)
            HopRewriteUtils.cleanupUnreferenced(hi);
            hi = input;
            LOG.debug("Applied removeUnnecessaryCumulativeOp: " + op);
        }
    }
    return hi;
}
Also used : OpOp1(org.apache.sysml.hops.Hop.OpOp1) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop)

Example 15 with UnaryOp

use of org.apache.sysml.hops.UnaryOp in project incubator-systemml by apache.

Method simplifyWeightedSigmoidMMChains of class RewriteAlgebraicSimplificationDynamic.

/**
 * Fuses a weighted sigmoid matrix-multiplication chain into a single WSIGMOID quaternary
 * operator. The recognized patterns, checked in this order (at most one is applied):
 * <ol>
 *   <li>{@code W * sigmoid(Y%*%t(X))} (basic)</li>
 *   <li>{@code W * sigmoid(-(Y%*%t(X)))} (minus; the negation is matched as {@code 0 - ...})</li>
 *   <li>{@code W * log(sigmoid(Y%*%t(X)))} (log)</li>
 *   <li>{@code W * log(sigmoid(-(Y%*%t(X))))} (log_minus)</li>
 * </ol>
 * Each pattern also requires the left input of the matrix multiply ({@code Y}) to pass
 * {@code HopRewriteUtils.isSingleBlock(Y, true)}.
 *
 * @param parent parent hop of {@code hi}
 * @param hi     candidate root hop ({@code W * ...})
 * @param pos    input position of {@code hi} within {@code parent}
 * @return the new WSIGMOID hop if a pattern matched, otherwise {@code hi}
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
    Hop fused = null;
    if (HopRewriteUtils.isBinary(hi, OpOp2.MULT) // all patterns subrooted by W *
            && hi.getDim2() > 1 // not applied for vector-vector mult
            && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) // prevent mv
            && hi.getInput().get(0).getDataType() == DataType.MATRIX
            && hi.getInput().get(1) instanceof UnaryOp) { // sigmoid/log
        Hop W = hi.getInput().get(0);
        UnaryOp outer = (UnaryOp) hi.getInput().get(1);
        Hop arg = outer.getInput().get(0);

        if (outer.getOp() == OpOp1.SIGMOID
                && arg instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(0), true)) {
            // Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
            fused = buildWSigmoidOp(hi, W, arg, false, false);
            LOG.debug("Applied simplifyWeightedSigmoid1 (line " + hi.getBeginLine() + ")");
        } else if (outer.getOp() == OpOp1.SIGMOID
                && HopRewriteUtils.isBinary(arg, OpOp2.MINUS)
                && arg.getInput().get(0) instanceof LiteralOp
                && HopRewriteUtils.getDoubleValueSafe((LiteralOp) arg.getInput().get(0)) == 0
                && arg.getInput().get(1) instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(1).getInput().get(0), true)) {
            // Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
            fused = buildWSigmoidOp(hi, W, arg.getInput().get(1), false, true);
            LOG.debug("Applied simplifyWeightedSigmoid2 (line " + hi.getBeginLine() + ")");
        } else if (outer.getOp() == OpOp1.LOG
                && HopRewriteUtils.isUnary(arg, OpOp1.SIGMOID)
                && arg.getInput().get(0) instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(0).getInput().get(0), true)) {
            // Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
            fused = buildWSigmoidOp(hi, W, arg.getInput().get(0), true, false);
            LOG.debug("Applied simplifyWeightedSigmoid3 (line " + hi.getBeginLine() + ")");
        } else if (outer.getOp() == OpOp1.LOG
                && HopRewriteUtils.isUnary(arg, OpOp1.SIGMOID)
                && HopRewriteUtils.isBinary(arg.getInput().get(0), OpOp2.MINUS)) {
            // Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
            BinaryOp negation = (BinaryOp) arg.getInput().get(0);
            if (negation.getInput().get(0) instanceof LiteralOp
                    && HopRewriteUtils.getDoubleValueSafe((LiteralOp) negation.getInput().get(0)) == 0
                    && negation.getInput().get(1) instanceof AggBinaryOp
                    && HopRewriteUtils.isSingleBlock(negation.getInput().get(1).getInput().get(0), true)) {
                fused = buildWSigmoidOp(hi, W, negation.getInput().get(1), true, true);
                LOG.debug("Applied simplifyWeightedSigmoid4 (line " + hi.getBeginLine() + ")");
            }
        }
    }

    // relink new hop into original position
    if (fused != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
        hi = fused;
    }
    return hi;
}

/**
 * Builds the WSIGMOID quaternary operator for {@code W * f(Y %*% t(X))}, where
 * {@code matMult} is the matrix-multiply hop {@code Y %*% tX}.
 * The third operator input is the un-transposed right-hand side: an existing transpose
 * on {@code tX} is stripped, otherwise a new transpose of {@code tX} is created.
 *
 * @param hi        the root hop being replaced (supplies the name)
 * @param W         the weight matrix (supplies the output block sizes)
 * @param matMult   the matrix-multiply hop whose inputs are Y and tX
 * @param withLog   first operator flag, set for the log patterns
 * @param withMinus second operator flag, set for the negated (minus) patterns
 * @return the new WSIGMOID hop with block sizes set and size information refreshed
 */
private static Hop buildWSigmoidOp(Hop hi, Hop W, Hop matMult, boolean withLog, boolean withMinus) {
    Hop Y = matMult.getInput().get(0);
    Hop tX = matMult.getInput().get(1);
    Hop X = HopRewriteUtils.isTransposeOperation(tX)
            ? tX.getInput().get(0)
            : HopRewriteUtils.createTranspose(tX);
    Hop wsigmoid = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
            OpOp4.WSIGMOID, W, Y, X, withLog, withMinus);
    wsigmoid.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
    wsigmoid.refreshSizeInformation();
    return wsigmoid;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

UnaryOp (org.apache.sysml.hops.UnaryOp)42 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)39 Hop (org.apache.sysml.hops.Hop)35 LiteralOp (org.apache.sysml.hops.LiteralOp)24 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)18 BinaryOp (org.apache.sysml.hops.BinaryOp)18 DataOp (org.apache.sysml.hops.DataOp)9 IndexingOp (org.apache.sysml.hops.IndexingOp)8 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)6 TernaryOp (org.apache.sysml.hops.TernaryOp)6 ArrayList (java.util.ArrayList)5 HashMap (java.util.HashMap)5 ReorgOp (org.apache.sysml.hops.ReorgOp)5 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)5 DataGenOp (org.apache.sysml.hops.DataGenOp)4 HopsException (org.apache.sysml.hops.HopsException)4 LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp)4 MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject)4 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)3 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3