Usage of org.apache.sysml.runtime.controlprogram.ForProgramBlock in the Apache project incubator-systemml:
class OptimizerRuleBased, method removeUnnecessaryParFor.
/**
 * Recursively replaces every PARFOR node that runs with a degree of parallelism of 1
 * by an equivalent sequential FOR loop, both in the optimizer tree and in the
 * mapped runtime program.
 *
 * @param n root of the (sub)tree to rewrite; leaves are left untouched
 * @return number of PARFOR nodes that were converted into FOR nodes
 */
protected int removeUnnecessaryParFor(OptNode n) {
    // a leaf has no children, hence nothing to rewrite
    if (n.isLeaf())
        return 0;

    int removed = 0;
    for (OptNode child : n.getChilds()) {
        boolean serialParFor = child.getNodeType() == NodeType.PARFOR && child.getK() == 1;
        if (serialParFor) {
            Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(child.getID());
            ParForStatementBlock parforSb = (ParForStatementBlock) mapped[0];
            ParForProgramBlock parforPb = (ParForProgramBlock) mapped[1];

            // build a plain for program block sharing the parfor's instructions/children
            ForProgramBlock forPb = ProgramConverter.createShallowCopyForProgramBlock(parforPb, parforPb.getProgram());

            // swap the runtime block and keep the abstract plan mapping in sync
            OptTreeConverter.replaceProgramBlock(n, child, parforPb, forPb, false);
            // the shallow copy does not carry a statement block, so link it explicitly
            forPb.setStatementBlock(parforSb);

            // reflect the rewrite in the optimizer tree node itself
            child.setNodeType(NodeType.FOR);
            child.setK(1);
            removed++;
        }
        // descend regardless of whether this child was rewritten
        removed += removeUnnecessaryParFor(child);
    }
    return removed;
}
Usage of org.apache.sysml.runtime.controlprogram.ForProgramBlock in the Apache project incubator-systemml:
class ProgramRecompiler, method rFindAndRecompileIndexingHOP.
/**
 * Recursively searches the given statement/program block pair for indexing hops on
 * variable {@code var} and recompiles the affected instructions.
 * <p>
 * NOTE: if force is set, we set and recompile the respective indexing hops;
 * otherwise, we release the forced exec type and recompile again. Hence,
 * any changes can be exactly reverted with the same access behavior.
 * <p>
 * Control-flow blocks (if, while, for/parfor) are handled in two steps. First their
 * predicate or loop-header instructions are processed. Then the method recurses into
 * positionally matched pairs of child statement blocks and child program blocks.
 * Any other pair is treated as a last-level block. Its hops are updated, and if
 * anything changed, new instructions are recompiled from them.
 *
 * @param sb statement block
 * @param pb program block
 * @param var variable
 * @param ec execution context
 * @param force if true, set and recompile the respective indexing hops; if false,
 *              release the forced execution type and recompile
 * @throws DMLRuntimeException if updating or recompiling a last-level block fails
 */
