Search in sources :

Example 1 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in the Apache project incubator-systemml.

Class TemplateCell, method rConstructCplan.

/**
 * Recursively constructs the cell-template CNode DAG for the given hop and stores the
 * resulting node in {@code tmp} under the hop's ID.
 * <p>
 * Inputs that are plan references of the best CellTpl memo entry are expanded recursively;
 * all other inputs become {@link CNodeData} leaves and are added to {@code inHops}. If the
 * best entry found for the hop itself is a row or outer-product plan, the hop is consumed
 * as a plain data input instead.
 *
 * @param hop current high-level operator
 * @param memo memo table of candidate fusion plans
 * @param tmp map from hop ID to constructed CNode; also serves as memoization table
 * @param inHops collects the hops that become data inputs of the generated operator
 * @param compileLiterals passed through to {@code TemplateUtils.createCNodeData}
 */
protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
    //memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CellTpl);
    //if the best entry is a row or outer-product plan, consume this hop as a data input
    if (me != null && (me.type == TemplateType.RowTpl || me.type == TemplateType.OuterProdTpl)) {
        CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
        tmp.put(hop.getHopID(), cdata);
        inHops.add(hop);
        return;
    }
    //recursively process required childs; all other inputs become data leaves
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        //DataOps are never expanded; for multi-aggregates only children with a cell plan are
        if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateType.MultiAggTpl || memo.contains(c.getHopID(), TemplateType.CellTpl)))
            rConstructCplan(c, memo, tmp, inHops, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    //construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        BinaryOp bop = (BinaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        String primitiveOpName = bop.getOp().name();
        //add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        //specialize X^2 and X*2 with a literal 2 into unary operations
        if (bop.getOp() == OpOp2.POW && cdata2.isLiteral() && cdata2.getVarname().equals("2"))
            out = new CNodeUnary(cdata1, UnaryType.POW2);
        else if (bop.getOp() == OpOp2.MULT && cdata2.isLiteral() && cdata2.getVarname().equals("2"))
            out = new CNodeUnary(cdata1, UnaryType.MULT2);
        else
            //default binary
            out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        //add lookups if required (the second input is not wrapped here)
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        //construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().name()));
    } else if (hop instanceof ParameterizedBuiltinOp) {
        ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp) hop;
        Hop target = pbop.getTargetHop();
        //decide the lookup based on the target hop itself rather than assuming it is input 0
        CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(tmp.get(target.getHopID()), target);
        CNode cdata2 = tmp.get(pbop.getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(pbop.getParameterHop("replacement").getHopID());
        //a literal NaN pattern selects the dedicated NaN replacement
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        //indexing becomes a LOOKUP_RC1 on input 0, using its column count and input 4
        //(presumably the column index -- confirm against IndexingOp) as operands
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), TernaryType.LOOKUP_RC1);
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        //transpose is a pass-through in the cell template
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggUnaryOp) {
        //aggregation handled in template implementation (note: we do not compile
        //^2 of SUM_SQ into the operator to simplify the detection of single operators)
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggBinaryOp) {
        //(1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
        if (HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1))) {
            CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
            out = new CNodeUnary(cdata1, UnaryType.POW2);
        } else {
            CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals);
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            if (TemplateUtils.isColVector(cdata2))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.MULT);
        }
    }
    tmp.put(hop.getHopID(), out);
}
Also used: TernaryType(org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeTernary(org.apache.sysml.hops.codegen.cplan.CNodeTernary) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) TernaryOp(org.apache.sysml.hops.TernaryOp) CNode(org.apache.sysml.hops.codegen.cplan.CNode) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) IndexingOp(org.apache.sysml.hops.IndexingOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 2 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in the Apache project incubator-systemml.

Class TemplateOuterProduct, method rConstructCplan.

/**
 * Recursively constructs the outer-product template CNode DAG for the given hop and stores
 * the resulting node in {@code tmp} under the hop's ID.
 * <p>
 * Inputs that are plan references of the best OuterProdTpl memo entry are expanded
 * recursively; all other inputs become {@link CNodeData} leaves and are added to
 * {@code inHops}. Special template inputs are recorded by name in {@code inHops2}
 * ("_X" for the data side of an equal-sized binary op, "_U"/"_V" for the outer-product
 * factors).
 *
 * @param hop current high-level operator
 * @param memo memo table of candidate fusion plans
 * @param tmp map from hop ID to constructed CNode; also serves as memoization table
 * @param inHops collects the hops that become data inputs of the generated operator
 * @param inHops2 named special inputs ("_X", "_U", "_V") of the template
 * @param compileLiterals passed through to {@code TemplateUtils} helpers
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    //memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    //recursively process required childs
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.OuterProdTpl);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        //guard a missing memo entry (as the cell template does) instead of failing with an NPE
        if (me != null && me.isPlanRef(i))
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    //construct cnode for current hop
    CNode out = null;
    if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        //use name(): Enum.valueOf matches the constant name, which toString() need not return
        String primitiveOpName = ((UnaryOp) hop).getOp().name();
        out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        String primitiveOpName = ((BinaryOp) hop).getOp().name();
        //for equal-sized inputs, remember the data-side input (a CNodeData leaf) as _X
        if (HopRewriteUtils.isEqualSize(hop.getInput().get(0), hop.getInput().get(1))) {
            Hop main = hop.getInput().get((cdata1 instanceof CNodeData) ? 0 : 1);
            inHops2.put("_X", main);
        }
        //add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
    } else if (hop instanceof AggBinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        //handle transpose in outer or final product
        cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
        cdata2 = TemplateUtils.skipTranspose(cdata2, hop.getInput().get(1), tmp, compileLiterals);
        //outer product U%*%t(V), see open
        if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
            //keep U and V (without its transpose) for later reference
            inHops2.put("_U", hop.getInput().get(0));
            if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(1)))
                inHops2.put("_V", hop.getInput().get(1).getInput().get(0));
            else
                inHops2.put("_V", hop.getInput().get(1));
            out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
        } else {
            //final left/right matrix mult, see close; a scalar operand goes second
            if (cdata1.getDataType().isScalar())
                out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
            else
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        out = tmp.get(hop.getInput().get(0).getHopID());
    } else if (hop instanceof AggUnaryOp && ((AggUnaryOp) hop).getOp() == AggOp.SUM && ((AggUnaryOp) hop).getDirection() == Direction.RowCol) {
        //full sum aggregation is handled by the template itself
        out = tmp.get(hop.getInput().get(0).getHopID());
    }
    tmp.put(hop.getHopID(), out);
}
Also used: CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 3 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in the Apache project incubator-systemml.

Class PlanSelectionFuseCostBased, method rIsRowTemplateWithoutAgg.

/**
 * Recursively checks whether the row-template plan rooted at {@code current} contains
 * neither an AggUnaryOp nor an AggBinaryOp. Only inputs that are plan references of the
 * best RowTpl memo entry are traversed.
 *
 * @param memo memo table of candidate fusion plans
 * @param current hop to check
 * @param visited IDs of hops already checked; every newly checked hop is added
 * @return true if no aggregation was found among the hops not visited before
 */
private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
    //hops seen before already contributed to the result on their first visit
    if (visited.contains(current.getHopID()))
        return true;
    MemoTableEntry entry = memo.getBest(current.getHopID(), TemplateType.RowTpl);
    boolean childrenWithoutAgg = true;
    //plan references are checked for input positions 0..2 only
    for (int pos = 0; pos < 3; pos++) {
        if (!entry.isPlanRef(pos))
            continue;
        //evaluate the child first so every referenced subplan is visited (no short-circuit)
        boolean childOk = rIsRowTemplateWithoutAgg(memo, current.getInput().get(pos), visited);
        childrenWithoutAgg = childrenWithoutAgg && childOk;
    }
    boolean isAggregation = current instanceof AggUnaryOp || current instanceof AggBinaryOp;
    visited.add(current.getHopID());
    return childrenWithoutAgg && !isAggregation;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp)

