Search in sources:

Example 21 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

Class HopRewriteUtils, method copyDataGenOp.

/**
 * Creates a new rand {@link DataGenOp} from {@code inputGen} whose min/max bounds are
 * transformed as {@code min*scale+shift} and {@code max*scale+shift}. All other rand
 * parameters (rows, cols, pdf, lambda, sparsity, seed) are reused by reference from
 * {@code inputGen}.
 *
 * <p>Min and max of {@code inputGen} must be literal ops. This is checked here: if either
 * one is not a literal, no copy is created and {@code null} is returned. The bounds are not
 * reordered for a negative {@code scale}, so the new min may exceed the new max.
 * NOTE(review): callers presumably also ensure the pdf is one for which rewriting min/max
 * is meaningful (e.g. uniform). TODO confirm at call sites.
 *
 * @param inputGen input rand data gen op to copy
 * @param scale factor applied to min and max
 * @param shift offset added to min and max after scaling
 * @return the new data gen op, or {@code null} if min or max is not a literal op
 */
public static DataGenOp copyDataGenOp(DataGenOp inputGen, double scale, double shift) {
    HashMap<String, Integer> paramIndex = inputGen.getParamIndexMap();

    // Copy every rand parameter of the input, inserted in a fixed order.
    String[] paramNames = { DataExpression.RAND_ROWS, DataExpression.RAND_COLS,
        DataExpression.RAND_MIN, DataExpression.RAND_MAX, DataExpression.RAND_PDF,
        DataExpression.RAND_LAMBDA, DataExpression.RAND_SPARSITY, DataExpression.RAND_SEED };
    HashMap<String, Hop> newParams = new HashMap<>();
    for (String name : paramNames)
        newParams.put(name, inputGen.getInput().get(paramIndex.get(name)));

    // The bounds can only be rewritten if both of them are literals.
    Hop oldMin = newParams.get(DataExpression.RAND_MIN);
    Hop oldMax = newParams.get(DataExpression.RAND_MAX);
    if (!(oldMin instanceof LiteralOp && oldMax instanceof LiteralOp))
        return null;

    double newMin = getDoubleValue((LiteralOp) oldMin) * scale + shift;
    double newMax = getDoubleValue((LiteralOp) oldMax) * scale + shift;
    // Replacing the values of existing keys leaves the map's key set unchanged.
    newParams.put(DataExpression.RAND_MIN, new LiteralOp(newMin));
    newParams.put(DataExpression.RAND_MAX, new LiteralOp(newMax));

    // size information is presumably refreshed inside the DataGenOp constructor
    DataGenOp copy = new DataGenOp(DataGenMethod.RAND, new DataIdentifier("tmp"), newParams);
    copy.setOutputBlocksizes(inputGen.getRowsInBlock(), inputGen.getColsInBlock());
    copyLineNumbers(inputGen, copy);
    // a [0,0] value range yields an all-zero output
    if (newMin == 0 && newMax == 0)
        copy.setNnz(0);
    return copy;
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) HashMap(java.util.HashMap) DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 22 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

Class HopRewriteUtils, method isBinarySparseSafe.

/**
 * Checks whether the given hop is a binary operation that is treated as sparse-safe.
 * MULT and DIV are always considered sparse-safe. Any other binary operation is
 * sparse-safe only if one of its inputs is a literal for which
 * {@code OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe} holds. If both inputs
 * are literals, the first one is used.
 *
 * @param hop the hop to check
 * @return true if hop is a sparse-safe binary operation, false otherwise
 */
