Search in sources:

Example 31 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by Apache.

Class TemplateRow, method rConstructCplan.

/**
 * Recursively constructs the row-template CNode for {@code hop} and stores it in {@code tmp}
 * under the hop's ID.
 *
 * <p>Inputs that the best ROW/CELL memo-table entry marks as plan references are compiled
 * recursively; every other input becomes a {@code CNodeData} leaf and is recorded in
 * {@code inHops}. Matrix-typed results get their row/column counts copied from the hop.
 *
 * @param hop             hop to compile
 * @param memo            memo table queried for the best ROW or CELL plan of each hop
 * @param tmp             hop ID to already-constructed CNode; doubles as the CSE/memoization map
 * @param inHops          collects hops that are fed into the generated plan as data inputs
 * @param inHops2         named special inputs; this method writes the keys "X" and "B1"
 * @param compileLiterals forwarded to CNodeData creation and transpose skipping
 * @throws RuntimeException if the hop's operation is not supported by this template
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    // recursively process required children; inputs that are not plan references
    // become data leaves and are registered as plan inputs
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i))
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof AggUnaryOp) {
        // unary aggregates: supported row aggregates, column sums, and full sums
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        if (((AggUnaryOp) hop).getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
            // single-column input: a scalar passes through, anything else is wrapped in LOOKUP_R
            if (hop.getInput().get(0).getDim2() == 1)
                out = (cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            else {
                // map the aggregate op name onto the ROW_<OP>S unary type, e.g. ROW_SUMS
                String opcode = "ROW_" + ((AggUnaryOp) hop).getOp().name().toUpperCase() + "S";
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
                    inHops2.put("X", hop.getInput().get(0));
            }
        } else if (((AggUnaryOp) hop).getDirection() == Direction.Col && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            // vector add without temporary copy
            if (cdata1 instanceof CNodeBinary && ((CNodeBinary) cdata1).getType().isVectorScalarPrimitive())
                out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), ((CNodeBinary) cdata1).getType().getVectorAddPrimitive());
            else
                out = cdata1;
        } else if (((AggUnaryOp) hop).getDirection() == Direction.RowCol && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            // full sum: matrix inputs are reduced via ROW_SUMS, scalars pass through
            out = (cdata1.getDataType().isMatrix()) ? new CNodeUnary(cdata1, UnaryType.ROW_SUMS) : cdata1;
        }
    } else if (hop instanceof AggBinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
            // correct input under transpose
            cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
            // the transpose itself is no longer a plan input; its input is (if it became data)
            inHops.remove(hop.getInput().get(0));
            if (cdata1 instanceof CNodeData)
                inHops.add(hop.getInput().get(0).getInput().get(0));
            // note: vectorMultAdd applicable to vector-scalar, and vector-vector
            if (hop.getInput().get(1).getDim2() == 1)
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
            else {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD);
                if (!inHops2.containsKey("B1")) {
                    // incl modification of X for consistency
                    if (cdata1 instanceof CNodeData)
                        inHops2.put("X", hop.getInput().get(0).getInput().get(0));
                    inHops2.put("B1", hop.getInput().get(1));
                }
            }
            if (!inHops2.containsKey("X"))
                inHops2.put("X", hop.getInput().get(0).getInput().get(0));
        } else {
            // both inputs single-column: scalar MULT, non-scalar operands wrapped in LOOKUP0
            if (hop.getInput().get(0).getDim2() == 1 && hop.getInput().get(1).getDim2() == 1)
                out = new CNodeBinary((cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), (cdata2.getDataType() == DataType.SCALAR) ? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
            else if (hop.getInput().get(1).getDim2() == 1) {
                // matrix times single-column right input: per-row dot product
                out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
                inHops2.put("X", hop.getInput().get(0));
            } else {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MATRIXMULT);
                inHops2.put("X", hop.getInput().get(0));
                inHops2.put("B1", hop.getInput().get(1));
            }
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        // NOTE(review): tmp has no entry for this hop here (see the early return above), so the
        // first argument is always null; this relies on skipTranspose resolving the transpose
        // input from tmp itself — confirm against TemplateUtils.skipTranspose
        out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
        if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
            inHops.add(hop.getInput().get(0));
    } else if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if (hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1 || (!hop.dimsKnown() && cdata1.getDataType() == DataType.MATRIX)) {
            if (HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY)) {
                // vector variant is looked up by name: VECT_<OP>
                String opname = "VECT_" + ((UnaryOp) hop).getOp().name();
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opname));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
                    inHops2.put("X", hop.getInput().get(0));
            } else
                throw new RuntimeException("Unsupported unary matrix " + "operation: " + ((UnaryOp) hop).getOp().name());
        } else // general scalar case
        {
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
            String primitiveOpName = ((UnaryOp) hop).getOp().toString();
            out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
        }
    } else if (HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
        // special case for cbind with zeros
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = null;
        if (HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) {
            // constant-valued right input is inlined as a literal instead of a data input
            cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
            // rm 0-matrix
            inHops.remove(hop.getInput().get(1));
        } else {
            cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        }
        out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
        if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
            inHops2.put("X", hop.getInput().get(0));
    } else if (hop instanceof BinaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if ((hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1) || (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1) || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()) && // not a known vector output
        (hop.getDim2() != 1) && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix()))) {
            if (HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY)) {
                if (TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2) || TemplateUtils.isRowVector(cdata2))) {
                    // vector-vector variant: VECT_<OP>
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name();
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                } else {
                    // vector-scalar variant: VECT_<OP>_SCALAR; column vectors become row lookups
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name() + "_SCALAR";
                    if (TemplateUtils.isColVector(cdata1))
                        cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
                    if (TemplateUtils.isColVector(cdata2))
                        cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                }
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X") && !(cdata1.getDataType() == DataType.SCALAR)) {
                    inHops2.put("X", hop.getInput().get(0));
                }
            } else
                throw new RuntimeException("Unsupported binary matrix " + "operation: " + ((BinaryOp) hop).getOp().name());
        } else // one input is a vector/scalar other is a scalar
        {
            String primitiveOpName = ((BinaryOp) hop).getOp().toString();
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            if (// vector or vector can be inferred from lhs
            TemplateUtils.isColVector(cdata2) || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix()))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
        }
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        // add lookups if required
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        // construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString()));
    } else if (HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
        // n-ary cbind: vector inputs get lookups where necessary; a data first input may become X
        CNode[] inputs = new CNode[hop.getInput().size()];
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop c = hop.getInput().get(i);
            CNode cdata = tmp.get(c.getHopID());
            if (TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata))
                cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
            inputs[i] = cdata;
            if (i == 0 && cdata instanceof CNodeData && !inHops2.containsKey("X"))
                inHops2.put("X", c);
        }
        out = new CNodeNary(inputs, NaryType.VECT_CBIND);
    } else if (hop instanceof ParameterizedBuiltinOp) {
        // reads the 'pattern' and 'replacement' parameters (a replace operation);
        // a literal Double.NaN pattern selects the dedicated REPLACE_NAN variant
        CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        // indexing compiled as a ternary lookup over the input, its column count, and input 4
        // (NOTE(review): presumably the column lower bound — confirm); multi-column results
        // use LOOKUP_RVECT1, single-column results LOOKUP_RC1
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), (hop.getDim2() != 1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
    }
    // no branch above produced a node: the hop type is not supported by this template
    if (out == null) {
        throw new RuntimeException(hop.getHopID() + " " + hop.getOpString());
    }
    if (out.getDataType().isMatrix()) {
        out.setNumRows(hop.getDim1());
        out.setNumCols(hop.getDim2());
    }
    tmp.put(hop.getHopID(), out);
}
Also used: TernaryType(org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeTernary(org.apache.sysml.hops.codegen.cplan.CNodeTernary) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) CNodeNary(org.apache.sysml.hops.codegen.cplan.CNodeNary) TernaryOp(org.apache.sysml.hops.TernaryOp) CNode(org.apache.sysml.hops.codegen.cplan.CNode) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) IndexingOp(org.apache.sysml.hops.IndexingOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 32 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by Apache.

Class RewriteCompressedReblock, method rAnalyzeHopDag.

/**
 * Recursively analyzes a HOP DAG bottom-up (children before parents).
 *
 * <p>It propagates the set of matrix names derived from the start hop
 * ({@code status.startHopID}) through:
 * <ul>
 *   <li>transient writes and reads (name mapping);</li>
 *   <li>function calls, where arguments are mapped to parameters and results back to
 *   outputs;</li>
 *   <li>operations that keep their output in the tracked set.</li>
 * </ul>
 * Tracked names are stored in {@code status.compMtx}. {@code status.nonApplicable} is set
 * as soon as a tracked matrix feeds an operation that is not one of the supported
 * operations listed in part c). Each function key is analyzed at most once, memoized in
 * {@code status.procFn}. Each hop is marked visited after processing.
 *
 * @param current hop to analyze
 * @param status  mutable probe state shared across the whole traversal
 */
