Search in sources :

Example 66 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

In class HopRewriteUtils, method rContainsRead.

/**
 * Recursively checks whether the hop DAG rooted at {@code root} contains a read of the
 * variable {@code var}, i.e., a {@link DataOp} with {@code isRead()} and name {@code var}.
 * Every processed hop is marked visited; already-visited hops contribute {@code false}
 * (NOTE(review): callers presumably reset visit status around this call — confirm).
 *
 * @param root root of the (sub-)DAG to search
 * @param var name of the variable to look for
 * @param includeMetaOp if false, a matching read only counts when at least one of its
 *        parents is something other than an nrow/ncol unary operation
 * @return true if a qualifying read of {@code var} is reachable from {@code root}
 */
public static boolean rContainsRead(Hop root, String var, boolean includeMetaOp) {
    if (root.isVisited())
        return false;
    boolean found = false;
    boolean isVarRead = root instanceof DataOp && ((DataOp) root).isRead() && root.getName().equals(var);
    if (isVarRead) {
        if (includeMetaOp) {
            found = true;
        } else {
            // a read consumed exclusively by nrow/ncol does not count as a real read
            for (Hop consumer : root.getParent()) {
                boolean metaConsumer = consumer instanceof UnaryOp
                    && (((UnaryOp) consumer).getOp() == OpOp1.NROW || ((UnaryOp) consumer).getOp() == OpOp1.NCOL);
                if (!metaConsumer) {
                    found = true;
                    break;
                }
            }
        }
    }
    // always recurse into every input (no short-circuit) so all reachable hops get marked visited
    for (Hop child : root.getInput()) {
        if (rContainsRead(child, var, includeMetaOp))
            found = true;
    }
    root.setVisited();
    return found;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) Hop(org.apache.sysml.hops.Hop) DataOp(org.apache.sysml.hops.DataOp)

Example 67 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

In class HopRewriteUtils, method isBinarySparseSafe.

/**
 * Indicates whether the given hop is a binary operation that is treated as sparse-safe.
 * MULT and DIV are always treated as sparse-safe. Any other binary operation qualifies only
 * if one of its inputs is a literal for which
 * {@code OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe} holds.
 *
 * For POW only a literal exponent ({@code X ^ c}) is considered. A literal base
 * ({@code c ^ X}) is never sparse-safe because {@code c ^ 0 = 1} maps zeros to non-zeros
 * regardless of {@code c}.
 *
 * @param hop the hop to inspect
 * @return true if {@code hop} is a binary op considered sparse-safe, false otherwise
 */
