Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class RewriteForLoopVectorization, method vectorizeScalarAggregate.
/**
 * Vectorizes a for-loop body that folds one matrix cell per iteration into a
 * scalar accumulator, e.g. (assuming the source op '+' maps to a sum aggregate)
 * {@code for(i in from:to) s = s + as.scalar(X[i,1])}, into one full aggregate
 * over the indexed range {@code X[from:to,1]}.
 *
 * @param sb        the original for statement block (returned if not applicable)
 * @param csb       the single child (body) statement block of the loop
 * @param from      loop predicate lower bound
 * @param to        loop predicate upper bound
 * @param increment loop increment; only a literal 1 is supported
 * @param itervar   name of the loop iteration variable
 * @return {@code csb} with its DAG rewritten if the pattern matched, otherwise {@code sb}
 */
private static StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
	// Only a literal unit increment is supported (instanceof also rejects null).
	if (!(increment instanceof LiteralOp) || ((LiteralOp) increment).getDoubleValue() != 1.0)
		return sb;

	// The body must be a single DAG whose scalar root is fed by a binary op.
	if (csb.getHops() == null || csb.getHops().size() != 1)
		return sb;
	Hop root = csb.getHops().get(0);
	if (root.getDataType() != DataType.SCALAR || !(root.getInput().get(0) instanceof BinaryOp))
		return sb;
	BinaryOp bop = (BinaryOp) root.getInput().get(0);
	Hop left = bop.getInput().get(0);
	Hop right = bop.getInput().get(1);
	if (!HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS))
		return sb;

	// One operand must be the accumulator itself (a scalar data op named like the
	// root), the other must be as.scalar(<indexing>); castPos is the latter's position.
	int castPos;
	if (left instanceof DataOp && left.getDataType() == DataType.SCALAR && root.getName().equals(left.getName())
		&& right instanceof UnaryOp && ((UnaryOp) right).getOp() == OpOp1.CAST_AS_SCALAR
		&& right.getInput().get(0) instanceof IndexingOp) {
		castPos = 1;
	}
	else if (right instanceof DataOp && right.getDataType() == DataType.SCALAR && root.getName().equals(right.getName())
		&& left instanceof UnaryOp && ((UnaryOp) left).getOp() == OpOp1.CAST_AS_SCALAR
		&& left.getInput().get(0) instanceof IndexingOp) {
		castPos = 0;
	}
	else {
		return sb;
	}

	Hop cast = bop.getInput().get(castPos);
	IndexingOp ix = (IndexingOp) cast.getInput().get(0);

	// The loop variable must be the single row (input 1) or single column (input 3) index.
	boolean rowIx;
	if (ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp && ix.getInput().get(1).getName().equals(itervar))
		rowIx = true;
	else if (ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp && ix.getInput().get(3).getName().equals(itervar))
		rowIx = false;
	else
		return sb;

	// Swap the cast for a full (RowCol) aggregate of the mapped target type.
	int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
	AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
	AggUnaryOp agg = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
	HopRewriteUtils.removeChildReference(cast, ix);
	HopRewriteUtils.removeChildReference(bop, cast);
	HopRewriteUtils.addChildReference(bop, agg, castPos);

	// Widen the indexed row/column range to the loop range from:to.
	// NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
	int lowerPos = rowIx ? 1 : 3;
	int upperPos = lowerPos + 1;
	HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(lowerPos), from, lowerPos);
	HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(upperPos), to, upperPos);
	if (rowIx)
		ix.setRowLowerEqualsUpper(false);
	else
		ix.setColLowerEqualsUpper(false);
	ix.refreshSizeInformation();

	LOG.debug("Applied vectorizeScalarSumForLoop.");
	return csb;
}
Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class RewriteIndexingVectorization, method vectorizeRightLeftIndexingChains.
/**
 * Merges a chain of left indexing ops, each writing one full row (or one full
 * column) taken from a right indexing op, into a single left/right indexing
 * pair over a range of rows (columns).
 * <p>
 * The chain is walked via input 0 of each left indexing op while
 * {@code isConsecutiveLeftRightIndexing} holds and every intermediate op is
 * consumed only once. The merged range starts at the bounds of the deepest
 * pair found and spans {@code lix.size()} rows (columns).
 *
 * @param hi candidate root hop
 * @return the new left indexing op if at least two pairs were merged, otherwise {@code hi}
 */
