Search in sources :

Example 1 with CNodeOuterProduct

use of org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct in project incubator-systemml by apache.

the class TemplateOuterProduct method constructCplan.

/**
 * Builds the outer-product template (cplan) for the sub-DAG rooted at {@code hop}.
 *
 * <p>The recursive construction fills a hop-ID to CNode map, a set of input hops and a
 * name to hop map. The hops registered under {@code "_X"}, {@code "_U"} and {@code "_V"}
 * are then moved to the front of the input list, in that order, followed by all
 * remaining inputs in the iteration order of the collected set.
 *
 * @param hop root hop of the sub-DAG to compile
 * @param memo memo table handed through to the recursive construction
 * @param compileLiterals flag handed through to the recursive construction
 * @return pair of ordered input hops and the outer-product template; the hop array
 *         holds a {@code null} entry for each of X/U/V that was not registered
 */
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // recursively build the cnodes required for the cplan output
    HashMap<Long, CNode> cnodeByHopId = new HashMap<Long, CNode>();
    HashSet<Hop> collectedInputs = new HashSet<Hop>();
    HashMap<String, Hop> namedInputs = new HashMap<String, Hop>();
    hop.resetVisitStatus();
    rConstructCplan(hop, memo, cnodeByHopId, collectedInputs, namedInputs, compileLiterals);
    hop.resetVisitStatus();

    // reorder inputs (ensure matrix is first input): pushing V, then U, then X
    // to the front yields the final order X, U, V, <others>
    Hop mainX = namedInputs.get("_X");
    Hop factorU = namedInputs.get("_U");
    Hop factorV = namedInputs.get("_V");
    LinkedList<Hop> orderedHops = new LinkedList<Hop>(collectedInputs);
    for (Hop front : new Hop[] { factorV, factorU, mainX }) {
        orderedHops.remove(front);
        orderedHops.addFirst(front);
    }

    // construct template node from the cnodes of all non-null inputs
    ArrayList<CNode> tplInputs = new ArrayList<CNode>();
    for (Hop in : orderedHops) {
        if (in == null)
            continue;
        tplInputs.add(cnodeByHopId.get(in.getHopID()));
    }
    CNodeOuterProduct tpl = new CNodeOuterProduct(tplInputs, cnodeByHopId.get(hop.getHopID()));
    tpl.setOutProdType(TemplateUtils.getOuterProductType(mainX, factorU, factorV, hop));

    // the output is flagged for transposition only for left outer products
    // whose root hop is not itself a transpose
    boolean transposeOut = !HopRewriteUtils.isTransposeOperation(hop)
        && tpl.getOutProdType() == OutProdType.LEFT_OUTER_PRODUCT;
    tpl.setTransposeOutput(transposeOut);

    Hop[] inputArray = orderedHops.toArray(new Hop[0]);
    return new Pair<Hop[], CNodeTpl>(inputArray, tpl);
}
Also used : HashMap(java.util.HashMap) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeOuterProduct(org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct) HashSet(java.util.HashSet) Pair(org.apache.sysml.runtime.matrix.data.Pair)

Example 2 with CNodeOuterProduct

use of org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct in project incubator-systemml by apache.

the class SpoofCompiler method cleanupCPlans.

/**
 * Cleans up generated cplans by dropping inputs that are not referenced by
 * any leaf of the plan (such inputs are left over from incremental
 * construction and would cause redundant computation), and by discarding
 * plans that are invalid or not worth generating.
 *
 * @param cplans set of cplans
 * @return new map holding the retained cplans, keyed like the input map
 */
