Use of org.apache.sysml.hops.Hop.AggOp in project incubator-systemml by Apache.
Class CNodeMultiAgg, method codegen:
/**
 * Generates the source code of this multi-aggregate operator by filling the
 * placeholders of {@code TEMPLATE}: the class name ({@code %TMP%}), the dense
 * body ({@code %BODY_dense%}), and the comma-separated list of aggregation
 * operators ({@code %AGG_OP%}).
 *
 * @param sparse ignored; the body is always generated via {@code codegen(false)}
 * @return the generated source code
 */
@Override
public String codegen(boolean sparse) {
    // note: ignore sparse flag, generate both
    String tmp = TEMPLATE;

    //rename inputs: the main input matrix (input 0) is referenced as "a",
    //the remaining inputs are renamed starting at input position 1
    rReplaceDataNode(_outputs, _inputs.get(0), "a");
    renameInputs(_outputs, _inputs, 1);

    //generate dense/sparse bodies; generated-flags are reset only after all
    //outputs are emitted (presumably so shared sub-expressions are emitted once)
    StringBuilder sb = new StringBuilder();
    for (CNode out : _outputs)
        sb.append(out.codegen(false));
    for (CNode out : _outputs)
        out.resetGenerated();

    //append output assignments
    for (int i = 0; i < _outputs.size(); i++) {
        CNode out = _outputs.get(i);
        String tmpOut = getAggTemplate(i);
        //get variable name (w/ handling of direct consumption of inputs)
        String varName = (out instanceof CNodeData
            && ((CNodeData) out).getHopID() == ((CNodeData) _inputs.get(0)).getHopID()) ? "a" : out.getVarname();
        tmpOut = tmpOut.replace("%IN%", varName);
        tmpOut = tmpOut.replace("%IX%", String.valueOf(i));
        sb.append(tmpOut);
    }

    //replace class name and body; literal replace() is used instead of the
    //regex-based replaceAll(), which interprets '$' and '\' in the replacement
    //and would corrupt (or throw on) generated code containing them
    tmp = tmp.replace("%TMP%", createVarname());
    tmp = tmp.replace("%BODY_dense%", sb.toString());

    //replace meta data information (e.g., "AggOp.SUM,AggOp.MIN")
    StringBuilder aggList = new StringBuilder();
    for (AggOp aggOp : _aggOps) {
        if (aggList.length() > 0)
            aggList.append(',');
        aggList.append("AggOp.").append(aggOp.name());
    }
    tmp = tmp.replace("%AGG_OP%", aggList.toString());
    return tmp;
}
Use of org.apache.sysml.hops.Hop.AggOp in project incubator-systemml by Apache.
Class RewriteForLoopVectorization, method vectorizeScalarAggregate:
/**
 * Vectorizes a for loop whose body is a single scalar aggregation over a slice
 * of a matrix pinned to the loop variable, e.g.,
 * {@code for(i in a:b) s = s + as.scalar(X[i,1])}, by replacing the scalar cast
 * with a full aggregate over the slice whose pinned index is widened to the
 * loop range {@code from:to}.
 *
 * @param sb the statement block returned if the rewrite does not apply
 * @param csb the loop body's statement block (presumably the loop's only child block)
 * @param from the loop's from expression, used as the new lower index bound
 * @param to the loop's to expression, used as the new upper index bound
 * @param increment the loop's increment; only a literal value of 1 is supported
 * @param itervar the name of the loop iteration variable
 * @return {@code csb} with the rewritten aggregate if the pattern matched, otherwise {@code sb}
 * @throws HopsException if hop inspection or rewiring fails
 */
private StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) throws HopsException {
    StatementBlock ret = sb;

    //check missing and supported increment values (instanceof is false for null)
    if (!(increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0)) {
        return ret;
    }

    //check for applicability
    boolean leftScalar = false;  //aggregate variable is the left operand, indexed value the right
    boolean rightScalar = false; //aggregate variable is the right operand, indexed value the left
    boolean rowIx = false;       //true: slice pinned on rows, false: pinned on columns
    if (csb.get_hops() != null && csb.get_hops().size() == 1) {
        Hop root = csb.get_hops().get(0);
        //guard against roots without inputs before accessing input 0
        if (root.getDataType() == DataType.SCALAR && !root.getInput().isEmpty()
            && root.getInput().get(0) instanceof BinaryOp) {
            BinaryOp bop = (BinaryOp) root.getInput().get(0);
            Hop left = bop.getInput().get(0);
            Hop right = bop.getInput().get(1);
            if (HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)) {
                if (isScalarAggregateVar(root, left) && isScalarCastOfIndexing(right)) {
                    //check for left scalar plus
                    int pos = getIterVarIndexPos((IndexingOp) right.getInput().get(0), itervar);
                    leftScalar = (pos > 0);
                    rowIx = (pos == 1);
                } else if (isScalarAggregateVar(root, right) && isScalarCastOfIndexing(left)) {
                    //check for right scalar plus
                    int pos = getIterVarIndexPos((IndexingOp) left.getInput().get(0), itervar);
                    rightScalar = (pos > 0);
                    rowIx = (pos == 1);
                }
            }
        }
    }

    //apply rewrite if possible
    if (leftScalar || rightScalar) {
        Hop root = csb.get_hops().get(0);
        BinaryOp bop = (BinaryOp) root.getInput().get(0);
        int castPos = leftScalar ? 1 : 0;
        Hop cast = bop.getInput().get(castPos);
        Hop ix = cast.getInput().get(0);
        int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
        AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];

        //replace cast with a full (RowCol) aggregate over the indexed slice
        AggUnaryOp newAgg = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
        HopRewriteUtils.removeChildReference(cast, ix);
        HopRewriteUtils.removeChildReference(bop, cast);
        HopRewriteUtils.addChildReference(bop, newAgg, castPos);

        //modify indexing expression according to loop predicate from-to
        //NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
        int index1 = rowIx ? 1 : 3; //input position of the lower index bound
        int index2 = rowIx ? 2 : 4; //input position of the upper index bound
        HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index1), from, index1);
        HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index2), to, index2);

        //update indexing size information
        if (rowIx)
            ((IndexingOp) ix).setRowLowerEqualsUpper(false);
        else
            ((IndexingOp) ix).setColLowerEqualsUpper(false);
        ix.refreshSizeInformation();

        ret = csb;
        LOG.debug("Applied vectorizeScalarSumForLoop.");
    }
    return ret;
}

