Search in sources :

Example 6 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class TemplateOuterProduct method rConstructCplan.

/**
 * Recursively translates the given hop (and those inputs the memo table marks
 * as plan references) into cnodes, registering each result in {@code tmp}
 * under the hop ID.
 *
 * @param hop current hop to translate
 * @param memo memo table queried for the best OUTER/CELL entry of the hop
 * @param tmp map from hop ID to the cnode constructed for it (also serves as memoization)
 * @param inHops collects input hops that are not plan references
 * @param inHops2 collects named inputs ("_X", "_U", "_V") for later reference
 * @param compileLiterals forwarded to the cnode data creation helpers
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    // a hop that already has a cnode was handled before (common subexpression)
    if (tmp.containsKey(hop.getHopID())) {
        return;
    }

    // step 1: make sure every input has a cnode registered in tmp
    MemoTableEntry best = memo.getBest(hop.getHopID(), TemplateType.OUTER, TemplateType.CELL);
    for (int pos = 0; pos < hop.getInput().size(); pos++) {
        Hop child = hop.getInput().get(pos);
        if (best.isPlanRef(pos)) {
            // input is part of the fused plan: descend into it
            rConstructCplan(child, memo, tmp, inHops, inHops2, compileLiterals);
            continue;
        }
        // input is a plan boundary: represent it as data and record it as input hop
        CNodeData leaf = TemplateUtils.createCNodeData(child, compileLiterals);
        tmp.put(child.getHopID(), leaf);
        inHops.add(child);
    }

    // step 2: build the cnode of the current hop from the cnodes of its inputs
    CNode result = null;
    if (hop instanceof UnaryOp) {
        CNode operand = tmp.get(hop.getInput().get(0).getHopID());
        String opName = ((UnaryOp) hop).getOp().toString();
        result = new CNodeUnary(operand, UnaryType.valueOf(opName));
    } else if (hop instanceof BinaryOp) {
        Hop lhsHop = hop.getInput().get(0);
        Hop rhsHop = hop.getInput().get(1);
        CNode lhs = tmp.get(lhsHop.getHopID());
        CNode rhs = tmp.get(rhsHop.getHopID());
        String opName = ((BinaryOp) hop).getOp().toString();
        if (HopRewriteUtils.isBinarySparseSafe(hop)) {
            // remember a matrix data input of the sparse-safe operation as "_X";
            // if both sides qualify, the right one overwrites the left one
            if (TemplateUtils.isMatrix(lhsHop) && lhs instanceof CNodeData) {
                inHops2.put("_X", lhsHop);
            }
            if (TemplateUtils.isMatrix(rhsHop) && rhs instanceof CNodeData) {
                inHops2.put("_X", rhsHop);
            }
        }
        // add lookups if required
        lhs = TemplateUtils.wrapLookupIfNecessary(lhs, lhsHop);
        rhs = TemplateUtils.wrapLookupIfNecessary(rhs, rhsHop);
        result = new CNodeBinary(lhs, rhs, BinType.valueOf(opName));
    } else if (hop instanceof AggBinaryOp) {
        Hop lhsHop = hop.getInput().get(0);
        Hop rhsHop = hop.getInput().get(1);
        // both lookups happen before skipTranspose, which is handed tmp as well
        CNode lhs = tmp.get(lhsHop.getHopID());
        CNode rhs = tmp.get(rhsHop.getHopID());
        // handle transpose in outer or final product
        lhs = TemplateUtils.skipTranspose(lhs, lhsHop, tmp, compileLiterals);
        rhs = TemplateUtils.skipTranspose(rhs, rhsHop, tmp, compileLiterals);
        if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
            // outer product U%*%t(V), see open: keep U and V for later reference
            inHops2.put("_U", lhsHop);
            Hop vHop = HopRewriteUtils.isTransposeOperation(rhsHop) ? rhsHop.getInput().get(0) : rhsHop;
            inHops2.put("_V", vHop);
            result = new CNodeBinary(lhs, rhs, BinType.DOT_PRODUCT);
        } else {
            // final left/right matrix mult, see close: a scalar left operand goes second
            boolean scalarOnLeft = lhs.getDataType().isScalar();
            result = scalarOnLeft
                ? new CNodeBinary(rhs, lhs, BinType.VECT_MULT_ADD)
                : new CNodeBinary(lhs, rhs, BinType.VECT_MULT_ADD);
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)
        || (hop instanceof AggUnaryOp && ((AggUnaryOp) hop).getOp() == AggOp.SUM && ((AggUnaryOp) hop).getDirection() == Direction.RowCol)) {
        // transpose and full sum are passed through: reuse the cnode of the input
        result = tmp.get(hop.getInput().get(0).getHopID());
    }

    // note: result remains null if none of the cases above applied
    tmp.put(hop.getHopID(), result);
}
Also used : CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) BinaryOp(org.apache.sysml.hops.BinaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp)

Example 7 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class TemplateUtils method getRowType.

/**
 * Classifies the output of a row-wise operation by comparing the dimensions
 * of {@code output} against the first input and, if given, the second input.
 * The checks are evaluated in order and the first match wins.
 *
 * @param output hop whose dimensions are classified
 * @param inputs input hops; the first entry is the main input (may be null),
 *               the optional second entry is the side input
 * @return the matching row type
 * @throws RuntimeException if no classification rule applies
 */
