Usage of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.
Class PydmlSyntacticValidator, method getFunctionParameters.
// -----------------------------------------------------------------
// Internal & External Functions Definitions
// -----------------------------------------------------------------
/**
 * Converts the typed parameter declarations of a PyDML function signature
 * into {@link DataIdentifier}s carrying the declared data and value types.
 * <p>
 * A parameter without an explicit (or with an empty) data type is treated
 * as {@code scalar}. Value types are mapped as {@code int -> INT},
 * {@code str -> STRING}, {@code bool -> BOOLEAN} and {@code float -> DOUBLE}.
 *
 * @param ctx the typed-argument contexts of the function signature
 * @return the parameter identifiers in declaration order, or {@code null}
 *         if a parameter has a missing or unrecognized value type (an error
 *         is reported to the error listeners in that case)
 */
private ArrayList<DataIdentifier> getFunctionParameters(List<TypedArgNoAssignContext> ctx) {
	ArrayList<DataIdentifier> retVal = new ArrayList<>();
	for (TypedArgNoAssignContext paramCtx : ctx) {
		DataIdentifier dataId = new DataIdentifier(paramCtx.paramName.getText());

		// default to scalar when no data type was written
		String dataType;
		if (paramCtx.paramType == null || paramCtx.paramType.dataType() == null
				|| paramCtx.paramType.dataType().getText() == null
				|| paramCtx.paramType.dataType().getText().isEmpty()) {
			dataType = "scalar";
		} else {
			dataType = paramCtx.paramType.dataType().getText();
		}

		// check and assign data type; unrecognized names leave the data type unset
		checkValidDataType(dataType, paramCtx.start);
		switch (dataType) {
		case "matrix":
			dataId.setDataType(DataType.MATRIX);
			break;
		case "frame":
			dataId.setDataType(DataType.FRAME);
			break;
		case "scalar":
			dataId.setDataType(DataType.SCALAR);
			break;
		default:
			break;
		}

		// paramType may be null (tolerated above for the data type), so guard it
		// before reading the value type instead of failing with a NullPointerException
		if (paramCtx.paramType == null || paramCtx.paramType.valueType() == null) {
			notifyErrorListeners("missing valuetype for parameter " + paramCtx.paramName.getText(), paramCtx.start);
			return null;
		}
		String valueType = paramCtx.paramType.valueType().getText();
		switch (valueType) {
		case "int":
			dataId.setValueType(ValueType.INT);
			break;
		case "str":
			dataId.setValueType(ValueType.STRING);
			break;
		case "bool":
			dataId.setValueType(ValueType.BOOLEAN);
			break;
		case "float":
			dataId.setValueType(ValueType.DOUBLE);
			break;
		default:
			notifyErrorListeners("invalid valuetype " + valueType, paramCtx.start);
			return null;
		}
		retVal.add(dataId);
	}
	return retVal;
}
Usage of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.
Class PydmlSyntacticValidator, method exitIndexedExpression.
/**
 * PyDML uses 0-based indexing, so we increment lower indices by 1
 * when translating to DML.
 * <p>
 * Builds an {@link IndexedIdentifier} for {@code ctx.dataInfo.expr} whose
 * row and column index lists are produced by
 * {@link #toDmlSliceIndices}. An invalid combination of bounds, or any
 * failure while setting the indices, is reported to the error listeners.
 *
 * @param ctx the parse tree
 */
