Usage of org.apache.sysml.parser.FunctionStatement in the Apache incubator-systemml project.
Class SpoofCompiler, method generateCodeFromStatementBlock:
/**
 * Recursively applies code generation to a statement block and all of its children.
 *
 * <p>Control-flow blocks (function, while, if, for/parfor) are traversed. Their predicate HOP
 * DAGs, where they exist, are replaced by {@code optimize(..., false)}, and their child blocks
 * are visited recursively. Any other (last-level) block has its HOP DAGs replaced by the result
 * of {@code generateCodeFromHopDAGs}, and its recompilation flag is updated afterwards.
 *
 * @param current the statement block to process
 */
public static void generateCodeFromStatementBlock(StatementBlock current) {
    if (current instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) current;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        for (StatementBlock child : funcStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }
    if (current instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) current;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        // optimize the loop predicate before descending into the body
        whileBlock.setPredicateHops(optimize(whileBlock.getPredicateHops(), false));
        for (StatementBlock child : whileStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }
    if (current instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) current;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        ifBlock.setPredicateHops(optimize(ifBlock.getPredicateHops(), false));
        for (StatementBlock child : ifStmt.getIfBody()) {
            generateCodeFromStatementBlock(child);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }
    if (current instanceof ForStatementBlock) {
        // also covers parfor blocks
        ForStatementBlock forBlock = (ForStatementBlock) current;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        forBlock.setFromHops(optimize(forBlock.getFromHops(), false));
        forBlock.setToHops(optimize(forBlock.getToHops(), false));
        forBlock.setIncrementHops(optimize(forBlock.getIncrementHops(), false));
        for (StatementBlock child : forStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }
    // generic (last-level) block: rewrite its HOP DAGs, then refresh the recompilation flag
    current.setHops(generateCodeFromHopDAGs(current.getHops()));
    current.updateRecompilationFlag();
}
Usage of org.apache.sysml.parser.FunctionStatement in the Apache incubator-systemml project.
Class RewriteCompressedReblock, method rAnalyzeProgram:
/**
 * Recursively analyzes a statement block and records in {@code status} how the compressed
 * matrix, and any names derived from it, are used.
 *
 * <p>Control-flow blocks are analyzed body-first. Afterwards, a while/for block that reads any
 * name in {@code status.compMtx} sets {@code usedInLoop}, and an if block that updates any such
 * name sets {@code condUpdate}. A last-level block with HOPs has its DAG analyzed via
 * {@code rAnalyzeHopDag}. Temporary names (prefixed with {@code TMP_PREFIX}) are then dropped from
 * {@code status.compMtx}.
 *
 * @param sb     the statement block to analyze
 * @param status the probe state, updated in place
 */
private static void rAnalyzeProgram(StatementBlock sb, ProbeStatus status) {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) sb;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        for (StatementBlock child : funcStmt.getBody()) {
            rAnalyzeProgram(child, status);
        }
        return;
    }
    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        for (StatementBlock child : whileStmt.getBody()) {
            rAnalyzeProgram(child, status);
        }
        // checked after the body, which may have added names to compMtx
        status.usedInLoop |= whileBlock.variablesRead().containsAnyName(status.compMtx);
        return;
    }
    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        for (StatementBlock child : ifStmt.getIfBody()) {
            rAnalyzeProgram(child, status);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            rAnalyzeProgram(child, status);
        }
        status.condUpdate |= ifBlock.variablesUpdated().containsAnyName(status.compMtx);
        return;
    }
    if (sb instanceof ForStatementBlock) {
        // also covers parfor blocks
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        for (StatementBlock child : forStmt.getBody()) {
            rAnalyzeProgram(child, status);
        }
        status.usedInLoop |= forBlock.variablesRead().containsAnyName(status.compMtx);
        return;
    }
    // generic (last-level) block; blocks without HOPs are skipped
    ArrayList<Hop> dagRoots = sb.getHops();
    if (dagRoots == null) {
        return;
    }
    Hop.resetVisitStatus(dagRoots);
    for (Hop dagRoot : dagRoots) {
        rAnalyzeHopDag(dagRoot, status);
    }
    // temporary (hop-level) names are only meaningful within this DAG
    status.compMtx.removeIf(name -> name.startsWith(TMP_PREFIX));
    Hop.resetVisitStatus(dagRoots);
}
Usage of org.apache.sysml.parser.FunctionStatement in the Apache incubator-systemml project.
Class RewriteCompressedReblock, method rAnalyzeHopDag:
/**
 * Recursively analyzes a HOP DAG and records in {@code status} how the compressed matrix and the
 * names derived from it are consumed. Inputs are processed before the hop itself, and each hop
 * is processed at most once (guarded by its visit flag, which is set at the end).
 *
 * <p>After the start-hop check, exactly one of these cases applies, in this order:
 * <ul>
 * <li>a) a {@code FunctionOp} with a compressed input: the callee is analyzed once per function
 *     key (memoized in {@code status.procFn}). Compressed call inputs are mapped positionally to
 *     the callee's input parameter names. The callee's flags are OR-merged into {@code status},
 *     and compressed callee outputs are mapped back to the call's output variable names.</li>
 * <li>b) a transient write whose input is compressed marks the written variable name as
 *     compressed. A transient read of a compressed variable name marks the read hop's temporary
 *     name as compressed.</li>
 * <li>c) any other hop with a compressed input is classified as valid with uncompressed output,
 *     valid with compressed output, or a meta operation (nrow/ncol). Anything else sets
 *     {@code status.nonApplicable}. Only the "compressed output" class marks the hop's own
 *     temporary name as compressed.</li>
 * </ul>
 *
 * @param current the hop to analyze
 * @param status  the probe state, updated in place
 */
