Use of org.apache.sysml.hops.NaryOp in project incubator-systemml by Apache.
Class DMLTranslator, method constructHops.
/**
 * Constructs the HOP DAG for a single statement block and attaches it to the block
 * via {@code sb.setHops(...)}.
 * <p>
 * While, if, for (incl. parfor) and function statement blocks are delegated to their
 * dedicated constructors. For all other blocks this method
 * <ol>
 * <li>creates a transient read for every live-in variable that is contained in the
 * block's {@code _updated} or {@code _gen} set,</li>
 * <li>translates output (write), print/assert/stop/printf, assignment and
 * multi-assignment statements into HOPs, and</li>
 * <li>emits a transient write for a live-out variable at the last statement of the
 * block that assigns it, and records it in the block's updated live-out set.</li>
 * </ol>
 *
 * @param sb statement block to translate
 * @throws LanguageException if a statement cannot be translated, e.g., a write of a
 *         left-indexed expression, an unrecognized output format, an accumulator
 *         assignment to a non-existing variable, a call to an undefined function, a
 *         function call inside a left-indexing expression, or an unsupported source
 *         in a multi-assignment statement
 */
public void constructHops(StatementBlock sb) {
    // control-flow blocks have dedicated constructors
    if (sb instanceof WhileStatementBlock) {
        constructHopsForWhileControlBlock((WhileStatementBlock) sb);
        return;
    }
    if (sb instanceof IfStatementBlock) {
        constructHopsForIfControlBlock((IfStatementBlock) sb);
        return;
    }
    if (sb instanceof ForStatementBlock) {
        // incl ParForStatementBlock
        constructHopsForForControlBlock((ForStatementBlock) sb);
        return;
    }
    if (sb instanceof FunctionStatementBlock) {
        constructHopsForFunctionControlBlock((FunctionStatementBlock) sb);
        return;
    }

    // maps variable names to the hop currently producing their value within this block
    HashMap<String, Hop> ids = new HashMap<>();
    ArrayList<Hop> output = new ArrayList<>();
    VariableSet liveIn = sb.liveIn();
    VariableSet liveOut = sb.liveOut();
    VariableSet updated = sb._updated;
    VariableSet gen = sb._gen;
    VariableSet updatedLiveOut = new VariableSet();

    // handle liveout variables that are updated --> target identifiers for Assignment.
    // Later statements overwrite earlier map entries, so every live-out variable maps to
    // the index of the LAST statement assigning it; the transient write is emitted there.
    HashMap<String, Integer> liveOutToTemp = new HashMap<>();
    for (int i = 0; i < sb.getNumStatements(); i++) {
        Statement current = sb.getStatement(i);
        if (current instanceof AssignmentStatement) {
            DataIdentifier target = ((AssignmentStatement) current).getTarget();
            if (target != null && liveOut.containsVariable(target.getName())) {
                liveOutToTemp.put(target.getName(), Integer.valueOf(i));
            }
        }
        if (current instanceof MultiAssignmentStatement) {
            for (DataIdentifier target : ((MultiAssignmentStatement) current).getTargetList()) {
                if (liveOut.containsVariable(target.getName())) {
                    liveOutToTemp.put(target.getName(), Integer.valueOf(i));
                }
            }
        }
    }

    // create transient reads for live-in variables that are updated or in gen
    // (i.e., from LV analysis, updated and gen sets)
    for (String varName : liveIn.getVariables().keySet()) {
        if (updated.containsVariable(varName) || gen.containsVariable(varName)) {
            DataIdentifier var = liveIn.getVariables().get(varName);
            // indexed identifiers report their original dimensions via getOrigDim1/2
            long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier) var).getOrigDim1() : var.getDim1();
            long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier) var).getOrigDim2() : var.getDim2();
            DataOp read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), DataOpTypes.TRANSIENTREAD, null, actualDim1, actualDim2, var.getNnz(), var.getRowsInBlock(), var.getColumnsInBlock());
            read.setParseInfo(var);
            ids.put(varName, read);
        }
    }

    for (int i = 0; i < sb.getNumStatements(); i++) {
        Statement current = sb.getStatement(i);

        // write statements
        if (current instanceof OutputStatement) {
            OutputStatement os = (OutputStatement) current;
            DataExpression source = os.getSource();
            DataIdentifier target = os.getIdentifier();
            // error handling unsupported indexing expression in write statement
            if (target instanceof IndexedIdentifier) {
                throw new LanguageException(source.printErrorLocation() + ": Unsupported indexing expression in write statement. " + "Please, assign the right indexing result to a variable and write this variable.");
            }
            DataOp ae = (DataOp) processExpression(source, target, ids);
            String formatName = os.getExprParam(DataExpression.FORMAT_TYPE).toString();
            ae.setInputFormatType(Expression.convertFormatType(formatName));
            if (ae.getDataType() == DataType.SCALAR) {
                ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1);
            } else {
                switch (ae.getInputFormatType()) {
                    case TEXT:
                    case MM:
                    case CSV:
                        // write output in textcell format
                        ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1);
                        break;
                    case BINARY:
                        // write output in binary block format
                        ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize());
                        break;
                    default:
                        throw new LanguageException("Unrecognized file format: " + ae.getInputFormatType());
                }
            }
            output.add(ae);
        }

        // print, assert, stop and printf statements (scalar string target)
        if (current instanceof PrintStatement) {
            DataIdentifier target = createTarget();
            target.setDataType(DataType.SCALAR);
            target.setValueType(ValueType.STRING);
            target.setParseInfo(current);
            PrintStatement ps = (PrintStatement) current;
            PRINTTYPE ptype = ps.getType();
            try {
                if (ptype == PRINTTYPE.PRINT || ptype == PRINTTYPE.ASSERT || ptype == PRINTTYPE.STOP) {
                    // single-argument statements map to a unary op
                    Hop.OpOp1 op = (ptype == PRINTTYPE.PRINT) ? Hop.OpOp1.PRINT
                        : (ptype == PRINTTYPE.ASSERT) ? Hop.OpOp1.ASSERT : Hop.OpOp1.STOP;
                    Expression source = ps.getExpressions().get(0);
                    Hop ae = processExpression(source, target, ids);
                    Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
                    printHop.setParseInfo(current);
                    output.add(printHop);
                    if (ptype == PRINTTYPE.STOP) {
                        // avoid merge
                        sb.setSplitDag(true);
                    }
                } else if (ptype == PRINTTYPE.PRINTF) {
                    // printf takes a variable number of arguments -> n-ary op
                    List<Expression> expressions = ps.getExpressions();
                    Hop[] inHops = new Hop[expressions.size()];
                    for (int j = 0; j < expressions.size(); j++) {
                        inHops[j] = processExpression(expressions.get(j), target, ids);
                    }
                    target.setValueType(ValueType.STRING);
                    Hop printfHop = new NaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOpN.PRINTF, inHops);
                    // attach source location like the other print-type hops
                    printfHop.setParseInfo(current);
                    output.add(printfHop);
                }
            } catch (HopsException e) {
                throw new LanguageException(e);
            }
        }

        if (current instanceof AssignmentStatement) {
            AssignmentStatement as = (AssignmentStatement) current;
            DataIdentifier target = as.getTarget();
            Expression source = as.getSource();
            // CASE: regular assignment statement -- source is DML expression that is NOT user-defined or external function
            if (!(source instanceof FunctionCallIdentifier)) {
                // CASE: target is regular data identifier
                if (!(target instanceof IndexedIdentifier)) {
                    // process right hand side and accumulation
                    Hop ae = processExpression(source, target, ids);
                    if (as.isAccumulator()) {
                        DataIdentifier accum = liveIn.getVariable(target.getName());
                        if (accum == null)
                            throw new LanguageException("Invalid accumulator assignment " + "to non-existing variable " + target.getName() + ".");
                        // X += expr  -->  X = X + expr
                        ae = HopRewriteUtils.createBinary(ids.get(target.getName()), ae, OpOp2.PLUS);
                        target.setProperties(accum.getOutput());
                    } else {
                        target.setProperties(source.getOutput());
                    }
                    ids.put(target.getName(), ae);
                    // add transient write if needed
                    Integer statementId = liveOutToTemp.get(target.getName());
                    if ((statementId != null) && (statementId.intValue() == i)) {
                        DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null);
                        transientwrite.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock());
                        transientwrite.setParseInfo(target);
                        updatedLiveOut.addVariable(target.getName(), target);
                        output.add(transientwrite);
                    }
                } else {
                    // CASE: target is indexed identifier (left-hand side indexed expression)
                    Hop ae = processLeftIndexedExpression(source, (IndexedIdentifier) target, ids);
                    ids.put(target.getName(), ae);
                    // obtain origDim values BEFORE they are potentially updated during setProperties call
                    // (this is incorrect for LHS Indexing)
                    long origDim1 = ((IndexedIdentifier) target).getOrigDim1();
                    long origDim2 = ((IndexedIdentifier) target).getOrigDim2();
                    target.setProperties(source.getOutput());
                    ((IndexedIdentifier) target).setOriginalDimensions(origDim1, origDim2);
                    // (required for scalar input to left indexing)
                    if (target.getDataType() != DataType.MATRIX) {
                        target.setDataType(DataType.MATRIX);
                        target.setValueType(ValueType.DOUBLE);
                        target.setBlockDimensions(ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize());
                    }
                    Integer statementId = liveOutToTemp.get(target.getName());
                    if ((statementId != null) && (statementId.intValue() == i)) {
                        DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null);
                        transientwrite.setOutputParams(origDim1, origDim2, ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock());
                        transientwrite.setParseInfo(target);
                        updatedLiveOut.addVariable(target.getName(), target);
                        output.add(transientwrite);
                    }
                }
            } else {
                // assignment, function call
                FunctionCallIdentifier fci = (FunctionCallIdentifier) source;
                FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(), fci.getName());
                // error handling missing function
                if (fsb == null) {
                    String error = source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace();
                    LOG.error(error);
                    throw new LanguageException(error);
                }
                // error handling unsupported function call in indexing expression
                if (target instanceof IndexedIdentifier) {
                    String fkey = DMLProgram.constructFunctionKey(fci.getNamespace(), fci.getName());
                    throw new LanguageException("Unsupported function call to '" + fkey + "' in left indexing expression. " + "Please, assign the function output to a variable.");
                }
                ArrayList<Hop> finputs = new ArrayList<>();
                for (ParameterExpression paramName : fci.getParamExprs()) {
                    finputs.add(processExpression(paramName.getExpr(), null, ids));
                }
                // create function op (no output names when the call has no target)
                FunctionType ftype = fsb.getFunctionOpType();
                String[] foutputs = (target == null) ? new String[] {} : new String[] { target.getName() };
                FunctionOp fcall = new FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, foutputs, false);
                output.add(fcall);
                // TODO function output dataops (phase 3)
            }
        } else if (current instanceof MultiAssignmentStatement) {
            // multi-assignment, by definition a function call
            MultiAssignmentStatement mas = (MultiAssignmentStatement) current;
            Expression source = mas.getSource();
            if (source instanceof FunctionCallIdentifier) {
                FunctionCallIdentifier fci = (FunctionCallIdentifier) source;
                FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(), fci.getName());
                // error handling missing function: check fsb before dereferencing it, so an
                // undefined function surfaces as LanguageException rather than NullPointerException
                if (fsb == null || fsb.getStatement(0) == null) {
                    String error = source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace();
                    LOG.error(error);
                    throw new LanguageException(error);
                }
                ArrayList<Hop> finputs = new ArrayList<>();
                for (ParameterExpression paramName : fci.getParamExprs()) {
                    finputs.add(processExpression(paramName.getExpr(), null, ids));
                }
                // create function op
                String[] foutputs = mas.getTargetList().stream().map(d -> d.getName()).toArray(String[]::new);
                FunctionType ftype = fsb.getFunctionOpType();
                FunctionOp fcall = new FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, foutputs, false);
                output.add(fcall);
                // TODO function output dataops (phase 3)
            } else if (source instanceof BuiltinFunctionExpression && ((BuiltinFunctionExpression) source).multipleReturns()) {
                // construct input hops
                Hop fcall = processMultipleReturnBuiltinFunctionExpression((BuiltinFunctionExpression) source, mas.getTargetList(), ids);
                output.add(fcall);
            } else if (source instanceof ParameterizedBuiltinFunctionExpression && ((ParameterizedBuiltinFunctionExpression) source).multipleReturns()) {
                // construct input hops
                Hop fcall = processMultipleReturnParameterizedBuiltinFunctionExpression((ParameterizedBuiltinFunctionExpression) source, mas.getTargetList(), ids);
                output.add(fcall);
            } else {
                throw new LanguageException("Class \"" + source.getClass() + "\" is not supported in Multiple Assignment statements");
            }
        }
    }
    sb.updateLiveVariablesOut(updatedLiveOut);
    sb.setHops(output);
}
Use of org.apache.sysml.hops.NaryOp in project incubator-systemml by Apache.
Class RewriteAlgebraicSimplificationStatic, method foldMultipleAppendOperations.
/**
 * Flattens nested matrix cbind/rbind operations into a single n-ary cbind/rbind.
 * <p>
 * Repeatedly looks for the first input of {@code hi} that is an append of the same
 * kind (binary or n-ary). If that input has {@code hi} as its only parent, its inputs
 * are spliced into {@code hi}'s input list at its position (original order preserved).
 * A new n-ary op is then created and all parents of {@code hi} are rewired to it.
 * The rewrite is skipped for non-matrix outputs (string appends, frames) and in
 * Hadoop execution mode.
 *
 * @param hi candidate hop
 * @return the (possibly replaced) hop
 */
