use of org.apache.sysml.hops.codegen.cplan.CNode in project incubator-systemml by apache.
the class CPlanComparisonTest method testNotEqualBinaryDAG1.
/**
 * Checks that two nested multiplication plans are reported as not equal when
 * they share the same inner product but differ in the second operand of the
 * outer multiplication.
 */
@Test
public void testNotEqualBinaryDAG1() {
    // leaf data nodes: two created with MATRIX data type, one with SCALAR
    CNode matrixA = createCNodeData(DataType.MATRIX);
    CNode matrixB = createCNodeData(DataType.MATRIX);
    CNode scalar = createCNodeData(DataType.SCALAR);

    // first plan: (matrixA * matrixB) * scalar
    CNode innerFirst = new CNodeBinary(matrixA, matrixB, BinType.MULT);
    CNode outerFirst = new CNodeBinary(innerFirst, scalar, BinType.MULT);

    // second plan: (matrixA * matrixB) * matrixA
    CNode innerSecond = new CNodeBinary(matrixA, matrixB, BinType.MULT);
    CNode outerSecond = new CNodeBinary(innerSecond, matrixA, BinType.MULT);

    // the outer nodes differ in their second input, so they must not compare equal
    Assert.assertNotEquals(outerFirst, outerSecond);
}
use of org.apache.sysml.hops.codegen.cplan.CNode in project systemml by apache.
the class TemplateRow method constructCplan.
/**
 * Constructs the row-template cplan rooted at the given hop.
 *
 * <p>The hop DAG is first translated via {@code rConstructCplan}, which fills
 * the hop-ID-to-CNode map as well as the set of input hops and the map of
 * named inputs ("X", "B1"). Scalar inputs whose CNode is a literal are dropped
 * from the input list, the remaining inputs are sorted with a
 * {@code HopInputComparator}, and a {@code CNodeRow} template is configured
 * from the result.
 *
 * @param hop root hop of the plan to construct
 * @param memo memo table, forwarded to {@code rConstructCplan}
 * @param compileLiterals flag forwarded to {@code rConstructCplan}
 * @return pair of the ordered input hops and the configured template node
 */
@Override
public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
    // translate the hop DAG into CNodes, resetting visit flags before and after
    HashSet<Hop> inputHops = new HashSet<>();
    HashMap<String, Hop> namedInputs = new HashMap<>();
    HashMap<Long, CNode> cnodeByHopId = new HashMap<>();
    hop.resetVisitStatus();
    rConstructCplan(hop, memo, cnodeByHopId, inputHops, namedInputs, compileLiterals);
    hop.resetVisitStatus();

    // reorder inputs (ensure matrix is first input, and other inputs ordered by size);
    // scalar inputs that map to literal CNodes are not passed as template inputs
    HopInputComparator inputOrder =
        new HopInputComparator(namedInputs.get("X"), namedInputs.get("B1"));
    Hop[] sinHops = inputHops.stream()
        .filter(h -> !h.getDataType().isScalar() || !cnodeByHopId.get(h.getHopID()).isLiteral())
        .sorted(inputOrder)
        .toArray(Hop[]::new);

    // robustness special case: fall back to the first ordered input if no "X" was recorded
    namedInputs.putIfAbsent("X", sinHops[0]);
    Hop mainInput = namedInputs.get("X");
    Hop sideInput = namedInputs.get("B1");

    // collect the CNodes of the ordered inputs and build the template node
    ArrayList<CNode> inputs = new ArrayList<>();
    for (Hop inputHop : sinHops) {
        inputs.add(cnodeByHopId.get(inputHop.getHopID()));
    }
    CNode output = cnodeByHopId.get(hop.getHopID());
    CNodeRow tpl = new CNodeRow(inputs, output);
    tpl.setRowType(TemplateUtils.getRowType(hop, mainInput, sideInput));

    // COL_AGG_B1 takes the candidate constant from dim1 of the root hop, all other types from dim2
    long n2;
    if (tpl.getRowType() == RowType.COL_AGG_B1) {
        n2 = hop.getDim1();
    } else {
        n2 = hop.getDim2();
    }
    if (tpl.getRowType().isConstDim2(n2)) {
        tpl.setConstDim2(n2);
    }

    tpl.setNumVectorIntermediates(TemplateUtils.determineMinVectorIntermediates(output));
    tpl.getOutput().resetVisitStatus();
    tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), sinHops[0].getHopID());
    tpl.setBeginLine(hop.getBeginLine());

    // return cplan instance
    return new Pair<>(sinHops, tpl);
}
use of org.apache.sysml.hops.codegen.cplan.CNode in project systemml by apache.
the class TemplateUtils method countVectorIntermediates.
/**
 * Counts the nodes in the not-yet-visited part of the DAG below {@code node}
 * (inclusive) that are flagged as producing a vector intermediate.
 *
 * <p>A node contributes one if it is a binary node whose type is a vector
 * primitive and whose type name does not end in {@code "_ADD"}, a unary node
 * whose type is a vector-scalar primitive, or a ternary node whose type is a
 * vector primitive. Every traversed node is marked visited; nodes that are
 * already visited contribute zero and are not descended into.
 *
 * @param node root of the sub-DAG to inspect
 * @return number of counted nodes in the sub-DAG
 */
