Usage of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.
Class RewriteForLoopVectorization, method rewriteStatementBlock:
/**
 * Attempts to replace a for loop by an equivalent vectorized last-level statement block.
 * Only loops whose body is exactly one last-level block (i.e., not a while, if, or for
 * block) are considered. The vectorization patterns are tried in a fixed order, and each
 * one receives the result of the previous one.
 *
 * @param sb the statement block to rewrite
 * @param state rewrite status (not used by this rewrite)
 * @return a single-element list holding either the original block (no pattern applied)
 *         or the loop body block that now contains the vectorized operations
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
    // blocks other than for loops pass through unchanged
    if (!(sb instanceof ForStatementBlock))
        return Arrays.asList(sb);

    ForStatementBlock loop = (ForStatementBlock) sb;
    ForStatement loopStmt = (ForStatement) loop.getStatement(0);
    Hop lower = loop.getFromHops();
    Hop upper = loop.getToHops();
    Hop step = loop.getIncrementHops();
    String iterVar = loop.getIterPredicate().getIterVar().getName();

    // only a loop body consisting of a single child block is supported
    if (loopStmt.getBody() == null || loopStmt.getBody().size() != 1)
        return Arrays.asList(sb);

    StatementBlock body = (StatementBlock) loopStmt.getBody().get(0);
    // only last-level child blocks (no nested control flow) are supported
    boolean nestedControlFlow = body instanceof WhileStatementBlock
        || body instanceof IfStatementBlock
        || body instanceof ForStatementBlock;
    if (nestedControlFlow)
        return Arrays.asList(sb);

    // AUTO VECTORIZATION PATTERNS
    // Note: unnecessary row or column indexing is later removed via hop rewrites
    StatementBlock result = sb;
    // e.g., for(i in a:b){s = s + as.scalar(X[i,2])} -> s = s + sum(X[a:b,2])
    result = vectorizeScalarAggregate(result, body, lower, upper, step, iterVar);
    // e.g., for(i in a:b){X[i,2] = Y[i,1] + Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3];
    result = vectorizeElementwiseBinary(result, body, lower, upper, step, iterVar);
    // e.g., for(i in a:b){X[i,2] = abs(Y[i,1])} -> X[a:b,2] = abs(Y[a:b,1]);
    result = vectorizeElementwiseUnary(result, body, lower, upper, step, iterVar);
    // e.g., for(i in a:b){X[7,i] = Y[1,i]} -> X[7,a:b] = Y[1,a:b];
    result = vectorizeIndexedCopy(result, body, lower, upper, step, iterVar);

    // If no pattern applied, result is still the original for loop. Otherwise it is the
    // last-level body block that includes the equivalent vectorized operations.
    return Arrays.asList(result);
}
Usage of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.
Class RewriteForLoopVectorization, method vectorizeElementwiseUnary:
/**
 * Vectorizes a for loop whose single last-level body writes an element-wise unary operation
 * over a right-indexed matrix into a left-indexed matrix, e.g.,
 * {@code for(i in a:b){X[i,2] = abs(Y[i,1])}} becomes {@code X[a:b,2] = abs(Y[a:b,1])}.
 * Whether both indexing expressions are driven by the loop variable, and in which
 * dimension, is decided by {@code checkLeftAndRightIndexing}.
 *
 * @param sb the original for statement block
 * @param csb the single last-level child block of the loop body (modified in place on success)
 * @param from loop lower bound, used as the new lower index bound
 * @param to loop upper bound, used as the new upper index bound
 * @param increment loop increment; only a literal 1 is supported
 * @param itervar name of the loop variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 */