private static HashMap<Long, Pair<Hop[], CNodeTpl>> cleanupCPlans(HashMap<Long, Pair<Hop[], CNodeTpl>> cplans) {
    HashMap<Long, Pair<Hop[], CNodeTpl>> retained = new HashMap<Long, Pair<Hop[], CNodeTpl>>();
    for (Entry<Long, Pair<Hop[], CNodeTpl>> entry : cplans.entrySet()) {
        final Long planKey = entry.getKey();
        final Pair<Hop[], CNodeTpl> plan = entry.getValue();
        final CNodeTpl tpl = plan.getValue();
        final Hop[] inHops = plan.getKey();
        final boolean isCell = tpl instanceof CNodeCell;
        final boolean isMultiAgg = tpl instanceof CNodeMultiAgg;

        // collect cplan leaf node ids (over all outputs for multi-aggregates)
        HashSet<Long> leafIds = new HashSet<Long>();
        if (isMultiAgg) {
            for (CNode root : ((CNodeMultiAgg) tpl).getOutputs())
                rCollectLeafIDs(root, leafIds);
        } else {
            rCollectLeafIDs(tpl.getOutput(), leafIds);
        }

        // create clean cplan w/ minimal inputs
        if (inHops.length != leafIds.size()) {
            tpl.cleanupInputs(leafIds);
            ArrayList<Hop> usedHops = new ArrayList<Hop>();
            for (int i = 0; i < inHops.length; i++) {
                Hop candidate = inHops[i];
                if (candidate == null || !leafIds.contains(candidate.getHopID()))
                    continue;
                usedHops.add(candidate);
            }
            retained.put(planKey, new Pair<Hop[], CNodeTpl>(usedHops.toArray(new Hop[0]), tpl));
        } else {
            retained.put(planKey, plan);
        }

        // remove invalid plans with column indexing on main input
        if (isCell) {
            CNodeData mainIn = (CNodeData) tpl.getInput().get(0);
            if (rHasLookupRC1(tpl.getOutput(), mainIn) || isLookupRC1(tpl.getOutput(), mainIn)) {
                retained.remove(planKey);
                if (LOG.isTraceEnabled())
                    LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
            }
        } else if (isMultiAgg) {
            CNodeData mainIn = (CNodeData) tpl.getInput().get(0);
            // every offending output triggers the removal (and trace message)
            for (CNode root : ((CNodeMultiAgg) tpl).getOutputs()) {
                if (!(rHasLookupRC1(root, mainIn) || isLookupRC1(root, mainIn)))
                    continue;
                retained.remove(planKey);
                if (LOG.isTraceEnabled())
                    LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
            }
        }

        // remove spurious lookups on main input of cell template
        if (isCell || tpl instanceof CNodeOuterProduct) {
            CNodeData mainIn = (CNodeData) tpl.getInput().get(0);
            rFindAndRemoveLookup(tpl.getOutput(), mainIn);
        } else if (isMultiAgg) {
            CNodeData mainIn = (CNodeData) tpl.getInput().get(0);
            rFindAndRemoveLookupMultiAgg((CNodeMultiAgg) tpl, mainIn);
        }

        // remove cplan w/ single op and w/o agg; evaluated after the lookup
        // removal above, with the same short-circuit order as a single expression
        boolean trivialPlan = false;
        if (isCell) {
            boolean singleOpNoAgg = ((CNodeCell) tpl).getCellType() == CellType.NO_AGG
                && TemplateUtils.hasSingleOperation(tpl);
            trivialPlan = singleOpNoAgg || TemplateUtils.hasNoOperation(tpl);
        }
        if (!trivialPlan)
            trivialPlan = tpl instanceof CNodeRow && TemplateUtils.hasSingleOperation(tpl);
        if (trivialPlan)
            retained.remove(planKey);

        // remove cplan if empty (output is a plain data node)
        if (tpl.getOutput() instanceof CNodeData)
            retained.remove(planKey);
    }
    return retained;
}
Also used : CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) CNodeTpl(org.apache.sysml.hops.codegen.cplan.CNodeTpl) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) CNodeCell(org.apache.sysml.hops.codegen.cplan.CNodeCell) CNode(org.apache.sysml.hops.codegen.cplan.CNode) CNodeOuterProduct(org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct) CNodeRow(org.apache.sysml.hops.codegen.cplan.CNodeRow) CNodeMultiAgg(org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg) Pair(org.apache.sysml.runtime.matrix.data.Pair) HashSet(java.util.HashSet)

Example 3 with CNodeOuterProduct

use of org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct in project incubator-systemml by apache.

the class SpoofCompiler method rConstructModifiedHopDag.

/**
 * Recursively traverses the hop DAG starting at {@code hop} and, for every
 * hop whose ID has an entry in {@code clas}, substitutes a newly created
 * {@link SpoofFusedOp} that references the generated class and takes the
 * plan's input hops as its inputs.
 *
 * @param hop current hop of the traversal
 * @param cplans map from hop ID to (input hops, template); only the template is read here
 * @param clas map from hop ID to (input hops, generated class)
 * @param memo IDs of hops that were already processed; updated by this method
 */
