use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class HopRewriteUtils method copyDataGenOp.
/**
 * Builds a new rand data generator from an existing one whose min and max bounds
 * are rewritten as {@code bound * scale + shift}. The hops for rows, cols, pdf,
 * lambda, sparsity and seed are reused from the input generator unchanged.
 *
 * The rewrite is only possible if both min and max of the input are literal ops;
 * otherwise no generator is created.
 *
 * @param inputGen input data gen op
 * @param scale the scale (factor applied to min and max)
 * @param shift the shift (offset added to min and max after scaling)
 * @return new data gen op, or null if min or max of the input is not a literal op
 */
public static DataGenOp copyDataGenOp(DataGenOp inputGen, double scale, double shift) {
	HashMap<String, Integer> index = inputGen.getParamIndexMap();

	// resolve all parameter hops of the source generator by their input position
	Hop rowsHop = inputGen.getInput().get(index.get(DataExpression.RAND_ROWS));
	Hop colsHop = inputGen.getInput().get(index.get(DataExpression.RAND_COLS));
	Hop minHop = inputGen.getInput().get(index.get(DataExpression.RAND_MIN));
	Hop maxHop = inputGen.getInput().get(index.get(DataExpression.RAND_MAX));
	Hop pdfHop = inputGen.getInput().get(index.get(DataExpression.RAND_PDF));
	Hop lambdaHop = inputGen.getInput().get(index.get(DataExpression.RAND_LAMBDA));
	Hop sparsityHop = inputGen.getInput().get(index.get(DataExpression.RAND_SPARSITY));
	Hop seedHop = inputGen.getInput().get(index.get(DataExpression.RAND_SEED));

	// bounds can only be rewritten at compile time if both are literals
	if (!(minHop instanceof LiteralOp && maxHop instanceof LiteralOp)) {
		return null;
	}

	// read both literal values first, then apply scale and shift
	double oldMin = getDoubleValue((LiteralOp) minHop);
	double oldMax = getDoubleValue((LiteralOp) maxHop);
	double newMin = oldMin * scale + shift;
	double newMax = oldMax * scale + shift;
	Hop newMinHop = new LiteralOp(newMin);
	Hop newMaxHop = new LiteralOp(newMax);

	// assemble the parameter map of the new generator
	HashMap<String, Hop> genParams = new HashMap<>();
	genParams.put(DataExpression.RAND_ROWS, rowsHop);
	genParams.put(DataExpression.RAND_COLS, colsHop);
	genParams.put(DataExpression.RAND_MIN, newMinHop);
	genParams.put(DataExpression.RAND_MAX, newMaxHop);
	genParams.put(DataExpression.RAND_PDF, pdfHop);
	genParams.put(DataExpression.RAND_LAMBDA, lambdaHop);
	genParams.put(DataExpression.RAND_SPARSITY, sparsityHop);
	genParams.put(DataExpression.RAND_SEED, seedHop);

	// note internal refresh size information
	DataGenOp result = new DataGenOp(DataGenMethod.RAND, new DataIdentifier("tmp"), genParams);
	result.setOutputBlocksizes(inputGen.getRowsInBlock(), inputGen.getColsInBlock());
	copyLineNumbers(inputGen, result);

	// both bounds zero: mark the output as having no non-zeros
	boolean allZero = (newMin == 0 && newMax == 0);
	if (allZero) {
		result.setNnz(0);
	}
	return result;
}
use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class HopRewriteUtils method isBinarySparseSafe.
/**
 * Decides whether the given hop is a binary operation that is considered sparse-safe.
 * Multiplications and divisions always qualify. Any other binary operation qualifies
 * only if one of its two operands is a literal and
 * {@code OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe} accepts the
 * operator together with that literal (the left operand takes precedence).
 *
 * @param hop high-level operator to check
 * @return true if the hop is a sparse-safe binary operation, false otherwise
 */
