use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class TemplateUtils method getRowTemplateMatrixInput.
/**
 * Searches the inputs of {@code current} (recursively through inputs that are
 * plan references of the best ROW memo entry) for the first matrix input that
 * is not itself a plan reference.
 *
 * @param current hop whose inputs are scanned in order
 * @param memo    memo table used to look up the best ROW entry and to test
 *                whether an input has a ROW entry of its own
 * @return the hop ID of the first qualifying input, or a negative value if
 *         no input qualifies
 */
public static long getRowTemplateMatrixInput(Hop current, CPlanMemoTable memo) {
    final MemoTableEntry best = memo.getBest(current.getHopID(), TemplateType.ROW);
    long matrixId = -1;
    int pos = 0;
    // stop as soon as a non-negative ID was found or all inputs were scanned
    while (matrixId < 0 && pos < current.getInput().size()) {
        Hop child = current.getInput().get(pos);
        if (best.isPlanRef(pos)) {
            // plan references are only followed if the input has a ROW entry
            if (memo.contains(child.getHopID(), TemplateType.ROW)) {
                matrixId = getRowTemplateMatrixInput(child, memo);
            }
        } else if (isMatrix(child)) {
            matrixId = child.getHopID();
        }
        pos++;
    }
    return matrixId;
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class FunctionCallSizeInfo method constructFunctionCallSizeInfo.
/**
 * Derives the size information for all reachable functions of the function
 * call graph and stores it in {@code _fcand}, {@code _fcandSafeNNZ}, and
 * {@code _fSafeLiterals}:
 * <ol>
 * <li>function candidates: functions called once, or (if multiple calls are
 *     allowed) called with consistent input sizes and literal values,</li>
 * <li>per candidate, the input positions whose nnz is known,</li>
 * <li>per reachable function, the input positions that are literals with
 *     equal value across all calls.</li>
 * </ol>
 */
private void constructFunctionCallSizeInfo() {
    // step 1: determine function candidates by evaluating all function calls
    for (String fkey : _fgraph.getReachableFunctions()) {
        List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey);
        // condition 1: function called just once
        if (flist.size() == 1) {
            _fcand.add(fkey);
        }
        // condition 2: check for consistent input sizes
        else if (InterProceduralAnalysis.ALLOW_MULTIPLE_FUNCTION_CALLS) {
            // compare input matrix characteristics of first against all other calls
            FunctionOp first = flist.get(0);
            boolean consistent = true;
            for (int i = 1; i < flist.size(); i++) {
                FunctionOp other = flist.get(i);
                for (int j = 0; j < first.getInput().size(); j++) {
                    Hop h1 = first.getInput().get(j);
                    Hop h2 = other.getInput().get(j);
                    // check matrix and scalar sizes (if known dims, nnz known/unknown,
                    // safeness of nnz propagation, determined later per input)
                    consistent &= (h1.dimsKnown() && h2.dimsKnown()
                        && h1.getDim1() == h2.getDim1()
                        && h1.getDim2() == h2.getDim2()
                        && h1.getNnz() == h2.getNnz());
                    // check literal values (equi value)
                    if (h1 instanceof LiteralOp) {
                        consistent &= (h2 instanceof LiteralOp
                            && HopRewriteUtils.isEqualValue((LiteralOp) h1, (LiteralOp) h2));
                    }
                }
            }
            if (consistent)
                _fcand.add(fkey);
        }
    }

    // step 2: determine safe nnz propagation per input
    // (considered for valid functions only)
    for (String fkey : _fcand) {
        FunctionOp first = _fgraph.getFunctionCalls(fkey).get(0);
        HashSet<Integer> tmp = new HashSet<>();
        for (int j = 0; j < first.getInput().size(); j++) {
            // if nnz known it is safe to propagate those nnz because for multiple calls
            // we checked of equivalence and hence all calls have the same nnz
            // (inspect the j-th input; position j is only safe if its own nnz is known)
            Hop input = first.getInput().get(j);
            if (input.getNnz() >= 0)
                tmp.add(j);
        }
        _fcandSafeNNZ.put(fkey, tmp);
    }

    // step 3: determine safe literal replacement per input
    // (considered for all functions)
    for (String fkey : _fgraph.getReachableFunctions()) {
        List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey);
        FunctionOp first = flist.get(0);
        // initialize w/ all literals of first call
        HashSet<Integer> tmp = new HashSet<>();
        for (int j = 0; j < first.getInput().size(); j++) {
            if (first.getInput().get(j) instanceof LiteralOp)
                tmp.add(j);
        }
        // check consistency across all function calls
        for (int i = 1; i < flist.size(); i++) {
            FunctionOp other = flist.get(i);
            for (int j = 0; j < first.getInput().size(); j++) {
                if (tmp.contains(j)) {
                    Hop h1 = first.getInput().get(j);
                    Hop h2 = other.getInput().get(j);
                    // drop positions where another call passes a non-literal
                    // or a literal of different value
                    if (!(h2 instanceof LiteralOp
                        && HopRewriteUtils.isEqualValue((LiteralOp) h1, (LiteralOp) h2)))
                        tmp.remove(j);
                }
            }
        }
        _fSafeLiterals.put(fkey, tmp);
    }
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class IPAPassInlineFunctions method rewriteProgram.
/**
 * Inlines function calls of reachable functions whose body is a single
 * last-level statement block without further function calls, and that are
 * either called exactly once or have at most
 * {@code InterProceduralAnalysis.INLINING_MAX_NUM_OPS} operators.
 *
 * <p>For each inlined call, the function's hop DAG is deep-copied, its
 * transient reads are replaced by the call's inputs, its transient writes
 * are renamed to the call's output variables, and the call is replaced by
 * the copied roots in the calling statement block. Calls whose number of
 * inputs or outputs does not match the function signature are skipped.
 *
 * @param prog       program providing the function statement blocks
 * @param fgraph     function call graph providing calls and their statement blocks
 * @param fcallSizes function call size information (not used by this pass)
 */