public static RowType getRowType(Hop output, Hop... inputs) {
    final Hop main = inputs[0];
    final Hop side = (inputs.length > 1) ? inputs[1] : null;

    // no main input, same size as main input, or unknown dimensions
    if (main == null || HopRewriteUtils.isEqualSize(output, main) || !main.dimsKnown()) {
        return RowType.NO_AGG;
    }

    // rows of main input x columns of side input, or column range indexing;
    // excluded for a matrix multiply whose first input is the transpose of the main input
    if (((side != null && output.getDim1() == main.getDim1() && output.getDim2() == side.getDim2())
            || (output instanceof IndexingOp && HopRewriteUtils.isColumnRangeIndexing((IndexingOp) output)))
        && (!(output instanceof AggBinaryOp) || !HopRewriteUtils.isTransposeOfItself(output.getInput().get(0), main))) {
        return RowType.NO_AGG_B1;
    }

    // column vector with one entry per row of the main input (same exclusion as above)
    if (output.getDim1() == main.getDim1() && (output.getDim2() == 1)
        && (!(output instanceof AggBinaryOp) || !HopRewriteUtils.isTransposeOfItself(output.getInput().get(0), main))) {
        return RowType.ROW_AGG;
    }

    // full aggregate over rows and columns
    if (output instanceof AggUnaryOp && ((AggUnaryOp) output).getDirection() == Direction.RowCol) {
        return RowType.FULL_AGG;
    }

    // column aggregates in transposed and row-vector form
    if (output.getDim1() == main.getDim2() && output.getDim2() == 1) {
        return RowType.COL_AGG_T;
    }
    if (output.getDim1() == 1 && output.getDim2() == main.getDim2()) {
        return RowType.COL_AGG;
    }

    // column aggregates involving the side input
    if (side != null && output.getDim1() == main.getDim2() && output.getDim2() == side.getDim2()) {
        return RowType.COL_AGG_B1_T;
    }
    if (side != null && output.getDim1() == side.getDim2() && output.getDim2() == main.getDim2()) {
        return RowType.COL_AGG_B1;
    }
    if (side != null && output.getDim1() == 1 && side.getDim2() == output.getDim2()) {
        return RowType.COL_AGG_B1R;
    }

    // output width differs from the main input
    if (main.getDim1() == output.getDim1() && main.getDim2() != output.getDim2()) {
        return RowType.NO_AGG_CONST;
    }
    if (output.getDim1() == 1 && main.getDim2() != output.getDim2()) {
        return RowType.COL_AGG_CONST;
    }

    throw new RuntimeException("Unknown row type for hop " + output.getHopID() + ".");
}
Also used : IndexingOp(org.apache.sysml.hops.IndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop)

Example 8 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class HopRewriteUtils method createMatrixMultiply.

/**
 * Creates a matrix multiplication hop (aggregate binary with MULT and SUM)
 * over the two given operands.
 *
 * @param left left operand; also supplies name, data type, value type,
 *             row block size, and line numbers of the new hop
 * @param right right operand; supplies the column block size of the new hop
 * @return the new hop with refreshed size information
 */
