Use of org.apache.sysml.parser.ForStatementBlock in the project incubator-systemml by Apache.
The class RewriteInjectSparkLoopCheckpointing, method rewriteStatementBlock.
/**
 * Places a checkpointing statement block in front of a while/for/parfor loop
 * when Spark execution mode is active.
 * <p>
 * Every matrix variable that the loop reads but does not update is a
 * checkpoint candidate. For each candidate a transient read, flagged via
 * {@code setRequiresCheckpoint(true)}, is fed into a transient write of the
 * same variable; these writes become the DAG roots of a new statement block
 * that is returned ahead of the loop.
 * <p>
 * Limitations carried over from the original notes: (1) checkpoints are added
 * without information about the global program structure, which assumes that
 * redundant checkpointing is prevented at runtime (instruction level);
 * (2) size information is not taken into account, so all candidates are
 * checkpointed even if they are only used by CP operations.
 *
 * @param sb statement block to inspect
 * @param status rewrite status; provides the block size and the parfor-context
 *        flag, and is notified via {@code setInjectedCheckpoints()} whenever
 *        a checkpoint block is created
 * @return the optional checkpoint block followed by {@code sb} itself
 * @throws HopsException declared in the signature; presumably raised while
 *         constructing hops (not verifiable from this method alone)
 */
@Override
public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) throws HopsException {
    ArrayList<StatementBlock> result = new ArrayList<StatementBlock>();

    // outside of Spark execution mode the block is handed back unchanged
    if (!OptimizerUtils.isSparkExecutionMode()) {
        result.add(sb);
        return result;
    }

    // block size taken from the rewrite status (set by the reblock rewrite)
    final int blen = status.getBlocksize();

    // ForStatementBlock also covers parfor; with _checkCtx enabled, loops
    // inside a parfor context are skipped (the original note gives the
    // reason that the rewrite would otherwise prevent remote parfor)
    boolean isLoop = (sb instanceof WhileStatementBlock) || (sb instanceof ForStatementBlock);
    if (isLoop && (!_checkCtx || !status.isInParforContext())) {
        // step 1: collect matrix variables that are read but never updated
        VariableSet readVars = sb.variablesRead();
        VariableSet updatedVars = sb.variablesUpdated();
        ArrayList<String> checkpointVars = new ArrayList<String>();
        for (String name : readVars.getVariableNames()) {
            if (updatedVars.containsVariable(name)) {
                continue;
            }
            if (readVars.getVariable(name).getDataType() == DataType.MATRIX) {
                checkpointVars.add(name);
            }
        }

        // step 2: build the checkpoint block only if there is something to checkpoint
        if (!checkpointVars.isEmpty()) {
            StatementBlock chkBlock = new StatementBlock();
            chkBlock.setDMLProg(sb.getDMLProg());
            chkBlock.setAllPositions(sb.getFilename(), sb.getBeginLine(), sb.getBeginColumn(), sb.getEndLine(), sb.getEndColumn());

            ArrayList<Hop> roots = new ArrayList<Hop>();
            VariableSet liveIn = new VariableSet();
            VariableSet liveOut = new VariableSet();
            for (String name : checkpointVars) {
                DataIdentifier id = readVars.getVariable(name);

                // indexed identifiers report the dimensions of the original matrix
                long rows;
                long cols;
                if (id instanceof IndexedIdentifier) {
                    rows = ((IndexedIdentifier) id).getOrigDim1();
                } else {
                    rows = id.getDim1();
                }
                if (id instanceof IndexedIdentifier) {
                    cols = ((IndexedIdentifier) id).getOrigDim2();
                } else {
                    cols = id.getDim2();
                }

                // transient read (checkpoint requested) -> transient write of the same variable
                DataOp chkRead = new DataOp(name, DataType.MATRIX, ValueType.DOUBLE, DataOpTypes.TRANSIENTREAD, id.getFilename(), rows, cols, id.getNnz(), blen, blen);
                chkRead.setRequiresCheckpoint(true);
                DataOp chkWrite = new DataOp(name, DataType.MATRIX, ValueType.DOUBLE, chkRead, DataOpTypes.TRANSIENTWRITE, null);
                HopRewriteUtils.setOutputParameters(chkWrite, rows, cols, blen, blen, id.getNnz());
                roots.add(chkWrite);

                liveIn.addVariable(name, readVars.getVariable(name));
                liveOut.addVariable(name, readVars.getVariable(name));
            }
            chkBlock.set_hops(roots);
            chkBlock.setLiveIn(liveIn);
            chkBlock.setLiveOut(liveOut);
            result.add(chkBlock);

            // record in the rewrite status that checkpoints were injected
            status.setInjectedCheckpoints();
        }
    }

    // the original statement block always comes last
    result.add(sb);
    return result;
}
Use of org.apache.sysml.parser.ForStatementBlock in the project incubator-systemml by Apache.
The class InterProceduralAnalysis, method getFunctionCandidatesForStatisticPropagation.
/////////////////////////////
// GET FUNCTION CANDIDATES
//////
/**
 * Walks a statement block recursively and hands every DAG root of its
 * last-level blocks to the hop-level overload of this method.
 * <p>
 * Function, while, if and for (including parfor) blocks are descended into
 * through the bodies of their first statement; any other block is treated as
 * a last-level block whose hops, if present, are processed one by one.
 *
 * @param sb statement block to traverse
 * @param fcandCounts map passed through unchanged to the hop-level overload
 * @param fcandHops map passed through unchanged to the hop-level overload
 * @throws HopsException propagated from the recursive calls
 * @throws ParseException propagated from the recursive calls
 */