@Override
public void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
    for (String fkey : fgraph.getReachableFunctions()) {
        FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
        FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
        if (fstmt.getBody().size() == 1
            && HopRewriteUtils.isLastLevelStatementBlock(fstmt.getBody().get(0))
            && !containsFunctionOp(fstmt.getBody().get(0).getHops())
            && (fgraph.getFunctionCalls(fkey).size() == 1
                || countOperators(fstmt.getBody().get(0).getHops()) <= InterProceduralAnalysis.INLINING_MAX_NUM_OPS)) {
            if (LOG.isDebugEnabled())
                LOG.debug("IPA: Inline function '" + fkey + "'");
            // replace all relevant function calls
            ArrayList<Hop> hops = fstmt.getBody().get(0).getHops();
            List<FunctionOp> fcalls = fgraph.getFunctionCalls(fkey);
            List<StatementBlock> fcallsSB = fgraph.getFunctionCallsSB(fkey);
            // tracks whether every call was actually inlined; skipped calls
            // remain in the program and must stay in the call graph
            boolean removedAll = true;
            for (int i = 0; i < fcalls.size(); i++) {
                FunctionOp op = fcalls.get(i);
                // step 0: robustness for special cases
                if (op.getInput().size() != fstmt.getInputParams().size()
                    || op.getOutputVariableNames().length != fstmt.getOutputParams().size()) {
                    removedAll = false;
                    continue;
                }
                // step 1: deep copy hop dag
                ArrayList<Hop> hops2 = Recompiler.deepCopyHopsDag(hops);
                // step 2: replace inputs
                HashMap<String, Hop> inMap = new HashMap<>();
                for (int j = 0; j < op.getInput().size(); j++)
                    inMap.put(fstmt.getInputParams().get(j).getName(), op.getInput().get(j));
                replaceTransientReads(hops2, inMap);
                // step 3: replace outputs
                HashMap<String, String> outMap = new HashMap<>();
                String[] opOutputs = op.getOutputVariableNames();
                for (int j = 0; j < opOutputs.length; j++)
                    outMap.put(fstmt.getOutputParams().get(j).getName(), opOutputs[j]);
                for (int j = 0; j < hops2.size(); j++) {
                    Hop out = hops2.get(j);
                    if (HopRewriteUtils.isData(out, DataOpTypes.TRANSIENTWRITE))
                        out.setName(outMap.get(out.getName()));
                }
                // swap the call for the inlined roots in the calling block
                fcallsSB.get(i).getHops().remove(op);
                fcallsSB.get(i).getHops().addAll(hops2);
            }
            // update the function call graph to avoid repeated inlining
            // (and thus op replication) on repeated IPA calls; only done if
            // no call was skipped, as skipped calls still exist in the program
            if (removedAll)
                fgraph.removeFunctionCalls(fkey);
        }
    }
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class IPAPassInlineFunctions method replaceTransientReads.
/**
 * Applies {@code rReplaceTransientReads} to every root in {@code hops},
 * resetting the visit status of the roots before and after the traversal.
 *
 * @param hops  root hops to process
 * @param inMap mapping from name to hop, passed through to the recursive helper
 */
private static void replaceTransientReads(ArrayList<Hop> hops, HashMap<String, Hop> inMap) {
    // start from a clean visit state
    Hop.resetVisitStatus(hops);
    for (int k = 0; k < hops.size(); k++) {
        rReplaceTransientReads(hops.get(k), inMap);
    }
    // leave a clean visit state behind for subsequent traversals
    Hop.resetVisitStatus(hops);
}
use of org.apache.sysml.hops.Hop in project incubator-systemml by apache.
the class IPAPassInlineFunctions method rCountOperators.
/**
 * Recursively counts the hops reachable from {@code current} that are
 * neither a {@code DataOp} nor a {@code LiteralOp}. Hops already marked as
 * visited contribute zero; each hop is marked visited after its inputs have
 * been processed.
 *
 * @param current hop at which the count starts
 * @return number of counted hops among the not-yet-visited hops reachable
 *         from {@code current}
 */
private static int rCountOperators(Hop current) {
    // a hop seen earlier in this traversal adds nothing
    if (current.isVisited()) {
        return 0;
    }
    int total = 0;
    // data and literal hops are not counted themselves
    if (!(current instanceof DataOp) && !(current instanceof LiteralOp)) {
        total = 1;
    }
    for (Hop child : current.getInput()) {
        total += rCountOperators(child);
    }
    current.setVisited();
    return total;
}
Aggregations