Use of org.apache.sysml.hops.AggBinaryOp in project systemml by apache.
Class RewriteCompressedReblock, method rAnalyzeHopDag.
/**
 * Analyzes the HOP DAG rooted at {@code current} (inputs first) and updates {@code status}:
 * names reached by the probed compressed matrix are collected in {@code status.compMtx},
 * and {@code status.nonApplicable} is raised when a hop with a compressed input is neither
 * one of the supported operations nor a meta-data query. Hops already marked as visited are
 * skipped; every hop processed here is marked as visited at the end.
 *
 * @param current hop to analyze
 * @param status probe state, modified in place
 */
private static void rAnalyzeHopDag(Hop current, ProbeStatus status) {
    if (current.isVisited()) {
        return;
    }

    // post-order: all inputs are analyzed before the hop itself
    for (Hop child : current.getInput()) {
        rAnalyzeHopDag(child, status);
    }

    // the hop the probe started from (source persistent read)
    if (current.getHopID() == status.startHopID) {
        status.compMtx.add(getTmpName(current));
        status.foundStart = true;
    }

    if (current instanceof FunctionOp && hasCompressedInput(current, status)) {
        // a) function calls
        // TODO handle of functions in a more fine-grained manner
        // to cover special cases multiple calls where compressed
        // inputs might occur for different input parameters
        FunctionOp call = (FunctionOp) current;
        String fnKey = call.getFunctionKey();
        if (!status.procFn.contains(fnKey)) {
            // remember the function first so that redundant analysis and
            // recursive calls are cut off
            status.procFn.add(fnKey);

            // translate compressed call arguments into function parameter names
            FunctionStatementBlock fnBlock = status.prog.getFunctionStatementBlock(fnKey);
            FunctionStatement fnStmt = (FunctionStatement) fnBlock.getStatement(0);
            ProbeStatus fnStatus = new ProbeStatus(status);
            for (int i = 0; i < call.getInput().size(); i++) {
                if (status.compMtx.contains(getTmpName(call.getInput().get(i)))) {
                    fnStatus.compMtx.add(fnStmt.getInputParams().get(i).getName());
                }
            }

            // analyze the function body and fold its flags into the caller's status
            rAnalyzeProgram(fnBlock, fnStatus);
            status.foundStart |= fnStatus.foundStart;
            status.usedInLoop |= fnStatus.usedInLoop;
            status.condUpdate |= fnStatus.condUpdate;
            status.nonApplicable |= fnStatus.nonApplicable;

            // translate compressed function outputs back into caller variable names
            String[] outNames = call.getOutputVariableNames();
            for (int i = 0; i < outNames.length; i++) {
                if (fnStatus.compMtx.contains(fnStmt.getOutputParams().get(i).getName())) {
                    status.compMtx.add(outNames[i]);
                }
            }
        }
    } else if (HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTWRITE)
            && status.compMtx.contains(getTmpName(current.getInput().get(0)))) {
        // b) transient write: temporary name -> variable name
        status.compMtx.add(current.getName());
    } else if (HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTREAD)
            && status.compMtx.contains(current.getName())) {
        // b) transient read: variable name -> temporary name
        status.compMtx.add(getTmpName(current));
    } else if (hasCompressedInput(current, status)) {
        // c) operations consuming a compressed input

        // valid with uncompressed outputs
        boolean compUCOut =
                // tsmm
                (current instanceof AggBinaryOp
                        && current.getDim2() <= current.getColsInBlock()
                        && ((AggBinaryOp) current).checkTransposeSelf() == MMTSJType.LEFT)
                // mvmm
                || (current instanceof AggBinaryOp
                        && (current.getDim1() == 1 || current.getDim2() == 1))
                // transpose whose single parent is a matrix multiply with vector output
                || (HopRewriteUtils.isTransposeOperation(current)
                        && current.getParent().size() == 1
                        && current.getParent().get(0) instanceof AggBinaryOp
                        && (current.getParent().get(0).getDim1() == 1
                                || current.getParent().get(0).getDim2() == 1))
                // selected unary aggregates
                || HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX);

        // valid with compressed outputs
        boolean compCOut = HopRewriteUtils.isBinaryMatrixScalarOperation(current)
                || HopRewriteUtils.isBinary(current, OpOp2.CBIND);

        // pure meta-data queries
        boolean metaOp = HopRewriteUtils.isUnary(current, OpOp1.NROW, OpOp1.NCOL);

        status.nonApplicable |= !compUCOut && !compCOut && !metaOp;
        if (compCOut) {
            // output is tracked as compressed as well
            status.compMtx.add(getTmpName(current));
        }
    }

    current.setVisited();
}
Use of org.apache.sysml.hops.AggBinaryOp in project systemml by apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyRowSumsMVMult.
/**
 * Rewrites a row-wise sum over an element-wise multiply, {@code rowSums(X * Y)}, into the
 * matrix multiply {@code X %*% t(Y)} when {@code X} has more than one row and column and
 * {@code Y} is a row vector with more than one column. The introduced transpose is
 * removed by another rewrite if unnecessary, i.e., if Y==t(Z).
 *
 * @param parent parent hop whose child reference at {@code pos} is relinked on success
 * @param hi candidate hop
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new matrix multiply if the rewrite applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyRowSumsMVMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }

    AggUnaryOp agg = (AggUnaryOp) hi;
    Hop product = agg.getInput().get(0);

    // must be rowsums over b(*)
    boolean rowSumsOverMult = agg.getOp() == AggOp.SUM
            && agg.getDirection() == Direction.Row
            && HopRewriteUtils.isBinary(product, OpOp2.MULT);
    if (!rowSumsOverMult) {
        return hi;
    }

    Hop left = product.getInput().get(0);
    Hop right = product.getInput().get(1);

    // MV (row vector): left is a true matrix, right is a 1 x n vector
    boolean matrixTimesRowVector = left.getDim1() > 1 && left.getDim2() > 1
            && right.getDim1() == 1 && right.getDim2() > 1;
    if (!matrixTimesRowVector) {
        return hi;
    }

    // create new operators
    ReorgOp rightT = HopRewriteUtils.createTranspose(right);
    AggBinaryOp matMult = HopRewriteUtils.createMatrixMultiply(left, rightT);

    // relink new child and drop the replaced operators if no longer referenced
    HopRewriteUtils.replaceChildReference(parent, hi, matMult, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, product);

    LOG.debug("Applied simplifyRowSumsMVMult");
    return matMult;
}
Use of org.apache.sysml.hops.AggBinaryOp in project systemml by apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSigmoidMMChains.
/**
 * Fuses weighted sigmoid matrix-multiply chains into a single {@code WSIGMOID}
 * quaternary operator. Recognized patterns (tried in this order, first match wins):
 * <ol>
 * <li>{@code W * sigmoid(Y%*%t(X))} (basic)</li>
 * <li>{@code W * sigmoid(-(Y%*%t(X)))} (minus)</li>
 * <li>{@code W * log(sigmoid(Y%*%t(X)))} (log)</li>
 * <li>{@code W * log(sigmoid(-(Y%*%t(X))))} (log_minus)</li>
 * </ol>
 *
 * @param parent parent hop whose child reference at {@code pos} is relinked on success
 * @param hi candidate hop
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new quaternary operator if a pattern applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
    // precondition shared by all patterns
    boolean candidate = HopRewriteUtils.isBinary(hi, OpOp2.MULT) // all patterns subrooted by W *
            && hi.getDim2() > 1 // not applied for vector-vector mult
            && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) // prevent mv
            && hi.getInput().get(0).getDataType() == DataType.MATRIX
            && hi.getInput().get(1) instanceof UnaryOp; // sigmoid/log
    if (!candidate) {
        return hi;
    }

    UnaryOp uop = (UnaryOp) hi.getInput().get(1);

    // result of the pattern matching: the matrix multiply Y%*%t(X), the flags
    // handed to the quaternary operator, and the prefix of the debug message
    Hop mm = null;
    boolean logFlag = false;
    boolean minusFlag = false;
    String msgPrefix = null;

    if (uop.getOp() == OpOp1.SIGMOID) {
        Hop arg = uop.getInput().get(0);
        if (arg instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(0), true)) {
            // Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
            mm = arg;
            msgPrefix = "Applied simplifyWeightedSigmoid1 (line ";
        } else if (HopRewriteUtils.isBinary(arg, OpOp2.MINUS)
                && arg.getInput().get(0) instanceof LiteralOp
                && HopRewriteUtils.getDoubleValueSafe((LiteralOp) arg.getInput().get(0)) == 0
                && arg.getInput().get(1) instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(1).getInput().get(0), true)) {
            // Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
            mm = arg.getInput().get(1);
            minusFlag = true;
            msgPrefix = "Applied simplifyWeightedSigmoid2 (line ";
        }
    } else if (uop.getOp() == OpOp1.LOG
            && HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)) {
        Hop arg = uop.getInput().get(0).getInput().get(0);
        if (arg instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(arg.getInput().get(0), true)) {
            // Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
            mm = arg;
            logFlag = true;
            msgPrefix = "Applied simplifyWeightedSigmoid3 (line ";
        } else if (HopRewriteUtils.isBinary(arg, OpOp2.MINUS)) {
            // Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
            if (arg.getInput().get(0) instanceof LiteralOp
                    && HopRewriteUtils.getDoubleValueSafe((LiteralOp) arg.getInput().get(0)) == 0
                    && arg.getInput().get(1) instanceof AggBinaryOp
                    && HopRewriteUtils.isSingleBlock(arg.getInput().get(1).getInput().get(0), true)) {
                mm = arg.getInput().get(1);
                logFlag = true;
                minusFlag = true;
                msgPrefix = "Applied simplifyWeightedSigmoid4 (line ";
            }
        }
    }

    if (mm == null) {
        return hi;
    }

    // assemble the fused operator from W, Y and X
    Hop W = hi.getInput().get(0);
    Hop Y = mm.getInput().get(0);
    Hop tX = mm.getInput().get(1);
    if (HopRewriteUtils.isTransposeOperation(tX)) {
        // strip the existing transpose
        tX = tX.getInput().get(0);
    } else {
        tX = HopRewriteUtils.createTranspose(tX);
    }

    Hop fused = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
            OpOp4.WSIGMOID, W, Y, tX, logFlag, minusFlag);
    fused.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
    fused.refreshSizeInformation();
    LOG.debug(msgPrefix + hi.getBeginLine() + ")");

    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    return fused;
}
Use of org.apache.sysml.hops.AggBinaryOp in project systemml by apache.
Class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedUnaryMM.
/**
 * Fuses weighted unary matrix-multiply patterns into a single {@code WUMM} quaternary
 * operator. Patterns are tried in the order below and the first match wins:
 * <ul>
 * <li>1) {@code W op uop(U%*%t(V))}</li>
 * <li>2.7) {@code (W*(U%*%t(V)))*2} or {@code 2*(W*(U%*%t(V)))}</li>
 * <li>2) {@code W op sop(U%*%t(V),c)} for known scalar operations with literal 2</li>
 * </ul>
 *
 * @param parent parent hop whose child reference at {@code pos} is relinked on success
 * @param hi candidate hop
 * @param pos position of {@code hi} among the inputs of {@code parent}
 * @return the new quaternary operator if a pattern applied, otherwise {@code hi} unchanged
 */