private void getFunctionCandidatesForStatisticPropagation(StatementBlock sb, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops) throws HopsException, ParseException {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement fn = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
        for (StatementBlock child : fn.getBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatement loop = (WhileStatement) ((WhileStatementBlock) sb).getStatement(0);
        for (StatementBlock child : loop.getBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        return;
    }

    if (sb instanceof IfStatementBlock) {
        IfStatement branch = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
        // if-body first, else-body second
        for (StatementBlock child : branch.getIfBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        for (StatementBlock child : branch.getElseBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        return;
    }

    // ForStatementBlock also covers parfor
    if (sb instanceof ForStatementBlock) {
        ForStatement loop = (ForStatement) ((ForStatementBlock) sb).getStatement(0);
        for (StatementBlock child : loop.getBody()) {
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        }
        return;
    }

    // generic (last-level) block: empty statement blocks carry no hops
    ArrayList<Hop> dagRoots = sb.get_hops();
    if (dagRoots == null) {
        return;
    }
    for (Hop dagRoot : dagRoots) {
        getFunctionCandidatesForStatisticPropagation(sb.getDMLProg(), dagRoot, fcandCounts, fcandHops);
    }
}
Use of org.apache.sysml.parser.ForStatementBlock in the project incubator-systemml by Apache.
The class InterProceduralAnalysis, method removeCheckpointBeforeUpdate.
/////////////////////////////
// REMOVE UNNECESSARY CHECKPOINTS
//////
/**
 * Clears the checkpoint flag of checkpoints that are superseded by a later
 * checkpoint of the same variable without having been used in between.
 * <p>
 * Approach: scan the top-level program (described in the original notes as
 * guaranteed to be unconditional) and remember the latest checkpoint per
 * variable name. A remembered checkpoint is dropped from the candidates when
 * the variable is read before being updated, when it is updated inside
 * if/while/for control flow, or when it is updated by a root that does not
 * have a simple read chain. If a second checkpoint of a still-remembered
 * variable shows up, the first one gets {@code setRequiresCheckpoint(false)}.
 *
 * @param dmlp program whose top-level statement blocks are scanned
 * @throws HopsException declared in the signature; presumably raised by the
 *         hop utilities used here (not verifiable from this method alone)
 */
private void removeCheckpointBeforeUpdate(DMLProgram dmlp) throws HopsException {
    // latest checkpoint hop per variable name that may still be removed
    HashMap<String, Hop> pending = new HashMap<String, Hop>();

    for (StatementBlock sb : dmlp.getStatementBlocks()) {
        ArrayList<Hop> roots = sb.get_hops();

        // (1) prune candidates that are used before being updated;
        // iterate over a snapshot because entries are removed on the fly
        for (String name : new HashSet<String>(pending.keySet())) {
            if (!sb.variablesRead().containsVariable(name) || sb.variablesUpdated().containsVariable(name)) {
                continue;
            }
            // variablesRead might include false positives due to metadata
            // operations like nrow(X) or operations removed by rewrites, so
            // double check the hops of basic blocks; blocks without hops
            // are treated as the worst case (variable is read)
            boolean actuallyRead = true;
            if (roots != null) {
                Hop.resetVisitStatus(roots);
                actuallyRead = false;
                for (Hop root : roots) {
                    // non-short-circuit on purpose: every root is visited
                    actuallyRead |= HopRewriteUtils.rContainsRead(root, name, false);
                }
            }
            if (actuallyRead) {
                pending.remove(name);
            }
        }

        // (2) prune candidates that are updated in this block
        boolean controlFlow = sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock;
        for (String name : new HashSet<String>(pending.keySet())) {
            if (!sb.variablesUpdated().containsVariable(name)) {
                continue;
            }
            if (controlFlow) {
                // updated in conditional control flow
                pending.remove(name);
            } else if (roots != null) {
                // updated with multiple reads, i.e. without a simple read chain
                Hop.resetVisitStatus(roots);
                for (Hop root : roots) {
                    if (root.getName().equals(name) && !HopRewriteUtils.rHasSimpleReadChain(root, name)) {
                        pending.remove(name);
                    }
                }
            }
        }

        // (3) collect the checkpoints of this block; an earlier checkpoint of
        // the same variable that is still pending becomes unnecessary
        for (Hop chk : collectCheckpoints(roots)) {
            String name = chk.getName();
            Hop earlier = pending.get(name);
            if (earlier != null) {
                earlier.setRequiresCheckpoint(false);
            }
            pending.put(name, chk);
        }
    }
}
Use of org.apache.sysml.parser.ForStatementBlock in the project incubator-systemml by Apache.
The class Recompiler, method rRecompileProgramBlock.
//////////////////////////////
// private helper functions //
//////////////////////////////
/**
 * Recursively recompiles a program block together with its child blocks.
 * <p>
 * While and for (including parfor) blocks: the predicate(s) are recompiled,
 * updated scalars are removed from {@code vars}, and the body is recompiled;
 * if reconciling the variables or the status against their pre-body copies
 * reports a change, predicate(s) and body are recompiled a second time.
 * If blocks: both branches are recompiled on separate copies of the
 * variables and status, which are reconciled afterwards. Function blocks are
 * skipped. Any other block with a statement block has its hop DAG recompiled
 * and its output statistics extracted into {@code vars}.
 *
 * @param pb program block to recompile
 * @param vars variable map used for recompilation; modified by this method
 * @param status recompilation status handed to all recompile calls
 * @param tid id forwarded to the recompile calls (presumably the thread id)
 * @param resetRecompile whether recompilation flags of last-level blocks may
 *        be reset (also requires the global parfor switch and a DAG without
 *        root function op)
 * @throws HopsException propagated from recompilation
 * @throws DMLRuntimeException propagated from recompilation
 * @throws LopsException propagated from recompilation
 * @throws IOException propagated from recompilation
 */
private static void rRecompileProgramBlock(ProgramBlock pb, LocalVariableMap vars, RecompileStatus status, long tid, boolean resetRecompile) throws HopsException, DMLRuntimeException, LopsException, IOException {
    if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock whilePb = (WhileProgramBlock) pb;
        WhileStatementBlock whileSb = (WhileStatementBlock) whilePb.getStatementBlock();

        recompileWhilePredicate(whilePb, whileSb, vars, status, tid, resetRecompile);
        // scalars updated in the loop are removed because we are inside a loop
        removeUpdatedScalars(vars, whileSb);

        // snapshots for the comparison after the first pass over the body
        LocalVariableMap varsBefore = (LocalVariableMap) vars.clone();
        RecompileStatus statusBefore = (RecompileStatus) status.clone();
        for (ProgramBlock child : whilePb.getChildBlocks()) {
            rRecompileProgramBlock(child, vars, status, tid, resetRecompile);
        }

        // both reconciliations must run, hence two separate statements
        boolean varsChanged = reconcileUpdatedCallVarsLoops(varsBefore, vars, whileSb);
        boolean statusChanged = reconcileUpdatedCallVarsLoops(statusBefore, status, whileSb);
        if (varsChanged || statusChanged) {
            // second pass with unknowns
            recompileWhilePredicate(whilePb, whileSb, vars, status, tid, resetRecompile);
            for (ProgramBlock child : whilePb.getChildBlocks()) {
                rRecompileProgramBlock(child, vars, status, tid, resetRecompile);
            }
        }
        removeUpdatedScalars(vars, whileSb);
        return;
    }

    if (pb instanceof IfProgramBlock) {
        IfProgramBlock ifPb = (IfProgramBlock) pb;
        IfStatementBlock ifSb = (IfStatementBlock) ifPb.getStatementBlock();

        recompileIfPredicate(ifPb, ifSb, vars, status, tid, resetRecompile);

        // snapshots for the later comparison plus separate state for the else branch
        LocalVariableMap varsBefore = (LocalVariableMap) vars.clone();
        LocalVariableMap elseVars = (LocalVariableMap) vars.clone();
        RecompileStatus statusBefore = (RecompileStatus) status.clone();
        RecompileStatus elseStatus = (RecompileStatus) status.clone();

        for (ProgramBlock child : ifPb.getChildBlocksIfBody()) {
            rRecompileProgramBlock(child, vars, status, tid, resetRecompile);
        }
        for (ProgramBlock child : ifPb.getChildBlocksElseBody()) {
            rRecompileProgramBlock(child, elseVars, elseStatus, tid, resetRecompile);
        }

        reconcileUpdatedCallVarsIf(varsBefore, vars, elseVars, ifSb);
        reconcileUpdatedCallVarsIf(statusBefore, status, elseStatus, ifSb);
        removeUpdatedScalars(vars, ifPb.getStatementBlock());
        return;
    }

    // ForProgramBlock includes ParFORProgramBlock
    if (pb instanceof ForProgramBlock) {
        ForProgramBlock forPb = (ForProgramBlock) pb;
        ForStatementBlock forSb = (ForStatementBlock) forPb.getStatementBlock();

        recompileForPredicates(forPb, forSb, vars, status, tid, resetRecompile);
        // scalars updated in the loop are removed because we are inside a loop
        removeUpdatedScalars(vars, forPb.getStatementBlock());

        // snapshots for the comparison after the first pass over the body
        LocalVariableMap varsBefore = (LocalVariableMap) vars.clone();
        RecompileStatus statusBefore = (RecompileStatus) status.clone();
        for (ProgramBlock child : forPb.getChildBlocks()) {
            rRecompileProgramBlock(child, vars, status, tid, resetRecompile);
        }

        // both reconciliations must run, hence two separate statements
        boolean varsChanged = reconcileUpdatedCallVarsLoops(varsBefore, vars, forSb);
        boolean statusChanged = reconcileUpdatedCallVarsLoops(statusBefore, status, forSb);
        if (varsChanged || statusChanged) {
            // second pass with unknowns
            recompileForPredicates(forPb, forSb, vars, status, tid, resetRecompile);
            for (ProgramBlock child : forPb.getChildBlocks()) {
                rRecompileProgramBlock(child, vars, status, tid, resetRecompile);
            }
        }
        removeUpdatedScalars(vars, forPb.getStatementBlock());
        return;
    }

    // function blocks (includes ExternalFunctionProgramBlock and
    // ExternalFunctionProgramBlockCP) are intentionally left alone
    if (pb instanceof FunctionProgramBlock) {
        return;
    }

    // last-level block
    StatementBlock sb = pb.getStatementBlock();
    ArrayList<Instruction> insts = pb.getInstructions();
    if (sb == null) {
        return;
    }

    // recompile unconditionally for stats propagation and recompile flags
    insts = Recompiler.recompileHopsDag(sb, sb.get_hops(), vars, status, true, false, tid);
    pb.setInstructions(insts);

    // propagate stats across hops (should be executed on clone of vars)
    Recompiler.extractDAGOutputStatistics(sb.get_hops(), vars);

    // reset recompilation flags, with special handling of functions
    if (ParForProgramBlock.RESET_RECOMPILATION_FLAGs && !containsRootFunctionOp(sb.get_hops()) && resetRecompile) {
        Hop.resetRecompilationFlag(sb.get_hops(), ExecType.CP);
        sb.updateRecompilationFlag();
    }
}
Use of org.apache.sysml.parser.ForStatementBlock in the project incubator-systemml by Apache.
The class Recompiler, method rRecompileProgramBlock2Forced.
/**
 * Recursively recompiles a program block, its children and the functions it
 * calls with a forced execution type.
 * <p>
 * Predicates of while/if blocks and the from/to/increment parts of for
 * (including parfor) blocks are recompiled unless the forced type is
 * {@code ExecType.CP} and the existing instructions contain no MR job
 * instruction. Last-level blocks with a statement block are always
 * recompiled; function calls found in their instructions trigger the
 * recompilation of the called function, once per function key.
 *
 * @param pb program block to recompile
 * @param tid id forwarded to the recompile calls (presumably the thread id)
 * @param fnStack keys of functions already handled; serves as memoization
 *        for multiple calls and recursion, and is extended by this method
 * @param et execution type to force
 * @throws HopsException propagated from recompilation
 * @throws DMLRuntimeException propagated from recompilation
 * @throws LopsException propagated from recompilation
 * @throws IOException propagated from recompilation
 */
private static void rRecompileProgramBlock2Forced(ProgramBlock pb, long tid, HashSet<String> fnStack, ExecType et) throws HopsException, DMLRuntimeException, LopsException, IOException {
    if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock whilePb = (WhileProgramBlock) pb;
        WhileStatementBlock whileSb = (WhileStatementBlock) whilePb.getStatementBlock();

        // predicate: skipped for forced CP when no MR job instruction is present
        if (whileSb != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(whilePb.getPredicate(), true, true))) {
            whilePb.setPredicate(Recompiler.recompileHopsDag2Forced(whileSb.getPredicateHops(), tid, et));
        }

        // body
        for (ProgramBlock child : whilePb.getChildBlocks()) {
            rRecompileProgramBlock2Forced(child, tid, fnStack, et);
        }
        return;
    }

    if (pb instanceof IfProgramBlock) {
        IfProgramBlock ifPb = (IfProgramBlock) pb;
        IfStatementBlock ifSb = (IfStatementBlock) ifPb.getStatementBlock();

        // predicate: skipped for forced CP when no MR job instruction is present
        if (ifSb != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(ifPb.getPredicate(), true, true))) {
            ifPb.setPredicate(Recompiler.recompileHopsDag2Forced(ifSb.getPredicateHops(), tid, et));
        }

        // both branches
        for (ProgramBlock child : ifPb.getChildBlocksIfBody()) {
            rRecompileProgramBlock2Forced(child, tid, fnStack, et);
        }
        for (ProgramBlock child : ifPb.getChildBlocksElseBody()) {
            rRecompileProgramBlock2Forced(child, tid, fnStack, et);
        }
        return;
    }

    // ForProgramBlock includes ParFORProgramBlock
    if (pb instanceof ForProgramBlock) {
        ForProgramBlock forPb = (ForProgramBlock) pb;
        ForStatementBlock forSb = (ForStatementBlock) forPb.getStatementBlock();

        // from/to/increment: each one only if hops exist, and skipped for
        // forced CP when no MR job instruction is present
        if (forSb != null) {
            if (forSb.getFromHops() != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(forPb.getFromInstructions(), true, true))) {
                forPb.setFromInstructions(Recompiler.recompileHopsDag2Forced(forSb.getFromHops(), tid, et));
            }
            if (forSb.getToHops() != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(forPb.getToInstructions(), true, true))) {
                forPb.setToInstructions(Recompiler.recompileHopsDag2Forced(forSb.getToHops(), tid, et));
            }
            if (forSb.getIncrementHops() != null && (et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(forPb.getIncrementInstructions(), true, true))) {
                forPb.setIncrementInstructions(Recompiler.recompileHopsDag2Forced(forSb.getIncrementHops(), tid, et));
            }
        }

        // body
        for (ProgramBlock child : forPb.getChildBlocks()) {
            rRecompileProgramBlock2Forced(child, tid, fnStack, et);
        }
        return;
    }

    // includes ExternalFunctionProgramBlock and ExternalFunctionProgramBlockCP
    if (pb instanceof FunctionProgramBlock) {
        for (ProgramBlock child : ((FunctionProgramBlock) pb).getChildBlocks()) {
            rRecompileProgramBlock2Forced(child, tid, fnStack, et);
        }
        return;
    }

    // last-level block: recompile whenever a statement block exists (the
    // original note mentions that this would otherwise be invalid with
    // permutation matrix mult across multiple dags)
    StatementBlock sb = pb.getStatementBlock();
    if (sb != null) {
        ArrayList<Instruction> insts = pb.getInstructions();
        insts = Recompiler.recompileHopsDag2Forced(sb, sb.get_hops(), tid, et);
        pb.setInstructions(insts);
    }

    // recompile the functions called from this block
    if (!OptTreeConverter.containsFunctionCallInstruction(pb)) {
        return;
    }
    for (Instruction inst : pb.getInstructions()) {
        if (!(inst instanceof FunctionCallCPInstruction)) {
            continue;
        }
        FunctionCallCPInstruction call = (FunctionCallCPInstruction) inst;
        String name = call.getFunctionName();
        String namespace = call.getNamespace();
        String key = DMLProgram.constructFunctionKey(namespace, name);

        // add() returns true only for keys not seen before, which gives
        // memoization for multiple calls and recursion
        if (fnStack.add(key)) {
            FunctionProgramBlock callee = pb.getProgram().getFunctionProgramBlock(namespace, name);
            // follow chains of functions
            rRecompileProgramBlock2Forced(callee, tid, fnStack, et);
        }
    }
}
Aggregations