Search in sources:

Example 1 with AggOp

Use of org.apache.sysml.hops.Hop.AggOp in project incubator-systemml by Apache.

Class CNodeMultiAgg, method codegen.

/**
 * Generates the Java source of the fused multi-aggregate operator.
 *
 * <p>The sparse flag is ignored: a single body is generated and substituted for the
 * {@code %BODY_dense%} placeholder of {@code TEMPLATE}. That body contains the code of all
 * output expressions followed by one aggregation assignment per output.
 *
 * @param sparse ignored (see above)
 * @return the instantiated template source
 */
@Override
public String codegen(boolean sparse) {
    // note: ignore sparse flag, generate both
    String tmp = TEMPLATE;

    // rename inputs: the first (main) input matrix is referred to as "a" in the
    // generated code; the remaining inputs are renamed by renameInputs starting at 1
    rReplaceDataNode(_outputs, _inputs.get(0), "a");
    renameInputs(_outputs, _inputs, 1);

    // generate the body shared by all outputs, then clear the per-node "generated"
    // flags so subsequent codegen calls start from a clean state
    StringBuilder sb = new StringBuilder();
    for (CNode out : _outputs)
        sb.append(out.codegen(false));
    for (CNode out : _outputs)
        out.resetGenerated();

    // append one aggregation assignment per output (placeholder %IN% = value, %IX% = slot)
    for (int i = 0; i < _outputs.size(); i++) {
        CNode out = _outputs.get(i);
        String tmpOut = getAggTemplate(i);
        // get variable name (w/ handling of direct consumption of the main input "a")
        String varName = (out instanceof CNodeData
            && ((CNodeData) out).getHopID() == ((CNodeData) _inputs.get(0)).getHopID())
            ? "a" : out.getVarname();
        tmpOut = tmpOut.replace("%IN%", varName);
        tmpOut = tmpOut.replace("%IX%", String.valueOf(i));
        sb.append(tmpOut);
    }

    // replace class name and body; use literal replace (not the regex-based replaceAll)
    // because generated code may contain '$' or '\', which replaceAll would interpret
    // as group references / escapes in the replacement string
    tmp = tmp.replace("%TMP%", createVarname());
    tmp = tmp.replace("%BODY_dense%", sb.toString());

    // replace meta data information: comma-separated list of AggOp constants
    StringBuilder aggList = new StringBuilder();
    for (AggOp aggOp : _aggOps) {
        if (aggList.length() > 0)
            aggList.append(',');
        aggList.append("AggOp.").append(aggOp.name());
    }
    tmp = tmp.replace("%AGG_OP%", aggList.toString());
    return tmp;
}
Also used: AggOp (org.apache.sysml.hops.Hop.AggOp)

Example 2 with AggOp

Use of org.apache.sysml.hops.Hop.AggOp in project incubator-systemml by Apache.

Class RewriteForLoopVectorization, method vectorizeScalarAggregate.

/**
 * Tries to replace a loop body consisting of a single scalar aggregate update by one
 * vectorized aggregate over an indexed range.
 *
 * <p>Matched pattern (single hop DAG root in {@code csb}), with {@code <op>} taken from
 * MAP_SCALAR_AGGREGATE_SOURCE_OPS:
 * <pre>
 *   s = s &lt;op&gt; as.scalar(X[i, ...])     (left scalar, row index)
 *   s = s &lt;op&gt; as.scalar(X[..., i])     (left scalar, column index)
 *   s = as.scalar(X[i, ...]) &lt;op&gt; s     (right scalar, mirrored forms)
 * </pre>
 * where {@code i} is a DataOp named {@code itervar}. On a match, the cast is replaced by an
 * aggregate with Direction.RowCol (the AggOp at the same position in
 * MAP_SCALAR_AGGREGATE_TARGET_OPS). The indexed row (or column) range is rewritten to
 * {@code from:to}, and {@code csb} (mutated in place) is returned.
 *
 * @param sb        statement block returned unchanged when the rewrite does not apply
 *                  (presumably the enclosing for-loop block — confirm against caller)
 * @param csb       the loop body block; inspected and, on a match, modified in place
 * @param from      hop substituted as the lower index bound (assumed loop "from" expression)
 * @param to        hop substituted as the upper index bound (assumed loop "to" expression)
 * @param increment loop increment hop; only a LiteralOp with value 1.0 is supported
 * @param itervar   name of the loop iteration variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 * @throws HopsException as declared; presumably propagated from the hop utilities used here
 */
private StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) throws HopsException {
    StatementBlock ret = sb;
    //check missing and supported increment values (only a literal increment of 1 qualifies)
    if (!(increment != null && increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0)) {
        return ret;
    }
    //check for applicability
    //leftScalar: s <op> as.scalar(ix); rightScalar: as.scalar(ix) <op> s
    boolean leftScalar = false;
    boolean rightScalar = false;
    //row or col: true if the loop variable indexes rows, false if it indexes columns
    boolean rowIx = false;
    if (csb.get_hops() != null && csb.get_hops().size() == 1) {
        Hop root = csb.get_hops().get(0);
        if (root.getDataType() == DataType.SCALAR && root.getInput().get(0) instanceof BinaryOp) {
            BinaryOp bop = (BinaryOp) root.getInput().get(0);
            Hop left = bop.getInput().get(0);
            Hop right = bop.getInput().get(1);
            //check for left scalar plus: left reads the same scalar the root writes,
            //right is as.scalar over an indexing op
            if (HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) && left instanceof DataOp && left.getDataType() == DataType.SCALAR && root.getName().equals(left.getName()) && right instanceof UnaryOp && ((UnaryOp) right).getOp() == OpOp1.CAST_AS_SCALAR && right.getInput().get(0) instanceof IndexingOp) {
                IndexingOp ix = (IndexingOp) right.getInput().get(0);
                //indexing inputs 1/2 are the row bounds, 3/4 the column bounds
                //(see index1/index2 below); a single row/col indexed by itervar qualifies
                if (ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp && ix.getInput().get(1).getName().equals(itervar)) {
                    leftScalar = true;
                    rowIx = true;
                } else if (ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp && ix.getInput().get(3).getName().equals(itervar)) {
                    leftScalar = true;
                    rowIx = false;
                }
            } else //otherwise, check for right scalar plus (mirrored operand order)
            if (HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) && right instanceof DataOp && right.getDataType() == DataType.SCALAR && root.getName().equals(right.getName()) && left instanceof UnaryOp && ((UnaryOp) left).getOp() == OpOp1.CAST_AS_SCALAR && left.getInput().get(0) instanceof IndexingOp) {
                IndexingOp ix = (IndexingOp) left.getInput().get(0);
                if (ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp && ix.getInput().get(1).getName().equals(itervar)) {
                    rightScalar = true;
                    rowIx = true;
                } else if (ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp && ix.getInput().get(3).getName().equals(itervar)) {
                    rightScalar = true;
                    rowIx = false;
                }
            }
        }
    }
    //apply rewrite if possible
    if (leftScalar || rightScalar) {
        Hop root = csb.get_hops().get(0);
        BinaryOp bop = (BinaryOp) root.getInput().get(0);
        //the cast operand sits opposite the scalar operand
        Hop cast = bop.getInput().get(leftScalar ? 1 : 0);
        Hop ix = cast.getInput().get(0);
        //position of the binary op in the source-op list selects the target aggregate
        int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
        AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
        //replace cast with a full (RowCol) aggregate of the indexing op, rewiring the
        //DAG in this order: detach ix from cast, detach cast from bop, attach the aggregate
        AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
        HopRewriteUtils.removeChildReference(cast, ix);
        HopRewriteUtils.removeChildReference(bop, cast);
        HopRewriteUtils.addChildReference(bop, newSum, leftScalar ? 1 : 0);
        //modify indexing expression according to loop predicate from-to
        //NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
        int index1 = rowIx ? 1 : 3;
        int index2 = rowIx ? 2 : 4;
        HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index1), from, index1);
        HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index2), to, index2);
        //update indexing size information (the indexed dimension is now a range)
        if (rowIx)
            ((IndexingOp) ix).setRowLowerEqualsUpper(false);
        else
            ((IndexingOp) ix).setColLowerEqualsUpper(false);
        ix.refreshSizeInformation();
        ret = csb;
        //NOTE(review): message still names the old sum-only rewrite, although any
        //aggregate from MAP_SCALAR_AGGREGATE_TARGET_OPS may have been applied
        LOG.debug("Applied vectorizeScalarSumForLoop.");
    }
    return ret;
}
Also used: AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), UnaryOp (org.apache.sysml.hops.UnaryOp), IndexingOp (org.apache.sysml.hops.IndexingOp), LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp), AggOp (org.apache.sysml.hops.Hop.AggOp), Hop (org.apache.sysml.hops.Hop), LiteralOp (org.apache.sysml.hops.LiteralOp), DataOp (org.apache.sysml.hops.DataOp), IfStatementBlock (org.apache.sysml.parser.IfStatementBlock), WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock), ForStatementBlock (org.apache.sysml.parser.ForStatementBlock), StatementBlock (org.apache.sysml.parser.StatementBlock), BinaryOp (org.apache.sysml.hops.BinaryOp)

Example 3 with AggOp

Use of org.apache.sysml.hops.Hop.AggOp in project incubator-systemml by Apache.

Class TemplateMultiAgg, method constructCplan.