private static StatementBlock vectorizeElementwiseUnary(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
    // guard: only a literal unit increment is supported
    boolean unitIncrement = increment instanceof LiteralOp
        && ((LiteralOp) increment).getDoubleValue() == 1.0;
    if (!unitIncrement)
        return sb;

    // guard: the loop body must consist of exactly one root hop
    if (csb.getHops() == null || csb.getHops().size() != 1)
        return sb;

    // guard: the root must be a matrix fed by a left indexing op
    Hop root = csb.getHops().get(0);
    if (root.getDataType() != DataType.MATRIX || !(root.getInput().get(0) instanceof LeftIndexingOp))
        return sb;

    // guard: pattern X[..] = unaryop(Y[..]) with X and Y being data ops
    LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
    Hop target = lix.getInput().get(0);
    Hop source = lix.getInput().get(1);
    boolean matchesPattern = target instanceof DataOp
        && source instanceof UnaryOp
        && source.getInput().get(0) instanceof IndexingOp
        && source.getInput().get(0).getInput().get(0) instanceof DataOp;
    if (!matchesPattern)
        return sb;

    // guard: left and right indexing must be vectorizable over the loop variable
    boolean[] check = checkLeftAndRightIndexing(lix, (IndexingOp) source.getInput().get(0), itervar);
    if (!check[0])
        return sb;
    boolean rowIx = check[1];

    UnaryOp uop = (UnaryOp) lix.getInput().get(1);
    IndexingOp rix = (IndexingOp) uop.getInput().get(0);

    // left indexing inputs are [target, source, rl, ru, cl, cu]; right indexing inputs are
    // [input, rl, ru, cl, cu], i.e., the bound positions are shifted by one
    int lower = rowIx ? 2 : 4;
    int upper = lower + 1;
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lower), from, lower);
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(upper), to, upper);
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(lower - 1), from, lower - 1);
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(upper - 1), to, upper - 1);
    updateLeftAndRightIndexingSizes(rowIx, lix, rix);

    // refresh the unary op first, because left indexing consumes its output
    uop.refreshSizeInformation();
    lix.refreshSizeInformation();

    LOG.debug("Applied vectorizeElementwiseUnaryForLoop.");
    return csb;
}
Usage of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.
Class RewriteForLoopVectorization, method vectorizeIndexedCopy:
/**
 * Vectorizes a for loop whose single last-level body copies a right-indexed slice into a
 * left-indexed target, e.g., {@code for(i in a:b){X[7,i] = Y[1,i]}} becomes
 * {@code X[7,a:b] = Y[1,a:b]}. Whether both indexing expressions are driven by the loop
 * variable, and in which dimension, is decided by {@code checkLeftAndRightIndexing}.
 *
 * @param sb the original for statement block
 * @param csb the single last-level child block of the loop body (modified in place on success)
 * @param from loop lower bound, used as the new lower index bound
 * @param to loop upper bound, used as the new upper index bound
 * @param increment loop increment; only a literal 1 is supported
 * @param itervar name of the loop variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 */
private static StatementBlock vectorizeIndexedCopy(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
    // guard: only a literal unit increment is supported
    boolean unitIncrement = increment instanceof LiteralOp
        && ((LiteralOp) increment).getDoubleValue() == 1.0;
    if (!unitIncrement)
        return sb;

    // guard: the loop body must consist of exactly one root hop
    if (csb.getHops() == null || csb.getHops().size() != 1)
        return sb;

    // guard: the root must be a matrix fed by a left indexing op
    Hop root = csb.getHops().get(0);
    if (root.getDataType() != DataType.MATRIX || !(root.getInput().get(0) instanceof LeftIndexingOp))
        return sb;

    // guard: pattern X[..] = Y[..] with X and Y being data ops
    LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
    Hop target = lix.getInput().get(0);
    Hop source = lix.getInput().get(1);
    boolean matchesPattern = target instanceof DataOp
        && source instanceof IndexingOp
        && source.getInput().get(0) instanceof DataOp;
    if (!matchesPattern)
        return sb;

    // guard: left and right indexing must be vectorizable over the loop variable
    boolean[] check = checkLeftAndRightIndexing(lix, (IndexingOp) source, itervar);
    if (!check[0])
        return sb;
    boolean rowIx = check[1];

    IndexingOp rix = (IndexingOp) lix.getInput().get(1);

    // left indexing inputs are [target, source, rl, ru, cl, cu]; right indexing inputs are
    // [input, rl, ru, cl, cu], i.e., the bound positions are shifted by one
    int lower = rowIx ? 2 : 4;
    int upper = lower + 1;
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lower), from, lower);
    HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(upper), to, upper);
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(lower - 1), from, lower - 1);
    HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(upper - 1), to, upper - 1);
    updateLeftAndRightIndexingSizes(rowIx, lix, rix);

    LOG.debug("Applied vectorizeIndexedCopy.");
    return csb;
}
Usage of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.
Class RewriteForLoopVectorization, method vectorizeScalarAggregate:
/**
 * Vectorizes a for loop whose single last-level body accumulates a scalar over single cells
 * of a matrix addressed by the loop variable. For example,
 * {@code for(i in a:b){s = s + as.scalar(X[i,2])}} becomes {@code s = s + sum(X[a:b,2])}.
 * The accumulation operator must be one of {@code MAP_SCALAR_AGGREGATE_SOURCE_OPS}. It is
 * replaced by the aggregate at the same position in {@code MAP_SCALAR_AGGREGATE_TARGET_OPS}.
 * Redundant index operations that remain after the rewrite are left to later hop rewrites.
 *
 * The loop variable must not occur anywhere in the indexing except the bounds of the
 * vectorized dimension, because only those bounds are replaced by from/to.
 * For example, {@code as.scalar(X[i,i])} is not rewritten.
 *
 * @param sb the original for statement block
 * @param csb the single last-level child block of the loop body (modified in place on success)
 * @param from loop lower bound, used as the new lower index bound
 * @param to loop upper bound, used as the new upper index bound
 * @param increment loop increment (may be missing); only a literal 1 is supported
 * @param itervar name of the loop variable
 * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
 */