private static Hop vectorizeRightLeftIndexingChains(Hop hi) {
	// check for valid root operator: left indexing whose source (input 1) is a
	// right indexing op consumed only by this left indexing
	if (!(hi instanceof LeftIndexingOp && hi.getInput().get(1) instanceof IndexingOp && hi.getInput().get(1).getParent().size() == 1))
		return hi;
	LeftIndexingOp lix0 = (LeftIndexingOp) hi;
	IndexingOp rix0 = (IndexingOp) hi.getInput().get(1);
	// both ops must select a single row or a single column, and the same kind
	if (!(lix0.isRowLowerEqualsUpper() || lix0.isColLowerEqualsUpper()) || lix0.isRowLowerEqualsUpper() != rix0.isRowLowerEqualsUpper() || lix0.isColLowerEqualsUpper() != rix0.isColLowerEqualsUpper())
		return hi;
	boolean row = lix0.isRowLowerEqualsUpper();
	if (!((row ? HopRewriteUtils.isFullRowIndexing(lix0) : HopRewriteUtils.isFullColumnIndexing(lix0)) && (row ? HopRewriteUtils.isFullRowIndexing(rix0) : HopRewriteUtils.isFullColumnIndexing(rix0))))
		return hi;

	// determine consecutive left-right indexing chains for rows/columns
	List<LeftIndexingOp> lix = new ArrayList<>();
	lix.add(lix0);
	List<IndexingOp> rix = new ArrayList<>();
	rix.add(rix0);
	LeftIndexingOp clix = lix0;
	IndexingOp crix = rix0;
	while (isConsecutiveLeftRightIndexing(clix, crix, clix.getInput().get(0)) && clix.getInput().get(0).getParent().size() == 1 && clix.getInput().get(0).getInput().get(1).getParent().size() == 1) {
		clix = (LeftIndexingOp) clix.getInput().get(0);
		crix = (IndexingOp) clix.getInput().get(1);
		lix.add(clix);
		rix.add(crix);
	}

	// rewrite pattern if at least two consecutive pairs
	if (lix.size() >= 2) {
		// lix and rix always have the same length; span = additional rows/cols beyond the first
		int span = lix.size() - 1;

		// new right indexing: extend the deepest pair's range by 'span'
		IndexingOp rixn = rix.get(span);
		Hop rlrix = rixn.getInput().get(1);
		Hop rurix = row ? HopRewriteUtils.createBinary(rlrix, new LiteralOp(span), OpOp2.PLUS) : rixn.getInput().get(2);
		Hop clrix = rixn.getInput().get(3);
		Hop curix = row ? rixn.getInput().get(4) : HopRewriteUtils.createBinary(clrix, new LiteralOp(span), OpOp2.PLUS);
		IndexingOp rixNew = HopRewriteUtils.createIndexingOp(rixn.getInput().get(0), rlrix, rurix, clrix, curix);

		// new left indexing: same extension on the deepest left indexing op
		LeftIndexingOp lixn = lix.get(span);
		Hop rllix = lixn.getInput().get(2);
		Hop rulix = row ? HopRewriteUtils.createBinary(rllix, new LiteralOp(span), OpOp2.PLUS) : lixn.getInput().get(3);
		Hop cllix = lixn.getInput().get(4);
		Hop culix = row ? lixn.getInput().get(5) : HopRewriteUtils.createBinary(cllix, new LiteralOp(span), OpOp2.PLUS);
		LeftIndexingOp lixNew = HopRewriteUtils.createLeftIndexingOp(lixn.getInput().get(0), rixNew, rllix, rulix, cllix, culix);

		// rewire ALL consumers of the old root (previously only the first one was
		// rewired, leaving any other consumer attached to a hop whose inputs are
		// removed below); iterate over a copy because rewiring presumably
		// updates hi's parent list
		List<Hop> parents = new ArrayList<>(hi.getParent());
		for (Hop p : parents)
			HopRewriteUtils.replaceChildReference(p, hi, lixNew);
		for (int i = 0; i < lix.size(); i++) {
			HopRewriteUtils.removeAllChildReferences(lix.get(i));
			HopRewriteUtils.removeAllChildReferences(rix.get(i));
		}
		hi = lixNew;
		LOG.debug("Applied vectorizeRightLeftIndexingChains (line " + hi.getBeginLine() + ")");
	}
	return hi;
}
Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class RewriteIndexingVectorization, method vectorizeLeftIndexing.
/**
 * Vectorizes chains of single-cell left indexing ops that write the same row
 * (or, if no row merge was applied, the same column) of a matrix.
 * <p>
 * Starting at {@code hop} (the top of the chain, which must write a single
 * cell), the method walks down input 0 (the left indexing target). It keeps
 * going while each op has a single consumer, writes a single row, uses the
 * very same row expression hop (reference comparison of input 2), and has a
 * target with more than one column. If at least two ops are collected, it:
 * <ol>
 * <li>extracts that row of the bottom op's input via a new right indexing op
 * ({@code tmp1}),</li>
 * <li>re-targets the bottom op onto this intermediate, then resets every
 * collected op's row bounds (inputs 2/3) to the literal 1,</li>
 * <li>writes the updated row back into the original input via a new left
 * indexing op ({@code tmp2}), which takes the place of {@code hop} in all of
 * its parents.</li>
 * </ol>
 * The column variant mirrors this with the column bounds (inputs 4/5).
 * <p>
 * Statement order matters here. The parents of the top op are cloned and
 * detached before {@code tmp2} is created on top of it.
 * The unchecked suppression covers the {@code ArrayList.clone()} casts.
 *
 * @param hop candidate hop (only {@link LeftIndexingOp} instances are rewritten)
 * @return the new top-level left indexing op if a rewrite was applied, otherwise {@code hop}
 */
