Usage of org.apache.sysml.hops.codegen.cplan.CNodeRow in the Apache incubator-systemml project:
class SpoofCompiler, method rConstructModifiedHopDag.
/**
 * Recursively walks the HOP DAG below {@code hop} and splices in a {@link SpoofFusedOp}
 * for every hop whose ID has a generated operator class in {@code clas}.
 *
 * <p>For such a hop, a new {@code SpoofFusedOp} is created over the recorded input hops.
 * For outer-product templates, the input whose ID matches the template's third data input
 * is wrapped in a transpose, unless
 * {@code TemplateUtils.hasTransposeParentUnderOuterProduct} reports otherwise. The new
 * operator inherits the original hop's output parameters and is post-processed per
 * template type:
 * <ul>
 *   <li>transposed output for outer products,</li>
 *   <li>column indexing per root for multi-aggregates,</li>
 *   <li>cast-to-matrix for cell templates with {@code requiredCastDtm()},</li>
 *   <li>const dim2 for NO_AGG_CONST / COL_AGG_CONST row templates.</li>
 * </ul>
 * It is then rewired into the parents of the original hop. The walk continues into the
 * inputs of the (possibly replaced) hop.
 *
 * @param hop    current hop of the DAG walk
 * @param cplans hop ID -> (input hops, CPlan template). NOTE(review): dereferenced without
 *               a null check, so it must hold an entry for every key of {@code clas}
 * @param clas   hop ID -> (input hops, generated operator class)
 * @param memo   IDs of hops already processed; updated in place
 */
private static void rConstructModifiedHopDag(Hop hop, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans, HashMap<Long, Pair<Hop[], Class<?>>> clas, HashSet<Long> memo) {
	final long hopId = hop.getHopID();
	if (memo.contains(hopId))
		return; // already processed

	Hop current = hop;
	if (clas.containsKey(hopId)) {
		// replace the sub-DAG below this hop with the generated fused operator
		Pair<Hop[], Class<?>> compiled = clas.get(hopId);
		CNodeTpl cplan = cplans.get(hopId).getValue();
		current = new SpoofFusedOp(hop.getName(), hop.getDataType(), hop.getValueType(),
			compiled.getValue(), false, cplan.getOutputDimType());
		Hop[] fusedInputs = compiled.getKey();
		boolean outerProduct = cplan instanceof CNodeOuterProduct;
		boolean multiAgg = cplan instanceof CNodeMultiAgg;

		for (Hop input : fusedInputs) {
			// the outer-product input bound to the template's third data node is fed
			// transposed, unless it already sits under a transpose parent
			boolean feedTransposed = outerProduct
				&& input.getHopID() == ((CNodeData) cplan.getInput().get(2)).getHopID()
				&& !TemplateUtils.hasTransposeParentUnderOuterProduct(input);
			current.addInput(feedTransposed ? HopRewriteUtils.createTranspose(input) : input);
		}

		// the fused operator takes over the output parameters of the replaced hop
		HopRewriteUtils.setOutputParameters(current, hop.getDim1(), hop.getDim2(),
			hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz());

		if (outerProduct && ((CNodeOuterProduct) cplan).isTransposeOutput()) {
			current = HopRewriteUtils.createTranspose(current);
		}
		else if (multiAgg) {
			// a single 1 x #roots matrix output, one column per aggregation root
			ArrayList<Hop> roots = ((CNodeMultiAgg) cplan).getRootNodes();
			current.setDataType(DataType.MATRIX);
			HopRewriteUtils.setOutputParameters(current, 1, roots.size(),
				fusedInputs[0].getRowsInBlock(), fusedInputs[0].getColsInBlock(), -1);
			// redirect all parents of each root to an indexing op over its output column
			for (int col = 1; col <= roots.size(); col++) {
				Hop root = roots.get(col - 1);
				Hop extracted = (root instanceof AggUnaryOp)
					? HopRewriteUtils.createScalarIndexing(current, 1, col)
					: HopRewriteUtils.createIndexingOp(current, 1, col);
				HopRewriteUtils.rewireAllParentChildReferences(root, extracted);
			}
		}
		else if (cplan instanceof CNodeCell && ((CNodeCell) cplan).requiredCastDtm()) {
			// fused op is typed with scalar output parameters and wrapped in a cast to matrix
			HopRewriteUtils.setOutputParametersForScalar(current);
			current = HopRewriteUtils.createUnary(current, OpOp1.CAST_AS_MATRIX);
		}
		else if (cplan instanceof CNodeRow) {
			CNodeRow rowTpl = (CNodeRow) cplan;
			RowType rowType = rowTpl.getRowType();
			if (rowType == RowType.NO_AGG_CONST || rowType == RowType.COL_AGG_CONST)
				((SpoofFusedOp) current).setConstDim2(rowTpl.getConstDim2());
		}

		// multi-aggregates were already connected to their parents via indexing ops above
		if (!multiAgg)
			HopRewriteUtils.rewireAllParentChildReferences(hop, current);
		memo.add(current.getHopID());
	}

	// descend into the inputs; index-based because the recursion may rewire this input list
	for (int pos = 0; pos < current.getInput().size(); pos++)
		rConstructModifiedHopDag(current.getInput().get(pos), cplans, clas, memo);
	memo.add(current.getHopID());
}
Usage of org.apache.sysml.hops.codegen.cplan.CNodeRow in the Apache systemml project:
class SpoofCompiler, method rConstructModifiedHopDag.
/**
 * Recursively walks the HOP DAG below {@code hop} and splices in a {@link SpoofFusedOp}
 * for every hop whose ID has a generated operator class in {@code clas}.
 *
 * <p>For such a hop, a new {@code SpoofFusedOp} is created over the recorded input hops.
 * For outer-product templates, the input whose ID matches the template's third data input
 * is wrapped in a transpose, unless
 * {@code TemplateUtils.hasTransposeParentUnderOuterProduct} reports otherwise. The new
 * operator inherits the original hop's output parameters and is post-processed per
 * template type:
 * <ul>
 *   <li>transposed output for outer products,</li>
 *   <li>column indexing per root for multi-aggregates,</li>
 *   <li>cast-to-matrix for cell templates with {@code requiredCastDtm()},</li>
 *   <li>const dim2 for NO_AGG_CONST / COL_AGG_CONST row templates.</li>
 * </ul>
 * It is then rewired into the parents of the original hop. The walk continues into the
 * inputs of the (possibly replaced) hop.
 *
 * @param hop    current hop of the DAG walk
 * @param cplans hop ID -> (input hops, CPlan template). NOTE(review): dereferenced without
 *               a null check, so it must hold an entry for every key of {@code clas}
 * @param clas   hop ID -> (input hops, generated operator class)
 * @param memo   IDs of hops already processed; updated in place
 */