public static int countVectorIntermediates(CNode node) {
    if (node.isVisited()) {
        return 0;
    }
    node.setVisited();

    // accumulate the counts of all inputs first
    int total = 0;
    for (CNode child : node.getInput()) {
        total += countVectorIntermediates(child);
    }

    // then add the contribution of this node itself
    if (node instanceof CNodeBinary) {
        CNodeBinary binary = (CNodeBinary) node;
        if (binary.getType().isVectorPrimitive() && !binary.getType().name().endsWith("_ADD")) {
            total += 1;
        }
    }
    if (node instanceof CNodeUnary && ((CNodeUnary) node).getType().isVectorScalarPrimitive()) {
        total += 1;
    }
    if (node instanceof CNodeTernary && ((CNodeTernary) node).getType().isVectorPrimitive()) {
        total += 1;
    }
    return total;
}
use of org.apache.sysml.hops.codegen.cplan.CNode in project systemml by apache.
the class TemplateUtils method getMaxVectorIntermediates.
/**
 * Returns the largest per-node vector-intermediate requirement found in the
 * not-yet-visited part of the DAG below {@code node} (inclusive).
 *
 * <p>Per-node values used: ternary vector primitive = 1; binary vector-vector
 * primitive = 3, binary vector-scalar primitive = 2, binary vector-matrix
 * primitive = 1; unary vector-scalar primitive = 2; anything else = 0.
 * A node is marked visited after its inputs have been processed; nodes that
 * are already visited yield zero.
 *
 * @param node root of the sub-DAG to inspect
 * @return maximum requirement over the traversed nodes
 */
public static int getMaxVectorIntermediates(CNode node) {
    if (node.isVisited()) {
        return 0;
    }

    // maximum over all inputs
    int best = 0;
    for (CNode child : node.getInput()) {
        best = Math.max(best, getMaxVectorIntermediates(child));
    }

    // requirement of this node when it is a ternary operation
    if (node instanceof CNodeTernary && ((CNodeTernary) node).getType().isVectorPrimitive()) {
        best = Math.max(best, 1);
    }

    // requirement of this node when it is a binary operation
    if (node instanceof CNodeBinary) {
        CNodeBinary binary = (CNodeBinary) node;
        int required;
        if (binary.getType().isVectorVectorPrimitive()) {
            required = 3;
        } else if (binary.getType().isVectorScalarPrimitive()) {
            required = 2;
        } else if (binary.getType().isVectorMatrixPrimitive()) {
            required = 1;
        } else {
            required = 0;
        }
        best = Math.max(best, required);
    }

    // requirement of this node when it is a unary operation
    if (node instanceof CNodeUnary && ((CNodeUnary) node).getType().isVectorScalarPrimitive()) {
        best = Math.max(best, 2);
    }

    node.setVisited();
    return best;
}
use of org.apache.sysml.hops.codegen.cplan.CNode in project systemml by apache.
the class CPlanCSERewriter method rEliminateCommonSubexpression.
/**
 * Recursively rewires the inputs of {@code current} to previously registered
 * equal nodes and then registers {@code current} itself.
 *
 * <p>Each input that is already a key in {@code cseSet} is replaced in place
 * by the node stored for it. Afterwards the (possibly replaced) inputs are
 * processed recursively, {@code current} is put into {@code cseSet} mapped to
 * itself, and it is marked visited. Nodes that are already visited are skipped.
 *
 * @param current node whose inputs are rewritten
 * @param cseSet map from a node to the instance to use in its place; updated by this call
 */
private void rEliminateCommonSubexpression(CNode current, HashMap<CNode, CNode> cseSet) {
    // avoid redundant re-evaluation
    if (current.isVisited()) {
        return;
    }

    // swap every input that has a registered counterpart for that counterpart
    int pos = 0;
    while (pos < current.getInput().size()) {
        CNode candidate = current.getInput().get(pos);
        if (cseSet.containsKey(candidate)) {
            current.getInput().set(pos, cseSet.get(candidate));
        }
        pos++;
    }

    // descend into the inputs as they stand after replacement
    for (CNode child : current.getInput()) {
        rEliminateCommonSubexpression(child, cseSet);
    }

    // register this node and mark it as done
    cseSet.put(current, current);
    current.setVisited();
}
Aggregations