@Override
public void exitIndexedExpression(IndexedExpressionContext ctx) {
	boolean isRowLower = (ctx.rowLower != null && !ctx.rowLower.isEmpty() && (ctx.rowLower.info.expr != null));
	boolean isRowUpper = (ctx.rowUpper != null && !ctx.rowUpper.isEmpty() && (ctx.rowUpper.info.expr != null));
	boolean isColLower = (ctx.colLower != null && !ctx.colLower.isEmpty() && (ctx.colLower.info.expr != null));
	boolean isColUpper = (ctx.colUpper != null && !ctx.colUpper.isEmpty() && (ctx.colUpper.info.expr != null));
	boolean isRowSliceImplicit = ctx.rowImplicitSlice != null;
	boolean isColSliceImplicit = ctx.colImplicitSlice != null;
	// a null ExpressionInfo means "bound not given"
	ExpressionInfo rowLower = isRowLower ? ctx.rowLower.info : null;
	ExpressionInfo rowUpper = isRowUpper ? ctx.rowUpper.info : null;
	ExpressionInfo colLower = isColLower ? ctx.colLower.info : null;
	ExpressionInfo colUpper = isColUpper ? ctx.colUpper.info : null;

	ctx.dataInfo.expr = new IndexedIdentifier(ctx.name.getText(), false, false);
	setFileLineColumn(ctx.dataInfo.expr, ctx);
	try {
		ArrayList<Expression> rowIndices = toDmlSliceIndices(ctx, rowLower, rowUpper, isRowSliceImplicit,
				Expression.BuiltinFunctionOp.NROW);
		if (rowIndices == null) {
			notifyErrorListeners("incorrect index expression for row", ctx.start);
			return;
		}
		ArrayList<Expression> colIndices = toDmlSliceIndices(ctx, colLower, colUpper, isColSliceImplicit,
				Expression.BuiltinFunctionOp.NCOL);
		if (colIndices == null) {
			notifyErrorListeners("incorrect index expression for column", ctx.start);
			return;
		}
		ArrayList<ArrayList<Expression>> exprList = new ArrayList<>();
		exprList.add(rowIndices);
		exprList.add(colIndices);
		((IndexedIdentifier) ctx.dataInfo.expr).setIndices(exprList);
	} catch (Exception e) {
		notifyErrorListeners("cannot set the indices", ctx.start);
	}
}

/**
 * Translates one dimension (rows or columns) of a PyDML index expression
 * into the DML index list expected by {@code IndexedIdentifier.setIndices}.
 * <ul>
 * <li>no bounds: {@code [null, null]} (whole dimension)</li>
 * <li>both bounds: {@code [lower + 1, upper]}</li>
 * <li>lower only, implicit slice: {@code [lower + 1, sizeOp(X)]}</li>
 * <li>lower only, no implicit slice: {@code [lower + 1]} (single index)</li>
 * <li>upper only, implicit slice: {@code [1, upper]}</li>
 * </ul>
 * Upper bounds are passed through unchanged.
 *
 * @param ctx the indexed-expression parse tree (supplies the indexed name)
 * @param lower the lower bound, or {@code null} if absent
 * @param upper the upper bound, or {@code null} if absent
 * @param isSliceImplicit whether the dimension has an implicit slice
 *        (presumably a bare ':' in the source)
 * @param sizeOp {@code NROW} for rows, {@code NCOL} for columns; used for an
 *        implicit upper bound
 * @return the index list, or {@code null} if the combination of bounds is invalid
 */
private ArrayList<Expression> toDmlSliceIndices(IndexedExpressionContext ctx, ExpressionInfo lower,
		ExpressionInfo upper, boolean isSliceImplicit, Expression.BuiltinFunctionOp sizeOp) {
	ArrayList<Expression> indices = new ArrayList<>();
	if (lower == null && upper == null) {
		// both not set
		indices.add(null);
		indices.add(null);
	} else if (lower != null && upper != null) {
		// both set
		indices.add(incrementByOne(lower.expr, ctx));
		indices.add(upper.expr);
	} else if (lower != null) {
		// Add given lower bound
		indices.add(incrementByOne(lower.expr, ctx));
		if (isSliceImplicit) {
			// Add expression for nrow(X)/ncol(X) for implicit upper bound
			DataIdentifier x = new DataIdentifier(ctx.name.getText());
			indices.add(new BuiltinFunctionExpression(ctx, sizeOp, new Expression[] { x }, currentFile));
		}
	} else if (isSliceImplicit) {
		// Add expression for `1` for implicit lower bound
		// Note: We go ahead and increment by 1 to convert from 0-based to 1-based indexing
		indices.add(new IntIdentifier(ctx, 1, currentFile));
		// Add given upper bound
		indices.add(upper.expr);
	} else {
		return null;
	}
	return indices;
}
Usage of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.
Class FunctionCallCPInstruction, method processInstruction.
/**
 * Invokes the target script function.
 * <p>
 * Binds the caller's input operands to the function's formal parameters
 * (converting scalar inputs to the declared value type), executes the
 * {@link FunctionProgramBlock} in a fresh execution context, cleans up
 * cacheable data left in that context that is not a declared return value,
 * and binds each declared return value to the caller's output variable name.
 * Any data previously bound to that name is cleaned up.
 *
 * @param ec the caller's execution context
 */