private static void rAnalyzeHopDag(Hop current, ProbeStatus status) {
    if (current.isVisited())
        return;
    // process children recursively
    for (Hop input : current.getInput()) rAnalyzeHopDag(input, status);
    // handle source persistent read
    if (current.getHopID() == status.startHopID) {
        status.compMtx.add(getTmpName(current));
        status.foundStart = true;
    }
    // a) handle function calls
    // NOTE(review): hasCompressedInput presumably checks whether any input of current is
    // tracked in status.compMtx — confirm against its definition
    if (current instanceof FunctionOp && hasCompressedInput(current, status)) {
        // TODO handle of functions in a more fine-grained manner
        // to cover special cases multiple calls where compressed
        // inputs might occur for different input parameters
        FunctionOp fop = (FunctionOp) current;
        String fkey = fop.getFunctionKey();
        if (!status.procFn.contains(fkey)) {
            // memoization to avoid redundant analysis and recursive calls
            status.procFn.add(fkey);
            // map inputs to function inputs
            FunctionStatementBlock fsb = status.prog.getFunctionStatementBlock(fkey);
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            // the function body is probed with a derived status seeded with the parameter names
            ProbeStatus status2 = new ProbeStatus(status);
            for (int i = 0; i < fop.getInput().size(); i++) if (status.compMtx.contains(getTmpName(fop.getInput().get(i))))
                status2.compMtx.add(fstmt.getInputParams().get(i).getName());
            // analyze function and merge meta info
            rAnalyzeProgram(fsb, status2);
            status.foundStart |= status2.foundStart;
            status.usedInLoop |= status2.usedInLoop;
            status.condUpdate |= status2.condUpdate;
            status.nonApplicable |= status2.nonApplicable;
            // map function outputs to outputs
            String[] outputs = fop.getOutputVariableNames();
            for (int i = 0; i < outputs.length; i++) if (status2.compMtx.contains(fstmt.getOutputParams().get(i).getName()))
                status.compMtx.add(outputs[i]);
        }
    } else // b) handle transient reads and writes (name mapping)
    if (HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTWRITE) && status.compMtx.contains(getTmpName(current.getInput().get(0))))
        status.compMtx.add(current.getName());
    else if (HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTREAD) && status.compMtx.contains(current.getName()))
        status.compMtx.add(getTmpName(current));
    else // c) handle applicable operations
    if (hasCompressedInput(current, status)) {
        // valid with uncompressed outputs:
        //  - tsmm: left transpose-self matmult whose output fits in one column block
        //  - mvmm: matmult with a single-row or single-column output
        //  - a transpose whose only consumer is such a matrix-vector matmult
        //  - sum, sum of squares, min, max aggregates
        boolean compUCOut = (// tsmm
        current instanceof AggBinaryOp && current.getDim2() <= current.getColsInBlock() && ((AggBinaryOp) current).checkTransposeSelf() == MMTSJType.LEFT) || // mvmm
        (current instanceof AggBinaryOp && (current.getDim1() == 1 || current.getDim2() == 1)) || (HopRewriteUtils.isTransposeOperation(current) && current.getParent().size() == 1 && current.getParent().get(0) instanceof AggBinaryOp && (current.getParent().get(0).getDim1() == 1 || current.getParent().get(0).getDim2() == 1)) || HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX);
        // valid with compressed outputs: matrix-scalar binary ops and binary cbind
        boolean compCOut = HopRewriteUtils.isBinaryMatrixScalarOperation(current) || HopRewriteUtils.isBinary(current, OpOp2.CBIND);
        // metadata-only operations (nrow/ncol)
        boolean metaOp = HopRewriteUtils.isUnary(current, OpOp1.NROW, OpOp1.NCOL);
        status.nonApplicable |= !(compUCOut || compCOut || metaOp);
        // only operations with compressed outputs propagate tracking to their result
        if (compCOut)
            status.compMtx.add(getTmpName(current));
    }
    current.setVisited();
}
Also used : FunctionStatement(org.apache.sysml.parser.FunctionStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) FunctionOp(org.apache.sysml.hops.FunctionOp)