public static boolean isBinarySparseSafe(Hop hop) {
	if (!(hop instanceof BinaryOp)) {
		return false;
	}
	if (isBinary(hop, OpOp2.MULT, OpOp2.DIV)) {
		return true;
	}
	BinaryOp binary = (BinaryOp) hop;
	// probe left operand first, then right; the first literal found decides
	for (int side = 0; side < 2; side++) {
		Hop operand = binary.getInput().get(side);
		if (operand instanceof LiteralOp) {
			return OptimizerUtils.isBinaryOpSparsityConditionalSparseSafe(binary.getOp(), (LiteralOp) operand);
		}
	}
	return false;
}
use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class HopRewriteUtils method createSeqDataGenOp.
/**
 * Creates a seq data generator spanning the row count of the given input:
 * from 1 to nrow(input) with increment 1 if ascending, or from nrow(input)
 * down to 1 with increment -1 otherwise. The row count is a literal if the
 * rows of the input are known, and an nrow operation over the input if not.
 *
 * @param input hop whose number of rows bounds the sequence
 * @param asc true for an ascending sequence, false for a descending one
 * @return seq data gen op
 */
public static DataGenOp createSeqDataGenOp(Hop input, boolean asc) {
	// sequence bound: compile-time literal if possible, otherwise computed at runtime
	Hop numRows;
	if (input.rowsKnown()) {
		numRows = new LiteralOp(input.getDim1());
	} else {
		numRows = new UnaryOp("tmprows", DataType.SCALAR, ValueType.INT, OpOp1.NROW, input);
	}

	HashMap<String, Hop> seqArgs = new HashMap<>();
	seqArgs.put(Statement.SEQ_FROM, asc ? new LiteralOp(1) : numRows);
	seqArgs.put(Statement.SEQ_TO, asc ? numRows : new LiteralOp(1));
	seqArgs.put(Statement.SEQ_INCR, new LiteralOp(asc ? 1 : -1));

	// note internal refresh size information
	DataGenOp seq = new DataGenOp(DataGenMethod.SEQ, new DataIdentifier("tmp"), seqArgs);
	seq.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
	copyLineNumbers(input, seq);
	return seq;
}
use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method removeEmptyRightIndexing.
/**
 * Replaces a matrix right indexing operation over an input with zero non-zeros
 * by a data generator of value 0 with the dimensions of the indexing output.
 * The rewrite only applies if the output dimensions of the indexing op are known.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position of hi in the inputs of parent
 * @return the replacement hop if the rewrite applied, otherwise hi unchanged
 */