Example 4 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in the Apache project incubator-systemml.

Class RewriteAlgebraicSimplificationStatic, method pushdownSumBinaryMult.

/**
 * Rewrites {@code sum(lamda*X)} into {@code lamda*sum(X)} when {@code hi} is a full (RowCol)
 * sum over a scalar-matrix multiplication whose only parent is that sum.
 *
 * @param parent parent high-level operator of {@code hi}
 * @param hi candidate sum operator
 * @param pos input position of {@code hi} within {@code parent}
 * @return the new multiplication hop if the rewrite applied, otherwise {@code hi}
 * @throws HopsException if HopsException occurs
 */
private Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos) throws HopsException {
    //pattern:  sum(lamda*X) -> lamda*sum(X)
    if (!(hi instanceof AggUnaryOp))
        return hi;
    AggUnaryOp sum = (AggUnaryOp) hi;
    if (sum.getDirection() != Direction.RowCol || sum.getOp() != Hop.AggOp.SUM)
        return hi;
    Hop mult = hi.getInput().get(0);
    //the multiplication must have the sum as its only parent
    if (!HopRewriteUtils.isBinary(mult, OpOp2.MULT, 1))
        return hi;
    Hop operand1 = mult.getInput().get(0);
    Hop operand2 = mult.getInput().get(1);
    DataType type1 = operand1.getDataType();
    DataType type2 = operand2.getDataType();
    boolean scalarAndMatrix = (type1 == DataType.SCALAR && type2 == DataType.MATRIX)
        || (type1 == DataType.MATRIX && type2 == DataType.SCALAR);
    if (!scalarAndMatrix)
        return hi;
    //pick the scalar factor and the matrix operand
    Hop lamda = (type1 == DataType.SCALAR) ? operand1 : operand2;
    Hop matrix = (type1 == DataType.MATRIX) ? operand1 : operand2;
    AggUnaryOp matrixSum = HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol);
    Hop scaled = HopRewriteUtils.createBinary(lamda, matrixSum, OpOp2.MULT);
    HopRewriteUtils.replaceChildReference(parent, hi, scaled, pos);
    LOG.debug("Applied pushdownSumBinaryMult.");
    return scaled;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop)