private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) {
    // stays null until one of the patterns below has matched
    Hop hnew = null;

    // Pattern 1) (W*uop(U%*%t(V)))
    if (hi instanceof BinaryOp
            && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
            && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) // prevent mv
            && hi.getDim2() > 1 // not applied for vector-vector mult
            && hi.getInput().get(0).getDataType() == DataType.MATRIX
            && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
            && hi.getInput().get(1) instanceof UnaryOp
            && HopRewriteUtils.isValidOp(((UnaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_UNARY)
            && hi.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
            // BLOCKSIZE CONSTRAINT
            && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0).getInput().get(0), true)) {
        UnaryOp unary = (UnaryOp) hi.getInput().get(1);
        Hop mm = unary.getInput().get(0);
        Hop W = hi.getInput().get(0);
        Hop U = mm.getInput().get(0);
        Hop V = mm.getInput().get(1);
        boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
        OpOp1 uop = unary.getOp();

        if (HopRewriteUtils.isTransposeOperation(V)) {
            V = V.getInput().get(0);
        } else {
            V = HopRewriteUtils.createTranspose(V);
        }

        hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
                OpOp4.WUMM, W, U, V, mult, uop, null);
        hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
        hnew.refreshSizeInformation();
        LOG.debug("Applied simplifyWeightedUnaryMM1 (line " + hi.getBeginLine() + ")");
    }

    // Pattern 2.7) (W*(U%*%t(V)))*2 or 2*(W*(U%*%t(V)))
    if (hnew == null
            && hi instanceof BinaryOp
            && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), OpOp2.MULT)
            && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
                    || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2))) {
        // non-literal operand of the multiplication by 2
        final Hop nl = (hi.getInput().get(0) instanceof LiteralOp)
                ? hi.getInput().get(1)
                : hi.getInput().get(0);

        if (HopRewriteUtils.isBinary(nl, OpOp2.MULT)
                && nl.getParent().size() == 1 // ensure no foreign parents
                && HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) // prevent mv
                && nl.getDim2() > 1 // not applied for vector-vector mult
                && nl.getInput().get(0).getDataType() == DataType.MATRIX
                && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock()
                && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
                // no mmchain
                && (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE
                        || nl.getInput().get(1).getInput().get(1).getDim2() > 1)
                && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0), true)) {
            final Hop mm = nl.getInput().get(1);
            final Hop W = nl.getInput().get(0);
            final Hop U = mm.getInput().get(0);
            Hop V = mm.getInput().get(1);

            if (HopRewriteUtils.isTransposeOperation(V)) {
                V = V.getInput().get(0);
            } else {
                V = HopRewriteUtils.createTranspose(V);
            }

            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
                    OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line " + hi.getBeginLine() + ")");
        }
    }

    // Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
    if (hnew == null
            && hi instanceof BinaryOp
            && HopRewriteUtils.isValidOp(((BinaryOp) hi).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
            && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) // prevent mv
            && hi.getDim2() > 1 // not applied for vector-vector mult
            && hi.getInput().get(0).getDataType() == DataType.MATRIX
            && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
            && hi.getInput().get(1) instanceof BinaryOp
            && HopRewriteUtils.isValidOp(((BinaryOp) hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY)) {
        BinaryOp scalarOp = (BinaryOp) hi.getInput().get(1);
        Hop left = scalarOp.getInput().get(0);
        Hop right = scalarOp.getInput().get(1);
        Hop mm = null;

        if (right.getDataType() == DataType.SCALAR
                && right instanceof LiteralOp
                && HopRewriteUtils.getDoubleValue((LiteralOp) right) == 2 // pow2, mult2
                && left instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(left.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            // pattern 2a) matrix-scalar operations
            mm = left;
        } else if (left.getDataType() == DataType.SCALAR
                && left instanceof LiteralOp
                && HopRewriteUtils.getDoubleValue((LiteralOp) left) == 2 // mult2
                && scalarOp.getOp() == OpOp2.MULT
                && right instanceof AggBinaryOp
                && HopRewriteUtils.isSingleBlock(right.getInput().get(0), true)) { // BLOCKSIZE CONSTRAINT
            // pattern 2b) scalar-matrix operations
            mm = right;
        }

        if (mm != null) {
            Hop W = hi.getInput().get(0);
            Hop U = mm.getInput().get(0);
            Hop V = mm.getInput().get(1);
            boolean mult = ((BinaryOp) hi).getOp() == OpOp2.MULT;
            OpOp2 sop = scalarOp.getOp();

            if (HopRewriteUtils.isTransposeOperation(V)) {
                V = V.getInput().get(0);
            } else {
                V = HopRewriteUtils.createTranspose(V);
            }

            hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
                    OpOp4.WUMM, W, U, V, mult, null, sop);
            hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
            hnew.refreshSizeInformation();
            LOG.debug("Applied simplifyWeightedUnaryMM2 (line " + hi.getBeginLine() + ")");
        }
    }

    // nothing matched: leave the DAG untouched
    if (hnew == null) {
        return hi;
    }

    // relink new hop into original position
    HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
    return hnew;
}
Use of org.apache.sysml.hops.AggBinaryOp in project systemml by apache.
Class HopRewriteUtils, method createMatrixMultiply.
/**
 * Creates a matrix-multiply hop (binary {@code MULT} aggregated with {@code SUM}) over the
 * given operands. Name, data type, value type and line numbers are taken from
 * {@code left}; the output block sizes combine the row block size of {@code left} with
 * the column block size of {@code right}.
 *
 * @param left left-hand-side operand
 * @param right right-hand-side operand
 * @return the new matrix-multiply hop with refreshed size information
 */
public static AggBinaryOp createMatrixMultiply(Hop left, Hop right) {
    AggBinaryOp product = new AggBinaryOp(
            left.getName(),
            left.getDataType(),
            left.getValueType(),
            OpOp2.MULT,
            AggOp.SUM,
            left,
            right);

    // blocking of the result: rows follow the left input, columns the right input
    product.setOutputBlocksizes(left.getRowsInBlock(), right.getColsInBlock());
    copyLineNumbers(left, product);
    product.refreshSizeInformation();

    return product;
}
Aggregations