public static AggBinaryOp createMatrixMultiply(Hop left, Hop right) {
    final AggBinaryOp product = new AggBinaryOp(
        left.getName(), left.getDataType(), left.getValueType(),
        OpOp2.MULT, AggOp.SUM,
        left, right);

    // output blocking: rows follow the left operand, columns the right operand
    product.setOutputBlocksizes(left.getRowsInBlock(), right.getColsInBlock());
    copyLineNumbers(left, product);
    product.refreshSizeInformation();

    return product;
}
Also used : AggBinaryOp(org.apache.sysml.hops.AggBinaryOp)

Example 9 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedSquaredLoss.

/**
 * Searches for weighted squared loss expressions and replaces them with a quaternary operator.
 * Currently, this search includes the following three patterns:
 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 *
 * For each pattern, both operand orders of the subtraction are matched
 * (X - UVt as well as UVt - X).
 *
 * NOTE: We include transpose into the pattern because during runtime we need to compute
 * U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
 * without additional memory requirements for internal transpose.
 *
 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
 * as we generalized the runtime or hop/lop compilation.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return the new quaternary operator if a pattern matched, otherwise the unmodified {@code hi}
 */
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) {
    // NOTE: there might be also a general simplification without custom operator
    // via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
    // hnew stays null unless one of the three patterns below matches
    Hop hnew = null;
    // Layout note: in the multi-line conditions of this method, a trailing comment
    // after '&&' describes the condition on the FOLLOWING line.
    if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol && // all patterns rooted by sum()
    ((AggUnaryOp) hi).getOp() == AggOp.SUM && // all patterns subrooted by binary op
    hi.getInput().get(0) instanceof BinaryOp && // not applied for vector-vector mult
    hi.getInput().get(0).getDim2() > 1) {
        BinaryOp bop = (BinaryOp) hi.getInput().get(0);
        // guards patterns 2 and 3 so that at most one pattern is applied
        boolean appliedPattern = false;
        // Pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2)
        // alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
        // requires a literal exponent of 2 and equally sized operands of the multiply
        if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) {
            Hop W = bop.getInput().get(0);
            // (X - U %*% t(V))
            Hop tmp = bop.getInput().get(1).getInput().get(0);
            if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS) && // prevent mv
            HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
                // a) sum (W * (X - U %*% t(V)) ^ 2)
                // uvIndex: input position of the matrix multiply within the minus (-1 = no match)
                int uvIndex = -1;
                if (// ba guarantees matrices
                tmp.getInput().get(1) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    uvIndex = 1;
                } else // b) sum (W * (U %*% t(V) - X) ^ 2)
                if (// ba guarantees matrices
                tmp.getInput().get(0) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) {
                    uvIndex = 0;
                }
                if (// rewrite match
                uvIndex >= 0) {
                    // X is the minus operand that is not the matrix multiply
                    Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
                    Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
                    Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
                    // hand over V such that t(V) corresponds to the right matmult input:
                    // strip an existing transpose, otherwise add one
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    // handle special case of post_nz
                    // (W is replaced by the literal 1 if it is the non-zero indicator of X)
                    if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
                        W = new LiteralOp(1);
                    }
                    // construct quaternary hop
                    // (last argument is true only here, i.e., for the post-weighting pattern)
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    // message suffix is the pattern number followed by uvIndex, e.g. "10" or "11"
                    LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        // Pattern 2 (pre weighting): sum ((X - W * (U %*% t(V))) ^ 2)
        // alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            // a) sum ((X - W * (U %*% t(V))) ^ 2)
            // wuvIndex: input position of W * (U %*% t(V)) within the minus (-1 = no match)
            int wuvIndex = -1;
            if (lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 1;
            } else // b) sum ((W * (U %*% t(V)) - X) ^ 2)
            if (lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 0;
            }
            if (// rewrite match
            wuvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
                // (W * (U %*% t(V)))
                Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
                // the cast is safe: wuvIndex >= 0 implies the selected operand is a BinaryOp
                // whose second input is an AggBinaryOp
                if (((BinaryOp) tmp).getOp() == OpOp2.MULT && tmp.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
                HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && // BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    Hop W = tmp.getInput().get(0);
                    Hop U = tmp.getInput().get(1).getInput().get(0);
                    Hop V = tmp.getInput().get(1).getInput().get(1);
                    // hand over V such that t(V) corresponds to the right matmult input
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        // Pattern 3 (no weighting): sum ((X - (U %*% t(V))) ^ 2)
        // alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            // a) sum ((X - (U %*% t(V))) ^ 2)
            // uvIndex: input position of the matrix multiply within the minus (-1 = no match)
            int uvIndex = -1;
            if (// ba guarantees matrices
            lright instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
            HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) {
                uvIndex = 1;
            } else // b) sum (((U %*% t(V)) - X) ^ 2)
            if (// ba guarantees matrices
            lleft instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
            HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) {
                uvIndex = 0;
            }
            if (// rewrite match
            uvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
                // (U %*% t(V))
                Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
                // no weighting
                Hop W = new LiteralOp(1);
                Hop U = tmp.getInput().get(0);
                Hop V = tmp.getInput().get(1);
                // hand over V such that t(V) corresponds to the right matmult input
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                HopRewriteUtils.setOutputParametersForScalar(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
            }
        }
    }
    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used : QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 10 with AggBinaryOp

use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifySumMatrixMult.

/**
 * Pushes a sum aggregate through a matrix multiplication so that the
 * aggregation is applied to the (smaller) inputs instead of the product:
 * <pre>
 *   sum(A%*%B)     -> sum(t(colSums(A))*rowSums(B))
 *   colSums(A%*%B) -> colSums(A)%*%B
 *   rowSums(A%*%B) -> A%*%rowSums(B)
 * </pre>
 * The aggregate {@code hi} itself is kept; only its input is replaced.
 *
 * @param parent parent high-level operator (not used by this rewrite)
 * @param hi high-level operator to inspect
 * @param pos position of {@code hi} in its parent (not used by this rewrite)
 * @return {@code hi}, with its input rewired if the rewrite applied
 */
private static Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos) {
    // only sum aggregates qualify
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    final AggUnaryOp agg = (AggUnaryOp) hi;
    if (agg.getOp() != AggOp.SUM) {
        return hi;
    }

    // input must be a matrix multiply A%*%B that is not a dot product and has
    // no other consumers (otherwise the rewrite would introduce redundancy)
    final Hop mmult = hi.getInput().get(0);
    if (!(mmult instanceof AggBinaryOp)
        || !(mmult.getDim1() > 1 || mmult.getDim2() > 1)
        || mmult.getParent().size() != 1) {
        return hi;
    }

    final Hop left = mmult.getInput().get(0);
    final Hop right = mmult.getInput().get(1);

    // remove link from parent to matrix mult
    HopRewriteUtils.removeChildReference(hi, mmult);

    // create new operators
    Hop replacement = null;
    if (agg.getDirection() == Direction.RowCol) {
        // sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product
        AggUnaryOp colSums = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
        ReorgOp colSumsT = HopRewriteUtils.createTranspose(colSums);
        AggUnaryOp rowSums = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
        replacement = HopRewriteUtils.createBinary(colSumsT, rowSums, OpOp2.MULT);
        LOG.debug("Applied simplifySumMatrixMult RC.");
    } else if (agg.getDirection() == Direction.Col) {
        // colSums(A%*%B) -> colSums(A)%*%B
        AggUnaryOp colSums = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
        replacement = HopRewriteUtils.createMatrixMultiply(colSums, right);
        LOG.debug("Applied simplifySumMatrixMult C.");
    } else if (agg.getDirection() == Direction.Row) {
        // rowSums(A%*%B) -> A%*%rowSums(B)
        AggUnaryOp rowSums = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
        replacement = HopRewriteUtils.createMatrixMultiply(left, rowSums);
        LOG.debug("Applied simplifySumMatrixMult R.");
    }

    // rehang new subdag under current node (keep hi intact)
    HopRewriteUtils.addChildReference(hi, replacement, 0);
    hi.refreshSizeInformation();

    // cleanup if only consumer of intermediate
    HopRewriteUtils.cleanupUnreferenced(mmult);

    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Aggregations

AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)29 Hop (org.apache.sysml.hops.Hop)24 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)17 BinaryOp (org.apache.sysml.hops.BinaryOp)14 LiteralOp (org.apache.sysml.hops.LiteralOp)11 ReorgOp (org.apache.sysml.hops.ReorgOp)8 UnaryOp (org.apache.sysml.hops.UnaryOp)8 TernaryOp (org.apache.sysml.hops.TernaryOp)6 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)6 IndexingOp (org.apache.sysml.hops.IndexingOp)5 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)5 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)5 CNode (org.apache.sysml.hops.codegen.cplan.CNode)4 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)3 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)3 CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary)3 DataOp (org.apache.sysml.hops.DataOp)2 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)2 CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary)2 TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType)2