@SuppressWarnings("unchecked")
private static Hop vectorizeLeftIndexing(Hop hop) {
	Hop ret = hop;
	if (hop instanceof LeftIndexingOp) { // left indexing
		LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
		boolean isSingleRow = ihop0.isRowLowerEqualsUpper();
		boolean isSingleCol = ihop0.isColLowerEqualsUpper();
		boolean appliedRow = false;

		// ---- row merge: chain of cell writes into the same row ----
		if (isSingleRow && isSingleCol) {
			// collect simple chains (w/o multiple consumers) of left indexing ops
			ArrayList<Hop> ihops = new ArrayList<>();
			ihops.add(ihop0);
			Hop current = ihop0;
			while (current.getInput().get(0) instanceof LeftIndexingOp) {
				LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
				if ( // multiple consumers, i.e., not a simple chain
					tmp.getParent().size() > 1 ||
					// does not write a single row, so row merge is not applicable
					!((LeftIndexingOp) tmp).isRowLowerEqualsUpper() ||
					// not the same row (reference comparison of the row lower expression hop)
					tmp.getInput().get(2) != ihop0.getInput().get(2) ||
					// target is single column or has unknown column count
					tmp.getInput().get(0).getDim2() <= 1) {
					break;
				}
				ihops.add(tmp);
				current = tmp;
			}
			// apply rewrite if found candidates
			if (ihops.size() > 1) {
				// original target matrix below the bottom op of the chain
				Hop input = current.getInput().get(0);
				// keep before reset (the row bounds of all chain ops are replaced below)
				Hop rowExpr = ihop0.getInput().get(2);
				// new right indexing op extracting row rowExpr of input
				// (NOTE(review): createValueHop(input, false) presumably yields ncol(input);
				// trailing flags presumably mark row-lower==upper / col-lower==upper)
				IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false);
				HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
				newRix.refreshSizeInformation();
				// reset visit status of copied hops (otherwise hidden by left indexing)
				for (Hop c : newRix.getInput()) c.resetVisitStatus();
				// rewrite bottom left indexing operator:
				// replace its target (input 0) with the extracted row
				HopRewriteUtils.removeChildReference(current, input);
				HopRewriteUtils.addChildReference(current, newRix, 0);
				// reset row index of all candidates to 1 and refresh sizes (bottom-up)
				for (int i = ihops.size() - 1; i >= 0; i--) {
					Hop c = ihops.get(i);
					// row lower expr
					HopRewriteUtils.replaceChildReference(c, c.getInput().get(2), new LiteralOp(1), 2);
					// row upper expr
					HopRewriteUtils.replaceChildReference(c, c.getInput().get(3), new LiteralOp(1), 3);
					((LeftIndexingOp) c).setRowLowerEqualsUpper(true);
					c.refreshSizeInformation();
				}
				// new row left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
				// (note: it's important to clone the parent list before creating newLix on top of ihop0)
				ArrayList<Hop> ihop0parents = (ArrayList<Hop>) ihop0.getParent().clone();
				ArrayList<Integer> ihop0parentsPos = new ArrayList<>();
				for (Hop parent : ihop0parents) {
					int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
					// detach ihop0 from this parent, remembering the input position
					HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp);
					ihop0parentsPos.add(posp);
				}
				// write the updated row (ihop0) back into row rowExpr of input
				LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false);
				HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
				newLix.refreshSizeInformation();
				// reset visit status of copied hops (otherwise hidden by left indexing)
				for (Hop c : newLix.getInput()) c.resetVisitStatus();
				// attach newLix to the former parents of ihop0 at the remembered positions
				for (int i = 0; i < ihop0parentsPos.size(); i++) {
					Hop parent = ihop0parents.get(i);
					int posp = ihop0parentsPos.get(i);
					HopRewriteUtils.addChildReference(parent, newLix, posp);
				}
				appliedRow = true;
				ret = newLix;
				LOG.debug("Applied vectorizeLeftIndexingRow for hop " + hop.getHopID());
			}
		}

		// ---- column merge: chain of cell writes into the same column (only if no row merge) ----
		if (isSingleRow && isSingleCol && !appliedRow) {
			// collect simple chains (w/o multiple consumers) of left indexing ops
			ArrayList<Hop> ihops = new ArrayList<>();
			ihops.add(ihop0);
			Hop current = ihop0;
			while (current.getInput().get(0) instanceof LeftIndexingOp) {
				LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
				if ( // multiple consumers, i.e., not a simple chain
					tmp.getParent().size() > 1 ||
					// does not write a single column, so column merge is not applicable
					!((LeftIndexingOp) tmp).isColLowerEqualsUpper() ||
					// not the same col (reference comparison of the col lower expression hop)
					tmp.getInput().get(4) != ihop0.getInput().get(4) ||
					// target is single row or has unknown row count
					tmp.getInput().get(0).getDim1() <= 1) {
					break;
				}
				ihops.add(tmp);
				current = tmp;
			}
			// apply rewrite if found candidates
			if (ihops.size() > 1) {
				// original target matrix below the bottom op of the chain
				Hop input = current.getInput().get(0);
				// keep before reset (the col bounds of all chain ops are replaced below)
				Hop colExpr = ihop0.getInput().get(4);
				// new right indexing op extracting column colExpr of input
				// (NOTE(review): createValueHop(input, true) presumably yields nrow(input))
				IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), colExpr, colExpr, false, true);
				HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
				newRix.refreshSizeInformation();
				// reset visit status of copied hops (otherwise hidden by left indexing)
				for (Hop c : newRix.getInput()) c.resetVisitStatus();
				// rewrite bottom left indexing operator:
				// replace its target (input 0) with the extracted column
				HopRewriteUtils.removeChildReference(current, input);
				HopRewriteUtils.addChildReference(current, newRix, 0);
				// reset col index of all candidates to 1 and refresh sizes (bottom-up)
				for (int i = ihops.size() - 1; i >= 0; i--) {
					Hop c = ihops.get(i);
					// col lower expr
					HopRewriteUtils.replaceChildReference(c, c.getInput().get(4), new LiteralOp(1), 4);
					// col upper expr
					HopRewriteUtils.replaceChildReference(c, c.getInput().get(5), new LiteralOp(1), 5);
					((LeftIndexingOp) c).setColLowerEqualsUpper(true);
					c.refreshSizeInformation();
				}
				// new column left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
				// (note: it's important to clone the parent list before creating newLix on top of ihop0)
				ArrayList<Hop> ihop0parents = (ArrayList<Hop>) ihop0.getParent().clone();
				ArrayList<Integer> ihop0parentsPos = new ArrayList<>();
				for (Hop parent : ihop0parents) {
					int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
					// detach ihop0 from this parent, remembering the input position
					HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp);
					ihop0parentsPos.add(posp);
				}
				// write the updated column (ihop0) back into column colExpr of input
				LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), colExpr, colExpr, false, true);
				HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
				newLix.refreshSizeInformation();
				// reset visit status of copied hops (otherwise hidden by left indexing)
				for (Hop c : newLix.getInput()) c.resetVisitStatus();
				// attach newLix to the former parents of ihop0 at the remembered positions
				for (int i = 0; i < ihop0parentsPos.size(); i++) {
					Hop parent = ihop0parents.get(i);
					int posp = ihop0parentsPos.get(i);
					HopRewriteUtils.addChildReference(parent, newLix, posp);
				}
				ret = newLix;
				LOG.debug("Applied vectorizeLeftIndexingCol for hop " + hop.getHopID());
			}
		}
	}
	return ret;
}
Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class RewriteRemoveUnnecessaryBranches, method rewriteStatementBlock.
/**
 * Removes if/else blocks whose predicate is a constant literal, splicing the
 * statically selected branch (possibly empty) in place of the if statement.
 *
 * @param sb    statement block to inspect
 * @param state rewrite status; flagged via {@code setRemovedBranches()} when a branch is removed
 * @return the replacement blocks: the selected branch's blocks, or {@code sb} itself if not rewritten
 */
