Search in sources :

Example 21 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteForLoopVectorization method vectorizeScalarAggregate.

/**
 * Tries to replace a loop body of the shape {@code s = s op as.scalar(X[i, ...])}
 * (or the mirrored operand order, or a column index {@code X[..., i]}) by a single
 * full aggregate over the indexed range {@code from..to}.
 *
 * <p>The rewrite is only attempted when the loop increment is the literal 1, the
 * child block holds exactly one hop DAG with a scalar root fed by a binary
 * operation listed in {@code MAP_SCALAR_AGGREGATE_SOURCE_OPS}, and the indexing
 * expression uses the loop variable for its single row or single column.
 *
 * @param sb        the enclosing statement block, returned unchanged if nothing applies
 * @param csb       the loop body block that is inspected and rewritten in place
 * @param from      hop of the loop's lower bound, becomes the new lower index
 * @param to        hop of the loop's upper bound, becomes the new upper index
 * @param increment hop of the loop increment; must be a literal equal to 1.0
 * @param itervar   name of the loop iteration variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 */
private static StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
    // only loops with a literal increment of exactly 1 are supported
    // (instanceof already rejects a missing increment)
    boolean unitIncrement = increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0;
    if (!unitIncrement) {
        return sb;
    }
    // the loop body must consist of exactly one hop DAG
    if (csb.getHops() == null || csb.getHops().size() != 1) {
        return sb;
    }
    // ... whose root is a scalar fed by a binary operation
    Hop root = csb.getHops().get(0);
    if (root.getDataType() != DataType.SCALAR || !(root.getInput().get(0) instanceof BinaryOp)) {
        return sb;
    }
    BinaryOp bop = (BinaryOp) root.getInput().get(0);
    Hop[] operands = { bop.getInput().get(0), bop.getInput().get(1) };
    if (!HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)) {
        return sb;
    }

    // locate the as.scalar(X[...]) operand: position 1 means the accumulated
    // scalar is the left operand, position 0 means it is the right operand
    int castPos = -1;
    for (int pos = 1; pos >= 0 && castPos < 0; pos--) {
        Hop accum = operands[1 - pos];
        Hop candidate = operands[pos];
        if (accum instanceof DataOp
                && accum.getDataType() == DataType.SCALAR
                && root.getName().equals(accum.getName())
                && candidate instanceof UnaryOp
                && ((UnaryOp) candidate).getOp() == OpOp1.CAST_AS_SCALAR
                && candidate.getInput().get(0) instanceof IndexingOp) {
            castPos = pos;
        }
    }
    if (castPos < 0) {
        return sb;
    }

    // the loop variable must drive either the single row or the single column
    Hop cast = bop.getInput().get(castPos);
    Hop ix = cast.getInput().get(0);
    IndexingOp rix = (IndexingOp) ix;
    final boolean rowIx;
    if (rix.isRowLowerEqualsUpper() && rix.getInput().get(1) instanceof DataOp && rix.getInput().get(1).getName().equals(itervar)) {
        rowIx = true;
    } else if (rix.isColLowerEqualsUpper() && rix.getInput().get(3) instanceof DataOp && rix.getInput().get(3).getName().equals(itervar)) {
        rowIx = false;
    } else {
        return sb;
    }

    // map the scalar operation onto its aggregate counterpart
    int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
    AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];

    // swap the scalar cast for a full aggregate over the indexing result
    AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
    HopRewriteUtils.removeChildReference(cast, ix);
    HopRewriteUtils.removeChildReference(bop, cast);
    HopRewriteUtils.addChildReference(bop, newSum, castPos);

    // widen the index range to the loop bounds from..to
    // NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
    int lowerPos = rowIx ? 1 : 3;
    int upperPos = rowIx ? 2 : 4;
    HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(lowerPos), from, lowerPos);
    HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(upperPos), to, upperPos);

    // the indexed dimension is no longer a single row/column
    if (rowIx) {
        rix.setRowLowerEqualsUpper(false);
    } else {
        rix.setColLowerEqualsUpper(false);
    }
    ix.refreshSizeInformation();

    LOG.debug("Applied vectorizeScalarSumForLoop.");
    return csb;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggOp(org.apache.sysml.hops.Hop.AggOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 22 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class LiteralReplacement method replaceLiteralFullUnaryAggregateRightIndexing.

/**
 * Replaces a full unary aggregate over a right-indexed in-memory matrix by a
 * literal, e.g. the result of aggregating {@code X[rl:ru, cl:cu]} when
 * {@code X} is a known variable and all four bounds resolve to integer
 * literals.
 *
 * <p>The replacement is only attempted when the matrix has fewer than
 * {@code REPLACE_LITERALS_MAX_MATRIX_SIZE} cells. The matrix is acquired for
 * reading and is always released again, even if slicing or aggregation fails.
 *
 * @param c    candidate hop; expected to be an {@link AggUnaryOp} over an {@link IndexingOp}
 * @param vars symbol table used to resolve the indexed variable and the index bounds
 * @return the literal holding the aggregate value, or {@code null} if the
 *         pattern does not apply
 */
private static LiteralOp replaceLiteralFullUnaryAggregateRightIndexing(Hop c, LocalVariableMap vars) {
    LiteralOp ret = null;
    // full unary aggregate w/ indexed matrix less than 10^6 cells
    if (c instanceof AggUnaryOp && isReplaceableUnaryAggregate((AggUnaryOp) c) && c.getInput().get(0) instanceof IndexingOp && c.getInput().get(0).getInput().get(0) instanceof DataOp) {
        IndexingOp rix = (IndexingOp) c.getInput().get(0);
        // guaranteed to be a DataOp by the condition above
        Hop data = rix.getInput().get(0);
        Hop rl = rix.getInput().get(1);
        Hop ru = rix.getInput().get(2);
        Hop cl = rix.getInput().get(3);
        Hop cu = rix.getInput().get(4);
        if (vars.keySet().contains(data.getName()) && isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) && isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars)) {
            long rlval = getIntValueDataLiteral(rl, vars);
            long ruval = getIntValueDataLiteral(ru, vars);
            long clval = getIntValueDataLiteral(cl, vars);
            long cuval = getIntValueDataLiteral(cu, vars);
            // skip variables that are not matrices instead of failing on the cast
            Object var = vars.get(data.getName());
            if (var instanceof MatrixObject) {
                MatrixObject mo = (MatrixObject) var;
                // dimensions might not have been updated during recompile
                if (mo.getNumRows() * mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE) {
                    double value;
                    MatrixBlock mBlock = mo.acquireRead();
                    try {
                        // index bounds are 1-based, slice expects 0-based positions
                        MatrixBlock mBlock2 = mBlock.slice((int) (rlval - 1), (int) (ruval - 1), (int) (clval - 1), (int) (cuval - 1), new MatrixBlock());
                        value = replaceUnaryAggregate((AggUnaryOp) c, mBlock2);
                    } finally {
                        // always undo acquireRead, even if slice/aggregate throws
                        mo.release();
                    }
                    // literal substitution (always double)
                    ret = new LiteralOp(value);
                }
            }
        }
    }
    return ret;
}
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp)

