use of org.apache.sysml.hops.codegen.cplan.CNodeCell in project incubator-systemml by apache.
the class TemplateCell method constructCplan.
/**
 * Builds the cell-template cplan rooted at {@code hop}.
 *
 * <p>The cplan nodes are created recursively via {@code rConstructCplan}; the
 * hops collected as inputs are then filtered (scalar inputs whose cplan node is
 * a literal are dropped) and sorted with {@code HopInputComparator} before the
 * {@code CNodeCell} template is assembled.
 *
 * @param hop root hop whose cplan node becomes the template output
 * @param memo memo table handed through to the recursive construction
 * @param compileLiterals flag handed through to the recursive construction
 * @return pair of the ordered, literal-pruned input hops and the constructed template
 */
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // build all required cplan nodes, collecting candidate inputs along the way
    HashSet<Hop> collectedInputs = new HashSet<>();
    HashMap<Long, CNode> nodeByHopId = new HashMap<>();
    hop.resetVisitStatus();
    rConstructCplan(hop, memo, nodeByHopId, collectedInputs, compileLiterals);
    hop.resetVisitStatus();

    // prune literal scalars and order the rest; per the original note the
    // comparator sorts by number of cells and then sparsity so that a sparse
    // input can serve as the main input without unnecessary conversion
    List<Hop> orderedInputs = collectedInputs.stream()
        .filter(h -> !h.getDataType().isScalar() || !nodeByHopId.get(h.getHopID()).isLiteral())
        .sorted(new HopInputComparator())
        .collect(Collectors.toList());

    // map the ordered input hops to their cplan nodes
    ArrayList<CNode> templateInputs = orderedInputs.stream()
        .map(h -> nodeByHopId.get(h.getHopID()))
        .collect(Collectors.toCollection(ArrayList::new));

    // assemble the template around the output node
    CNodeCell cellTpl = new CNodeCell(templateInputs, nodeByHopId.get(hop.getHopID()));
    cellTpl.setCellType(TemplateUtils.getCellType(hop));
    cellTpl.setAggOp(TemplateUtils.getAggOp(hop));

    // sparse-safe if the root is a multiply that has the first ordered input
    // among its inputs, or a divide whose left operand is that very input;
    // the first ordered input is only accessed behind the operator checks
    boolean multWithMainInput = HopRewriteUtils.isBinary(hop, OpOp2.MULT)
        && hop.getInput().contains(orderedInputs.get(0));
    boolean sparseSafe = multWithMainInput
        || (HopRewriteUtils.isBinary(hop, OpOp2.DIV)
            && hop.getInput().get(0) == orderedInputs.get(0));
    cellTpl.setSparseSafe(sparseSafe);
    cellTpl.setRequiresCastDtm(hop instanceof AggBinaryOp);

    // hand back the input hops together with the template
    Hop[] inputArray = orderedInputs.toArray(new Hop[0]);
    return new Pair<Hop[], CNodeTpl>(inputArray, cellTpl);
}
use of org.apache.sysml.hops.codegen.cplan.CNodeCell in project incubator-systemml by apache.
the class SpoofCompiler method cleanupCPlans.
/**
 * Cleans up generated cplans by dropping inputs that are no longer referenced
 * after incremental construction, which avoids unnecessary redundant
 * computation, and by discarding plans that are invalid or trivial.
 *
 * <p>Per plan, in this order: collect the leaf IDs of the output(s); prune the
 * input hops to those leaves if the counts differ; remove plans with rc1
 * lookups on the first input; strip lookups on the first input; remove
 * cell/row plans with a single (or no) operation; remove plans whose output
 * is a plain data node.
 *
 * @param cplans set of cplans
 * @return new map holding only the retained (and, where needed, pruned) cplans
 */