private static StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
    StatementBlock ret = sb;
    // check missing and supported increment values (a missing increment fails instanceof)
    if (!(increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0)) {
        return ret;
    }
    // check for applicability
    boolean leftScalar = false;
    boolean rightScalar = false;
    // row or col
    boolean rowIx = false;
    if (csb.getHops() != null && csb.getHops().size() == 1) {
        Hop root = csb.getHops().get(0);
        if (root.getDataType() == DataType.SCALAR && root.getInput().get(0) instanceof BinaryOp) {
            BinaryOp bop = (BinaryOp) root.getInput().get(0);
            Hop left = bop.getInput().get(0);
            Hop right = bop.getInput().get(1);
            // check for left scalar, i.e., s = s (op) as.scalar(X[..])
            if (HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) && left instanceof DataOp && left.getDataType() == DataType.SCALAR && root.getName().equals(left.getName()) && right instanceof UnaryOp && ((UnaryOp) right).getOp() == OpOp1.CAST_AS_SCALAR && right.getInput().get(0) instanceof IndexingOp) {
                IndexingOp ix = (IndexingOp) right.getInput().get(0);
                if (isVectorizableScalarIndexing(ix, true, itervar)) {
                    leftScalar = true;
                    rowIx = true;
                } else if (isVectorizableScalarIndexing(ix, false, itervar)) {
                    leftScalar = true;
                    rowIx = false;
                }
            }
            // check for right scalar, i.e., s = as.scalar(X[..]) (op) s
            else if (HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) && right instanceof DataOp && right.getDataType() == DataType.SCALAR && root.getName().equals(right.getName()) && left instanceof UnaryOp && ((UnaryOp) left).getOp() == OpOp1.CAST_AS_SCALAR && left.getInput().get(0) instanceof IndexingOp) {
                IndexingOp ix = (IndexingOp) left.getInput().get(0);
                if (isVectorizableScalarIndexing(ix, true, itervar)) {
                    rightScalar = true;
                    rowIx = true;
                } else if (isVectorizableScalarIndexing(ix, false, itervar)) {
                    rightScalar = true;
                    rowIx = false;
                }
            }
        }
    }
    // apply rewrite if possible
    if (leftScalar || rightScalar) {
        Hop root = csb.getHops().get(0);
        BinaryOp bop = (BinaryOp) root.getInput().get(0);
        Hop cast = bop.getInput().get(leftScalar ? 1 : 0);
        Hop ix = cast.getInput().get(0);
        int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
        AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
        // replace cast with the full aggregate over the indexed range
        AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
        HopRewriteUtils.removeChildReference(cast, ix);
        HopRewriteUtils.removeChildReference(bop, cast);
        HopRewriteUtils.addChildReference(bop, newSum, leftScalar ? 1 : 0);
        // modify indexing expression according to loop predicate from-to
        // (indexing inputs are [input, rl, ru, cl, cu])
        // NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
        int index1 = rowIx ? 1 : 3;
        int index2 = rowIx ? 2 : 4;
        HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index1), from, index1);
        HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index2), to, index2);
        // update indexing size information
        if (rowIx)
            ((IndexingOp) ix).setRowLowerEqualsUpper(false);
        else
            ((IndexingOp) ix).setColLowerEqualsUpper(false);
        ix.refreshSizeInformation();
        ret = csb;
        LOG.debug("Applied vectorizeScalarSumForLoop.");
    }
    return ret;
}

