Use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.
In class ProgramRewriter, method rRewriteStatementBlock.
/**
 * Rewrites a single statement block. Nested bodies of control-flow blocks
 * (function, while, if/else, for/parfor) are rewritten first via
 * {@code rRewriteStatementBlocks}; afterwards every rule in {@code _sbRuleSet}
 * is applied to the block itself.
 *
 * @param sb        statement block to rewrite
 * @param status    rewrite status passed to every rule; its parfor-context flag is set to
 *                  true while the body of a parfor block is processed and reset to its
 *                  previous value afterwards
 * @param splitDags if false, rules whose {@code createsSplitDag()} returns true are skipped
 * @return the list of statement blocks produced by chaining all applied rules,
 *         starting from a list that contains only {@code sb}
 */
public ArrayList<StatementBlock> rRewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status, boolean splitDags) {
    // step 1: descend into the bodies of control-flow blocks
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement fnStmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
        fnStmt.setBody(rRewriteStatementBlocks(fnStmt.getBody(), status, splitDags));
    }
    else if (sb instanceof WhileStatementBlock) {
        WhileStatement whileStmt = (WhileStatement) ((WhileStatementBlock) sb).getStatement(0);
        whileStmt.setBody(rRewriteStatementBlocks(whileStmt.getBody(), status, splitDags));
    }
    else if (sb instanceof IfStatementBlock) {
        IfStatement ifStmt = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
        ifStmt.setIfBody(rRewriteStatementBlocks(ifStmt.getIfBody(), status, splitDags));
        ifStmt.setElseBody(rRewriteStatementBlocks(ifStmt.getElseBody(), status, splitDags));
    }
    else if (sb instanceof ForStatementBlock) {
        // this branch also covers parfor; remember the enclosing parfor flag so it
        // can be put back once the loop body has been processed
        final boolean enclosingParfor = status.isInParforContext();
        if (sb instanceof ParForStatementBlock) {
            status.setInParforContext(true);
        }
        ForStatement loopStmt = (ForStatement) ((ForStatementBlock) sb).getStatement(0);
        loopStmt.setBody(rRewriteStatementBlocks(loopStmt.getBody(), status, splitDags));
        status.setInParforContext(enclosingParfor);
    }

    // step 2: run the rewrite rules; each rule consumes the output of the previous one
    ArrayList<StatementBlock> current = new ArrayList<>();
    current.add(sb);
    for (StatementBlockRewriteRule rule : _sbRuleSet) {
        if (splitDags || !rule.createsSplitDag()) {
            ArrayList<StatementBlock> rewritten = new ArrayList<>();
            for (StatementBlock block : current) {
                rewritten.addAll(rule.rewriteStatementBlock(block, status));
            }
            current = rewritten;
        }
    }
    return current;
}
Use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.
In class DmlSyntacticValidator, method exitInternalFunctionDefExpression.
// -----------------------------------------------------------------
// Internal & External Functions Definitions
// -----------------------------------------------------------------
/**
 * Builds a {@link FunctionStatement} for an internal function definition and
 * stores it, together with the function name, in {@code ctx.info}.
 * If the definition has no body statements, an error is reported through
 * {@code notifyErrorListeners} and {@code ctx.info} is not assigned.
 *
 * @param ctx parse-tree context of the function definition
 */
@Override
public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {
    FunctionStatement fn = new FunctionStatement();

    // signature: input parameters, output parameters, name
    fn.setInputParams(getFunctionParameters(ctx.inputParams));
    fn.setOutputParams(getFunctionParameters(ctx.outputParams));
    fn.setName(ctx.name.getText());

    // a function without statements is rejected
    if (ctx.body.isEmpty()) {
        notifyErrorListeners("functions with no statements are not allowed", ctx.start);
        return;
    }

    // one statement block per body statement, merged afterwards by the function statement
    ArrayList<StatementBlock> blocks = new ArrayList<>();
    for (StatementContext bodyStmt : ctx.body) {
        blocks.add(getStatementBlock(bodyStmt.info.stmt));
    }
    fn.setBody(blocks);
    fn.mergeStatementBlocks();

    // publish the result on the context
    ctx.info.stmt = fn;
    setFileLineColumn(ctx.info.stmt, ctx);
    ctx.info.functionName = ctx.name.getText();
}
Use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.
In class PydmlSyntacticValidator, method exitInternalFunctionDefExpression.
/**
 * Creates the {@link FunctionStatement} for an internal function definition
 * and records it, along with the function name, on {@code ctx.info}.
 * A definition whose body contains no statements is reported via
 * {@code notifyErrorListeners}; in that case {@code ctx.info} is left unassigned.
 *
 * @param ctx parse-tree context of the function definition
 */