Example 33 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by Apache.

Class RewriteSplitDagDataDependentOperators, method rCollectDataDependentOperators.

/**
 * Collects the data-dependent operators of a HOP DAG that should trigger a DAG split,
 * appending them to {@code cand}.
 *
 * <p>Candidates are:
 * <ul>
 *   <li>removeEmpty and ctable without explicit dims, when their output size is unknown
 *   and their result has non-write consumers;</li>
 *   <li>non-leaf parameters of a sort;</li>
 *   <li>second-order eval.</li>
 * </ul>
 * Consumer information is recorded on the candidate hops as flags. The children of a
 * candidate are not descended into; every processed hop is marked visited.
 *
 * @param hop  current hop of the traversal
 * @param cand output list receiving the split candidates
 */
private void rCollectDataDependentOperators(Hop hop, ArrayList<Hop> cand) {
    if (hop.isVisited()) {
        return;
    }
    // a split is pointless when the dims are already known or only writes consume this hop
    final boolean splitUnnecessary = hop.dimsKnown() || HopRewriteUtils.hasOnlyWriteParents(hop, true, true);
    boolean descend = true;

    // #1 removeEmpty (unless its sole consumer is a ternary op that can ignore zeros)
    final boolean isRemoveEmpty = hop instanceof ParameterizedBuiltinOp
        && ((ParameterizedBuiltinOp) hop).getOp() == ParamBuiltinOp.RMEMPTY;
    if (isRemoveEmpty && !splitUnnecessary
        && !(hop.getParent().size() == 1 && hop.getParent().get(0) instanceof TernaryOp
            && ((TernaryOp) hop.getParent().get(0)).isMatrixIgnoreZeroRewriteApplicable())) {
        final ParameterizedBuiltinOp rmEmpty = (ParameterizedBuiltinOp) hop;
        cand.add(rmEmpty);
        descend = false;
        // derive consumer properties and flag the hop accordingly
        boolean consumersSkipEmpty = true;
        boolean consumersAllPMM = true;
        final boolean targetIsDiag = rmEmpty.isTargetDiagInput();
        for (Hop consumer : hop.getParent()) {
            // consumers that need no empty blocks (extend as needed): left matmult input, nrow
            final boolean leftMatMultInput = consumer instanceof AggBinaryOp && hop == consumer.getInput().get(0);
            if (!leftMatMultInput) {
                consumersAllPMM = false;
                if (!HopRewriteUtils.isUnary(consumer, OpOp1.NROW)) {
                    consumersSkipEmpty = false;
                }
            }
        }
        rmEmpty.setOutputEmptyBlocks(!consumersSkipEmpty);
        if (consumersAllPMM && targetIsDiag) {
            if (ConfigurationManager.isDynamicRecompilation()) {
                rmEmpty.setOutputPermutationMatrix(true);
            }
            // every consumer is an AggBinaryOp here (checked in the loop above)
            for (Hop consumer : hop.getParent()) {
                ((AggBinaryOp) consumer).setHasLeftPMInput(true);
            }
        }
    }

    // #2 ctable whose output dims are not provided
    if (HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) && hop.getInput().size() < 4 && !splitUnnecessary) {
        cand.add(hop);
        descend = false;
        boolean feedsOnlyPMM = true;
        for (Hop consumer : hop.getParent()) {
            if (!(consumer instanceof AggBinaryOp && hop == consumer.getInput().get(0))) {
                feedsOnlyPMM = false;
            }
        }
        if (feedsOnlyPMM && HopRewriteUtils.isBasic1NSequence(hop.getInput().get(0))) {
            hop.setOutputEmptyBlocks(false);
        }
    }

    // #3 sort whose 'decreasing' / 'indexreturn' parameters are computed in the same DAG
    if (HopRewriteUtils.isReorg(hop, ReOrgOp.SORT)) {
        for (int paramIdx : new int[] {2, 3}) {
            final Hop param = hop.getInput().get(paramIdx);
            final boolean isLeaf = param instanceof LiteralOp || param instanceof DataOp;
            if (!isLeaf) {
                cand.add(param);
                param.setVisited();
                descend = false;
            }
        }
    }

    // #4 second-order eval function
    if (HopRewriteUtils.isNary(hop, OpOpN.EVAL) && !splitUnnecessary) {
        cand.add(hop);
        descend = false;
    }

    // otherwise, children are handled by the recursive rule application
    if (descend && hop.getInput() != null) {
        for (Hop child : hop.getInput()) {
            rCollectDataDependentOperators(child, cand);
        }
    }
    hop.setVisited();
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 34 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by Apache.

