Usage of org.apache.sysml.hops.IndexingOp in the Apache incubator-systemml project.
Class HopRewriteUtils, method createScalarIndexing:
/**
 * Builds a hop that reads the single cell input[rix, cix] as a scalar.
 * A right indexing op is created whose row bounds both use the literal rix
 * and whose column bounds both use the literal cix; its result is wrapped
 * in a CAST_AS_SCALAR unary op.
 *
 * @param input matrix hop to index into
 * @param rix   row index (used for both lower and upper row bound)
 * @param cix   column index (used for both lower and upper column bound)
 * @return the CAST_AS_SCALAR hop over the new indexing op
 */
public static Hop createScalarIndexing(Hop input, long rix, long cix) {
	LiteralOp rowLit = new LiteralOp(rix);
	LiteralOp colLit = new LiteralOp(cix);
	//same literal hop for lower and upper bound, flagged as rl==ru and cl==cu
	IndexingOp cell = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
		input, rowLit, rowLit, colLit, colLit, true, true);
	//inherit source position and blocking from the indexed input
	copyLineNumbers(input, cell);
	cell.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
	cell.refreshSizeInformation();
	return createUnary(cell, OpOp1.CAST_AS_SCALAR);
}
Usage of org.apache.sysml.hops.IndexingOp in the Apache incubator-systemml project.
Class RewriteAlgebraicSimplificationDynamic, method removeEmptyRightIndexing:
/**
 * Replaces a matrix right indexing over an input with zero non-zeros by a
 * constant-0 data generator of the indexing's output dimensions.
 * Requires the input nnz to be exactly 0 and the output dims of hi to be known.
 *
 * @param parent consumer of hi
 * @param hi     candidate hop
 * @param pos    position of hi within parent's inputs
 * @return the replacement hop if the rewrite applied, otherwise hi unchanged
 * @throws HopsException propagated from hop utilities
 */
private Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos) throws HopsException {
	//only matrix-valued right indexing is a candidate
	if (!(hi instanceof IndexingOp) || hi.getDataType() != DataType.MATRIX)
		return hi;
	Hop input = hi.getInput().get(0);
	//nnz input known and empty, and output dims known
	boolean emptyInput = input.getNnz() == 0;
	if (!emptyInput || !HopRewriteUtils.isDimsKnown(hi))
		return hi;
	//remove unnecessary right indexing: the slice of an empty matrix is all zeros
	LiteralOp rows = new LiteralOp(hi.getDim1());
	LiteralOp cols = new LiteralOp(hi.getDim2());
	Hop zeros = HopRewriteUtils.createDataGenOpByVal(rows, cols, 0);
	HopRewriteUtils.replaceChildReference(parent, hi, zeros, pos);
	HopRewriteUtils.cleanupUnreferenced(hi, input);
	LOG.debug("Applied removeEmptyRightIndexing");
	return zeros;
}
Usage of org.apache.sysml.hops.IndexingOp in the Apache incubator-systemml project.
Class RewriteAlgebraicSimplificationStatic, method simplifySlicedMatrixMult:
/**
 * Pushes a single-cell right indexing below a matrix multiply, e.g.,
 * (X%*%Y)[1,1] -> X[1,] %*% Y[,1].
 * Applies only if hi is an IndexingOp with rl==ru and cl==cu whose input is a
 * matrix multiply that has hi as its only consumer. The matrix multiply hop is
 * rewired in place over a row slice of X and a column slice of Y and returned.
 * NOTE(review): parent and pos are unused and hi is not detached from the
 * matrix multiply here; presumably a caller or later rewrite removes the
 * now-redundant indexing — confirm.
 *
 * @return the rewired matrix multiply if the rewrite applied, otherwise hi
 * @throws HopsException propagated from hop utilities
 */
private Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) throws HopsException {
	if (!(hi instanceof IndexingOp))
		return hi;
	IndexingOp rix = (IndexingOp) hi;
	if (!rix.isRowLowerEqualsUpper() || !rix.isColLowerEqualsUpper())
		return hi;
	Hop mm = rix.getInput().get(0);
	//rix must be the single consumer of a matrix multiply
	if (mm.getParent().size() != 1 || !HopRewriteUtils.isMatrixMultiply(mm))
		return hi;

	Hop lhs = mm.getInput().get(0);
	Hop rhs = mm.getInput().get(1);
	//IndexingOp inputs: 0=input, 1=rl, 2=ru, 3=cl, 4=cu (rl==ru and cl==cu here)
	Hop rowExpr = rix.getInput().get(1);
	Hop colExpr = rix.getInput().get(3);

	HopRewriteUtils.removeAllChildReferences(mm);

	//row rowExpr of the left operand, all columns
	IndexingOp rowSlice = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, lhs,
		rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(lhs, false), true, false);
	rowSlice.setOutputBlocksizes(lhs.getRowsInBlock(), lhs.getColsInBlock());
	rowSlice.refreshSizeInformation();
	//column colExpr of the right operand, all rows
	IndexingOp colSlice = new IndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, rhs,
		new LiteralOp(1), HopRewriteUtils.createValueHop(rhs, true), colExpr, colExpr, false, true);
	colSlice.setOutputBlocksizes(rhs.getRowsInBlock(), rhs.getColsInBlock());
	colSlice.refreshSizeInformation();

	//rewire matrix mult over the two slices
	HopRewriteUtils.addChildReference(mm, rowSlice, 0);
	HopRewriteUtils.addChildReference(mm, colSlice, 1);
	mm.refreshSizeInformation();

	LOG.debug("Applied simplifySlicedMatrixMult");
	return mm;
}
Usage of org.apache.sysml.hops.IndexingOp in the Apache incubator-systemml project.
Class RewriteForLoopVectorization, method vectorizeScalarAggregate:
/**
 * Vectorizes a for-loop body that aggregates single cells into a scalar.
 * The loop must have increment 1, and csb must consist of one scalar root s
 * computed as s = s OP as.scalar(X[itervar, c]) or as.scalar(X[r, itervar]) OP s,
 * where OP is in MAP_SCALAR_AGGREGATE_SOURCE_OPS.
 * On match, the cast is replaced by a full (RowCol) aggregate of the mapped
 * AggOp over X[from:to, c] or X[r, from:to], modifying the hops of csb in place.
 *
 * @param sb        the loop statement block, returned if the rewrite does not apply
 * @param csb       the single child statement block of the loop body
 * @param from      loop lower bound hop
 * @param to        loop upper bound hop
 * @param increment loop increment hop (only literal 1.0 is supported)
 * @param itervar   name of the loop iteration variable
 * @return csb if the rewrite applied, otherwise sb
 * @throws HopsException propagated from hop utilities
 */
private StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) throws HopsException {
StatementBlock ret = sb;
//check missing and supported increment values (only a literal increment of 1)
if (!(increment != null && increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0)) {
return ret;
}
//check for applicability
//leftScalar: accumulator is the left operand, the cast(indexing) the right one
boolean leftScalar = false;
//rightScalar: accumulator is the right operand, the cast(indexing) the left one
boolean rightScalar = false;
//row or col: true if the row index is the loop variable, false if the column index is
boolean rowIx = false;
if (csb.get_hops() != null && csb.get_hops().size() == 1) {
Hop root = csb.get_hops().get(0);
//NOTE(review): assumes a scalar root always has at least one input — confirm
if (root.getDataType() == DataType.SCALAR && root.getInput().get(0) instanceof BinaryOp) {
BinaryOp bop = (BinaryOp) root.getInput().get(0);
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
//check for left scalar: s OP as.scalar(X[..]) where s is read and written by root
if (HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) && left instanceof DataOp && left.getDataType() == DataType.SCALAR && root.getName().equals(left.getName()) && right instanceof UnaryOp && ((UnaryOp) right).getOp() == OpOp1.CAST_AS_SCALAR && right.getInput().get(0) instanceof IndexingOp) {
IndexingOp ix = (IndexingOp) right.getInput().get(0);
//IndexingOp inputs: 0=input, 1=rl, 2=ru, 3=cl, 4=cu
if (ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp && ix.getInput().get(1).getName().equals(itervar)) {
leftScalar = true;
rowIx = true;
} else if (ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp && ix.getInput().get(3).getName().equals(itervar)) {
leftScalar = true;
rowIx = false;
}
} else //check for right scalar: as.scalar(X[..]) OP s (mirror of the case above)
if (HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) && right instanceof DataOp && right.getDataType() == DataType.SCALAR && root.getName().equals(right.getName()) && left instanceof UnaryOp && ((UnaryOp) left).getOp() == OpOp1.CAST_AS_SCALAR && left.getInput().get(0) instanceof IndexingOp) {
IndexingOp ix = (IndexingOp) left.getInput().get(0);
if (ix.isRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp && ix.getInput().get(1).getName().equals(itervar)) {
rightScalar = true;
rowIx = true;
} else if (ix.isColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp && ix.getInput().get(3).getName().equals(itervar)) {
rightScalar = true;
rowIx = false;
}
}
}
}
//apply rewrite if possible
//NOTE(review): presumably a zero-trip loop (from > to) is not guarded here; the
//resulting range X[from:to] would then be empty/invalid — confirm against callers
if (leftScalar || rightScalar) {
Hop root = csb.get_hops().get(0);
BinaryOp bop = (BinaryOp) root.getInput().get(0);
//the cast sits opposite to the accumulator operand
Hop cast = bop.getInput().get(leftScalar ? 1 : 0);
Hop ix = cast.getInput().get(0);
//map the binary op to the corresponding aggregation (same index in both tables)
int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
//replace cast with a full (RowCol) aggregate over the indexing op
AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
HopRewriteUtils.removeChildReference(cast, ix);
HopRewriteUtils.removeChildReference(bop, cast);
HopRewriteUtils.addChildReference(bop, newSum, leftScalar ? 1 : 0);
//modify indexing expression according to loop predicate from-to
//NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
//rows: replace rl/ru (inputs 1/2); cols: replace cl/cu (inputs 3/4)
int index1 = rowIx ? 1 : 3;
int index2 = rowIx ? 2 : 4;
HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index1), from, index1);
HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index2), to, index2);
//update indexing size information (the indexed dimension is now a range)
if (rowIx)
((IndexingOp) ix).setRowLowerEqualsUpper(false);
else
((IndexingOp) ix).setColLowerEqualsUpper(false);
ix.refreshSizeInformation();
ret = csb;
//NOTE(review): message names vectorizeScalarSumForLoop although this method handles all mapped aggregates
LOG.debug("Applied vectorizeScalarSumForLoop.");
}
return ret;
}
Usage of org.apache.sysml.hops.IndexingOp in the Apache incubator-systemml project.
Class RewriteIndexingVectorization, method vectorizeLeftIndexing:
/**
 * Vectorizes chains of single-cell left indexing ops (rl==ru and cl==cu) that
 * share the same row index expression (row-wise) or the same column index
 * expression (column-wise). Such a chain is rewritten as follows:
 * - the affected row/column is sliced out of the original target once;
 * - the chained updates are applied to that slice;
 * - the updated slice is written back with a single row/column left indexing.
 * The row-wise variant is tried first; the column-wise variant only if the
 * row-wise one did not apply.
 *
 * @param hop candidate hop; anything but a single-cell LeftIndexingOp is ignored
 * @throws HopsException propagated from hop utilities
 */
private void vectorizeLeftIndexing(Hop hop) throws HopsException {
	if (!(hop instanceof LeftIndexingOp))
		return;
	LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
	//only single-cell left indexing is a candidate
	if (!ihop0.getRowLowerEqualsUpper() || !ihop0.getColLowerEqualsUpper())
		return;
	if (!vectorizeLeftIndexingChain(ihop0, true))
		vectorizeLeftIndexingChain(ihop0, false);
}