@Override
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
	ArrayList<StatementBlock> result = new ArrayList<>();

	// Only if blocks with a literal (constant) predicate qualify.
	Hop pred = null;
	if (sb instanceof IfStatementBlock)
		pred = ((IfStatementBlock) sb).getPredicateHops().getInput().get(0);
	if (!(pred instanceof LiteralOp)) {
		// keep original sb (no if, or non-constant condition)
		result.add(sb);
		return result;
	}

	IfStatement ifStmt = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
	boolean takeIfBranch = HopRewriteUtils.getBooleanValue((LiteralOp) pred);

	// Pull out the selected branch; an empty branch simply adds nothing,
	// which removes the whole if-else.
	result.addAll(takeIfBranch ? ifStmt.getIfBody() : ifStmt.getElseBody());

	state.setRemovedBranches();
	LOG.debug("Applied removeUnnecessaryBranches (lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ").");
	return result;
}
Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.
Class RewriteSplitDagDataDependentOperators, method rCollectDataDependentOperators.
/**
 * Recursively collects hops whose output size depends on data and that
 * should therefore become DAG split points. The candidates are:
 * removeEmpty, ctable without explicit dims, computed sort parameters, and
 * second-order eval. Children of a selected candidate are not traversed here.
 * It also records consumer-dependent hints (empty-block output,
 * permutation-matrix flags) on the selected hops.
 *
 * @param hop  current hop (skipped if already visited; marked visited on exit)
 * @param cand output list of split candidates, appended in discovery order
 */