Class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedCrossEntropy.

/**
 * Rewrites weighted cross-entropy patterns into a fused quaternary {@code OpOp4.WCEMM} hop.
 *
 * <p>Two patterns are recognized:
 * <ul>
 *   <li>Pattern 1: {@code sum(X * log(U %*% t(V)))}, emitted with type code 0.</li>
 *   <li>Pattern 2: {@code sum(X * log(U %*% t(V) + eps))}, where {@code eps} is a scalar
 *   literal. It is emitted with type code 1 (BASIC_EPS).</li>
 * </ul>
 *
 * <p>Both patterns require:
 * <ul>
 *   <li>{@code hi} is a full (RowCol) sum over a binary op whose input has more than one
 *   column;</li>
 *   <li>{@code X} is a matrix of the same size as the log term;</li>
 *   <li>{@code U} satisfies the single-block constraint.</li>
 * </ul>
 * The quaternary op receives {@code V} such that the original product is
 * {@code U %*% t(V)}. An existing transpose is unwrapped; otherwise one is created.
 *
 * @param parent parent of {@code hi}; its child reference at {@code pos} is replaced on success
 * @param hi     candidate root hop
 * @param pos    position of {@code hi} among {@code parent}'s inputs
 * @return the new quaternary hop if a pattern was applied, otherwise {@code hi}
 */