private static void rConstructModifiedHopDag(Hop hop, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans, HashMap<Long, Pair<Hop[], Class<?>>> clas, HashSet<Long> memo) {
    if (memo.contains(hop.getHopID()))
        //already processed
        return;
    // hnew stays the original hop unless a generated operator replaces it below
    Hop hnew = hop;
    if (clas.containsKey(hop.getHopID())) {
        //replace sub-dag with generated operator
        Pair<Hop[], Class<?>> tmpCla = clas.get(hop.getHopID());
        CNodeTpl tmpCNode = cplans.get(hop.getHopID()).getValue();
        hnew = new SpoofFusedOp(hop.getName(), hop.getDataType(), hop.getValueType(), tmpCla.getValue(), false, tmpCNode.getOutputDimType());
        Hop[] inHops = tmpCla.getKey();
        for (int i = 0; i < inHops.length; i++) {
            // outer-product templates: the hop matching the template's third input (index 2)
            // is wrapped in a transpose, unless hasTransposeParentUnderOuterProduct reports true
            if (tmpCNode instanceof CNodeOuterProduct && inHops[i].getHopID() == ((CNodeData) tmpCNode.getInput().get(2)).getHopID() && !TemplateUtils.hasTransposeParentUnderOuterProduct(inHops[i])) {
                hnew.addInput(HopRewriteUtils.createTranspose(inHops[i]));
            } else
                //add inputs
                hnew.addInput(inHops[i]);
        }
        //modify output parameters (copy dimensions, block sizes and nnz from the replaced hop)
        HopRewriteUtils.setOutputParameters(hnew, hop.getDim1(), hop.getDim2(), hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz());
        // template-specific post-processing of the fused operator
        if (tmpCNode instanceof CNodeOuterProduct && ((CNodeOuterProduct) tmpCNode).isTransposeOutput())
            // template flagged its output for transposition: wrap the fused op in a transpose
            hnew = HopRewriteUtils.createTranspose(hnew);
        else if (tmpCNode instanceof CNodeMultiAgg) {
            // multi-aggregate: fused op becomes a 1 x (#roots) matrix with unknown nnz (-1)
            ArrayList<Hop> roots = ((CNodeMultiAgg) tmpCNode).getRootNodes();
            hnew.setDataType(DataType.MATRIX);
            HopRewriteUtils.setOutputParameters(hnew, 1, roots.size(), inHops[0].getRowsInBlock(), inHops[0].getColsInBlock(), -1);
            //inject artificial right indexing operations for all parents of all nodes
            for (int i = 0; i < roots.size(); i++) {
                // root i is replaced by a scalar indexing of cell (1, i+1) of the fused output
                Hop hnewi = HopRewriteUtils.createScalarIndexing(hnew, 1, i + 1);
                HopRewriteUtils.rewireAllParentChildReferences(roots.get(i), hnewi);
            }
        } else if (tmpCNode instanceof CNodeCell && ((CNodeCell) tmpCNode).requiredCastDtm()) {
            // cell template requiring a cast: mark fused op as scalar and cast its result to a matrix
            HopRewriteUtils.setOutputParametersForScalar(hnew);
            hnew = HopRewriteUtils.createUnary(hnew, OpOp1.CAST_AS_MATRIX);
        }
        // multi-aggregates were already rewired per root above; all others rewire the replaced hop here
        if (!(tmpCNode instanceof CNodeMultiAgg))
            HopRewriteUtils.rewireAllParentChildReferences(hop, hnew);
        // register the replacement before descending into its inputs
        memo.add(hnew.getHopID());
    }
    //process hops recursively (parent-child links modified)
    for (int i = 0; i < hnew.getInput().size(); i++) {
        Hop c = hnew.getInput().get(i);
        rConstructModifiedHopDag(c, cplans, clas, memo);
    }
    // mark this hop (or its replacement) as processed
    memo.add(hnew.getHopID());
}
Also used : CNodeOuterProduct(org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct) CNodeTpl(org.apache.sysml.hops.codegen.cplan.CNodeTpl) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) CNodeMultiAgg(org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg) CNodeCell(org.apache.sysml.hops.codegen.cplan.CNodeCell)

Aggregations

ArrayList (java.util.ArrayList)3 Hop (org.apache.sysml.hops.Hop)3 CNodeOuterProduct (org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct)3 HashMap (java.util.HashMap)2 HashSet (java.util.HashSet)2 CNode (org.apache.sysml.hops.codegen.cplan.CNode)2 CNodeCell (org.apache.sysml.hops.codegen.cplan.CNodeCell)2 CNodeMultiAgg (org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg)2 CNodeTpl (org.apache.sysml.hops.codegen.cplan.CNodeTpl)2 Pair (org.apache.sysml.runtime.matrix.data.Pair)2 LinkedHashMap (java.util.LinkedHashMap)1 LinkedList (java.util.LinkedList)1 CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData)1 CNodeRow (org.apache.sysml.hops.codegen.cplan.CNodeRow)1