public static boolean isBinarySparseSafe(Hop hop) {
    // non-binary hops are never considered here
    if (!(hop instanceof BinaryOp))
        return false;
    // unconditionally sparse-safe binary operations
    if (isBinary(hop, OpOp2.MULT, OpOp2.DIV))
        return true;

    // Otherwise the decision depends on a literal operand (left input preferred).
    BinaryOp binary = (BinaryOp) hop;
    LiteralOp literal = null;
    if (binary.getInput().get(0) instanceof LiteralOp)
        literal = (LiteralOp) binary.getInput().get(0);
    else if (binary.getInput().get(1) instanceof LiteralOp)
        literal = (LiteralOp) binary.getInput().get(1);

    // NOTE(review): only the literal's value is passed on, not whether it is the left or
    // right operand. For non-commutative ops (e.g. POW), confirm the helper's answer holds
    // for a literal on either side.
    return literal != null
        && OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe(binary.getOp(), literal);
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 23 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

Class HopRewriteUtils, method createSeqDataGenOp.

/**
 * Creates a seq data generator over the row positions of {@code input}:
 * {@code seq(1, nrow(input), 1)} if {@code asc} is true, otherwise
 * {@code seq(nrow(input), 1, -1)}. The row count is a literal if it is known, and an
 * NROW unary op over {@code input} otherwise.
 *
 * @param input hop whose number of rows defines the sequence bound
 * @param asc true for an ascending sequence, false for a descending one
 * @return new seq data gen op carrying the block sizes and line numbers of {@code input}
 */
public static DataGenOp createSeqDataGenOp(Hop input, boolean asc) {
    Hop nrow;
    if (input.rowsKnown())
        nrow = new LiteralOp(input.getDim1());
    else
        nrow = new UnaryOp("tmprows", DataType.SCALAR, ValueType.INT, OpOp1.NROW, input);

    // ascending: 1 -> nrow by +1; descending: nrow -> 1 by -1
    Hop from = asc ? new LiteralOp(1) : nrow;
    Hop to = asc ? nrow : new LiteralOp(1);
    Hop incr = new LiteralOp(asc ? 1 : -1);

    HashMap<String, Hop> seqParams = new HashMap<>();
    seqParams.put(Statement.SEQ_FROM, from);
    seqParams.put(Statement.SEQ_TO, to);
    seqParams.put(Statement.SEQ_INCR, incr);

    // size information is presumably refreshed inside the DataGenOp constructor
    DataGenOp seq = new DataGenOp(DataGenMethod.SEQ, new DataIdentifier("tmp"), seqParams);
    seq.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
    copyLineNumbers(input, seq);
    return seq;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) DataIdentifier(org.apache.sysml.parser.DataIdentifier) HashMap(java.util.HashMap) DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 24 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

Class RewriteAlgebraicSimplificationDynamic, method removeEmptyRightIndexing.

/**
 * Removes a matrix-typed right indexing operation whose input is known to be empty
 * (nnz == 0). If the output dimensions of the indexing op are known, the indexing op is
 * replaced in {@code parent}. The replacement is a data gen op created via
 * {@code createDataGenOpByVal} with value 0 and the indexing output dimensions. The
 * replaced subplan is then cleaned up.
 *
 * @param parent parent hop of hi
 * @param hi candidate hop
 * @param pos position of hi within the inputs of parent
 * @return the replacement hop if the rewrite applied, otherwise hi
 */
private static Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos) {
    // only matrix-typed right indexing operations qualify
    if (!(hi instanceof IndexingOp) || hi.getDataType() != DataType.MATRIX)
        return hi;

    Hop input = hi.getInput().get(0);
    // the input must be known to be empty and the output dims must be known
    if (input.getNnz() != 0 || !HopRewriteUtils.isDimsKnown(hi))
        return hi;

    // the indexing result is all zeros, so replace it by a constant-0 data generator
    Hop zeros = HopRewriteUtils.createDataGenOpByVal(
        new LiteralOp(hi.getDim1()), new LiteralOp(hi.getDim2()), 0);
    HopRewriteUtils.replaceChildReference(parent, hi, zeros, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, input);
    LOG.debug("Applied removeEmptyRightIndexing");
    return zeros;
}
Also used : IndexingOp(org.apache.sysml.hops.IndexingOp) LeftIndexingOp(org.apache.sysml.hops.LeftIndexingOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 25 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

Class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSquaredLoss.

/**
 * Searches for weighted squared loss expressions and replaces them with a quaternary operator.
 * Currently, this search includes the following three patterns:
 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 * For each pattern, the variant with swapped minus operands (U %*% t(V) - X) is matched as
 * well, because the difference is squared. In pattern 2, only the form W * (U %*% t(V)) is
 * matched (W as the left operand of the multiplication).
 *
 * NOTE: We include transpose into the pattern because during runtime we need to compute
 * U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
 * without additional memory requirements for internal transpose.
 *
 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
 * as we generalize the runtime or hop/lop compilation.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator (candidate sum aggregate)
 * @param pos position of hi within the inputs of parent
 * @return the new WSLOSS quaternary operator if a pattern matched (already relinked into
 *         parent), otherwise hi unchanged
 */
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) {
    // NOTE: there might be also a general simplification without custom operator
    // via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
    Hop hnew = null;
    // Root must be a full sum() (RowCol, SUM) over a binary op with more than one column.
    // Note: in the multi-line conditions below, each trailing comment describes the
    // condition on the FOLLOWING line.
    if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol && // all patterns rooted by sum()
    ((AggUnaryOp) hi).getOp() == AggOp.SUM && // all patterns subrooted by binary op
    hi.getInput().get(0) instanceof BinaryOp && // not applied for vector-vector mult
    hi.getInput().get(0).getDim2() > 1) {
        BinaryOp bop = (BinaryOp) hi.getInput().get(0);
        boolean appliedPattern = false;
        // pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2), including the
        // variant sum (W * (U %*% t(V) - X) ^ 2); W must be a matrix of the same size
        if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv: operands must have equal size (no matrix-vector broadcast)
        HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) {
            Hop W = bop.getInput().get(0);
            // (X - U %*% t(V))
            Hop tmp = bop.getInput().get(1).getInput().get(0);
            if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS) && // prevent mv: operands must have equal size (no matrix-vector broadcast)
            HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
                // a) sum (W * (X - U %*% t(V)) ^ 2)
                // uvIndex: position of U %*% t(V) within the minus; -1 means no match
                int uvIndex = -1;
                if (// AggBinaryOp (ba) guarantees matrices
                tmp.getInput().get(1) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT (U must fit a single block, see javadoc)
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    uvIndex = 1;
                } else // b) sum (W * (U %*% t(V) - X) ^ 2)
                if (// AggBinaryOp (ba) guarantees matrices
                tmp.getInput().get(0) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT (U must fit a single block, see javadoc)
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) {
                    uvIndex = 0;
                }
                if (// rewrite match
                uvIndex >= 0) {
                    // X is the other minus operand; U and V are the matmult inputs
                    Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
                    Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
                    Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
                    // WSLOSS takes V rather than t(V): strip an existing transpose,
                    // otherwise add an explicit one
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    // handle special case of post_nz: if W is a non-zero indicator of X
                    // (as decided by isNonZeroIndicator), W is replaced by the literal 1
                    if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
                        W = new LiteralOp(1);
                    }
                    // construct quaternary hop (last argument true only for this
                    // post-weighting pattern; presumably the post-weighting flag)
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        // pattern 2 (pre weighting): sum ((X - W * (U %*% t(V))) ^ 2), including the
        // variant sum ((W * (U %*% t(V)) - X) ^ 2); only tried if pattern 1 did not apply
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv: operands must have equal size (no matrix-vector broadcast)
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            // a) sum ((X - W * (U %*% t(V))) ^ 2)
            // wuvIndex: position of W * (U %*% t(V)) within the minus; -1 means no match
            int wuvIndex = -1;
            if (lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 1;
            } else // b) sum ((W * (U %*% t(V)) - X) ^ 2)
            if (lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 0;
            }
            if (// rewrite match
            wuvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
                // (W * (U %*% t(V))); the cast below is safe because the wuvIndex match
                // above required a BinaryOp at this position
                Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
                if (((BinaryOp) tmp).getOp() == OpOp2.MULT && tmp.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv: operands must have equal size (no matrix-vector broadcast)
                HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && // BLOCKSIZE CONSTRAINT (U must fit a single block, see javadoc)
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    Hop W = tmp.getInput().get(0);
                    Hop U = tmp.getInput().get(1).getInput().get(0);
                    Hop V = tmp.getInput().get(1).getInput().get(1);
                    // WSLOSS takes V rather than t(V): strip an existing transpose,
                    // otherwise add an explicit one
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        // pattern 3 (no weighting): sum ((X - (U %*% t(V))) ^ 2), including the variant
        // sum (((U %*% t(V)) - X) ^ 2); same guard as pattern 2, only tried if neither
        // pattern 1 nor pattern 2 applied
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv: operands must have equal size (no matrix-vector broadcast)
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            // a) sum ((X - (U %*% t(V))) ^ 2)
            // uvIndex: position of U %*% t(V) within the minus; -1 means no match
            int uvIndex = -1;
            if (// AggBinaryOp (ba) guarantees matrices
            lright instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT (U must fit a single block, see javadoc)
            HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) {
                uvIndex = 1;
            } else // b) sum (((U %*% t(V)) - X) ^ 2)
            if (// AggBinaryOp (ba) guarantees matrices
            lleft instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT (U must fit a single block, see javadoc)
            HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) {
                uvIndex = 0;
            }
            if (// rewrite match
            uvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
                // (U %*% t(V))
                Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
                // no weighting: W is the literal 1
                Hop W = new LiteralOp(1);
                Hop U = tmp.getInput().get(0);
                Hop V = tmp.getInput().get(1);
                // WSLOSS takes V rather than t(V): strip an existing transpose,
                // otherwise add an explicit one
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                HopRewriteUtils.setOutputParametersForScalar(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
            }
        }
    }
    // relink new hop into original position
    // NOTE(review): the replaced hop is not cleaned up here (no cleanupUnreferenced);
    // confirm that this is intended
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

LiteralOp (org.apache.sysml.hops.LiteralOp): 102; Hop (org.apache.sysml.hops.Hop): 88; AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 30; BinaryOp (org.apache.sysml.hops.BinaryOp): 27; AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 25; UnaryOp (org.apache.sysml.hops.UnaryOp): 24; DataOp (org.apache.sysml.hops.DataOp): 23; IndexingOp (org.apache.sysml.hops.IndexingOp): 17; ArrayList (java.util.ArrayList): 13; DataGenOp (org.apache.sysml.hops.DataGenOp): 13; LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp): 12; HashMap (java.util.HashMap): 11; DataIdentifier (org.apache.sysml.parser.DataIdentifier): 10; ReorgOp (org.apache.sysml.hops.ReorgOp): 9; MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject): 9; HopsException (org.apache.sysml.hops.HopsException): 8; ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 8; StatementBlock (org.apache.sysml.parser.StatementBlock): 8; TernaryOp (org.apache.sysml.hops.TernaryOp): 7; CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 6