private static Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) {
    Hop hnew = null;
    boolean appliedPattern = false;
    if (hi instanceof AggUnaryOp
        && ((AggUnaryOp) hi).getDirection() == Direction.RowCol // pattern rooted by sum()
        && ((AggUnaryOp) hi).getOp() == AggOp.SUM
        && hi.getInput().get(0) instanceof BinaryOp // pattern subrooted by binary op
        && hi.getInput().get(0).getDim2() > 1) { // not applied for vector-vector mult
        BinaryOp bop = (BinaryOp) hi.getInput().get(0);
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);
        // Pattern 1) sum( X * log(U %*% t(V)))
        // (condition order matters: right's inputs are only accessed after the LOG check)
        if (bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isEqualSize(left, right) // prevent mb
            && HopRewriteUtils.isUnary(right, OpOp1.LOG)
            && right.getInput().get(0) instanceof AggBinaryOp // ba guarantees matrices
            && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop X = left;
            Hop U = right.getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(0).getInput().get(1);
            if (!HopRewriteUtils.isTransposeOperation(V))
                V = HopRewriteUtils.createTranspose(V);
            else
                V = V.getInput().get(0);
            hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V, new LiteralOp(0.0), 0, false, false);
            hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedCEMM (line " + hi.getBeginLine() + ")");
        }
        // Pattern 2) sum( X * log(U %*% t(V) + eps))
        if (!appliedPattern && bop.getOp() == OpOp2.MULT && left.getDataType() == DataType.MATRIX
            && HopRewriteUtils.isEqualSize(left, right)
            && HopRewriteUtils.isUnary(right, OpOp1.LOG)
            && HopRewriteUtils.isBinary(right.getInput().get(0), OpOp2.PLUS)
            && right.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
            && right.getInput().get(0).getInput().get(1) instanceof LiteralOp
            && right.getInput().get(0).getInput().get(1).getDataType() == DataType.SCALAR
            && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0).getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            Hop X = left;
            Hop U = right.getInput().get(0).getInput().get(0).getInput().get(0);
            Hop V = right.getInput().get(0).getInput().get(0).getInput().get(1);
            Hop eps = right.getInput().get(0).getInput().get(1);
            if (!HopRewriteUtils.isTransposeOperation(V))
                V = HopRewriteUtils.createTranspose(V);
            else
                V = V.getInput().get(0);
            // 1 => BASIC_EPS
            hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V, eps, 1, false, false);
            hnew.setOutputBlocksizes(X.getRowsInBlock(), X.getColsInBlock());
            // mark as applied (consistent with pattern 1) so any later pattern is skipped
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedCEMMEps (line " + hi.getBeginLine() + ")");
        }
    }
    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 35 with AggBinaryOp