private static Hop foldMultipleAppendOperations(Hop hi) {
    // no string appends or frames
    if (!hi.getDataType().isMatrix()
        || !(HopRewriteUtils.isBinary(hi, OpOp2.CBIND, OpOp2.RBIND) || HopRewriteUtils.isNary(hi, OpOpN.CBIND, OpOpN.RBIND))
        || OptimizerUtils.isHadoopExecutionMode()) {
        return hi;
    }

    // derive the matching binary and n-ary append opcodes from hi's own opcode
    final OpOp2 binaryAppend;
    final OpOpN naryAppend;
    if (hi instanceof BinaryOp) {
        binaryAppend = ((BinaryOp) hi).getOp();
        naryAppend = OpOpN.valueOf(binaryAppend.name());
    } else {
        naryAppend = ((NaryOp) hi).getOp();
        binaryAppend = OpOp2.valueOf(naryAppend.name());
    }

    while (true) {
        // locate the first input that is an append of the same kind
        Hop nested = null;
        for (Hop candidate : hi.getInput()) {
            if (HopRewriteUtils.isBinary(candidate, binaryAppend) || HopRewriteUtils.isNary(candidate, naryAppend)) {
                nested = candidate;
                break;
            }
        }
        // stop if nothing to fold or the nested append is shared with other consumers
        if (nested == null || nested.getParent().size() != 1) {
            break;
        }

        // splice the nested append's inputs in place of the nested append itself
        ArrayList<Hop> flattened = new ArrayList<>();
        for (Hop candidate : hi.getInput()) {
            if (candidate == nested) {
                flattened.addAll(nested.getInput());
            } else {
                flattened.add(candidate);
            }
        }
        Hop merged = HopRewriteUtils.createNary(naryAppend, flattened.toArray(new Hop[0]));

        // clear dangling references
        HopRewriteUtils.removeAllChildReferences(hi);
        HopRewriteUtils.removeAllChildReferences(nested);

        // rewire all parents over a snapshot (avoid anomalies with refs to hi)
        List<Hop> consumers = new ArrayList<>(hi.getParent());
        for (Hop consumer : consumers) {
            HopRewriteUtils.replaceChildReference(consumer, hi, merged);
        }

        hi = merged;
        LOG.debug("Applied foldMultipleAppendOperations (line " + hi.getBeginLine() + ").");
    }
    return hi;
}
Use of org.apache.sysml.hops.NaryOp in project incubator-systemml by Apache.
Class DMLTranslator, method processBuiltinFunctionExpression.
/**
* Construct Hops from parse tree : Process BuiltinFunction Expression in an
* assignment statement
*
* @param source built-in function expression
* @param target data identifier
* @param hops map of high-level operators
* @return high-level operator
*/
private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, DataIdentifier target, HashMap<String, Hop> hops) {
Hop expr = processExpression(source.getFirstExpr(), null, hops);
Hop expr2 = null;
if (source.getSecondExpr() != null) {
expr2 = processExpression(source.getSecondExpr(), null, hops);
}
Hop expr3 = null;
if (source.getThirdExpr() != null) {
expr3 = processExpression(source.getThirdExpr(), null, hops);
}
Hop currBuiltinOp = null;
if (target == null) {
target = createTarget(source);
}
// Construct the hop based on the type of Builtin function
switch(source.getOpCode()) {
case EVAL:
currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOpN.EVAL, processAllExpressions(source.getAllExpr(), hops));
break;
case COLSUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, Direction.Col, expr);
break;
case COLMAX:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, Direction.Col, expr);
break;
case COLMIN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, Direction.Col, expr);
break;
case COLMEAN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.Col, expr);
break;
case COLSD:
// colStdDevs = sqrt(colVariances)
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Col, expr);
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
break;
case COLVAR:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Col, expr);
break;
case ROWSUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, Direction.Row, expr);
break;
case ROWMAX:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, Direction.Row, expr);
break;
case ROWINDEXMAX:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAXINDEX, Direction.Row, expr);
break;
case ROWINDEXMIN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MININDEX, Direction.Row, expr);
break;
case ROWMIN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, Direction.Row, expr);
break;
case ROWMEAN:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.Row, expr);
break;
case ROWSD:
// rowStdDevs = sqrt(rowVariances)
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Row, expr);
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
break;
case ROWVAR:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.Row, expr);
break;
case NROW:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
long nRows = expr.getDim1();
if (nRows == -1) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NROW, expr);
} else {
currBuiltinOp = new LiteralOp(nRows);
}
break;
case NCOL:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
long nCols = expr.getDim2();
if (nCols == -1) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NCOL, expr);
} else {
currBuiltinOp = new LiteralOp(nCols);
}
break;
case LENGTH:
long nRows2 = expr.getDim1();
long nCols2 = expr.getDim2();
/*
* If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
* Else create a UnaryOp so that a control program instruction is generated
*/
if ((nCols2 == -1) || (nRows2 == -1)) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.LENGTH, expr);
} else {
long lval = (nCols2 * nRows2);
currBuiltinOp = new LiteralOp(lval);
}
break;
case SUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM, Direction.RowCol, expr);
break;
case MEAN:
if (expr2 == null) {
// example: x = mean(Y);
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN, Direction.RowCol, expr);
} else {
// example: x = mean(Y,W);
// stable weighted mean is implemented by using centralMoment with order = 0
Hop orderHop = new LiteralOp(0);
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.CENTRALMOMENT, expr, expr2, orderHop);
}
break;
case SD:
// stdDev = sqrt(variance)
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
HopRewriteUtils.setOutputParametersForScalar(currBuiltinOp);
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
break;
case VAR:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
break;
case MIN:
// construct AggUnary for min(X) but BinaryOp for min(X,Y)
if (expr2 == null) {
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MIN, Direction.RowCol, expr);
} else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MIN, expr, expr2);
}
break;
case MAX:
// construct AggUnary for max(X) but BinaryOp for max(X,Y)
if (expr2 == null) {
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MAX, Direction.RowCol, expr);
} else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MAX, expr, expr2);
}
break;
case PPRED:
String sop = ((StringIdentifier) source.getThirdExpr()).getValue();
sop = sop.replace("\"", "");
OpOp2 operation;
if (sop.equalsIgnoreCase(">="))
operation = OpOp2.GREATEREQUAL;
else if (sop.equalsIgnoreCase(">"))
operation = OpOp2.GREATER;
else if (sop.equalsIgnoreCase("<="))
operation = OpOp2.LESSEQUAL;
else if (sop.equalsIgnoreCase("<"))
operation = OpOp2.LESS;
else if (sop.equalsIgnoreCase("=="))
operation = OpOp2.EQUAL;
else if (sop.equalsIgnoreCase("!="))
operation = OpOp2.NOTEQUAL;
else {
LOG.error(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
throw new ParseException(source.printErrorLocation() + "Unknown argument (" + sop + ") for PPRED.");
}
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), operation, expr, expr2);
break;
case PROD:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.PROD, Direction.RowCol, expr);
break;
case TRACE:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.TRACE, Direction.RowCol, expr);
break;
case TRANS:
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.TRANSPOSE, expr);
break;
case REV:
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.REV, expr);
break;
case CBIND:
case RBIND:
OpOp2 appendOp1 = (source.getOpCode() == BuiltinFunctionOp.CBIND) ? OpOp2.CBIND : OpOp2.RBIND;
OpOpN appendOp2 = (source.getOpCode() == BuiltinFunctionOp.CBIND) ? OpOpN.CBIND : OpOpN.RBIND;
currBuiltinOp = (source.getAllExpr().length == 2) ? new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOp1, expr, expr2) : new NaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOp2, processAllExpressions(source.getAllExpr(), hops));
break;
case DIAG:
currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.DIAG, expr);
break;
case TABLE:
// Always a TertiaryOp is created for table().
// - create a hop for weights, if not provided in the function call.
int numTableArgs = source._args.length;
switch(numTableArgs) {
case 2:
case 4:
// example DML statement: F = ctable(A,B) or F = ctable(A,B,10,15)
// here, weight is interpreted as 1.0
Hop weightHop = new LiteralOp(1.0);
// set dimensions
weightHop.setDim1(0);
weightHop.setDim2(0);
weightHop.setNnz(-1);
weightHop.setRowsInBlock(0);
weightHop.setColsInBlock(0);
if (numTableArgs == 2)
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop);
else {
Hop outDim1 = processExpression(source._args[2], null, hops);
Hop outDim2 = processExpression(source._args[3], null, hops);
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop, outDim1, outDim2);
}
break;
case 3:
case 5:
// example DML statement: F = ctable(A,B,W) or F = ctable(A,B,W,10,15)
if (numTableArgs == 3)
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3);
else {
Hop outDim1 = processExpression(source._args[3], null, hops);
Hop outDim2 = processExpression(source._args[4], null, hops);
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, expr3, outDim1, outDim2);
}
break;
default:
throw new ParseException("Invalid number of arguments " + numTableArgs + " to table() function.");
}
break;
// data type casts
case CAST_AS_SCALAR:
currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.CAST_AS_SCALAR, expr);
break;
case CAST_AS_MATRIX:
currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.CAST_AS_MATRIX, expr);
break;
case CAST_AS_FRAME:
currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), Hop.OpOp1.CAST_AS_FRAME, expr);
break;
// value type casts
case CAST_AS_DOUBLE:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.DOUBLE, Hop.OpOp1.CAST_AS_DOUBLE, expr);
break;
case CAST_AS_INT:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT, Hop.OpOp1.CAST_AS_INT, expr);
break;
case CAST_AS_BOOLEAN:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, Hop.OpOp1.CAST_AS_BOOLEAN, expr);
break;
// Boolean binary
case XOR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.XOR, expr, expr2);
break;
case BITWAND:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWAND, expr, expr2);
break;
case BITWOR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWOR, expr, expr2);
break;
case BITWXOR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWXOR, expr, expr2);
break;
case BITWSHIFTL:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWSHIFTL, expr, expr2);
break;
case BITWSHIFTR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.BITWSHIFTR, expr, expr2);
break;
case ABS:
case SIN:
case COS:
case TAN:
case ASIN:
case ACOS:
case ATAN:
case SINH:
case COSH:
case TANH:
case SIGN:
case SQRT:
case EXP:
case ROUND:
case CEIL:
case FLOOR:
case CUMSUM:
case CUMPROD:
case CUMMIN:
case CUMMAX:
Hop.OpOp1 mathOp1;
switch(source.getOpCode()) {
case ABS:
mathOp1 = Hop.OpOp1.ABS;
break;
case SIN:
mathOp1 = Hop.OpOp1.SIN;
break;
case COS:
mathOp1 = Hop.OpOp1.COS;
break;
case TAN:
mathOp1 = Hop.OpOp1.TAN;
break;
case ASIN:
mathOp1 = Hop.OpOp1.ASIN;
break;
case ACOS:
mathOp1 = Hop.OpOp1.ACOS;
break;
case ATAN:
mathOp1 = Hop.OpOp1.ATAN;
break;
case SINH:
mathOp1 = Hop.OpOp1.SINH;
break;
case COSH:
mathOp1 = Hop.OpOp1.COSH;
break;
case TANH:
mathOp1 = Hop.OpOp1.TANH;
break;
case SIGN:
mathOp1 = Hop.OpOp1.SIGN;
break;
case SQRT:
mathOp1 = Hop.OpOp1.SQRT;
break;
case EXP:
mathOp1 = Hop.OpOp1.EXP;
break;
case ROUND:
mathOp1 = Hop.OpOp1.ROUND;
break;
case CEIL:
mathOp1 = Hop.OpOp1.CEIL;
break;
case FLOOR:
mathOp1 = Hop.OpOp1.FLOOR;
break;
case CUMSUM:
mathOp1 = Hop.OpOp1.CUMSUM;
break;
case CUMPROD:
mathOp1 = Hop.OpOp1.CUMPROD;
break;
case CUMMIN:
mathOp1 = Hop.OpOp1.CUMMIN;
break;
case CUMMAX:
mathOp1 = Hop.OpOp1.CUMMAX;
break;
default:
LOG.error(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
throw new ParseException(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
}
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp1, expr);
break;
case LOG:
if (expr2 == null) {
Hop.OpOp1 mathOp2;
switch(source.getOpCode()) {
case LOG:
mathOp2 = Hop.OpOp1.LOG;
break;
default:
LOG.error(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
throw new ParseException(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
}
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp2, expr);
} else {
Hop.OpOp2 mathOp3;
switch(source.getOpCode()) {
case LOG:
mathOp3 = Hop.OpOp2.LOG;
break;
default:
LOG.error(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
throw new ParseException(source.printErrorLocation() + "processBuiltinFunctionExpression():: Could not find Operation type for builtin function: " + source.getOpCode());
}
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), mathOp3, expr, expr2);
}
break;
case MOMENT:
if (expr3 == null) {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.CENTRALMOMENT, expr, expr2);
} else {
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.CENTRALMOMENT, expr, expr2, expr3);
}
break;
case COV:
if (expr3 == null) {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.COVARIANCE, expr, expr2);
} else {
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.COVARIANCE, expr, expr2, expr3);
}
break;
case QUANTILE:
if (expr3 == null) {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.QUANTILE, expr, expr2);
} else {
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.QUANTILE, expr, expr2, expr3);
}
break;
case INTERQUANTILE:
if (expr3 == null) {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.INTERQUANTILE, expr, expr2);
} else {
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.INTERQUANTILE, expr, expr2, expr3);
}
break;
case IQM:
if (expr2 == null) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.IQM, expr);
} else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.IQM, expr, expr2);
}
break;
case MEDIAN:
if (expr2 == null) {
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.MEDIAN, expr);
} else {
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.MEDIAN, expr, expr2);
}
break;
case IFELSE:
currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp3.IFELSE, expr, expr2, expr3);
break;
case SEQ:
HashMap<String, Hop> randParams = new HashMap<>();
randParams.put(Statement.SEQ_FROM, expr);
randParams.put(Statement.SEQ_TO, expr2);
randParams.put(Statement.SEQ_INCR, (expr3 != null) ? expr3 : new LiteralOp(1));
// note incr: default -1 (for from>to) handled during runtime
currBuiltinOp = new DataGenOp(DataGenMethod.SEQ, target, randParams);
break;
case SAMPLE:
{
Expression[] in = source.getAllExpr();
// arguments: range/size/replace/seed; defaults: replace=FALSE
HashMap<String, Hop> tmpparams = new HashMap<>();
// range
tmpparams.put(DataExpression.RAND_MAX, expr);
tmpparams.put(DataExpression.RAND_ROWS, expr2);
tmpparams.put(DataExpression.RAND_COLS, new LiteralOp(1));
if (in.length == 4) {
tmpparams.put(DataExpression.RAND_PDF, expr3);
Hop seed = processExpression(in[3], null, hops);
tmpparams.put(DataExpression.RAND_SEED, seed);
} else if (in.length == 3) {
// check if the third argument is "replace" or "seed"
if (expr3.getValueType() == ValueType.BOOLEAN) {
tmpparams.put(DataExpression.RAND_PDF, expr3);
tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED));
} else if (expr3.getValueType() == ValueType.INT) {
tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
tmpparams.put(DataExpression.RAND_SEED, expr3);
} else
throw new HopsException("Invalid input type " + expr3.getValueType() + " in sample().");
} else if (in.length == 2) {
tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false));
tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED));
}
currBuiltinOp = new DataGenOp(DataGenMethod.SAMPLE, target, tmpparams);
break;
}
case SOLVE:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.SOLVE, expr, expr2);
break;
case INVERSE:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.INVERSE, expr);
break;
case CHOLESKY:
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.CHOLESKY, expr);
break;
case OUTER:
if (!(expr3 instanceof LiteralOp))
throw new HopsException("Operator for outer builtin function must be a constant: " + expr3);
OpOp2 op = Hop.getOpOp2ForOuterVectorOperation(((LiteralOp) expr3).getStringValue());
if (op == null)
throw new HopsException("Unsupported outer vector binary operation: " + ((LiteralOp) expr3).getStringValue());
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), op, expr, expr2);
// flag op as specific outer vector operation
((BinaryOp) currBuiltinOp).setOuterVectorOperation(true);
// force size reevaluation according to 'outer' flag otherwise danger of incorrect dims
currBuiltinOp.refreshSizeInformation();
break;
case CONV2D:
{
Hop image = expr;
ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case BIAS_ADD:
{
ArrayList<Hop> inHops1 = new ArrayList<>();
inHops1.add(expr);
inHops1.add(expr2);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.BIAS_ADD, inHops1);
setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
break;
}
case BIAS_MULTIPLY:
{
ArrayList<Hop> inHops1 = new ArrayList<>();
inHops1.add(expr);
inHops1.add(expr2);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.BIAS_MULTIPLY, inHops1);
setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
break;
}
case AVG_POOL:
case MAX_POOL:
{
Hop image = expr;
ArrayList<Hop> inHops1 = getALHopsForPoolingForwardIM2COL(image, source, 1, hops);
if (source.getOpCode() == BuiltinFunctionOp.MAX_POOL)
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING, inHops1);
else
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.AVG_POOLING, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case AVG_POOL_BACKWARD:
case MAX_POOL_BACKWARD:
{
Hop image = expr;
// process dout as well
ArrayList<Hop> inHops1 = getALHopsForConvOpPoolingCOL2IM(image, source, 1, hops);
if (source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD)
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING_BACKWARD, inHops1);
else
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.AVG_POOLING_BACKWARD, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case CONV2D_BACKWARD_FILTER:
{
Hop image = expr;
ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
case CONV2D_BACKWARD_DATA:
{
Hop image = expr;
ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 1, hops);
currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA, inHops1);
setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp);
break;
}
default:
throw new ParseException("Unsupported builtin function type: " + source.getOpCode());
}
boolean isConvolution = source.getOpCode() == BuiltinFunctionOp.CONV2D || source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA || source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER || source.getOpCode() == BuiltinFunctionOp.MAX_POOL || source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD || source.getOpCode() == BuiltinFunctionOp.AVG_POOL || source.getOpCode() == BuiltinFunctionOp.AVG_POOL_BACKWARD;
if (!isConvolution) {
// Since the dimension of output doesnot match that of input variable for these operations
setIdentifierParams(currBuiltinOp, source.getOutput());
}
currBuiltinOp.setParseInfo(source);
return currBuiltinOp;
}
use of org.apache.sysml.hops.NaryOp in project incubator-systemml by apache.
the class HopRewriteUtils method createNary.
/**
 * Creates a new n-ary operator over the given inputs.
 * <p>
 * The first input acts as the "main" input. The new operator takes its name,
 * data type and value type from it. It also copies that input's row/column
 * block sizes and line numbers. Size information is refreshed before the
 * operator is returned.
 *
 * @param op     the n-ary operation type
 * @param inputs the operator inputs; must contain at least one element,
 *               and the first element must be non-null
 * @return the newly constructed n-ary operator
 * @throws IllegalArgumentException if {@code inputs} is null or empty, or its
 *                                  first element is null
 */
public static NaryOp createNary(OpOpN op, Hop... inputs) {
	// A zero-arg varargs call yields an empty array; fail with context instead
	// of an anonymous ArrayIndexOutOfBoundsException at inputs[0].
	if (inputs == null || inputs.length == 0)
		throw new IllegalArgumentException("createNary(" + op + ") requires at least one input.");
	Hop mainInput = inputs[0];
	if (mainInput == null)
		throw new IllegalArgumentException("createNary(" + op + ") requires a non-null first input.");
	NaryOp nop = new NaryOp(mainInput.getName(), mainInput.getDataType(), mainInput.getValueType(), op, inputs);
	nop.setOutputBlocksizes(mainInput.getRowsInBlock(), mainInput.getColsInBlock());
	copyLineNumbers(mainInput, nop);
	nop.refreshSizeInformation();
	return nop;
}
Aggregations