private static void rAnalyzeHopDag(Hop current, ProbeStatus status) {
if (current.isVisited())
return;
// process children recursively (post-order: inputs are classified before this hop)
for (Hop input : current.getInput()) rAnalyzeHopDag(input, status);
// handle source persistent read: the hop that produces the compressed matrix
if (current.getHopID() == status.startHopID) {
status.compMtx.add(getTmpName(current));
status.foundStart = true;
}
// a) handle function calls
if (current instanceof FunctionOp && hasCompressedInput(current, status)) {
// TODO handle of functions in a more fine-grained manner
// to cover special cases multiple calls where compressed
// inputs might occur for different input parameters
FunctionOp fop = (FunctionOp) current;
String fkey = fop.getFunctionKey();
// NOTE(review): when fkey was already processed, this whole block is skipped, so the outputs
// of this call site are never added to compMtx even if the callee returns a compressed matrix.
if (!status.procFn.contains(fkey)) {
// memoization to avoid redundant analysis and recursive calls
status.procFn.add(fkey);
// map inputs to function inputs
// NOTE(review): assumes fsb is non-null for fkey and that the call inputs line up
// positionally with fstmt's input params (the same for outputs below) - TODO confirm.
FunctionStatementBlock fsb = status.prog.getFunctionStatementBlock(fkey);
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
// child probe state for the callee; which fields the copy constructor shares with
// the parent is not visible here - confirm in ProbeStatus
ProbeStatus status2 = new ProbeStatus(status);
for (int i = 0; i < fop.getInput().size(); i++) if (status.compMtx.contains(getTmpName(fop.getInput().get(i))))
status2.compMtx.add(fstmt.getInputParams().get(i).getName());
// analyze function and merge meta info (flags are OR-ed into the caller's state)
rAnalyzeProgram(fsb, status2);
status.foundStart |= status2.foundStart;
status.usedInLoop |= status2.usedInLoop;
status.condUpdate |= status2.condUpdate;
status.nonApplicable |= status2.nonApplicable;
// map function outputs to outputs
String[] outputs = fop.getOutputVariableNames();
for (int i = 0; i < outputs.length; i++) if (status2.compMtx.contains(fstmt.getOutputParams().get(i).getName()))
status.compMtx.add(outputs[i]);
}
// the branches below form one if / else-if chain with the function-call case above
} else // b) handle transient reads and writes (name mapping)
if (HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTWRITE) && status.compMtx.contains(getTmpName(current.getInput().get(0))))
status.compMtx.add(current.getName());
else if (HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTREAD) && status.compMtx.contains(current.getName()))
status.compMtx.add(getTmpName(current));
else // c) handle applicable operations
if (hasCompressedInput(current, status)) {
// valid with uncompressed outputs: tsmm (left, column count fits within one column block),
// matrix-vector multiply, a transpose consumed only by a single matrix-vector multiply,
// or sum / sum-of-squares / min / max aggregation
boolean compUCOut = (// tsmm
current instanceof AggBinaryOp && current.getDim2() <= current.getColsInBlock() && ((AggBinaryOp) current).checkTransposeSelf() == MMTSJType.LEFT) || // mvmm
(current instanceof AggBinaryOp && (current.getDim1() == 1 || current.getDim2() == 1)) || (HopRewriteUtils.isTransposeOperation(current) && current.getParent().size() == 1 && current.getParent().get(0) instanceof AggBinaryOp && (current.getParent().get(0).getDim1() == 1 || current.getParent().get(0).getDim2() == 1)) || HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX);
// valid with compressed outputs: matrix-scalar binary operations and cbind
boolean compCOut = HopRewriteUtils.isBinaryMatrixScalarOperation(current) || HopRewriteUtils.isBinary(current, OpOp2.CBIND);
// metadata-only operations (nrow/ncol) are always fine
boolean metaOp = HopRewriteUtils.isUnary(current, OpOp1.NROW, OpOp1.NCOL);
status.nonApplicable |= !(compUCOut || compCOut || metaOp);
// only ops that keep the result compressed propagate compressed-ness to their output
if (compCOut)
status.compMtx.add(getTmpName(current));
}
current.setVisited();
}
Usage of org.apache.sysml.parser.FunctionStatement in the Apache incubator-systemml project.
Class InterProceduralAnalysis, method propagateStatisticsAcrossBlock:
// ///////////////////////////
// INTRA-PROCEDURE ANALYSIS
// ////
/**
 * Propagates statistics held in {@code callVars} through a statement block, recursing into
 * control-flow blocks.
 *
 * <ul>
 * <li>Function blocks: every body block is processed with the same {@code callVars}.</li>
 * <li>While/for blocks: statistics are propagated into the predicate DAGs (while predicate,
 *     for from/to/increment), and updated constant scalars are removed. The body is then
 *     processed against a snapshot of {@code callVars}. If
 *     {@code Recompiler.reconcileUpdatedCallVarsLoops} returns true, the body is processed a
 *     second time (for a while loop, the predicate is processed again as well).</li>
 * <li>If blocks: the if-body runs on {@code callVars} and the else-body on a clone. The two
 *     results are reconciled with {@code Recompiler.reconcileUpdatedCallVarsIf}.</li>
 * <li>Last-level blocks: updated constant scalars are removed. Then, resetting HOP visit status
 *     before each step, the block (1) replaces scalar reads with literals, (2) refreshes
 *     statistics across the DAG, and (3) propagates statistics into function calls.</li>
 * </ul>
 *
 * @param sb         the statement block to process
 * @param callVars   current variable statistics; the recursive calls are expected to update this
 *                   map in place (the loop branches compare it against a clone taken before the body)
 * @param fcallSizes function-call size information, passed through to nested propagation
 * @param fnStack    passed through to propagateStatisticsIntoFunctions; presumably the functions
 *                   on the current call path (recursion guard) - TODO confirm
 */