Use of org.apache.sysml.hops.AggBinaryOp in project incubator-systemml by Apache.

Class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedUnaryMM.

/**
 * Rewrites weighted unary matrix-multiply patterns into a fused quaternary {@code OpOp4.WUMM} hop.
 *
 * <p>Patterns are tried in this order, and at most one is applied:
 * <ul>
 *   <li>Pattern 1: {@code W op uop(U %*% t(V))}. The outer {@code op} must be in
 *   LOOKUP_VALID_WDIVMM_BINARY and {@code uop} in LOOKUP_VALID_WUMM_UNARY.</li>
 *   <li>Pattern 2.7: {@code (W * (U %*% t(V))) * 2} or {@code 2 * (W * (U %*% t(V)))}.
 *   The inner product must not form an mmchain. The inner multiply must have no other
 *   consumers.</li>
 *   <li>Pattern 2: {@code W op sop(U %*% t(V), 2)}, or {@code W op (2 * (U %*% t(V)))}
 *   when {@code sop} is MULT. The outer {@code op} must be in LOOKUP_VALID_WDIVMM_BINARY
 *   and {@code sop} in LOOKUP_VALID_WUMM_BINARY.</li>
 * </ul>
 *
 * <p>All patterns require:
 * <ul>
 *   <li>{@code W} and the matmult-based operand have equal size;</li>
 *   <li>the result has more than one column;</li>
 *   <li>{@code W} is a matrix with more columns than one column block;</li>
 *   <li>{@code U} satisfies the single-block constraint.</li>
 * </ul>
 * The quaternary op receives {@code V} such that the product is {@code U %*% t(V)}.
 * An existing transpose is unwrapped; otherwise one is created.
 *
 * @param parent parent of {@code hi}; its child reference at {@code pos} is replaced on success
 * @param hi     candidate root hop
 * @param pos    position of {@code hi} among {@code parent}'s inputs
 * @return the new quaternary hop if a pattern was applied, otherwise {@code hi}
 */