public static void rFindAndRecompileIndexingHOP(StatementBlock sb, ProgramBlock pb, String var, ExecutionContext ec, boolean force) {
    if (pb instanceof IfProgramBlock && sb instanceof IfStatementBlock) {
        IfProgramBlock ipb = (IfProgramBlock) pb;
        IfStatementBlock isb = (IfStatementBlock) sb;
        IfStatement is = (IfStatement) sb.getStatement(0);
        // process if condition
        if (isb.getPredicateHops() != null)
            ipb.setPredicate(rFindAndRecompileIndexingHOP(isb.getPredicateHops(), ipb.getPredicate(), var, ec, force));
        // process if branch and (optional) else branch
        rFindAndRecompileIndexingHOPChildren(is.getIfBody(), ipb.getChildBlocksIfBody(), var, ec, force);
        rFindAndRecompileIndexingHOPChildren(is.getElseBody(), ipb.getChildBlocksElseBody(), var, ec, force);
    } else if (pb instanceof WhileProgramBlock && sb instanceof WhileStatementBlock) {
        WhileProgramBlock wpb = (WhileProgramBlock) pb;
        WhileStatementBlock wsb = (WhileStatementBlock) sb;
        WhileStatement ws = (WhileStatement) sb.getStatement(0);
        // process while condition
        if (wsb.getPredicateHops() != null)
            wpb.setPredicate(rFindAndRecompileIndexingHOP(wsb.getPredicateHops(), wpb.getPredicate(), var, ec, force));
        // process body
        rFindAndRecompileIndexingHOPChildren(ws.getBody(), wpb.getChildBlocks(), var, ec, force);
    } else if (pb instanceof ForProgramBlock && sb instanceof ForStatementBlock) {
        // for or parfor
        ForProgramBlock fpb = (ForProgramBlock) pb;
        ForStatementBlock fsb = (ForStatementBlock) sb;
        ForStatement fs = (ForStatement) fsb.getStatement(0);
        // process loop header (from/to/increment)
        if (fsb.getFromHops() != null)
            fpb.setFromInstructions(rFindAndRecompileIndexingHOP(fsb.getFromHops(), fpb.getFromInstructions(), var, ec, force));
        if (fsb.getToHops() != null)
            fpb.setToInstructions(rFindAndRecompileIndexingHOP(fsb.getToHops(), fpb.getToInstructions(), var, ec, force));
        if (fsb.getIncrementHops() != null)
            fpb.setIncrementInstructions(rFindAndRecompileIndexingHOP(fsb.getIncrementHops(), fpb.getIncrementInstructions(), var, ec, force));
        // process body
        rFindAndRecompileIndexingHOPChildren(fs.getBody(), fpb.getChildBlocks(), var, ec, force);
    } else {
        // last level program block
        try {
            // process actual hops
            boolean ret = false;
            Hop.resetVisitStatus(sb.getHops());
            if (force) {
                // set forced execution type
                for (Hop h : sb.getHops())
                    ret |= rFindAndSetCPIndexingHOP(h, var);
            } else {
                // release forced execution type
                for (Hop h : sb.getHops())
                    ret |= rFindAndReleaseIndexingHOP(h, var);
            }
            // recompilation on-demand: only if some hop was actually changed
            if (ret) {
                // construct new instructions
                ArrayList<Instruction> newInst = Recompiler.recompileHopsDag(sb, sb.getHops(), ec.getVariables(), null, true, false, 0);
                pb.setInstructions(newInst);
            }
        } catch (DMLRuntimeException ex) {
            // already the exception type this method reports; rethrow without re-wrapping
            throw ex;
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }
}

/**
 * Recurses into positionally matched (statement block, program block) child pairs.
 * Only the common prefix of both lists is visited. This keeps the traversal robust
 * when the lists differ in length, e.g. because program blocks were added.
 * A null list on either side means there is nothing to visit.
 */
private static void rFindAndRecompileIndexingHOPChildren(java.util.List<? extends StatementBlock> sbs,
    java.util.List<? extends ProgramBlock> pbs, String var, ExecutionContext ec, boolean force) {
    if (sbs == null || pbs == null)
        return;
    int len = Math.min(sbs.size(), pbs.size());
    for (int i = 0; i < len; i++)
        rFindAndRecompileIndexingHOP(sbs.get(i), pbs.get(i), var, ec, force);
}
Usage of org.apache.sysml.runtime.controlprogram.ForProgramBlock in the Apache project incubator-systemml:
class ProgramConverter, method createShallowCopyForProgramBlock.
/**
 * Creates a new {@link ForProgramBlock} for {@code prog} with the same iteration
 * variable as {@code fpb}. The new block reuses the objects returned by
 * {@code fpb}'s getters for the from/to/increment/exit instructions and the child
 * blocks; no copying happens here.
 * <p>
 * The statement block is NOT transferred; callers must set it themselves if needed.
 *
 * @param fpb  source for (or parfor) program block
 * @param prog program the new block belongs to
 * @return the new shallow copy
 */
public static ForProgramBlock createShallowCopyForProgramBlock(ForProgramBlock fpb, Program prog) {
    ForProgramBlock copy = new ForProgramBlock(prog, fpb.getIterVar());

    // loop header and exit instructions, passed through as-is
    copy.setFromInstructions(fpb.getFromInstructions());
    copy.setToInstructions(fpb.getToInstructions());
    copy.setIncrementInstructions(fpb.getIncrementInstructions());
    copy.setExitInstructions(fpb.getExitInstructions());

    // loop body, passed through as-is
    copy.setChildBlocks(fpb.getChildBlocks());

    return copy;
}
Usage of org.apache.sysml.runtime.controlprogram.ForProgramBlock in the Apache project incubator-systemml:
class CostEstimator, method rGetTimeEstimate.
/**
 * Recursively estimates the execution time of a program block.
 * <ul>
 * <li>while: sum of child estimates times {@code DEFAULT_NUMITER}</li>
 * <li>if: sum of the if-branch children, or, if the else branch has children, the
 *     average of the if-branch and else-branch sums</li>
 * <li>for (includes parfor): sum of child estimates times {@code getNumIterations}</li>
 * <li>function (excluding external functions): sum of child estimates</li>
 * <li>anything else: per-instruction CP and MR-job estimates. Called functions are
 *     followed, and {@code memoFunc} guards against recursive call chains.</li>
 * </ul>
 * If {@code recursive} is false, control-flow and function blocks contribute 0.
 * NOTE(review): {@code stats} is passed to maintain and cleanup helpers. It is
 * presumably updated in place while instructions are visited.
 *
 * @param pb        program block to estimate
 * @param stats     variable statistics used by the instruction-level estimates
 * @param memoFunc  keys of functions currently on the estimation call stack
 * @param recursive whether to descend into child blocks of control-flow/function blocks
 * @return estimated execution time
 */
private double rGetTimeEstimate(ProgramBlock pb, HashMap<String, VarStats> stats, HashSet<String> memoFunc, boolean recursive) {
    double ret = 0;
    if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock tmp = (WhileProgramBlock) pb;
        if (recursive)
            for (ProgramBlock pb2 : tmp.getChildBlocks())
                ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
        ret *= DEFAULT_NUMITER;
    } else if (pb instanceof IfProgramBlock) {
        IfProgramBlock tmp = (IfProgramBlock) pb;
        if (recursive) {
            double ifCost = 0;
            for (ProgramBlock pb2 : tmp.getChildBlocksIfBody())
                ifCost += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
            ret = ifCost;
            if (tmp.getChildBlocksElseBody() != null) {
                double elseCost = 0;
                int numElseBlocks = 0;
                for (ProgramBlock pb2 : tmp.getChildBlocksElseBody()) {
                    elseCost += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
                    numElseBlocks++;
                }
                // weighted sum: average both branches exactly once. The previous
                // in-loop halving divided by 2 per else block and over-discounted.
                if (numElseBlocks > 0)
                    ret = (ifCost + elseCost) / 2;
            }
        }
    } else if (pb instanceof ForProgramBlock) {
        // includes ParFORProgramBlock
        ForProgramBlock tmp = (ForProgramBlock) pb;
        if (recursive)
            for (ProgramBlock pb2 : tmp.getChildBlocks())
                ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
        ret *= getNumIterations(stats, tmp);
    } else if (pb instanceof FunctionProgramBlock && !(pb instanceof ExternalFunctionProgramBlock)) {
        // external functions fall through to the generic instruction-level branch
        FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
        if (recursive)
            for (ProgramBlock pb2 : tmp.getChildBlocks())
                ret += rGetTimeEstimate(pb2, stats, memoFunc, recursive);
    } else {
        ArrayList<Instruction> tmp = pb.getInstructions();
        for (Instruction inst : tmp) {
            if (inst instanceof CPInstruction) {
                // CP: obtain stats from createvar, cpvar, rmvar, rand
                maintainCPInstVariableStatistics((CPInstruction) inst, stats);
                // extract statistics (instruction-specific)
                Object[] o = extractCPInstStatistics(inst, stats);
                VarStats[] vs = (VarStats[]) o[0];
                String[] attr = (String[]) o[1];
                // call time estimation for inst
                ret += getCPInstTimeEstimate(inst, vs, attr);
                if (inst instanceof FunctionCallCPInstruction) {
                    // functions
                    FunctionCallCPInstruction finst = (FunctionCallCPInstruction) inst;
                    String fkey = DMLProgram.constructFunctionKey(finst.getNamespace(), finst.getFunctionName());
                    // awareness of recursive functions, missing program
                    if (!memoFunc.contains(fkey) && pb.getProgram() != null) {
                        if (LOG.isDebugEnabled())
                            LOG.debug("Begin Function " + fkey);
                        memoFunc.add(fkey);
                        Program prog = pb.getProgram();
                        FunctionProgramBlock fpb = prog.getFunctionProgramBlock(finst.getNamespace(), finst.getFunctionName());
                        ret += rGetTimeEstimate(fpb, stats, memoFunc, recursive);
                        // removed again: the guard only covers the current call chain
                        memoFunc.remove(fkey);
                        if (LOG.isDebugEnabled())
                            LOG.debug("End Function " + fkey);
                    }
                }
            } else if (inst instanceof MRJobInstruction) {
                // MR: obtain stats for job
                maintainMRJobInstVariableStatistics(inst, stats);
                // extract input statistics
                Object[] o = extractMRJobInstStatistics(inst, stats);
                VarStats[] vs = (VarStats[]) o[0];
                if (LOG.isDebugEnabled())
                    LOG.debug("Begin MRJob type=" + ((MRJobInstruction) inst).getJobType());
                // call time estimation for complex MR inst
                ret += getMRJobInstTimeEstimate(inst, vs, null);
                if (LOG.isDebugEnabled())
                    LOG.debug("End MRJob");
                // cleanup stats for job
                cleanupMRJobVariableStatistics(inst, stats);
            }
        }
    }
    return ret;
}
Usage of org.apache.sysml.runtime.controlprogram.ForProgramBlock in the Apache project incubator-systemml:
class Recompiler, method rRecompileProgramBlock2Forced.
/**
 * Recursively recompiles a program block, and all blocks reachable from it, with
 * the execution type forced to {@code et}.
 * <p>
 * Predicates and loop headers are recompiled only when
 * {@link #requiresForcedRecompile} says so. Last-level blocks with a statement block
 * get new instructions from their HOP DAG. Functions called from last-level blocks
 * are recompiled once each. {@code fnStack} entries are added but never removed, so
 * a function reached again, through repeated calls or recursion, is skipped.
 *
 * @param pb      program block to recompile
 * @param tid     thread id passed through to the forced recompilation
 * @param fnStack keys of functions already visited by this traversal (mutated)
 * @param et      execution type to force
 */