/**
 * Checks whether a single-cell right indexing can be vectorized over the loop variable in
 * the given dimension. The dimension's lower bound must equal its upper bound and be a read
 * of the loop variable. No other input of the indexing may reference the loop variable:
 * neither the indexed matrix expression nor the bounds of the other dimension. Those inputs
 * are kept unchanged by the rewrite and would otherwise read the variable outside the loop.
 *
 * @param ix the right indexing op (inputs [input, rl, ru, cl, cu])
 * @param row true to check the row dimension, false to check the column dimension
 * @param itervar name of the loop variable
 * @return true if the indexing is vectorizable in the given dimension
 */
private static boolean isVectorizableScalarIndexing(IndexingOp ix, boolean row, String itervar) {
    if (!(row ? ix.isRowLowerEqualsUpper() : ix.isColLowerEqualsUpper()))
        return false;
    int lowerPos = row ? 1 : 3;
    int upperPos = row ? 2 : 4;
    Hop lower = ix.getInput().get(lowerPos);
    if (!(lower instanceof DataOp && lower.getName().equals(itervar)))
        return false;
    // all inputs that are not replaced by from/to must be independent of the loop variable
    for (int pos = 0; pos < ix.getInput().size(); pos++) {
        if (pos != lowerPos && pos != upperPos && rReferencesIterVar(ix.getInput().get(pos), itervar))
            return false;
    }
    return true;
}

/**
 * Returns true if the given hop or any of its transitive inputs is a data op named itervar.
 * Plain recursion without visit tracking; intended for small index/operand expressions.
 */
private static boolean rReferencesIterVar(Hop hop, String itervar) {
    if (hop instanceof DataOp && itervar.equals(hop.getName()))
        return true;
    for (Hop in : hop.getInput()) {
        if (rReferencesIterVar(in, itervar))
            return true;
    }
    return false;
}
Usage of org.apache.sysml.parser.StatementBlock in project incubator-systemml by Apache.
Class RewriteMergeBlockSequence, method rewriteStatementBlocks:
/**
 * Merges adjacent last-level statement blocks by folding the hop DAG of the first block
 * (sb1) into the second block (sb2). After each merge, the scan restarts from the beginning
 * of the list. This repeats until no adjacent pair qualifies.
 *
 * A pair qualifies if all of the following hold:
 * - both blocks are last-level statement blocks;
 * - neither block is marked as a split DAG;
 * - they do not both have external function op roots with side effects;
 * - for each block with a function op root, hasFunctionIOConflict reports no conflict
 *   with the other block.
 *
 * @param sbs the statement block sequence; returned as-is if null or empty
 * @param sate rewrite status (not used by this method)
 * @return a new list containing the (possibly merged) statement blocks
 */