private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
    Hop hnew = null;
    boolean appliedPattern = false;
    // Pattern 1) (W*uop(U%*%t(V)))
    // (condition order matters: nested inputs are only accessed after the type checks)
    if (hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && // prevent mv
    HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) && // not applied for vector-vector mult
    hi.getDim2() > 1 && hi.getInput().get(0).getDataType() == DataType.MATRIX && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() && hi.getInput().get(1) instanceof UnaryOp && HopRewriteUtils.isValidOp(((UnaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_UNARY) && hi.getInput().get(1).getInput().get(0) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
    HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
        Hop W = hi.getInput().get(0);
        Hop U = hi.getInput().get(1).getInput().get(0).getInput().get(0);
        Hop V = hi.getInput().get(1).getInput().get(0).getInput().get(1);
        // mult flag: whether the outer weighting op is MULT
        boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
        OpOp1 op = ((UnaryOp) hi.getInput().get(1)).getOp();
        if (!HopRewriteUtils.isTransposeOperation(V))
            V = HopRewriteUtils.createTranspose(V);
        else
            V = V.getInput().get(0);
        hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, op, null);
        hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
        hnew.refreshSizeInformation();
        appliedPattern = true;
        LOG.debug("Applied simplifyWeightedUnaryMM1 (line " + hi.getBeginLine() + ")");
    }
    // Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
    if (!appliedPattern && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), OpOp2.MULT) && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2) || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2))) {
        // non-literal
        final Hop nl;
        if (hi.getInput().get(0) instanceof LiteralOp) {
            nl = hi.getInput().get(1);
        } else {
            nl = hi.getInput().get(0);
        }
        // the cast to AggBinaryOp below relies on isOuterProductLikeMM having matched first
        if (HopRewriteUtils.isBinary(nl, OpOp2.MULT) && // ensure no foreign parents
        nl.getParent().size() == 1 && // prevent mv
        HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) && // not applied for vector-vector mult
        nl.getDim2() > 1 && nl.getInput().get(0).getDataType() == DataType.MATRIX && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock() && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1)) && // no mmchain
        (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0), true)) {
            final Hop W = nl.getInput().get(0);
            final Hop U = nl.getInput().get(1).getInput().get(0);
            Hop V = nl.getInput().get(1).getInput().get(1);
            if (!HopRewriteUtils.isTransposeOperation(V))
                V = HopRewriteUtils.createTranspose(V);
            else
                V = V.getInput().get(0);
            // mult=true with OpOp2.MULT as the scalar op and no unary op
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line " + hi.getBeginLine() + ")");
        }
    }
    // Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
    if (!appliedPattern && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && // prevent mv
    HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) && // not applied for vector-vector mult
    hi.getDim2() > 1 && hi.getInput().get(0).getDataType() == DataType.MATRIX && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() && hi.getInput().get(1) instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY)) {
        Hop left = hi.getInput().get(1).getInput().get(0);
        Hop right = hi.getInput().get(1).getInput().get(1);
        // abop: the matrix multiply found in either operand position, or null
        Hop abop = null;
        // pattern 2a) matrix-scalar operations
        if (right.getDataType() == DataType.SCALAR && right instanceof LiteralOp && // pow2, mult2
        HopRewriteUtils.getDoubleValue((LiteralOp) right) == 2 && left instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(left.getInput().get(0), true)) {
            abop = left;
        } else // pattern 2b) scalar-matrix operations
        if (left.getDataType() == DataType.SCALAR && left instanceof LiteralOp && // mult2
        HopRewriteUtils.getDoubleValue((LiteralOp) left) == 2 && ((BinaryOp) hi.getInput().get(1)).getOp() == OpOp2.MULT && right instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
        HopRewriteUtils.isSingleBlock(right.getInput().get(0), true)) {
            abop = right;
        }
        if (abop != null) {
            Hop W = hi.getInput().get(0);
            Hop U = abop.getInput().get(0);
            Hop V = abop.getInput().get(1);
            boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
            OpOp2 op = ((BinaryOp) hi.getInput().get(1)).getOp();
            if (!HopRewriteUtils.isTransposeOperation(V))
                V = HopRewriteUtils.createTranspose(V);
            else
                V = V.getInput().get(0);
            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WUMM, W, U, V, mult, null, op);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            appliedPattern = true;
            LOG.debug("Applied simplifyWeightedUnaryMM2 (line " + hi.getBeginLine() + ")");
        }
    }
    // relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) OpOp1(org.apache.sysml.hops.Hop.OpOp1) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) OpOp2(org.apache.sysml.hops.Hop.OpOp2) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)56 Hop (org.apache.sysml.hops.Hop)47 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)32 BinaryOp (org.apache.sysml.hops.BinaryOp)27 LiteralOp (org.apache.sysml.hops.LiteralOp)21 ReorgOp (org.apache.sysml.hops.ReorgOp)15 UnaryOp (org.apache.sysml.hops.UnaryOp)15 TernaryOp (org.apache.sysml.hops.TernaryOp)11 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)11 QuaternaryOp (org.apache.sysml.hops.QuaternaryOp)10 IndexingOp (org.apache.sysml.hops.IndexingOp)9 ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp)9 CNode (org.apache.sysml.hops.codegen.cplan.CNode)8 CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary)6 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)6 CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary)6 DataOp (org.apache.sysml.hops.DataOp)4 OpOp2 (org.apache.sysml.hops.Hop.OpOp2)4 CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary)4 TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType)4