Example 23 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class HopRewriteUtils method createAggUnaryOp.

/**
 * Builds a new unary aggregate hop on top of {@code input}.
 *
 * <p>A full aggregate ({@code Direction.RowCol}) yields a scalar; for any other
 * direction the result keeps the data type of the input. Name, value type,
 * block sizes and line numbers are taken over from the input, and the size
 * information of the new hop is refreshed before it is returned.
 *
 * @param input hop to aggregate
 * @param op    aggregation operation
 * @param dir   aggregation direction
 * @return the newly created aggregate hop
 */
public static AggUnaryOp createAggUnaryOp(Hop input, AggOp op, Direction dir) {
    final DataType outType;
    if (dir == Direction.RowCol) {
        outType = DataType.SCALAR;
    } else {
        outType = input.getDataType();
    }

    final AggUnaryOp agg = new AggUnaryOp(input.getName(), outType, input.getValueType(), op, dir, input);
    agg.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
    copyLineNumbers(input, agg);
    agg.refreshSizeInformation();
    return agg;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) DataType(org.apache.sysml.parser.Expression.DataType)

Example 24 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedCrossEntropy.

/**
 * Fuses weighted cross-entropy expressions into a single {@code WCEMM}
 * quaternary operation.
 *
 * <p>Recognized patterns, both rooted by a full {@code sum()}:
 * <ol>
 *   <li>{@code sum(X * log(U %*% t(V)))}</li>
 *   <li>{@code sum(X * log(U %*% t(V) + eps))} with a scalar literal {@code eps}</li>
 * </ol>
 * The rewrite is skipped when the multiplied expression has a single column
 * (vector-vector multiply) or when {@code U} fails the single-block check.
 *
 * @param parent parent hop of {@code hi}
 * @param hi     hop to inspect
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new quaternary hop if a pattern matched, otherwise {@code hi}
 */