private static void rConstructModifiedHopDag(Hop hop, HashMap<Long, Pair<Hop[], CNodeTpl>> cplans, HashMap<Long, Pair<Hop[], Class<?>>> clas, HashSet<Long> memo) {
	final long hopId = hop.getHopID();
	if (memo.contains(hopId))
		return; // already processed

	Hop current = hop;
	if (clas.containsKey(hopId)) {
		// replace the sub-DAG below this hop with the generated fused operator
		Pair<Hop[], Class<?>> compiled = clas.get(hopId);
		CNodeTpl cplan = cplans.get(hopId).getValue();
		current = new SpoofFusedOp(hop.getName(), hop.getDataType(), hop.getValueType(),
			compiled.getValue(), false, cplan.getOutputDimType());
		Hop[] fusedInputs = compiled.getKey();
		boolean outerProduct = cplan instanceof CNodeOuterProduct;
		boolean multiAgg = cplan instanceof CNodeMultiAgg;

		for (Hop input : fusedInputs) {
			// the outer-product input bound to the template's third data node is fed
			// transposed, unless it already sits under a transpose parent
			boolean feedTransposed = outerProduct
				&& input.getHopID() == ((CNodeData) cplan.getInput().get(2)).getHopID()
				&& !TemplateUtils.hasTransposeParentUnderOuterProduct(input);
			current.addInput(feedTransposed ? HopRewriteUtils.createTranspose(input) : input);
		}

		// the fused operator takes over the output parameters of the replaced hop
		HopRewriteUtils.setOutputParameters(current, hop.getDim1(), hop.getDim2(),
			hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz());

		if (outerProduct && ((CNodeOuterProduct) cplan).isTransposeOutput()) {
			current = HopRewriteUtils.createTranspose(current);
		}
		else if (multiAgg) {
			// a single 1 x #roots matrix output, one column per aggregation root
			ArrayList<Hop> roots = ((CNodeMultiAgg) cplan).getRootNodes();
			current.setDataType(DataType.MATRIX);
			HopRewriteUtils.setOutputParameters(current, 1, roots.size(),
				fusedInputs[0].getRowsInBlock(), fusedInputs[0].getColsInBlock(), -1);
			// redirect all parents of each root to an indexing op over its output column
			for (int col = 1; col <= roots.size(); col++) {
				Hop root = roots.get(col - 1);
				Hop extracted = (root instanceof AggUnaryOp)
					? HopRewriteUtils.createScalarIndexing(current, 1, col)
					: HopRewriteUtils.createIndexingOp(current, 1, col);
				HopRewriteUtils.rewireAllParentChildReferences(root, extracted);
			}
		}
		else if (cplan instanceof CNodeCell && ((CNodeCell) cplan).requiredCastDtm()) {
			// fused op is typed with scalar output parameters and wrapped in a cast to matrix
			HopRewriteUtils.setOutputParametersForScalar(current);
			current = HopRewriteUtils.createUnary(current, OpOp1.CAST_AS_MATRIX);
		}
		else if (cplan instanceof CNodeRow) {
			CNodeRow rowTpl = (CNodeRow) cplan;
			RowType rowType = rowTpl.getRowType();
			if (rowType == RowType.NO_AGG_CONST || rowType == RowType.COL_AGG_CONST)
				((SpoofFusedOp) current).setConstDim2(rowTpl.getConstDim2());
		}

		// multi-aggregates were already connected to their parents via indexing ops above
		if (!multiAgg)
			HopRewriteUtils.rewireAllParentChildReferences(hop, current);
		memo.add(current.getHopID());
	}

	// descend into the inputs; index-based because the recursion may rewire this input list
	for (int pos = 0; pos < current.getInput().size(); pos++)
		rConstructModifiedHopDag(current.getInput().get(pos), cplans, clas, memo);
	memo.add(current.getHopID());
}
Aggregations