@Override
public void processInstruction(ExecutionContext ec) {
	if (LOG.isTraceEnabled()) {
		LOG.trace("Executing instruction : " + this.toString());
	}
	// get the function program block (stored in the Program object)
	FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
	// sanity check number of function parameters
	if (_boundInputs.length < fpb.getInputParams().size()) {
		throw new DMLRuntimeException("Number of bound input parameters does not match the function signature " + "(" + _boundInputs.length + ", but " + fpb.getInputParams().size() + " expected)");
	}
	// create bindings to formal parameters for given function call
	// These are the bindings passed to the FunctionProgramBlock for function execution
	LocalVariableMap functionVariables = new LocalVariableMap();
	for (int i = 0; i < fpb.getInputParams().size(); i++) {
		// error handling non-existing variables
		CPOperand input = _boundInputs[i];
		if (!input.isLiteral() && !ec.containsVariable(input.getName())) {
			throw new DMLRuntimeException("Input variable '" + input.getName() + "' not existing on call of " + DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line " + getLineNum() + ").");
		}
		// get input matrix/frame/scalar
		DataIdentifier currFormalParam = fpb.getInputParams().get(i);
		Data value = ec.getVariable(input);
		// graceful value type conversion for scalar inputs with wrong type
		if (value.getDataType() == DataType.SCALAR && value.getValueType() != currFormalParam.getValueType()) {
			value = ScalarObjectFactory.createScalarObject(currFormalParam.getValueType(), (ScalarObject) value);
		}
		// set input parameter
		functionVariables.put(currFormalParam.getName(), value);
	}
	// Pin the input variables so that they do not get deleted
	// from pb's symbol table at the end of execution of function
	boolean[] pinStatus = ec.pinVariables(_boundInputNames);
	LocalVariableMap retVars;
	try {
		// Create a symbol table under a new execution context for the function invocation,
		// and copy the function arguments into the created table.
		ExecutionContext fn_ec = ExecutionContextFactory.createContext(false, ec.getProgram());
		if (DMLScript.USE_ACCELERATOR) {
			fn_ec.setGPUContexts(ec.getGPUContexts());
			fn_ec.getGPUContext(0).initializeThread();
		}
		fn_ec.setVariables(functionVariables);
		// execute the function block
		try {
			fpb._functionName = this._functionName;
			fpb._namespace = this._namespace;
			fpb.execute(fn_ec);
		} catch (DMLScriptException e) {
			throw e;
		} catch (Exception e) {
			String fname = DMLProgram.constructFunctionKey(_namespace, _functionName);
			throw new DMLRuntimeException("error executing function " + fname, e);
		}
		// cleanup all returned variables w/o binding
		HashSet<String> expectRetVars = new HashSet<>();
		for (DataIdentifier di : fpb.getOutputParams())
			expectRetVars.add(di.getName());
		retVars = fn_ec.getVariables();
		for (Entry<String, Data> var : retVars.entrySet()) {
			if (expectRetVars.contains(var.getKey()))
				continue;
			// cleanup unexpected return values to avoid leaks
			if (var.getValue() instanceof CacheableData)
				fn_ec.cleanupCacheableData((CacheableData<?>) var.getValue());
		}
	} finally {
		// Unpin the pinned variables on every path; otherwise a failing function
		// call would leave the caller's inputs pinned in its symbol table.
		ec.unpinVariables(_boundInputNames, pinStatus);
	}
	// add the updated binding for each return variable to the variables in original symbol table
	for (int i = 0; i < fpb.getOutputParams().size(); i++) {
		String boundVarName = _boundOutputNames.get(i);
		Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
		if (boundValue == null)
			throw new DMLRuntimeException(boundVarName + " was not assigned a return value");
		// cleanup existing data bound to output variable name (instanceof is false for null)
		Data exdata = ec.removeVariable(boundVarName);
		if (exdata instanceof CacheableData && exdata != boundValue) {
			ec.cleanupCacheableData((CacheableData<?>) exdata);
		}
		// add/replace data in symbol table
		ec.setVariable(boundVarName, boundValue);
	}
}
Usage of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.
Class GenerateClassesForMLContext, method createFunctionOutputClass.
/**
 * Generates, via Javassist, a class that bundles the outputs of a script
 * function returning two or more values, and writes it to
 * {@code destination}. Nothing is generated for functions with zero or one
 * output, because those values are returned directly.
 * <p>
 * The generated class has one public field per output, a constructor that
 * takes all outputs in declaration order, and a {@code toString()} that
 * lists each output with its type. Failures are reported by printing the
 * stack trace rather than being propagated.
 *
 * @param scriptFilePath
 *            the path to a script file
 * @param fs
 *            a SystemML function statement
 */