private void propagateStatisticsAcrossBlock(StatementBlock sb, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack) {
if (sb instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
} else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
// old stats into predicate
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
// remove updated constant scalars (before the body, which may change them per iteration)
Recompiler.removeUpdatedScalars(callVars, wsb);
// check and propagate stats into body; the snapshot lets the reconcile step detect changes
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb)) {
// second pass if required (the while predicate is re-evaluated as well)
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
for (StatementBlock sbi : wstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
}
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) isb.getStatement(0);
// old stats into predicate
propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), callVars);
// check and propagate stats into body: the if-branch works on callVars, the else-branch on a clone
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
LocalVariableMap callVarsElse = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : istmt.getIfBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
for (StatementBlock sbi : istmt.getElseBody()) propagateStatisticsAcrossBlock(sbi, callVarsElse, fcallSizes, fnStack);
// NOTE(review): this only rebinds the local parameter. The caller sees the reconciled stats
// only if reconcileUpdatedCallVarsIf returns (or mutates) the callVars map passed in - confirm in Recompiler.
callVars = Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else if (// incl parfor
sb instanceof ForStatementBlock) {
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) fsb.getStatement(0);
// old stats into predicate (from / to / increment)
propagateStatisticsAcrossPredicateDAG(fsb.getFromHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getToHops(), callVars);
propagateStatisticsAcrossPredicateDAG(fsb.getIncrementHops(), callVars);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, fsb);
// check and propagate stats into body
LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone();
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
// second pass over the body only; unlike the while case, the predicates are not re-propagated
if (Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb))
for (StatementBlock sbi : fstmt.getBody()) propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
} else // generic (last-level)
{
// remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
// old stats in, new stats out if updated
ArrayList<Hop> roots = sb.getHops();
DMLProgram prog = sb.getDMLProg();
// replace scalar reads with literals
Hop.resetVisitStatus(roots);
propagateScalarsAcrossDAG(roots, callVars);
// refresh stats across dag
Hop.resetVisitStatus(roots);
propagateStatisticsAcrossDAG(roots, callVars);
// propagate stats into function calls
Hop.resetVisitStatus(roots);
propagateStatisticsIntoFunctions(prog, roots, callVars, fcallSizes, fnStack);
}
}
Usage of org.apache.sysml.parser.FunctionStatement in the Apache incubator-systemml project.
Class InterProceduralAnalysis, method isUnarySizePreservingFunction:
/**
 * Checks whether a function is a unary, size-preserving matrix function. Such a function has
 * exactly one matrix input and exactly one matrix output, and the output has the same dimensions
 * as the input.
 *
 * <p>The check probes the function body. Statistics for a dummy 7777 x 3333 input are
 * propagated through the body, and the resulting output dimensions are compared with the input.
 * Afterwards the body is propagated again from a fresh variable map whose input has unknown
 * dimensions, so that the probe sizes are not left behind on the function's HOPs.
 *
 * @param fsb the function statement block to analyze
 * @return true if the function is unary over matrices and its output provably has the input's
 *         dimensions; false otherwise, including when the output's statistics are not tracked
 */