Example 5 with AggUnaryOp

Use of org.apache.sysml.hops.AggUnaryOp in the Apache project incubator-systemml.

Class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSquaredLoss.

/**
	 * Searches for weighted squared loss expressions and replaces them with a quaternary operator. 
	 * Currently, this search includes the following three patterns (each is also matched with the
	 * operands of the inner subtraction swapped, e.g., U %*% t(V) - X):
	 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
	 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
	 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
	 * 
	 * NOTE: We include transpose into the pattern because during runtime we need to compute
	 * U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
	 * without additional memory requirements for internal transpose.
	 * 
	 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
	 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
	 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
	 * as we generalized the runtime or hop/lop compilation. 
	 * 
	 * @param parent parent high-level operator
	 * @param hi candidate high-level operator (the potential sum root)
	 * @param pos input position of hi within parent (used to relink the new operator)
	 * @return the new WSLOSS quaternary operator if a pattern matched, otherwise hi unchanged
	 * @throws HopsException if HopsException occurs
	 */
private Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) throws HopsException {
    //NOTE: there might be also a general simplification without custom operator
    //via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
    Hop hnew = null;
    if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol && //all patterns rooted by sum()
    ((AggUnaryOp) hi).getOp() == AggOp.SUM && //all patterns subrooted by binary op
    hi.getInput().get(0) instanceof BinaryOp && //not applied for vector-vector mult
    hi.getInput().get(0).getDim2() > 1) {
        BinaryOp bop = (BinaryOp) hi.getInput().get(0);
        //ensures at most one of the three pattern groups below is applied
        boolean appliedPattern = false;
        //pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2) or sum (W * (U %*% t(V) - X) ^ 2)
        if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW) && bop.getInput().get(0).getDataType() == DataType.MATRIX && //prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) {
            Hop W = bop.getInput().get(0);
            //(X - U %*% t(V))
            Hop tmp = bop.getInput().get(1).getInput().get(0);
            if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS) && //prevent mv	
            HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
                //a) sum (W * (X - U %*% t(V)) ^ 2)
                //uvIndex: position of the matrix multiply within the subtraction (-1 if none)
                int uvIndex = -1;
                if (//ba guarantees matrices
                tmp.getInput().get(1) instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    uvIndex = 1;
                } else //b) sum (W * (U %*% t(V) - X) ^ 2)
                if (//ba guarantees matrices
                tmp.getInput().get(0) instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) {
                    uvIndex = 0;
                }
                if (//rewrite match
                uvIndex >= 0) {
                    //X is the other subtraction operand; U and V are the matrix-multiply inputs
                    Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
                    Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
                    Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
                    //the operator takes V rather than t(V): strip an existing transpose or add one
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    //handle special case of post_nz
                    if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
                        W = new LiteralOp(1);
                    }
                    //construct quaternary hop (last argument is true only for this post-weighting pattern)
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        //pattern 2 (pre weighting): sum ((X - W * (U %*% t(V))) ^ 2) or sum ((W * (U %*% t(V)) - X) ^ 2)
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && //prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            //a) sum ((X - W * (U %*% t(V))) ^ 2)
            //wuvIndex: position of the W * (U %*% t(V)) operand within the subtraction (-1 if none)
            int wuvIndex = -1;
            if (lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 1;
            } else //b) sum ((W * (U %*% t(V)) - X) ^ 2)
            if (lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 0;
            }
            if (//rewrite match
            wuvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
                //(W * (U %*% t(V)))
                Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
                //the selected operand must be an element-wise MULT; the cast is safe per the BinaryOp check above
                if (((BinaryOp) tmp).getOp() == OpOp2.MULT && tmp.getInput().get(0).getDataType() == DataType.MATRIX && //prevent mv
                HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && //BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    Hop W = tmp.getInput().get(0);
                    Hop U = tmp.getInput().get(1).getInput().get(0);
                    Hop V = tmp.getInput().get(1).getInput().get(1);
                    //the operator takes V rather than t(V): strip an existing transpose or add one
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        //pattern 3 (no weighting): sum ((X - (U %*% t(V))) ^ 2) or sum (((U %*% t(V)) - X) ^ 2)
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && //prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            //a) sum ((X - (U %*% t(V))) ^ 2)
            int uvIndex = -1;
            if (//ba guarantees matrices
            lright instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT
            HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) {
                uvIndex = 1;
            } else //b) sum (((U %*% t(V)) - X) ^ 2)
            if (//ba guarantees matrices
            lleft instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT
            HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) {
                uvIndex = 0;
            }
            if (//rewrite match
            uvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
                //(U %*% t(V))
                Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
                //no weighting 
                Hop W = new LiteralOp(1);
                Hop U = tmp.getInput().get(0);
                Hop V = tmp.getInput().get(1);
                //the operator takes V rather than t(V): strip an existing transpose or add one
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                HopRewriteUtils.setOutputParametersForScalar(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
            }
        }
    }
    //relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)31 Hop (org.apache.sysml.hops.Hop)29 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)15 LiteralOp (org.apache.sysml.hops.LiteralOp)13 BinaryOp (org.apache.sysml.hops.BinaryOp)11 ReorgOp (org.apache.sysml.hops.ReorgOp)10 UnaryOp (org.apache.sysml.hops.UnaryOp)10 IndexingOp (org.apache.sysml.hops.IndexingOp)5 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)5 ArrayList (java.util.ArrayList)4 DataOp (org.apache.sysml.hops.DataOp)4 TernaryOp (org.apache.sysml.hops.TernaryOp)4 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)3 CNode (org.apache.sysml.hops.codegen.cplan.CNode)3 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)3 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)3 CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary)3 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)2 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)2 CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary)2