Use of org.apache.sysml.parser.FunctionStatementBlock in the Apache SystemML project.
Class InterProceduralAnalysis, method isUnarySizePreservingFunction.
/**
 * Determines whether a function takes exactly one matrix input, returns
 * exactly one matrix output, and yields an output with the same number of
 * rows and columns as its input.
 *
 * The check is done by probing: a placeholder input matrix with fixed
 * dimensions is bound to the input parameter, statistics are propagated
 * across the function body, and the dimensions recorded for the output
 * variable are compared against the probe. Afterwards the probe's dimensions
 * are set to -1 and the propagation is repeated ("reset function").
 *
 * @param fsb function statement block to inspect; its first statement is
 *            cast to {@code FunctionStatement}
 * @return true if the signature is unary matrix-to-matrix and the propagated
 *         output dimensions equal the probe input dimensions; false
 *         otherwise, including when no {@code MatrixObject} is recorded for
 *         the output variable after propagation
 */
private boolean isUnarySizePreservingFunction(FunctionStatementBlock fsb) {
    FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
    // check unary functions over matrices
    boolean ret = (fstmt.getInputParams().size() == 1
        && fstmt.getInputParams().get(0).getDataType() == DataType.MATRIX
        && fstmt.getOutputParams().size() == 1
        && fstmt.getOutputParams().get(0).getDataType() == DataType.MATRIX);
    // check size-preserving characteristic
    if (ret) {
        FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph, false);
        HashSet<String> fnStack = new HashSet<>();
        LocalVariableMap callVars = new LocalVariableMap();
        // populate input with a probe matrix
        // NOTE(review): arguments are presumably (rows, cols, nnz) - confirm against createOutputMatrix
        MatrixObject mo = createOutputMatrix(7777, 3333, -1);
        callVars.put(fstmt.getInputParams().get(0).getName(), mo);
        // propagate statistics
        for (StatementBlock sbi : fstmt.getBody()) {
            propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
        }
        // compare output; a missing or non-matrix entry means the property
        // cannot be established, so answer false instead of failing with a
        // NullPointerException or ClassCastException
        Object out = callVars.get(fstmt.getOutputParams().get(0).getName());
        if (out instanceof MatrixObject) {
            MatrixObject mo2 = (MatrixObject) out;
            ret = mo.getNumRows() == mo2.getNumRows()
                && mo.getNumColumns() == mo2.getNumColumns();
        } else {
            ret = false;
        }
        // reset function: repeat the propagation with the probe's dimensions set to -1
        mo.getMatrixCharacteristics().setDimension(-1, -1);
        for (StatementBlock sbi : fstmt.getBody()) {
            propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
        }
    }
    return ret;
}
Use of org.apache.sysml.parser.FunctionStatementBlock in the Apache SystemML project.
Class ProgramRewriter, method rRewriteStatementBlockHopDAGs.
/**
 * Recursively applies {@code rewriteHopDAG} to the hop DAGs reachable from a
 * statement block.
 *
 * Function, while, if and for blocks are handled by rewriting their
 * predicate / from / to / increment hops (where the block has them) and then
 * recursing into every child block of their bodies. Any other block is
 * treated as a last-level block whose own hops are rewritten directly.
 *
 * @param current statement block to process
 * @param state   rewrite status passed to every rewrite; if {@code null}, a
 *                new {@code ProgramRewriteStatus} is created for this call
 */
public void rRewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) {
    // ensure robustness for calls from outside
    final ProgramRewriteStatus status = (state != null) ? state : new ProgramRewriteStatus();

    if (current instanceof FunctionStatementBlock) {
        FunctionStatementBlock funBlock = (FunctionStatementBlock) current;
        FunctionStatement funStmt = (FunctionStatement) funBlock.getStatement(0);
        for (StatementBlock child : funStmt.getBody()) {
            rRewriteStatementBlockHopDAGs(child, status);
        }
        return;
    }

    if (current instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) current;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        whileBlock.setPredicateHops(rewriteHopDAG(whileBlock.getPredicateHops(), status));
        for (StatementBlock child : whileStmt.getBody()) {
            rRewriteStatementBlockHopDAGs(child, status);
        }
        return;
    }

    if (current instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) current;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        ifBlock.setPredicateHops(rewriteHopDAG(ifBlock.getPredicateHops(), status));
        for (StatementBlock child : ifStmt.getIfBody()) {
            rRewriteStatementBlockHopDAGs(child, status);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            rRewriteStatementBlockHopDAGs(child, status);
        }
        return;
    }

    // incl parfor
    if (current instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) current;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        forBlock.setFromHops(rewriteHopDAG(forBlock.getFromHops(), status));
        forBlock.setToHops(rewriteHopDAG(forBlock.getToHops(), status));
        forBlock.setIncrementHops(rewriteHopDAG(forBlock.getIncrementHops(), status));
        for (StatementBlock child : forStmt.getBody()) {
            rRewriteStatementBlockHopDAGs(child, status);
        }
        return;
    }

    // generic (last-level)
    current.setHops(rewriteHopDAG(current.getHops(), status));
}
Use of org.apache.sysml.parser.FunctionStatementBlock in the Apache SystemML project.
Class ProgramRewriter, method rewriteProgramHopDAGs.
/**
 * Rewrites the hop DAGs of an entire DML program.
 *
 * Every function statement block of every namespace is processed first: its
 * hop DAGs are rewritten and, if the statement-block rule set is non-empty,
 * the statement-block rewrites are applied to it as well. The statement
 * blocks of the main program follow: their hop DAGs are rewritten one by
 * one, and, if the statement-block rule set is non-empty, the program's
 * statement-block list is replaced by the result of
 * {@code rRewriteStatementBlocks}.
 *
 * @param dmlp      program whose functions and main statement blocks are rewritten
 * @param splitDags flag forwarded unchanged to the statement-block rewrites
 * @return the rewrite status object created for, and shared across, this pass
 */
public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags) {
    final ProgramRewriteStatus status = new ProgramRewriteStatus();

    // for each namespace, handle function statement blocks
    for (String ns : dmlp.getNamespaces().keySet()) {
        for (String funName : dmlp.getFunctionStatementBlocks(ns).keySet()) {
            FunctionStatementBlock funBlock = dmlp.getFunctionStatementBlock(ns, funName);
            rRewriteStatementBlockHopDAGs(funBlock, status);
            if (!_sbRuleSet.isEmpty()) {
                rRewriteStatementBlock(funBlock, status, splitDags);
            }
        }
    }

    // handle regular statement blocks in "main" method
    for (int pos = 0; pos < dmlp.getNumStatementBlocks(); pos++) {
        rRewriteStatementBlockHopDAGs(dmlp.getStatementBlock(pos), status);
    }

    // without statement-block rules there is nothing left to do
    if (_sbRuleSet.isEmpty()) {
        return status;
    }
    dmlp.setStatementBlocks(rRewriteStatementBlocks(dmlp.getStatementBlocks(), status, splitDags));
    return status;
}
Aggregations