public static void createFunctionOutputClass(String scriptFilePath, FunctionStatement fs) {
	try {
		ArrayList<DataIdentifier> outputs = fs.getOutputParams();
		// With zero or one output, the function call method returns the value
		// directly rather than encapsulating it in a function output class.
		if (outputs.size() < 2) {
			return;
		}
		String className = getFullFunctionOutputClassName(scriptFilePath, fs);
		System.out.println("Generating Class: " + className);
		CtClass outClass = ClassPool.getDefault().makeClass(className);

		// Collect constructor parameters, constructor assignments and toString
		// lines while adding one field per output.
		StringBuilder ctorParams = new StringBuilder();
		StringBuilder ctorAssignments = new StringBuilder();
		StringBuilder toStringLines = new StringBuilder();
		for (DataIdentifier output : outputs) {
			String outName = output.getName();
			String outType = getParamTypeAsString(output);
			outClass.addField(CtField.make("public " + outType + " " + outName + ";", outClass));
			if (ctorParams.length() > 0) {
				ctorParams.append(", ");
			}
			ctorParams.append(outType).append(" ").append(outName);
			ctorAssignments.append("this.").append(outName).append("=").append(outName).append(";\n");
			toStringLines.append("sb.append(\"").append(outName).append(" (")
					.append(getSimpleParamTypeAsString(output)).append("): \" + ").append(outName)
					.append(" + \"\\n\");\n");
		}

		// constructor taking every output in declaration order
		String simpleName = className.substring(className.lastIndexOf(".") + 1);
		String ctorSource = "public " + simpleName + "(" + ctorParams + ") {\n" + ctorAssignments + "}\n";
		outClass.addConstructor(CtNewConstructor.make(ctorSource, outClass));

		// toString listing each output
		String toStringSource = "public String toString(){\n" + "StringBuilder sb = new StringBuilder();\n"
				+ toStringLines + "String str = sb.toString();\nreturn str;\n" + "}\n";
		outClass.addMethod(CtNewMethod.make(toStringSource, outClass));

		outClass.writeFile(destination);
	} catch (RuntimeException | CannotCompileException | IOException e) {
		e.printStackTrace();
	}
}
Usage of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.
Class GenerateClassesForMLContext, method generateFunctionCallMethod.
/**
 * Obtain method for invoking a script function.
 * <p>
 * The returned Java source declares one {@code Object} parameter per script
 * input and executes {@code dmlFunctionCall} through MLContext. It then
 * returns nothing (no outputs), the single output directly (one output), or
 * an instance of the generated function output class (two or more outputs).
 * <p>
 * The generated method's parameters are named after the script's inputs, so
 * every local variable the generated code introduces is prefixed with an
 * underscore, and output locals use {@code _out_<name>}. This keeps them from
 * colliding with those parameters (e.g. an input called {@code results}),
 * with each other (e.g. an output called {@code res}), or with outputs
 * whose names differ only by case.
 *
 * @param scriptFilePath
 *            the path to a script file
 * @param fs
 *            a SystemML function statement
 * @param dmlFunctionCall
 *            a string representing the invocation of a script function
 * @return string representation of a method that performs a function call
 */
