Use of org.apache.sysml.parser.IfStatementBlock in project incubator-systemml by Apache.
Class IPAPassRemoveUnnecessaryCheckpoints, method moveCheckpointAfterUpdate.
/**
 * Scans the top-level statement blocks of {@code dmlp} in program order and, where
 * it is considered safe, moves a previously collected checkpoint of a variable onto
 * the first input of a later root hop that updates the same variable.
 * <p>
 * Candidates are collected from last-level statement blocks and dropped when the
 * variable is read before it is updated, or when it is updated inside if/while/for
 * control flow. Hops are mutated in place via {@code setRequiresCheckpoint}, and the
 * hop visit status of inspected statement blocks is reset.
 *
 * @param dmlp the program whose top-level statement blocks are processed
 */
private static void moveCheckpointAfterUpdate(DMLProgram dmlp) {
// approach: scan over top-level program (guaranteed to be unconditional),
// collect checkpoints; determine if used before update; move first checkpoint
// after update if not used before update (best effort move which often avoids
// the second checkpoint on loops even though used in between)
// candidate checkpoint hops, keyed by hop (variable) name
HashMap<String, Hop> chkpointCand = new HashMap<>();
for (StatementBlock sb : dmlp.getStatementBlocks()) {
// prune candidates (used before updated)
// iterate over a snapshot of the keys because chkpointCand is modified inside the loop
Set<String> cands = new HashSet<>(chkpointCand.keySet());
for (String cand : cands) if (sb.variablesRead().containsVariable(cand) && !sb.variablesUpdated().containsVariable(cand)) {
// note: variableRead might include false positives due to meta
// data operations like nrow(X) or operations removed by rewrites
// double check hops on basic blocks; otherwise worst-case
// skipRemove stays true only if no root hop of this block contains a read of cand;
// blocks without hops (getHops() == null) are pruned conservatively
boolean skipRemove = false;
if (sb.getHops() != null) {
Hop.resetVisitStatus(sb.getHops());
skipRemove = true;
for (Hop root : sb.getHops()) skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
}
if (!skipRemove)
chkpointCand.remove(cand);
}
// prune candidates (updated in conditional control flow)
// fresh snapshot: the first pruning step may already have removed keys
Set<String> cands2 = new HashSet<>(chkpointCand.keySet());
if (sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
for (String cand : cands2) if (sb.variablesUpdated().containsVariable(cand)) {
chkpointCand.remove(cand);
}
} else // move checkpoint after update with simple read chain
// (note: right now this only applies if the checkpoints comes from a previous
// statement block, within-dag checkpoints should be handled during injection)
{
// for every root hop named like an updated candidate: if it has a simple read chain
// (as decided by HopRewriteUtils.rHasSimpleReadChain), clear the checkpoint flag on the
// old candidate hop, set it on the root's first input and track that input as the new
// candidate; otherwise drop the candidate
// NOTE(review): if two roots share the name cand and the first one removes the candidate,
// chkpointCand.get(cand) returns null for the second one -- confirm this cannot happen
for (String cand : cands2) if (sb.variablesUpdated().containsVariable(cand) && sb.getHops() != null) {
Hop.resetVisitStatus(sb.getHops());
for (Hop root : sb.getHops()) if (root.getName().equals(cand)) {
if (HopRewriteUtils.rHasSimpleReadChain(root, cand)) {
chkpointCand.get(cand).setRequiresCheckpoint(false);
root.getInput().get(0).setRequiresCheckpoint(true);
chkpointCand.put(cand, root.getInput().get(0));
} else
chkpointCand.remove(cand);
}
}
}
// collect checkpoints
// only last-level blocks contribute new candidates (keyed by hop name)
if (HopRewriteUtils.isLastLevelStatementBlock(sb)) {
ArrayList<Hop> tmp = collectCheckpoints(sb.getHops());
for (Hop chkpoint : tmp) chkpointCand.put(chkpoint.getName(), chkpoint);
}
}
}
Use of org.apache.sysml.parser.IfStatementBlock in project incubator-systemml by Apache.
Class SpoofCompiler, method generateCodeFromStatementBlock.
/**
 * Recursively walks the program structure rooted at {@code current} and applies code
 * generation to it.
 * <p>
 * For control-flow blocks (while, if, for/parfor), the predicate or loop-bound hops are
 * replaced by {@code optimize(..., false)} and all child blocks are visited in order.
 * Function blocks only visit their body. For any other (last-level) block, the hop DAGs
 * are replaced by the result of {@code generateCodeFromHopDAGs} and the block's
 * recompilation flag is updated.
 *
 * @param current the statement block to process
 */