public static boolean isBinarySparseSafe(Hop hop) {
    if (!(hop instanceof BinaryOp))
        return false;
    if (isBinary(hop, OpOp2.MULT, OpOp2.DIV))
        return true;
    BinaryOp bop = (BinaryOp) hop;
    Hop left = bop.getInput().get(0);
    Hop right = bop.getInput().get(1);
    // prefer the left literal (as before), except for POW where a literal base is never sparse-safe
    Hop lit = null;
    if (left instanceof LiteralOp && bop.getOp() != OpOp2.POW)
        lit = left;
    else if (right instanceof LiteralOp)
        lit = right;
    return lit != null && OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe(bop.getOp(), (LiteralOp) lit);
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 68 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

In class HopRewriteUtils, method createSeqDataGenOp.

/**
 * Creates a seq() data generator over the rows of {@code input}: seq(1, nrow, 1) when
 * ascending, or seq(nrow, 1, -1) when descending. The row count is a literal if it is
 * already known, otherwise an nrow(input) unary operation.
 *
 * @param input hop whose number of rows bounds the sequence
 * @param asc true for an ascending sequence, false for a descending one
 * @return the new seq data generator, with block sizes and line numbers taken from input
 */
public static DataGenOp createSeqDataGenOp(Hop input, boolean asc) {
    Hop nrow = input.rowsKnown()
        ? new LiteralOp(input.getDim1())
        : new UnaryOp("tmprows", DataType.SCALAR, ValueType.INT, OpOp1.NROW, input);
    HashMap<String, Hop> params = new HashMap<>();
    params.put(Statement.SEQ_FROM, asc ? new LiteralOp(1) : nrow);
    params.put(Statement.SEQ_TO, asc ? nrow : new LiteralOp(1));
    params.put(Statement.SEQ_INCR, new LiteralOp(asc ? 1 : -1));
    // note internal refresh size information
    DataGenOp seq = new DataGenOp(DataGenMethod.SEQ, new DataIdentifier("tmp"), params);
    seq.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
    copyLineNumbers(input, seq);
    return seq;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) DataIdentifier(org.apache.sysml.parser.DataIdentifier) HashMap(java.util.HashMap) DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 69 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

In class RewriteAlgebraicSimplificationDynamic, method removeEmptyRightIndexing.

/**
 * Replaces a matrix right indexing whose input is known to have zero non-zeros by a
 * data-generator hop of the indexing output's dimensions created via
 * {@code createDataGenOpByVal(..., 0)} (presumably an all-zero matrix — confirm).
 * Requires the output dimensions of the indexing to be known.
 *
 * @param parent parent of {@code hi}
 * @param hi candidate indexing operation
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the replacement hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos) {
    // only matrix-valued right indexing is a candidate
    if (!(hi instanceof IndexingOp) || hi.getDataType() != DataType.MATRIX)
        return hi;
    Hop input = hi.getInput().get(0);
    // need an input with nnz known to be zero and known output dims
    if (input.getNnz() != 0 || !HopRewriteUtils.isDimsKnown(hi))
        return hi;
    Hop replacement = HopRewriteUtils.createDataGenOpByVal(new LiteralOp(hi.getDim1()), new LiteralOp(hi.getDim2()), 0);
    HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, input);
    LOG.debug("Applied removeEmptyRightIndexing");
    return replacement;
}
Also used : IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 70 with Hop

use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.

In class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSquaredLoss.

/**
 * Searches for weighted squared loss expressions and replaces them with a quaternary operator.
 * Currently, this search includes the following three patterns:
 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 * Each pattern is also matched with the operands of the subtraction swapped
 * (e.g., U %*% t(V) - X), which yields the same squared difference.
 *
 * NOTE: We include transpose into the pattern because during runtime we need to compute
 * U %*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
 * without additional memory requirements for internal transpose.
 *
 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
 * as we generalize the runtime or hop/lop compilation.
 *
 * Patterns are tried in order 1, 2, 3; the first match wins. On a match, the new WSLOSS
 * QuaternaryOp replaces hi as input of parent; otherwise hi is returned unchanged.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate full sum aggregate)
 * @param pos position of hi among the inputs of parent
 * @return the new quaternary operator if a pattern was applied, otherwise hi
 */
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) {
    // NOTE: there might be also a general simplification without custom operator
    // via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
    Hop hnew = null;
    if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol && // all patterns rooted by full sum()
    ((AggUnaryOp) hi).getOp() == AggOp.SUM && // all patterns subrooted by binary op
    hi.getInput().get(0) instanceof BinaryOp && // not applied for vector-vector mult
    hi.getInput().get(0).getDim2() > 1) {
        BinaryOp bop = (BinaryOp) hi.getInput().get(0);
        // set once a pattern matched, so later patterns are skipped
        boolean appliedPattern = false;
        // pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2) or sum (W * (U %*% t(V) - X) ^ 2)
        // requires bop = W * (...)^2 with matrix W of equal size and a literal exponent of 2
        if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) {
            Hop W = bop.getInput().get(0);
            // base of the power: (X - U %*% t(V)) or (U %*% t(V) - X)
            Hop tmp = bop.getInput().get(1).getInput().get(0);
            if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS) && // prevent mv
            HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
                // a) sum (W * (X - U %*% t(V)) ^ 2)
                // uvIndex: which operand of the subtraction is the matrix multiply (-1 = none)
                int uvIndex = -1;
                if (// ba (AggBinaryOp) guarantees matrices
                tmp.getInput().get(1) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    uvIndex = 1;
                } else // b) sum (W * (U %*% t(V) - X) ^ 2)
                if (// ba (AggBinaryOp) guarantees matrices
                tmp.getInput().get(0) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) {
                    uvIndex = 0;
                }
                if (// rewrite match
                uvIndex >= 0) {
                    // X is the subtraction operand that is not the matrix multiply
                    Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
                    Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
                    Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
                    // the quaternary op takes V (not t(V)): unwrap an explicit transpose,
                    // otherwise transpose the right operand of the matrix multiply
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    // handle special case of post_nz: if W is a non-zero indicator of X
                    // (as decided by isNonZeroIndicator), it is replaced by the literal 1
                    if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
                        W = new LiteralOp(1);
                    }
                    // construct quaternary hop; the trailing flag is true only for this
                    // post-weighting pattern (presumably the post-weighting indicator — confirm in QuaternaryOp)
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        // pattern 2 (pre weighting): sum ((X - W * (U %*% t(V))) ^ 2) or sum ((W * (U %*% t(V)) - X) ^ 2)
        // requires bop = (matrix - matrix)^2 with equal-size operands and a literal exponent of 2
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            // a) sum ((X - W * (U %*% t(V))) ^ 2)
            // wuvIndex: which subtraction operand is a binary op whose right input is a matrix multiply
            int wuvIndex = -1;
            if (lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 1;
            } else // b) sum ((W * (U %*% t(V)) - X) ^ 2)
            if (lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 0;
            }
            if (// rewrite match
            wuvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
                // (W * (U %*% t(V))); the cast below is safe because of the instanceof checks above
                Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
                // NOTE(review): if the a) candidate fails these checks, the b) arrangement is not retried
                if (((BinaryOp) tmp).getOp() == OpOp2.MULT && tmp.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
                HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && // BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    Hop W = tmp.getInput().get(0);
                    Hop U = tmp.getInput().get(1).getInput().get(0);
                    Hop V = tmp.getInput().get(1).getInput().get(1);
                    // unwrap an explicit transpose, otherwise transpose the right matmul operand
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        // pattern 3 (no weighting): sum ((X - (U %*% t(V))) ^ 2) or sum (((U %*% t(V)) - X) ^ 2)
        // same outer conditions as pattern 2
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            // a) sum ((X - (U %*% t(V))) ^ 2)
            int uvIndex = -1;
            if (// ba (AggBinaryOp) guarantees matrices
            lright instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
            HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) {
                uvIndex = 1;
            } else // b) sum (((U %*% t(V)) - X) ^ 2)
            if (// ba (AggBinaryOp) guarantees matrices
            lleft instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
            HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) {
                uvIndex = 0;
            }
            if (// rewrite match
            uvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
                // (U %*% t(V))
                Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
                // no weighting
                Hop W = new LiteralOp(1);
                Hop U = tmp.getInput().get(0);
                Hop V = tmp.getInput().get(1);
                // unwrap an explicit transpose, otherwise transpose the right matmul operand
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                HopRewriteUtils.setOutputParametersForScalar(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
            }
        }
    }
    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

Hop (org.apache.sysml.hops.Hop)307 LiteralOp (org.apache.sysml.hops.LiteralOp)94 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)65 BinaryOp (org.apache.sysml.hops.BinaryOp)63 ArrayList (java.util.ArrayList)61 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)61 HashMap (java.util.HashMap)44 DataOp (org.apache.sysml.hops.DataOp)41 UnaryOp (org.apache.sysml.hops.UnaryOp)41 HashSet (java.util.HashSet)39 ReorgOp (org.apache.sysml.hops.ReorgOp)32 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)28 StatementBlock (org.apache.sysml.parser.StatementBlock)28 IndexingOp (org.apache.sysml.hops.IndexingOp)24 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)23 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)23 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)22 DataGenOp (org.apache.sysml.hops.DataGenOp)21 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)21 HopsException (org.apache.sysml.hops.HopsException)18