private static void rRecompileProgramBlock2Forced(ProgramBlock pb, long tid, HashSet<String> fnStack, ExecType et) {
    if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock pbTmp = (WhileProgramBlock) pb;
        WhileStatementBlock sbTmp = (WhileStatementBlock) pbTmp.getStatementBlock();
        // recompile predicate
        if (sbTmp != null && requiresForcedRecompile(pbTmp.getPredicate(), et))
            pbTmp.setPredicate(Recompiler.recompileHopsDag2Forced(sbTmp.getPredicateHops(), tid, et));
        // recompile body
        for (ProgramBlock pb2 : pbTmp.getChildBlocks())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else if (pb instanceof IfProgramBlock) {
        IfProgramBlock pbTmp = (IfProgramBlock) pb;
        IfStatementBlock sbTmp = (IfStatementBlock) pbTmp.getStatementBlock();
        // recompile predicate
        if (sbTmp != null && requiresForcedRecompile(pbTmp.getPredicate(), et))
            pbTmp.setPredicate(Recompiler.recompileHopsDag2Forced(sbTmp.getPredicateHops(), tid, et));
        // recompile body
        for (ProgramBlock pb2 : pbTmp.getChildBlocksIfBody())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
        // guard against a missing (null) else body
        if (pbTmp.getChildBlocksElseBody() != null)
            for (ProgramBlock pb2 : pbTmp.getChildBlocksElseBody())
                rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else if (pb instanceof ForProgramBlock) {
        // includes ParFORProgramBlock
        ForProgramBlock pbTmp = (ForProgramBlock) pb;
        ForStatementBlock sbTmp = (ForStatementBlock) pbTmp.getStatementBlock();
        // recompile loop header (from/to/increment)
        if (sbTmp != null && sbTmp.getFromHops() != null && requiresForcedRecompile(pbTmp.getFromInstructions(), et))
            pbTmp.setFromInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getFromHops(), tid, et));
        if (sbTmp != null && sbTmp.getToHops() != null && requiresForcedRecompile(pbTmp.getToInstructions(), et))
            pbTmp.setToInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getToHops(), tid, et));
        if (sbTmp != null && sbTmp.getIncrementHops() != null && requiresForcedRecompile(pbTmp.getIncrementInstructions(), et))
            pbTmp.setIncrementInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getIncrementHops(), tid, et));
        // recompile body
        for (ProgramBlock pb2 : pbTmp.getChildBlocks())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else if (pb instanceof FunctionProgramBlock) {
        // includes ExternalFunctionProgramBlock and ExternalFunctionProgramBlockCP
        FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
        for (ProgramBlock pb2 : tmp.getChildBlocks())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else {
        StatementBlock sb = pb.getStatementBlock();
        // recompile last-level instructions from the block's HOP DAG
        // NOTE(review): the original note here appears truncated:
        // "... would be invalid with permutation matrix mult across multiple dags)"
        if (sb != null)
            pb.setInstructions(Recompiler.recompileHopsDag2Forced(sb, sb.getHops(), tid, et));

        // recompile functions
        if (OptTreeConverter.containsFunctionCallInstruction(pb)) {
            for (Instruction inst : pb.getInstructions()) {
                if (!(inst instanceof FunctionCallCPInstruction))
                    continue;
                FunctionCallCPInstruction func = (FunctionCallCPInstruction) inst;
                String fname = func.getFunctionName();
                String fnamespace = func.getNamespace();
                String fKey = DMLProgram.constructFunctionKey(fnamespace, fname);
                // memoization for multiple calls, recursion
                if (!fnStack.contains(fKey)) {
                    fnStack.add(fKey);
                    FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(fnamespace, fname);
                    // recompile chains of functions
                    rRecompileProgramBlock2Forced(fpb, tid, fnStack, et);
                }
            }
        }
    }
}

/**
 * Decides whether instructions must be recompiled for forced execution type
 * {@code et}. Non-CP targets always need a recompile. For a CP target, recompile
 * only if {@code OptTreeConverter.containsMRJobInstruction(inst, true, true)} holds.
 * This is the De Morgan form of the former inline condition
 * {@code !(et == CP && !containsMRJobInstruction(...))}.
 */
private static boolean requiresForcedRecompile(ArrayList<Instruction> inst, ExecType et) {
    return et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(inst, true, true);
}
Aggregations