@Override
public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) {
if (sbs == null || sbs.isEmpty())
return sbs;
// execute binary merging iterations until fixpoint
ArrayList<StatementBlock> tmpList = new ArrayList<>(sbs);
boolean merged = true;
while (merged) {
merged = false;
for (int i = 0; i < tmpList.size() - 1; i++) {
StatementBlock sb1 = tmpList.get(i);
StatementBlock sb2 = tmpList.get(i + 1);
// merge only side-effect-compatible last-level blocks (see method doc for the conditions)
if (HopRewriteUtils.isLastLevelStatementBlock(sb1) && HopRewriteUtils.isLastLevelStatementBlock(sb2) && !sb1.isSplitDag() && !sb2.isSplitDag() && !(hasExternalFunctionOpRootWithSideEffect(sb1) && hasExternalFunctionOpRootWithSideEffect(sb2)) && (!hasFunctionOpRoot(sb1) || !hasFunctionIOConflict(sb1, sb2)) && (!hasFunctionOpRoot(sb2) || !hasFunctionIOConflict(sb2, sb1))) {
// note: we intend to merge sb1 into sb2 to connect data dependencies
// however, we work with a temporary list of root nodes to preserve
// the original order of roots, which affects prints w/o dependencies
ArrayList<Hop> sb1Hops = sb1.getHops();
ArrayList<Hop> sb2Hops = sb2.getHops();
ArrayList<Hop> newHops = new ArrayList<>();
// determine transient read inputs and transient write roots of s2 (keyed by variable name)
// NOTE(review): each map holds one hop per variable name; if s2 could contain several
// transient reads of the same variable, only the recorded one is rewired below -
// confirm that rCollectTransientReadWrites/CSE guarantee uniqueness
Hop.resetVisitStatus(sb2Hops);
HashMap<String, Hop> treads = new HashMap<>();
HashMap<String, Hop> twrites = new HashMap<>();
for (Hop root : sb2Hops) rCollectTransientReadWrites(root, treads, twrites);
Hop.resetVisitStatus(sb2Hops);
// merge hop dags of s1 and s2
Hop.resetVisitStatus(sb1Hops);
for (Hop root : sb1Hops) {
// connect transient writes s1 and reads s2
if (HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE) && treads.containsKey(root.getName())) {
// rewire transient write and transient read: consumers of the s2 read now
// directly consume the value written by s1, and the s1 write is detached
Hop tread = treads.get(root.getName());
Hop in = root.getInput().get(0);
for (Hop parent : new ArrayList<>(tread.getParent())) HopRewriteUtils.replaceChildReference(parent, tread, in);
HopRewriteUtils.removeAllChildReferences(root);
// add transient write if necessary: only if s2 does not overwrite the variable
// and the variable is live after s2
if (!twrites.containsKey(root.getName()) && sb2.liveOut().containsVariable(root.getName())) {
newHops.add(HopRewriteUtils.createDataOp(root.getName(), in, DataOpTypes.TRANSIENTWRITE));
}
} else // add remaining roots from s1 to s2, dropping s1 transient writes that s2 overwrites or that are not live after s2
if (!(HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE) && (twrites.containsKey(root.getName()) || !sb2.liveOut().containsVariable(root.getName())))) {
newHops.add(root);
}
}
// clear partial hops from the merged statement block to avoid problems with
// other statement block rewrites that iterate over the original program
sb1Hops.clear();
// append all root nodes of s2 after root nodes of s1
newHops.addAll(sb2Hops);
sb2.setHops(newHops);
// run common-subexpression elimination
Hop.resetVisitStatus(sb2.getHops());
rewriter.rewriteHopDAG(sb2.getHops(), new ProgramRewriteStatus());
// modify live variable sets of s2
// liveOut remains unchanged
// liveIn = liveIn(s1); gen = (gen(s1) u gen(s2)) \ kill(s1); kill, read and updated = unions
sb2.setLiveIn(sb1.liveIn());
sb2.setGen(VariableSet.minus(VariableSet.union(sb1.getGen(), sb2.getGen()), sb1.getKill()));
sb2.setKill(VariableSet.union(sb1.getKill(), sb2.getKill()));
sb2.setReadVariables(VariableSet.union(sb1.variablesRead(), sb2.variablesRead()));
sb2.setUpdatedVariables(VariableSet.union(sb1.variablesUpdated(), sb2.variablesUpdated()));
LOG.debug("Applied mergeStatementBlockSequences " + "(blocks of lines " + sb1.getBeginLine() + "-" + sb1.getEndLine() + " and " + sb2.getBeginLine() + "-" + sb2.getEndLine() + ").");
// modify line numbers of s2 so the merged block starts where s1 started
sb2.setBeginLine(sb1.getBeginLine());
sb2.setBeginColumn(sb1.getBeginColumn());
// remove sb1 from list of statement blocks
tmpList.remove(i);
merged = true;
// leave the for loop; the while loop restarts the scan over the modified list
break;
}
}
}
return tmpList;
}
Aggregations