private void rCollectDataDependentOperators(Hop hop, ArrayList<Hop> cand) {
	if (hop.isVisited())
		return;

	// a split is only worthwhile if dims are still unknown and there are
	// consumers other than writes
	final boolean splitNeeded = !hop.dimsKnown() && !HopRewriteUtils.hasOnlyWriteParents(hop, true, true);
	boolean descend = true;

	// #1 removeEmpty, unless its sole consumer is a ternary op that can apply
	// the matrix-ignore-zero rewrite
	if (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp) hop).getOp() == ParamBuiltinOp.RMEMPTY && splitNeeded) {
		boolean feedsIgnoreZeroTernary = hop.getParent().size() == 1
			&& hop.getParent().get(0) instanceof TernaryOp
			&& ((TernaryOp) hop.getParent().get(0)).isMatrixIgnoreZeroRewriteApplicable();
		if (!feedsIgnoreZeroTernary) {
			ParameterizedBuiltinOp rmEmpty = (ParameterizedBuiltinOp) hop;
			cand.add(rmEmpty);
			descend = false;

			// inspect consumers: left input of a matrix mult, or nrow (for the empty-block hint)
			boolean diagTarget = rmEmpty.isTargetDiagInput();
			boolean allLeftMatMult = true;
			boolean allEmptyBlockAgnostic = true;
			for (Hop consumer : hop.getParent()) {
				boolean leftMatMult = consumer instanceof AggBinaryOp && consumer.getInput().get(0) == hop;
				allLeftMatMult &= leftMatMult;
				allEmptyBlockAgnostic &= leftMatMult || HopRewriteUtils.isUnary(consumer, OpOp1.NROW);
			}
			rmEmpty.setOutputEmptyBlocks(!allEmptyBlockAgnostic);

			if (allLeftMatMult && diagTarget) {
				if (ConfigurationManager.isDynamicRecompilation())
					rmEmpty.setOutputPermutationMatrix(true);
				for (Hop consumer : hop.getParent())
					((AggBinaryOp) consumer).setHasLeftPMInput(true);
			}
		}
	}

	// #2 ctable whose output dims are not provided (fewer than 4 inputs)
	if (splitNeeded && HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) && hop.getInput().size() < 4) {
		cand.add(hop);
		descend = false;
		boolean allLeftMatMult = true;
		for (Hop consumer : hop.getParent())
			allLeftMatMult &= consumer instanceof AggBinaryOp && consumer.getInput().get(0) == hop;
		if (allLeftMatMult && HopRewriteUtils.isBasic1NSequence(hop.getInput().get(0)))
			hop.setOutputEmptyBlocks(false);
	}

	// #3 orderby: parameters computed in the same DAG
	if (HopRewriteUtils.isReorg(hop, ReOrgOp.SORT)) {
		// inputs 2 and 3 are the 'decreasing' / 'indexreturn' parameters
		for (int pos = 2; pos < 4; pos++) {
			Hop param = hop.getInput().get(pos);
			boolean computed = !(param instanceof LiteralOp) && !(param instanceof DataOp);
			if (computed) {
				cand.add(param);
				param.setVisited();
				descend = false;
			}
		}
	}

	// #4 second-order eval function
	if (splitNeeded && HopRewriteUtils.isNary(hop, OpOpN.EVAL)) {
		cand.add(hop);
		descend = false;
	}

	// children of a candidate are not traversed here (processed by recursive rule application)
	if (descend && hop.getInput() != null) {
		for (Hop child : hop.getInput())
			rCollectDataDependentOperators(child, cand);
	}
	hop.setVisited();
}
Aggregations