public static void generateCodeFromStatementBlock(StatementBlock current) {
    if (current instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) current;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        for (StatementBlock child : funcStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    if (current instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) current;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        whileBlock.setPredicateHops(optimize(whileBlock.getPredicateHops(), false));
        for (StatementBlock child : whileStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    if (current instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) current;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        ifBlock.setPredicateHops(optimize(ifBlock.getPredicateHops(), false));
        for (StatementBlock child : ifStmt.getIfBody()) {
            generateCodeFromStatementBlock(child);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    if (current instanceof ForStatementBlock) { // also covers parfor
        ForStatementBlock forBlock = (ForStatementBlock) current;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        forBlock.setFromHops(optimize(forBlock.getFromHops(), false));
        forBlock.setToHops(optimize(forBlock.getToHops(), false));
        forBlock.setIncrementHops(optimize(forBlock.getIncrementHops(), false));
        for (StatementBlock child : forStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    // generic (last-level) block: compile its hop DAGs directly
    current.setHops(generateCodeFromHopDAGs(current.getHops()));
    current.updateRecompilationFlag();
}
Use of org.apache.sysml.parser.IfStatementBlock in project incubator-systemml by Apache.
Class RewriteCompressedReblock, method rAnalyzeProgram.
/**
 * Recursively analyzes the statement block {@code sb} and records findings in {@code status}.
 * <p>
 * Child blocks of function, while, if and for/parfor blocks are visited first. Afterwards:
 * <ul>
 * <li>while and for blocks set {@code status.usedInLoop} if they read any name in
 * {@code status.compMtx};</li>
 * <li>if blocks set {@code status.condUpdate} if they update any of those names.</li>
 * </ul>
 * Last-level blocks with hops have their hop DAG analyzed from every root, after which
 * names starting with {@code TMP_PREFIX} are dropped from {@code status.compMtx}.
 *
 * @param sb     the statement block to analyze
 * @param status accumulated probe state, updated in place
 */
private static void rAnalyzeProgram(StatementBlock sb, ProbeStatus status) {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) sb;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        for (StatementBlock child : funcStmt.getBody()) {
            rAnalyzeProgram(child, status);
        }
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        for (StatementBlock child : whileStmt.getBody()) {
            rAnalyzeProgram(child, status);
        }
        if (whileBlock.variablesRead().containsAnyName(status.compMtx)) {
            status.usedInLoop = true;
        }
        return;
    }

    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        for (StatementBlock child : ifStmt.getIfBody()) {
            rAnalyzeProgram(child, status);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            rAnalyzeProgram(child, status);
        }
        if (ifBlock.variablesUpdated().containsAnyName(status.compMtx)) {
            status.condUpdate = true;
        }
        return;
    }

    if (sb instanceof ForStatementBlock) { // also covers parfor
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        for (StatementBlock child : forStmt.getBody()) {
            rAnalyzeProgram(child, status);
        }
        if (forBlock.variablesRead().containsAnyName(status.compMtx)) {
            status.usedInLoop = true;
        }
        return;
    }

    // generic (last-level) block: blocks without hops are skipped
    ArrayList<Hop> dagRoots = sb.getHops();
    if (dagRoots == null) {
        return;
    }
    Hop.resetVisitStatus(dagRoots);
    for (Hop dagRoot : dagRoots) {
        rAnalyzeHopDag(dagRoot, status);
    }
    // temporary variables are not tracked beyond this block
    status.compMtx.removeIf(name -> name.startsWith(TMP_PREFIX));
    Hop.resetVisitStatus(dagRoots);
}
Use of org.apache.sysml.parser.IfStatementBlock in project incubator-systemml by Apache.
Class RewriteForLoopVectorization, method rewriteStatementBlock.
/**
 * Attempts to replace a for loop whose body is a single last-level statement block
 * with equivalent vectorized operations.
 * <p>
 * The following patterns are tried in order, each on the result of the previous one:
 * scalar aggregate, element-wise binary, element-wise unary and indexed copy.
 * Any other statement block is returned unchanged.
 *
 * @param sb    the statement block to rewrite
 * @param state program rewrite status (not used by this rewrite)
 * @return a single-element list holding the (possibly replaced) statement block
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
    StatementBlock result = sb;
    if (sb instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        Hop from = forBlock.getFromHops();
        Hop to = forBlock.getToHops();
        Hop incr = forBlock.getIncrementHops();
        String iterVar = forBlock.getIterPredicate().getIterVar().getName();

        boolean hasSingleChild = forStmt.getBody() != null && forStmt.getBody().size() == 1;
        if (hasSingleChild) {
            StatementBlock child = (StatementBlock) forStmt.getBody().get(0);
            boolean isControlFlow = child instanceof WhileStatementBlock
                || child instanceof IfStatementBlock
                || child instanceof ForStatementBlock;
            if (!isControlFlow) {
                // AUTO VECTORIZATION PATTERNS
                // Note: unnecessary row or column indexing then later removed via hop rewrites
                // e.g., for(i in a:b){s = s + as.scalar(X[i,2])} -> s = sum(X[a:b,2])
                result = vectorizeScalarAggregate(result, child, from, to, incr, iterVar);
                // e.g., for(i in a:b){X[i,2] = Y[i,1] + Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3];
                result = vectorizeElementwiseBinary(result, child, from, to, incr, iterVar);
                // e.g., for(i in a:b){X[i,2] = abs(Y[i,1])} -> X[a:b,2] = abs(Y[a:b,1]);
                result = vectorizeElementwiseUnary(result, child, from, to, incr, iterVar);
                // e.g., for(i in a:b){X[7,i] = Y[1,i]} -> X[7,a:b] = Y[1,a:b];
                result = vectorizeIndexedCopy(result, child, from, to, incr, iterVar);
            }
        }
    }
    // If no pattern applied, result is still the original block. Otherwise it is the block
    // returned by the vectorize* helpers, which includes the equivalent vectorized operations.
    // (Assumes each helper returns its first argument unchanged when its pattern does not
    // match -- confirm.)
    return Arrays.asList(result);
}
Use of org.apache.sysml.parser.IfStatementBlock in project incubator-systemml by Apache.
Class RewriteRemoveUnnecessaryBranches, method rewriteStatementBlock.
/**
 * Removes if-else blocks whose predicate is a constant literal.
 * <p>
 * When the first input of the predicate hop is a {@code LiteralOp}, the statement blocks
 * of the statically selected branch are spliced in place of the if-else. An empty selected
 * branch removes the if-else entirely. In that case the rewrite status is marked via
 * {@code setRemovedBranches()}. All other statement blocks are returned unchanged.
 *
 * @param sb    the statement block to rewrite
 * @param state rewrite status, marked when a branch was removed
 * @return a new list holding the replacement statement blocks
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
    ArrayList<StatementBlock> ret = new ArrayList<>();

    if (!(sb instanceof IfStatementBlock)) {
        // not an if-block: keep the original statement block
        ret.add(sb);
        return ret;
    }

    IfStatementBlock ifBlock = (IfStatementBlock) sb;
    Hop predicate = ifBlock.getPredicateHops().getInput().get(0);
    if (!(predicate instanceof LiteralOp)) {
        // non-constant condition: keep the original statement block
        ret.add(sb);
        return ret;
    }

    IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
    boolean takeIfBranch = HopRewriteUtils.getBooleanValue((LiteralOp) predicate);
    // splice in the selected branch; adding an empty branch adds nothing (removes the if-else)
    ret.addAll(takeIfBranch ? ifStmt.getIfBody() : ifStmt.getElseBody());

    state.setRemovedBranches();
    LOG.debug("Applied removeUnnecessaryBranches (lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ").");
    return ret;
}
Aggregations