/**
 * Builds the CNode plan of a multi-aggregate fused operator rooted at {@code hop}.
 *
 * <p>The root hops are taken from the plan references of the best MultiAggTpl memo entry
 * (slots 0..2). Each root's cplan is constructed via the inherited cell-template
 * construction. The non-literal inputs are ordered with HopInputComparator so that the
 * main input comes first. Data outputs other than that main input are wrapped in a
 * lookup unary node.
 *
 * @param hop             hop whose best MultiAggTpl entry is expanded
 * @param memo            memo table holding the plan entries and hop references
 * @param compileLiterals forwarded to the cell-template cplan construction
 * @return the ordered input hops paired with the constructed CNodeMultiAgg template
 */
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // collect the aggregation roots referenced by the best multi-agg plan
    MemoTableEntry best = memo.getBest(hop.getHopID(), TemplateType.MultiAggTpl);
    ArrayList<Hop> aggRoots = new ArrayList<>();
    for (int slot = 0; slot < 3; slot++) {
        if (best.isPlanRef(slot)) {
            aggRoots.add(memo._hopRefs.get(best.input(slot)));
        }
    }
    Hop.resetVisitStatus(aggRoots);

    // build the cplan of every root with the cell-template construction
    HashSet<Hop> inputHops = new HashSet<>();
    HashMap<Long, CNode> cnodes = new HashMap<>();
    for (Hop r : aggRoots) {
        super.rConstructCplan(r, memo, cnodes, inputHops, compileLiterals);
    }
    Hop.resetVisitStatus(aggRoots);

    // drop scalar literals and order the rest by number of cells, then sparsity, so that
    // sparse inputs serve as the main input without unnecessary conversion
    List<Hop> orderedInputs = inputHops.stream()
        .filter(h -> !h.getDataType().isScalar() || !cnodes.get(h.getHopID()).isLiteral())
        .sorted(new HopInputComparator())
        .collect(Collectors.toList());

    // template inputs follow the sorted hop order
    ArrayList<CNode> inputs = orderedInputs.stream()
        .map(h -> cnodes.get(h.getHopID()))
        .collect(Collectors.toCollection(ArrayList::new));

    // one output (and aggregation op) per root
    ArrayList<CNode> outputs = new ArrayList<>();
    ArrayList<AggOp> aggOps = new ArrayList<>();
    for (Hop r : aggRoots) {
        CNode out = cnodes.get(r.getHopID());
        boolean sidewaysInput = out instanceof CNodeData
            && ((CNodeData) inputs.get(0)).getHopID() != ((CNodeData) out).getHopID();
        if (sidewaysInput) {
            // data outputs other than the main input are read through an explicit lookup
            UnaryType lookup = (aggRoots.get(0).getDim2() == 1)
                ? UnaryType.LOOKUP_R : UnaryType.LOOKUP_RC;
            out = new CNodeUnary(out, lookup);
        }
        outputs.add(out);
        aggOps.add(TemplateUtils.getAggOp(r));
    }

    CNodeMultiAgg tpl = new CNodeMultiAgg(inputs, outputs);
    tpl.setAggOps(aggOps);
    tpl.setRootNodes(aggRoots);
    return new Pair<Hop[], CNodeTpl>(orderedInputs.toArray(new Hop[0]), tpl);
}
Also used: CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData), HashMap (java.util.HashMap), AggOp (org.apache.sysml.hops.Hop.AggOp), ArrayList (java.util.ArrayList), Hop (org.apache.sysml.hops.Hop), CNode (org.apache.sysml.hops.codegen.cplan.CNode), CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary), MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry), CNodeMultiAgg (org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg), HashSet (java.util.HashSet), Pair (org.apache.sysml.runtime.matrix.data.Pair)

Aggregations

AggOp (org.apache.sysml.hops.Hop.AggOp): 3, Hop (org.apache.sysml.hops.Hop): 2, ArrayList (java.util.ArrayList): 1, HashMap (java.util.HashMap): 1, HashSet (java.util.HashSet): 1, AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 1, BinaryOp (org.apache.sysml.hops.BinaryOp): 1, DataOp (org.apache.sysml.hops.DataOp): 1, IndexingOp (org.apache.sysml.hops.IndexingOp): 1, LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp): 1, LiteralOp (org.apache.sysml.hops.LiteralOp): 1, UnaryOp (org.apache.sysml.hops.UnaryOp): 1, CNode (org.apache.sysml.hops.codegen.cplan.CNode): 1, CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 1, CNodeMultiAgg (org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg): 1, CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary): 1, MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 1, ForStatementBlock (org.apache.sysml.parser.ForStatementBlock): 1, IfStatementBlock (org.apache.sysml.parser.IfStatementBlock): 1, StatementBlock (org.apache.sysml.parser.StatementBlock): 1