public static String generateFunctionCallMethod(String scriptFilePath, FunctionStatement fs, String dmlFunctionCall) {
	createFunctionOutputClass(scriptFilePath, fs);
	StringBuilder sb = new StringBuilder();
	sb.append("public ");
	// begin return type
	ArrayList<DataIdentifier> oparams = fs.getOutputParams();
	if (oparams.size() == 0) {
		sb.append("void");
	} else if (oparams.size() == 1) {
		// if 1 output, no need to encapsulate it, so return the output directly
		sb.append(getParamTypeAsString(oparams.get(0)));
	} else {
		sb.append(getFullFunctionOutputClassName(scriptFilePath, fs));
	}
	sb.append(" ");
	// end return type
	sb.append(fs.getName());
	sb.append("(");
	ArrayList<DataIdentifier> inputParams = fs.getInputParams();
	for (int i = 0; i < inputParams.size(); i++) {
		if (i > 0) {
			sb.append(", ");
		}
		/*
		 * Note: Using Object is currently preferable to using
		 * datatype/valuetype to explicitly set the input type to
		 * Integer/Double/Boolean/String since Object allows the automatic
		 * handling of things such as automatic conversions from longs to
		 * ints.
		 */
		sb.append("Object ");
		sb.append(inputParams.get(i).getName());
	}
	sb.append(") {\n");
	sb.append("String _scriptString = \"" + dmlFunctionCall + "\";\n");
	sb.append("org.apache.sysml.api.mlcontext.Script _script = new org.apache.sysml.api.mlcontext.Script(_scriptString);\n");
	if ((inputParams.size() > 0) || (oparams.size() > 0)) {
		// register inputs/outputs as one chained statement on the script
		sb.append("_script");
		for (DataIdentifier inputParam : inputParams) {
			String name = inputParam.getName();
			sb.append(".in(\"" + name + "\", " + name + ")");
		}
		for (DataIdentifier outputParam : oparams) {
			sb.append(".out(\"" + outputParam.getName() + "\")");
		}
		sb.append(";\n");
	}
	sb.append("org.apache.sysml.api.mlcontext.MLResults _results = _script.execute();\n");
	if (oparams.size() == 0) {
		sb.append("return;\n");
	} else if (oparams.size() == 1) {
		DataIdentifier o = oparams.get(0);
		DataType dt = o.getDataType();
		ValueType vt = o.getValueType();
		// NOTE(review): other data/value type combinations emit no return
		// statement, so the generated method would not compile.
		if ((dt == DataType.SCALAR) && (vt == ValueType.INT)) {
			sb.append("long _res = _results.getLong(\"" + o.getName() + "\");\nreturn _res;\n");
		} else if ((dt == DataType.SCALAR) && (vt == ValueType.DOUBLE)) {
			sb.append("double _res = _results.getDouble(\"" + o.getName() + "\");\nreturn _res;\n");
		} else if ((dt == DataType.SCALAR) && (vt == ValueType.BOOLEAN)) {
			sb.append("boolean _res = _results.getBoolean(\"" + o.getName() + "\");\nreturn _res;\n");
		} else if ((dt == DataType.SCALAR) && (vt == ValueType.STRING)) {
			sb.append("String _res = _results.getString(\"" + o.getName() + "\");\nreturn _res;\n");
		} else if (dt == DataType.MATRIX) {
			sb.append("org.apache.sysml.api.mlcontext.Matrix _res = _results.getMatrix(\"" + o.getName() + "\");\nreturn _res;\n");
		} else if (dt == DataType.FRAME) {
			sb.append("org.apache.sysml.api.mlcontext.Frame _res = _results.getFrame(\"" + o.getName() + "\");\nreturn _res;\n");
		}
	} else {
		// fetch each output into its own prefixed local ...
		for (DataIdentifier outputParam : oparams) {
			String local = "_out_" + outputParam.getName();
			String type = getParamTypeAsString(outputParam);
			DataType dt = outputParam.getDataType();
			ValueType vt = outputParam.getValueType();
			if ((dt == DataType.SCALAR) && (vt == ValueType.INT)) {
				sb.append(type + " " + local + " = _results.getLong(\"" + outputParam.getName() + "\");\n");
			} else if ((dt == DataType.SCALAR) && (vt == ValueType.DOUBLE)) {
				sb.append(type + " " + local + " = _results.getDouble(\"" + outputParam.getName() + "\");\n");
			} else if ((dt == DataType.SCALAR) && (vt == ValueType.BOOLEAN)) {
				sb.append(type + " " + local + " = _results.getBoolean(\"" + outputParam.getName() + "\");\n");
			} else if ((dt == DataType.SCALAR) && (vt == ValueType.STRING)) {
				sb.append(type + " " + local + " = _results.getString(\"" + outputParam.getName() + "\");\n");
			} else if (dt == DataType.MATRIX) {
				sb.append(type + " " + local + " = _results.getMatrix(\"" + outputParam.getName() + "\");\n");
			} else if (dt == DataType.FRAME) {
				sb.append(type + " " + local + " = _results.getFrame(\"" + outputParam.getName() + "\");\n");
			}
		}
		// ... then wrap them in the generated function output class
		String ffocn = getFullFunctionOutputClassName(scriptFilePath, fs);
		sb.append(ffocn + " _res = new " + ffocn + "(");
		for (int i = 0; i < oparams.size(); i++) {
			if (i > 0) {
				sb.append(", ");
			}
			sb.append("_out_" + oparams.get(i).getName());
		}
		sb.append(");\nreturn _res;\n");
	}
	sb.append("}\n");
	return sb.toString();
}
Aggregations