@Override
public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {
    FunctionStatement definition = new FunctionStatement();

    // inputs, outputs and name of the function
    definition.setInputParams(getFunctionParameters(ctx.inputParams));
    definition.setOutputParams(getFunctionParameters(ctx.outputParams));
    definition.setName(ctx.name.getText());

    // empty function bodies are not allowed
    if (ctx.body.isEmpty()) {
        notifyErrorListeners("functions with no statements are not allowed", ctx.start);
        return;
    }

    // wrap every body statement in a statement block, then let the function statement merge them
    ArrayList<StatementBlock> bodyBlocks = new ArrayList<>();
    for (StatementContext child : ctx.body) {
        bodyBlocks.add(getStatementBlock(child.info.stmt));
    }
    definition.setBody(bodyBlocks);
    definition.mergeStatementBlocks();

    // hand the finished statement back through the context
    ctx.info.stmt = definition;
    setFileLineColumn(ctx.info.stmt, ctx);
    ctx.info.functionName = ctx.name.getText();
}
Use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.
In class OptTreeConverter, method rCreateAbstractOptNodes.
/**
 * Recursively creates optimizer-tree nodes for a HOP and its inputs.
 * <ul>
 * <li>A HOP that is neither a DataOp, LiteralOp nor FunctionOp yields a
 *     {@code NodeType.HOP} node carrying its op string and execution type.</li>
 * <li>A FunctionOp yields a {@code NodeType.FUNCCALL} node, but only when
 *     {@code INCLUDE_FUNCTIONS} is set. For functions outside the internal
 *     namespace the function body is expanded into child nodes, unless the
 *     function key is already in {@code memo}, in which case the node is
 *     tagged with {@code RECURSIVE_CALL} instead.</li>
 * <li>All other HOPs produce no node of their own.</li>
 * </ul>
 * Already visited HOPs return an empty list; the HOP is marked visited on exit.
 *
 * @param hop  root of the HOP sub-DAG to convert
 * @param vars variable map handed through to {@code rCreateAbstractOptNode}
 * @param memo keys of the functions currently being expanded; a key is added
 *             before and removed after the expansion of its body
 * @return nodes created for {@code hop} followed by those of its inputs
 * @throws DMLRuntimeException if the HOP reports an execution type other than CP, GPU, SPARK or MR
 */
public static ArrayList<OptNode> rCreateAbstractOptNodes(Hop hop, LocalVariableMap vars, Set<String> memo) {
    ArrayList<OptNode> nodes = new ArrayList<>();
    ArrayList<Hop> inputs = hop.getInput();
    if (hop.isVisited()) {
        return nodes;
    }

    final boolean isDataOrLiteral = hop instanceof DataOp || hop instanceof LiteralOp;
    final boolean isFunctionCall = hop instanceof FunctionOp;

    if (!isDataOrLiteral && !isFunctionCall) {
        // general case: regular operator
        OptNode hopNode = new OptNode(NodeType.HOP);
        hopNode.addParam(ParamType.OPSTRING, hop.getOpString());

        // execution type: a missing type defaults to CP, and GPU is mapped to CP
        LopProperties.ExecType declared = hop.getExecType();
        LopProperties.ExecType et = (declared != null) ? declared : LopProperties.ExecType.CP;
        if (et == LopProperties.ExecType.CP || et == LopProperties.ExecType.GPU) {
            hopNode.setExecType(ExecType.CP);
        }
        else if (et == LopProperties.ExecType.SPARK) {
            hopNode.setExecType(ExecType.SPARK);
        }
        else if (et == LopProperties.ExecType.MR) {
            hopNode.setExecType(ExecType.MR);
        }
        else {
            throw new DMLRuntimeException("Unsupported optnode exec type: " + et);
        }

        // degree of parallelism, only for multi-threaded CP operators
        if (et == LopProperties.ExecType.CP && hop instanceof MultiThreadedHop) {
            hopNode.setK(OptimizerUtils.getConstrainedNumThreads(((MultiThreadedHop) hop).getMaxNumThreads()));
        }

        _hlMap.putHopMapping(hop, hopNode);
        nodes.add(hopNode);
    }
    else if (isFunctionCall && INCLUDE_FUNCTIONS) {
        // function call
        FunctionOp call = (FunctionOp) hop;
        String fname = call.getFunctionName();
        String fnspace = call.getFunctionNamespace();
        String fKey = call.getFunctionKey();
        Object[] rootProgram = _hlMap.getRootProgram();

        OptNode callNode = new OptNode(NodeType.FUNCCALL);
        _hlMap.putHopMapping(call, callNode);
        callNode.setExecType(ExecType.CP);
        callNode.addParam(ParamType.OPSTRING, fKey);

        if (!fnspace.equals(DMLProgram.INTERNAL_NAMESPACE)) {
            // element 0 is used as the DMLProgram, element 1 as the runtime Program
            FunctionProgramBlock fpb = ((Program) rootProgram[1]).getFunctionProgramBlock(fnspace, fname);
            FunctionStatementBlock fsb = ((DMLProgram) rootProgram[0]).getFunctionStatementBlock(fnspace, fname);
            FunctionStatement fs = (FunctionStatement) fsb.getStatement(0);

            if (memo.contains(fKey)) {
                // function is already being expanded further up: mark instead of descending again
                callNode.addParam(ParamType.RECURSIVE_CALL, "true");
            }
            else {
                memo.add(fKey);
                int bodyLen = fs.getBody().size();
                // pair program blocks with statement blocks up to the shorter of the two lists
                for (int idx = 0; idx < fpb.getChildBlocks().size() && idx < bodyLen; idx++) {
                    ProgramBlock childPb = fpb.getChildBlocks().get(idx);
                    StatementBlock childSb = fs.getBody().get(idx);
                    callNode.addChild(rCreateAbstractOptNode(childSb, childPb, vars, false, memo));
                }
                memo.remove(fKey);
            }
        }
        nodes.add(callNode);
    }

    // recurse into inputs; DataOp and LiteralOp inputs need no opt nodes
    if (inputs != null) {
        for (Hop input : inputs) {
            if (input instanceof DataOp || input instanceof LiteralOp) {
                continue;
            }
            nodes.addAll(rCreateAbstractOptNodes(input, vars, memo));
        }
    }

    hop.setVisited();
    return nodes;
}
Use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.
In class OptTreePlanChecker, method checkProgramCorrectness.
/**
 * Recursively checks a runtime program block against its statement block.
 * For function, while, if and for (incl. parfor) blocks, the child program
 * blocks are paired by position with the statement blocks of the matching
 * body and checked recursively. Predicates and for-loop from/to/increment
 * expressions are checked with {@code checkHopDagCorrectness}. Any other
 * combination of block types is treated as a leaf whose HOPs are compared
 * with its instructions.
 *
 * @param pb      runtime program block
 * @param sb      statement block that corresponds to {@code pb}
 * @param fnStack passed unchanged to {@code checkHopDagCorrectness} and recursive calls
 */