private static HashMap<Long, Pair<Hop[], CNodeTpl>> cleanupCPlans(HashMap<Long, Pair<Hop[], CNodeTpl>> cplans) {
    HashMap<Long, Pair<Hop[], CNodeTpl>> retained = new HashMap<>();
    for (Entry<Long, Pair<Hop[], CNodeTpl>> entry : cplans.entrySet()) {
        Long planKey = entry.getKey();
        Pair<Hop[], CNodeTpl> plan = entry.getValue();
        CNodeTpl tpl = plan.getValue();
        Hop[] planInputs = plan.getKey();

        // gather the IDs of all leaf nodes reachable from the plan output(s)
        HashSet<Long> leafIds = new HashSet<>();
        if (tpl instanceof CNodeMultiAgg) {
            for (CNode root : ((CNodeMultiAgg) tpl).getOutputs()) {
                rCollectLeafIDs(root, leafIds);
            }
        } else {
            rCollectLeafIDs(tpl.getOutput(), leafIds);
        }

        // register the plan, pruning inputs down to the collected leaves if needed
        if (planInputs.length != leafIds.size()) {
            tpl.cleanupInputs(leafIds);
            ArrayList<Hop> usedInputs = new ArrayList<>();
            for (Hop candidate : planInputs) {
                if (candidate != null && leafIds.contains(candidate.getHopID())) {
                    usedInputs.add(candidate);
                }
            }
            retained.put(planKey, new Pair<Hop[], CNodeTpl>(usedInputs.toArray(new Hop[0]), tpl));
        } else {
            retained.put(planKey, plan);
        }

        // drop plans that index columns (rc1 lookup) on the main input
        if (tpl instanceof CNodeCell) {
            CNodeData mainInput = (CNodeData) tpl.getInput().get(0);
            if (rHasLookupRC1(tpl.getOutput(), mainInput) || isLookupRC1(tpl.getOutput(), mainInput)) {
                retained.remove(planKey);
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
                }
            }
        } else if (tpl instanceof CNodeMultiAgg) {
            CNodeData mainInput = (CNodeData) tpl.getInput().get(0);
            // checked per output; removal and trace happen for every offending output
            for (CNode root : ((CNodeMultiAgg) tpl).getOutputs()) {
                if (rHasLookupRC1(root, mainInput) || isLookupRC1(root, mainInput)) {
                    retained.remove(planKey);
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed cplan due to invalid rc1 indexing on main input.");
                    }
                }
            }
        }

        // strip spurious lookups on the main input (runs before the
        // single-operation check below, exactly as in the original order)
        if (tpl instanceof CNodeCell || tpl instanceof CNodeOuterProduct) {
            CNodeData mainInput = (CNodeData) tpl.getInput().get(0);
            rFindAndRemoveLookup(tpl.getOutput(), mainInput);
        } else if (tpl instanceof CNodeMultiAgg) {
            CNodeData mainInput = (CNodeData) tpl.getInput().get(0);
            rFindAndRemoveLookupMultiAgg((CNodeMultiAgg) tpl, mainInput);
        }

        // drop cell plans that are a single op without aggregation or have no
        // op at all, as well as row plans with a single op
        boolean trivialCell = tpl instanceof CNodeCell
            && ((((CNodeCell) tpl).getCellType() == CellType.NO_AGG && TemplateUtils.hasSingleOperation(tpl))
                || TemplateUtils.hasNoOperation(tpl));
        if (trivialCell || (tpl instanceof CNodeRow && TemplateUtils.hasSingleOperation(tpl))) {
            retained.remove(planKey);
        }

        // drop plans that ended up empty (output is just a data node)
        if (tpl.getOutput() instanceof CNodeData) {
            retained.remove(planKey);
        }
    }
    return retained;
}
use of org.apache.sysml.hops.codegen.cplan.CNodeCell in project incubator-systemml by apache.
the class SpoofCompiler method rConstructModifiedHopDag.
/**
 * Recursively rewrites the hop DAG below {@code hop}, replacing every hop that
 * has a generated class registered in {@code clas} with a {@code SpoofFusedOp}
 * wired to the corresponding input hops.
 *
 * @param hop current hop; skipped if its ID is already contained in {@code memo}
 * @param cplans cplans by hop ID, used to look up the template of a replaced hop
 * @param clas generated operator classes (with their input hops) by hop ID
 * @param memo IDs of hops that have already been processed; updated by this call
 */
private static void rConstructModifiedHopDag(Hop hop, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans, HashMap<Long, Pair<Hop[], Class<?>>> clas, HashSet<Long> memo) {
    // already processed
    if (memo.contains(hop.getHopID())) {
        return;
    }

    Hop current = hop;
    if (clas.containsKey(hop.getHopID())) {
        // replace sub-dag with generated operator
        Pair<Hop[], Class<?>> generated = clas.get(hop.getHopID());
        CNodeTpl tpl = cplans.get(hop.getHopID()).getValue();
        current = new SpoofFusedOp(hop.getName(), hop.getDataType(), hop.getValueType(), generated.getValue(), false, tpl.getOutputDimType());

        // add inputs; for outer-product templates the hop backing the third
        // template input is wrapped in a transpose unless
        // hasTransposeParentUnderOuterProduct reports true for it
        Hop[] fusedInputs = generated.getKey();
        for (Hop in : fusedInputs) {
            boolean wrapInTranspose = tpl instanceof CNodeOuterProduct
                && in.getHopID() == ((CNodeData) tpl.getInput().get(2)).getHopID()
                && !TemplateUtils.hasTransposeParentUnderOuterProduct(in);
            current.addInput(wrapInTranspose ? HopRewriteUtils.createTranspose(in) : in);
        }

        // modify output parameters
        HopRewriteUtils.setOutputParameters(current, hop.getDim1(), hop.getDim2(), hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz());

        if (tpl instanceof CNodeOuterProduct && ((CNodeOuterProduct) tpl).isTransposeOutput()) {
            current = HopRewriteUtils.createTranspose(current);
        } else if (tpl instanceof CNodeMultiAgg) {
            // the fused op yields a 1 x #roots matrix
            ArrayList<Hop> aggRoots = ((CNodeMultiAgg) tpl).getRootNodes();
            current.setDataType(DataType.MATRIX);
            HopRewriteUtils.setOutputParameters(current, 1, aggRoots.size(), fusedInputs[0].getRowsInBlock(), fusedInputs[0].getColsInBlock(), -1);
            // inject artificial right indexing operations for all parents of all nodes
            for (int pos = 0; pos < aggRoots.size(); pos++) {
                Hop cell = HopRewriteUtils.createScalarIndexing(current, 1, pos + 1);
                HopRewriteUtils.rewireAllParentChildReferences(aggRoots.get(pos), cell);
            }
        } else if (tpl instanceof CNodeCell && ((CNodeCell) tpl).requiredCastDtm()) {
            HopRewriteUtils.setOutputParametersForScalar(current);
            current = HopRewriteUtils.createUnary(current, OpOp1.CAST_AS_MATRIX);
        }

        // multi-agg roots were already rewired individually above
        if (!(tpl instanceof CNodeMultiAgg)) {
            HopRewriteUtils.rewireAllParentChildReferences(hop, current);
        }
        memo.add(current.getHopID());
    }

    // process hops recursively (parent-child links modified); an index loop
    // that re-reads the input list each round is kept deliberately because
    // the recursion may change the inputs of the current hop
    for (int pos = 0; pos < current.getInput().size(); pos++) {
        Hop child = current.getInput().get(pos);
        rConstructModifiedHopDag(child, cplans, clas, memo);
    }
    memo.add(current.getHopID());
}
Aggregations