private boolean isUnarySizePreservingFunction(FunctionStatementBlock fsb) {
    FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);

    // check unary functions over matrices
    boolean unaryMatrixFn = fstmt.getInputParams().size() == 1
        && fstmt.getInputParams().get(0).getDataType() == DataType.MATRIX
        && fstmt.getOutputParams().size() == 1
        && fstmt.getOutputParams().get(0).getDataType() == DataType.MATRIX;
    if (!unaryMatrixFn) {
        return false;
    }

    String inputName = fstmt.getInputParams().get(0).getName();
    String outputName = fstmt.getOutputParams().get(0).getName();
    FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph, false);
    HashSet<String> fnStack = new HashSet<>();

    // probe: propagate statistics for a dummy input of known size
    MatrixObject mo = createOutputMatrix(7777, 3333, -1);
    LocalVariableMap callVars = new LocalVariableMap();
    callVars.put(inputName, mo);
    for (StatementBlock sbi : fstmt.getBody())
        propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack);

    // compare output; an output that is absent or not a matrix cannot be shown to preserve size
    // (previously an unchecked cast plus dereference, which could throw CCE/NPE)
    Object out = callVars.get(outputName);
    boolean sizePreserving = out instanceof MatrixObject
        && mo.getNumRows() == ((MatrixObject) out).getNumRows()
        && mo.getNumColumns() == ((MatrixObject) out).getNumColumns();

    // reset function: re-propagate with an unknown-size input from a fresh map. Reusing callVars
    // would keep the probe sizes for variables derived in the first pass (or for a reassigned
    // input variable), so the dummy 7777 x 3333 sizes could stay on the function's HOPs.
    mo.getMatrixCharacteristics().setDimension(-1, -1);
    LocalVariableMap resetVars = new LocalVariableMap();
    resetVars.put(inputName, mo);
    for (StatementBlock sbi : fstmt.getBody())
        propagateStatisticsAcrossBlock(sbi, resetVars, fcallSizes, fnStack);

    return sizePreserving;
}
Aggregations