/**
 * Checks whether {@code hop} reads the scalar variable that {@code root} assigns,
 * i.e., whether it is the running aggregate of the loop body.
 */
private static boolean isScalarAggregateVar(Hop root, Hop hop) {
    return hop instanceof DataOp
        && hop.getDataType() == DataType.SCALAR
        && root.getName().equals(hop.getName());
}

/** Checks whether {@code hop} is a cast-as-scalar applied to an {@link IndexingOp}. */
private static boolean isScalarCastOfIndexing(Hop hop) {
    return hop instanceof UnaryOp
        && ((UnaryOp) hop).getOp() == OpOp1.CAST_AS_SCALAR
        && hop.getInput().get(0) instanceof IndexingOp;
}

/**
 * Determines which index of {@code ix} is pinned to the loop variable.
 *
 * @return 1 if the row lower/upper bounds are equal and the row lower bound reads
 *         {@code itervar}, 3 for the analogous column case, or -1 otherwise
 */
private static int getIterVarIndexPos(IndexingOp ix, String itervar) {
    if (ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp
        && ix.getInput().get(1).getName().equals(itervar)) {
        return 1;
    }
    if (ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp
        && ix.getInput().get(3).getName().equals(itervar)) {
        return 3;
    }
    return -1;
}
Use of org.apache.sysml.hops.Hop.AggOp in project incubator-systemml by Apache.
Class TemplateMultiAgg, method constructCplan:
/**
 * Builds the code-generation plan for a multi-aggregate fused operator rooted at {@code hop}.
 * The aggregation roots are taken from the best MultiAggTpl memo entry (up to three plan
 * references), their cplans are built via the inherited cell-template construction, and
 * the resulting inputs and outputs are assembled into a {@link CNodeMultiAgg}.
 *
 * @param hop the hop whose best MultiAggTpl memo entry defines the fused aggregations
 * @param memo the plan memo table
 * @param compileLiterals passed through to the recursive cplan construction
 * @return the ordered, literal-free input hops paired with the constructed template
 */
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    //collect the root hops of all aggregations referenced by the best memo entry
    MemoTableEntry best = memo.getBest(hop.getHopID(), TemplateType.MultiAggTpl);
    ArrayList<Hop> aggRoots = new ArrayList<>();
    for (int pos = 0; pos < 3; pos++) {
        if (best.isPlanRef(pos)) {
            aggRoots.add(memo._hopRefs.get(best.input(pos)));
        }
    }
    Hop.resetVisitStatus(aggRoots);

    //build the cnodes of every root via the inherited cell-template construction
    HashSet<Hop> inHops = new HashSet<>();
    HashMap<Long, CNode> cnodes = new HashMap<>();
    for (Hop aggRoot : aggRoots) {
        super.rConstructCplan(aggRoot, memo, cnodes, inHops, compileLiterals);
    }
    Hop.resetVisitStatus(aggRoots);

    //drop literal scalars and order the remaining inputs (matrices/vectors first,
    //by number of cells and then sparsity, so that a sparse input serves as the
    //main input without unnecessary conversion)
    List<Hop> sinHops = inHops.stream()
        .filter(h -> !h.getDataType().isScalar() || !cnodes.get(h.getHopID()).isLiteral())
        .sorted(new HopInputComparator())
        .collect(Collectors.toList());

    //template inputs follow the sorted hop order
    ArrayList<CNode> inputs = new ArrayList<>();
    for (Hop inHop : sinHops) {
        inputs.add(cnodes.get(inHop.getHopID()));
    }

    //one output and one aggregation operator per root
    ArrayList<CNode> outputs = new ArrayList<>();
    ArrayList<AggOp> aggOps = new ArrayList<>();
    for (Hop aggRoot : aggRoots) {
        CNode out = cnodes.get(aggRoot.getHopID());
        //data outputs other than the main input are sideways inputs and need a lookup
        boolean sideways = out instanceof CNodeData
            && ((CNodeData) inputs.get(0)).getHopID() != ((CNodeData) out).getHopID();
        if (sideways) {
            UnaryType lookup = (aggRoots.get(0).getDim2() == 1) ? UnaryType.LOOKUP_R : UnaryType.LOOKUP_RC;
            out = new CNodeUnary(out, lookup);
        }
        outputs.add(out);
        aggOps.add(TemplateUtils.getAggOp(aggRoot));
    }

    CNodeMultiAgg template = new CNodeMultiAgg(inputs, outputs);
    template.setAggOps(aggOps);
    template.setRootNodes(aggRoots);

    // return cplan instance
    Hop[] inHopArray = sinHops.toArray(new Hop[0]);
    return new Pair<Hop[], CNodeTpl>(inHopArray, template);
}
Aggregations