/**
 * Applies the row-wise (rowWise=true) or column-wise (rowWise=false) left
 * indexing vectorization rooted at ihop0.
 *
 * @param ihop0   topmost single-cell left indexing op of the chain
 * @param rowWise true to merge on a shared row index, false for a shared column index
 * @return true if a chain of at least two left indexing ops was found and rewritten
 * @throws HopsException propagated from hop utilities
 */
private boolean vectorizeLeftIndexingChain(LeftIndexingOp ihop0, boolean rowWise) throws HopsException {
	//LeftIndexingOp inputs: 0=target, 1=source, 2=rl, 3=ru, 4=cl, 5=cu
	final int lowerPos = rowWise ? 2 : 4;
	final int upperPos = rowWise ? 3 : 5;

	//collect simple chains (w/o multiple consumers) of left indexing ops
	ArrayList<Hop> ihops = new ArrayList<Hop>();
	ihops.add(ihop0);
	Hop current = ihop0;
	while (current.getInput().get(0) instanceof LeftIndexingOp) {
		LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
		Hop tmpTarget = tmp.getInput().get(0);
		if (tmp.getParent().size() > 1 //multiple consumers, i.e., not a simple chain
			|| !(rowWise ? tmp.getRowLowerEqualsUpper() : tmp.getColLowerEqualsUpper()) //row/col merge not applicable
			|| tmp.getInput().get(lowerPos) != ihop0.getInput().get(lowerPos) //not the same row/col
			|| (rowWise ? tmpTarget.getDim2() : tmpTarget.getDim1()) <= 1) { //target is single col/row or unknown
			break;
		}
		ihops.add(tmp);
		current = tmp;
	}
	//apply rewrite only if at least two candidates were found
	if (ihops.size() <= 1)
		return false;

	Hop input = current.getInput().get(0);
	//keep before reset (the loop below replaces ihop0's index inputs)
	Hop sliceExpr = ihop0.getInput().get(lowerPos);

	//new row/col right indexing operator: slice the affected row/col out of the original target
	IndexingOp newRix = rowWise
		? new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input,
			sliceExpr, sliceExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false)
		: new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input,
			new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), sliceExpr, sliceExpr, false, true);
	HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
	newRix.refreshSizeInformation();

	//rewrite bottom left indexing operator to update the slice instead of the full target
	HopRewriteUtils.removeChildReference(current, input);
	HopRewriteUtils.addChildReference(current, newRix, 0);

	//reset row/col index of all candidates to 1 and refresh sizes (bottom-up)
	for (int i = ihops.size() - 1; i >= 0; i--) {
		LeftIndexingOp c = (LeftIndexingOp) ihops.get(i);
		HopRewriteUtils.replaceChildReference(c, c.getInput().get(lowerPos), new LiteralOp(1), lowerPos);
		HopRewriteUtils.replaceChildReference(c, c.getInput().get(upperPos), new LiteralOp(1), upperPos);
		if (rowWise)
			c.setRowLowerEqualsUpper(true);
		else
			c.setColLowerEqualsUpper(true);
		c.refreshSizeInformation();
	}

	//snapshot ihop0's consumers before creating newLix on top of ihop0, because the
	//constructor registers newLix as an additional parent of ihop0
	ArrayList<Hop> ihop0parents = new ArrayList<Hop>(ihop0.getParent());
	//new row/col left indexing operator that writes the updated slice back into the target
	LeftIndexingOp newLix = rowWise
		? new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0,
			sliceExpr, sliceExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(input, false), true, false)
		: new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0,
			new LiteralOp(1), HopRewriteUtils.createValueHop(input, true), sliceExpr, sliceExpr, false, true);
	HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
	newLix.refreshSizeInformation();

	//redirect all consumers of ihop0 to newLix (only intermediates are guaranteed to have 1 parent).
	//Each reference is swapped in place, so a consumer that references ihop0 at several
	//(non-adjacent) positions keeps its input order.
	for (Hop parent : ihop0parents) {
		int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
		HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp);
		HopRewriteUtils.addChildReference(parent, newLix, posp);
	}
	LOG.debug(rowWise ? "Applied vectorizeLeftIndexingRow" : "Applied vectorizeLeftIndexingCol");
	return true;
}
Aggregations