private static Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) {
    // pattern rooted by sum()
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp agg = (AggUnaryOp) hi;
    if (agg.getDirection() != Direction.RowCol || agg.getOp() != AggOp.SUM) {
        return hi;
    }
    // pattern subrooted by binary op, not applied for vector-vector mult
    if (!(hi.getInput().get(0) instanceof BinaryOp) || hi.getInput().get(0).getDim2() <= 1) {
        return hi;
    }

    BinaryOp bop = (BinaryOp) hi.getInput().get(0);
    Hop X = bop.getInput().get(0);
    Hop logHop = bop.getInput().get(1);

    // prefix shared by both patterns: X * log(.) with a matrix X of equal size
    boolean weightedLog = bop.getOp() == OpOp2.MULT
            && X.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isEqualSize(X, logHop)
            && HopRewriteUtils.isUnary(logHop, OpOp1.LOG);
    if (!weightedLog) {
        return hi;
    }

    Hop logArg = logHop.getInput().get(0);
    Hop hnew = null;

    if (logArg instanceof AggBinaryOp
            // BLOCKSIZE CONSTRAINT
            && HopRewriteUtils.isSingleBlock(logArg.getInput().get(0), true)) {
        // Pattern 1) sum( X * log(U %*% t(V)))
        Hop U = logArg.getInput().get(0);
        Hop V = logArg.getInput().get(1);
        // the fused op takes V itself, so strip an existing transpose or add one
        V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
        hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V, new LiteralOp(0.0), 0, false, false);
        hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
        LOG.debug("Applied simplifyWeightedCEMM (line " + hi.getBeginLine() + ")");
    } else if (HopRewriteUtils.isBinary(logArg, OpOp2.PLUS)
            && logArg.getInput().get(0) instanceof AggBinaryOp
            && logArg.getInput().get(1) instanceof LiteralOp
            && logArg.getInput().get(1).getDataType() == DataType.SCALAR
            && HopRewriteUtils.isSingleBlock(logArg.getInput().get(0).getInput().get(0), true)) {
        // Pattern 2) sum( X * log(U %*% t(V) + eps))
        Hop mm = logArg.getInput().get(0);
        Hop U = mm.getInput().get(0);
        Hop V = mm.getInput().get(1);
        Hop eps = logArg.getInput().get(1);
        V = HopRewriteUtils.isTransposeOperation(V) ? V.getInput().get(0) : HopRewriteUtils.createTranspose(V);
        // 1 => BASIC_EPS
        hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V, eps, 1, false, false);
        hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
        LOG.debug("Applied simplifyWeightedCEMMEps (line " + hi.getBeginLine() + ")");
    }

    if (hnew == null) {
        return hi;
    }
    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 25 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyUnnecessaryAggregate.

/**
 * Removes a full aggregate whose input is known to be a 1x1 matrix by
 * replacing it with a scalar cast, e.g., {@code sum(X) -> as.scalar(X)}.
 * Only aggregate operations listed in
 * {@code LOOKUP_VALID_UNNECESSARY_AGGREGATE} are considered.
 *
 * @param parent parent hop of {@code hi}
 * @param hi     hop to inspect
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the inserted cast hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp agg = (AggUnaryOp) hi;
    if (agg.getDirection() != Direction.RowCol) {
        return hi;
    }

    Hop operand = agg.getInput().get(0);
    if (!HopRewriteUtils.isValidOp(agg.getOp(), LOOKUP_VALID_UNNECESSARY_AGGREGATE)) {
        return hi;
    }
    // aggregating a single cell is a no-op apart from the matrix-to-scalar cast
    if (operand.getDim1() != 1 || operand.getDim2() != 1) {
        return hi;
    }

    UnaryOp scalarCast = HopRewriteUtils.createUnary(operand, OpOp1.CAST_AS_SCALAR);
    // remove unnecessary aggregation
    HopRewriteUtils.replaceChildReference(parent, hi, scalarCast, pos);
    LOG.debug("Applied simplifyUnncessaryAggregate");
    return scalarCast;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop)

Aggregations

AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)36 Hop (org.apache.sysml.hops.Hop)33 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)19 LiteralOp (org.apache.sysml.hops.LiteralOp)14 BinaryOp (org.apache.sysml.hops.BinaryOp)12 ReorgOp (org.apache.sysml.hops.ReorgOp)11 UnaryOp (org.apache.sysml.hops.UnaryOp)11 IndexingOp (org.apache.sysml.hops.IndexingOp)7 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)6 ArrayList (java.util.ArrayList)5 TernaryOp (org.apache.sysml.hops.TernaryOp)5 DataOp (org.apache.sysml.hops.DataOp)4 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)4 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)3 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)3 CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary)3 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)2 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)2 CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary)2