public static void checkProgramCorrectness(ProgramBlock pb, StatementBlock sb, Set<String> fnStack) {
    final Program prog = pb.getProgram();
    final DMLProgram dprog = sb.getDMLProg();

    if (pb instanceof FunctionProgramBlock && sb instanceof FunctionStatementBlock) {
        FunctionProgramBlock fnPb = (FunctionProgramBlock) pb;
        FunctionStatement fnStmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
        for (int k = 0; k < fnPb.getChildBlocks().size(); k++) {
            checkProgramCorrectness(fnPb.getChildBlocks().get(k), fnStmt.getBody().get(k), fnStack);
        }
        // the program/statement-block link check is not invoked for function blocks
    }
    else if (pb instanceof WhileProgramBlock && sb instanceof WhileStatementBlock) {
        WhileProgramBlock whilePb = (WhileProgramBlock) pb;
        WhileStatementBlock whileSb = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileSb.getStatement(0);
        checkHopDagCorrectness(prog, dprog, whileSb.getPredicateHops(), whilePb.getPredicate(), fnStack);
        for (int k = 0; k < whilePb.getChildBlocks().size(); k++) {
            checkProgramCorrectness(whilePb.getChildBlocks().get(k), whileStmt.getBody().get(k), fnStack);
        }
        checkLinksProgramStatementBlock(whilePb, whileSb);
    }
    else if (pb instanceof IfProgramBlock && sb instanceof IfStatementBlock) {
        IfProgramBlock ifPb = (IfProgramBlock) pb;
        IfStatementBlock ifSb = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifSb.getStatement(0);
        checkHopDagCorrectness(prog, dprog, ifSb.getPredicateHops(), ifPb.getPredicate(), fnStack);
        // if-branch first, then else-branch
        for (int k = 0; k < ifPb.getChildBlocksIfBody().size(); k++) {
            checkProgramCorrectness(ifPb.getChildBlocksIfBody().get(k), ifStmt.getIfBody().get(k), fnStack);
        }
        for (int k = 0; k < ifPb.getChildBlocksElseBody().size(); k++) {
            checkProgramCorrectness(ifPb.getChildBlocksElseBody().get(k), ifStmt.getElseBody().get(k), fnStack);
        }
        checkLinksProgramStatementBlock(ifPb, ifSb);
    }
    else if (pb instanceof ForProgramBlock && sb instanceof ForStatementBlock) {
        // this branch also covers parfor
        ForProgramBlock forPb = (ForProgramBlock) pb;
        ForStatementBlock forSb = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) sb.getStatement(0);
        checkHopDagCorrectness(prog, dprog, forSb.getFromHops(), forPb.getFromInstructions(), fnStack);
        checkHopDagCorrectness(prog, dprog, forSb.getToHops(), forPb.getToInstructions(), fnStack);
        checkHopDagCorrectness(prog, dprog, forSb.getIncrementHops(), forPb.getIncrementInstructions(), fnStack);
        for (int k = 0; k < forPb.getChildBlocks().size(); k++) {
            checkProgramCorrectness(forPb.getChildBlocks().get(k), forStmt.getBody().get(k), fnStack);
        }
        checkLinksProgramStatementBlock(forPb, forSb);
    }
    else {
        // leaf block: compare the HOP DAGs with the generated instructions
        checkHopDagCorrectness(prog, dprog, sb.getHops(), pb.getInstructions(), fnStack);
        // the program/statement-block link check is not invoked for leaf blocks
    }
}
Aggregations