private static Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos) {
	// only matrix-valued indexing ops are candidates
	if (!(hi instanceof IndexingOp) || hi.getDataType() != DataType.MATRIX) {
		return hi;
	}

	Hop source = hi.getInput().get(0);
	// requires an input known to be empty and known output dims
	if (source.getNnz() != 0 || !HopRewriteUtils.isDimsKnown(hi)) {
		return hi;
	}

	// remove unnecessary right indexing
	LiteralOp outRows = new LiteralOp(hi.getDim1());
	LiteralOp outCols = new LiteralOp(hi.getDim2());
	Hop replacement = HopRewriteUtils.createDataGenOpByVal(outRows, outCols, 0);
	HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
	HopRewriteUtils.cleanupUnreferenced(hi, source);
	LOG.debug("Applied removeEmptyRightIndexing");
	return replacement;
}
use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by apache.
the class RewriteAlgebraicSimplificationDynamic method simplifyWeightedSquaredLoss.
/**
* Searches for weighted squared loss expressions and replaces them with a quaternary operator.
* Currently, this search includes the following three patterns:
* 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
* 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
* 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
*
* For each pattern, the mirrored operand order of the subtraction (e.g., U %*% t(V) - X)
* is matched as well. The patterns are tried in the order above and at most one is applied.
*
* NOTE: We include transpose into the pattern because during runtime we need to compute
* U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
* without additional memory requirements for internal transpose. Accordingly, if the right
* input of the matrix multiplication is a transpose operation, its input is used as V;
* otherwise a new transpose over that input is created.
*
* This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
* U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
* rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
* as we generalized the runtime or hop/lop compilation.
*
* @param parent parent high-level operator
* @param hi high-level operator
* @param pos position of hi in the inputs of parent
* @return the new quaternary operator (relinked into parent at pos) if a pattern matched,
*         otherwise hi unchanged
*/
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) {
// NOTE: there might be also a general simplification without custom operator
// via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
Hop hnew = null;
if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol && // all patterns rooted by sum()
((AggUnaryOp) hi).getOp() == AggOp.SUM && // all patterns subrooted by binary op
hi.getInput().get(0) instanceof BinaryOp && // not applied for vector-vector mult
hi.getInput().get(0).getDim2() > 1) {
BinaryOp bop = (BinaryOp) hi.getInput().get(0);
// guards the later patterns so that at most one rewrite is applied
boolean appliedPattern = false;
// Pattern 1) sum (W * (X - U %*% t(V)) ^ 2)
// alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) {
Hop W = bop.getInput().get(0);
// (X - U %*% t(V))
Hop tmp = bop.getInput().get(1).getInput().get(0);
if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS) && // prevent mv
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
// a) sum (W * (X - U %*% t(V)) ^ 2)
// uvIndex: input position of the matrix multiplication within the minus (-1 if none)
int uvIndex = -1;
if (// ba guarantees matrices
tmp.getInput().get(1) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
uvIndex = 1;
} else // b) sum (W * (U %*% t(V) - X) ^ 2)
if (// ba guarantees matrices
tmp.getInput().get(0) instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) {
uvIndex = 0;
}
if (// rewrite match
uvIndex >= 0) {
// X is the other operand of the minus
Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
// obtain V from t(V): strip an existing transpose or add a new one
if (!HopRewriteUtils.isTransposeOperation(V)) {
V = HopRewriteUtils.createTranspose(V);
} else {
V = V.getInput().get(0);
}
// handle special case of post_nz
if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
W = new LiteralOp(1);
}
// construct quaternary hop
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
HopRewriteUtils.setOutputParametersForScalar(hnew);
appliedPattern = true;
LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
}
}
}
// Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2)
// alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
Hop lleft = bop.getInput().get(0).getInput().get(0);
Hop lright = bop.getInput().get(0).getInput().get(1);
// a) sum ((X - W * (U %*% t(V))) ^ 2)
// wuvIndex: input position of W * (U %*% t(V)) within the minus (-1 if none)
int wuvIndex = -1;
if (lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp) {
wuvIndex = 1;
} else // b) sum ((W * (U %*% t(V)) - X) ^ 2)
if (lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp) {
wuvIndex = 0;
}
if (// rewrite match
wuvIndex >= 0) {
// X is the other operand of the minus
Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
// (W * (U %*% t(V)))
Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
if (((BinaryOp) tmp).getOp() == OpOp2.MULT && tmp.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && // BLOCKSIZE CONSTRAINT
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
Hop W = tmp.getInput().get(0);
Hop U = tmp.getInput().get(1).getInput().get(0);
Hop V = tmp.getInput().get(1).getInput().get(1);
// obtain V from t(V): strip an existing transpose or add a new one
if (!HopRewriteUtils.isTransposeOperation(V)) {
V = HopRewriteUtils.createTranspose(V);
} else {
V = V.getInput().get(0);
}
// construct quaternary hop (last argument false, in contrast to pattern 1)
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
HopRewriteUtils.setOutputParametersForScalar(hnew);
appliedPattern = true;
LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
}
}
}
// Pattern 3) sum ((X - (U %*% t(V))) ^ 2)
// alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && // prevent mv
HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
Hop lleft = bop.getInput().get(0).getInput().get(0);
Hop lright = bop.getInput().get(0).getInput().get(1);
// a) sum ((X - (U %*% t(V))) ^ 2)
// uvIndex: input position of the matrix multiplication within the minus (-1 if none)
int uvIndex = -1;
if (// ba guarantees matrices
lright instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) {
uvIndex = 1;
} else // b) sum (((U %*% t(V)) - X) ^ 2)
if (// ba guarantees matrices
lleft instanceof AggBinaryOp && // BLOCKSIZE CONSTRAINT
HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) {
uvIndex = 0;
}
if (// rewrite match
uvIndex >= 0) {
// X is the other operand of the minus
Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
// (U %*% t(V))
Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
// no weighting
Hop W = new LiteralOp(1);
Hop U = tmp.getInput().get(0);
Hop V = tmp.getInput().get(1);
// obtain V from t(V): strip an existing transpose or add a new one
if (!HopRewriteUtils.isTransposeOperation(V)) {
V = HopRewriteUtils.createTranspose(V);
} else {
V = V.getInput().get(0);
}
// construct quaternary hop (last argument false, in contrast to pattern 1)
hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
HopRewriteUtils.setOutputParametersForScalar(hnew);
appliedPattern = true;
LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
}
}
}
// relink new hop into original position
if (hnew != null) {
HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
}
return hi;
}
Aggregations