Search in sources :

Example 51 with DataIdentifier

Use of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.

From class PydmlSyntacticValidator, method getFunctionParameters.

// -----------------------------------------------------------------
// Internal & External Functions Definitions
// -----------------------------------------------------------------
/**
 * Converts the typed parameter declarations of a PyDML function definition
 * into {@link DataIdentifier}s carrying a data type and a value type.
 *
 * <p>A parameter whose data type is absent or empty defaults to {@code scalar}.
 * The value type must be one of {@code int}, {@code str}, {@code bool} or
 * {@code float}. A missing or unrecognized value type is reported through the
 * error listeners.
 *
 * @param ctx the parameter declaration contexts, in declaration order
 * @return the parameter identifiers in declaration order, or {@code null} if a
 *         parameter has a missing or invalid value type
 */
private ArrayList<DataIdentifier> getFunctionParameters(List<TypedArgNoAssignContext> ctx) {
    ArrayList<DataIdentifier> retVal = new ArrayList<>();
    for (TypedArgNoAssignContext paramCtx : ctx) {
        DataIdentifier dataId = new DataIdentifier(paramCtx.paramName.getText());
        String dataType;
        if (paramCtx.paramType == null || paramCtx.paramType.dataType() == null
                || paramCtx.paramType.dataType().getText() == null
                || paramCtx.paramType.dataType().getText().isEmpty()) {
            dataType = "scalar";
        } else {
            dataType = paramCtx.paramType.dataType().getText();
        }
        // check and assign data type
        checkValidDataType(dataType, paramCtx.start);
        if (dataType.equals("matrix"))
            dataId.setDataType(DataType.MATRIX);
        else if (dataType.equals("frame"))
            dataId.setDataType(DataType.FRAME);
        else if (dataType.equals("scalar"))
            dataId.setDataType(DataType.SCALAR);
        // The data-type lookup above tolerates a missing paramType. Guard the
        // value-type lookup the same way so a malformed declaration yields a
        // parser error instead of a NullPointerException.
        if (paramCtx.paramType == null || paramCtx.paramType.valueType() == null) {
            notifyErrorListeners("missing valuetype for parameter " + paramCtx.paramName.getText(), paramCtx.start);
            return null;
        }
        String valueType = paramCtx.paramType.valueType().getText();
        switch (valueType) {
            case "int":
                dataId.setValueType(ValueType.INT);
                break;
            case "str":
                dataId.setValueType(ValueType.STRING);
                break;
            case "bool":
                dataId.setValueType(ValueType.BOOLEAN);
                break;
            case "float":
                dataId.setValueType(ValueType.DOUBLE);
                break;
            default:
                notifyErrorListeners("invalid valuetype " + valueType, paramCtx.start);
                return null;
        }
        retVal.add(dataId);
    }
    return retVal;
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) ArrayList(java.util.ArrayList) TypedArgNoAssignContext(org.apache.sysml.parser.pydml.PydmlParser.TypedArgNoAssignContext)

Example 52 with DataIdentifier

Use of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.

From class PydmlSyntacticValidator, method exitIndexedExpression.

/**
 * PyDML uses 0-based indexing, so we increment lower indices by 1
 * when translating to DML.
 *
 * <p>Replaces {@code ctx.dataInfo.expr} with an {@link IndexedIdentifier} for the
 * indexed variable. It then resolves the row and column bounds (in that order)
 * with {@link #getIndexBoundsForDimension}. An invalid combination of bounds for a
 * dimension is reported through the error listeners, and no indices are set.
 *
 * @param ctx the parse tree
 */
@Override
public void exitIndexedExpression(IndexedExpressionContext ctx) {
    boolean isRowLower = (ctx.rowLower != null && !ctx.rowLower.isEmpty() && (ctx.rowLower.info.expr != null));
    boolean isRowUpper = (ctx.rowUpper != null && !ctx.rowUpper.isEmpty() && (ctx.rowUpper.info.expr != null));
    boolean isColLower = (ctx.colLower != null && !ctx.colLower.isEmpty() && (ctx.colLower.info.expr != null));
    boolean isColUpper = (ctx.colUpper != null && !ctx.colUpper.isEmpty() && (ctx.colUpper.info.expr != null));
    boolean isRowSliceImplicit = ctx.rowImplicitSlice != null;
    boolean isColSliceImplicit = ctx.colImplicitSlice != null;
    ExpressionInfo rowLower = isRowLower ? ctx.rowLower.info : null;
    ExpressionInfo rowUpper = isRowUpper ? ctx.rowUpper.info : null;
    ExpressionInfo colLower = isColLower ? ctx.colLower.info : null;
    ExpressionInfo colUpper = isColUpper ? ctx.colUpper.info : null;
    ctx.dataInfo.expr = new IndexedIdentifier(ctx.name.getText(), false, false);
    setFileLineColumn(ctx.dataInfo.expr, ctx);
    try {
        ArrayList<Expression> rowIndices = getIndexBoundsForDimension(ctx, rowLower, rowUpper,
                isRowSliceImplicit, Expression.BuiltinFunctionOp.NROW);
        if (rowIndices == null) {
            notifyErrorListeners("incorrect index expression for row", ctx.start);
            return;
        }
        ArrayList<Expression> colIndices = getIndexBoundsForDimension(ctx, colLower, colUpper,
                isColSliceImplicit, Expression.BuiltinFunctionOp.NCOL);
        if (colIndices == null) {
            notifyErrorListeners("incorrect index expression for column", ctx.start);
            return;
        }
        ArrayList<ArrayList<Expression>> exprList = new ArrayList<>();
        exprList.add(rowIndices);
        exprList.add(colIndices);
        ((IndexedIdentifier) ctx.dataInfo.expr).setIndices(exprList);
    } catch (Exception e) {
        notifyErrorListeners("cannot set the indices", ctx.start);
        return;
    }
}

/**
 * Resolves the bound list for one indexing dimension (rows or columns).
 *
 * <ul>
 * <li>no bounds: {@code [null, null]};</li>
 * <li>both bounds: {@code [incrementByOne(lower), upper]};</li>
 * <li>only a lower bound: {@code [incrementByOne(lower)]}, extended with
 *     {@code sizeOp(name)} (nrow/ncol) as the upper bound when the slice is implicit;</li>
 * <li>only an upper bound with an implicit slice: {@code [1, upper]}.</li>
 * </ul>
 *
 * @param ctx the indexed expression being translated
 * @param lower the given lower bound, or {@code null} if absent
 * @param upper the given upper bound, or {@code null} if absent
 * @param isSliceImplicit whether the parse tree recorded an implicit slice for this dimension
 * @param sizeOp {@code NROW} or {@code NCOL}, used to build the implicit upper bound
 * @return the bound list (one or two entries), or {@code null} if the combination is invalid
 */
private ArrayList<Expression> getIndexBoundsForDimension(IndexedExpressionContext ctx, ExpressionInfo lower,
        ExpressionInfo upper, boolean isSliceImplicit, Expression.BuiltinFunctionOp sizeOp) {
    ArrayList<Expression> indices = new ArrayList<>();
    if (lower == null && upper == null) {
        // both not set
        indices.add(null);
        indices.add(null);
    } else if (lower != null && upper != null) {
        // both set
        indices.add(incrementByOne(lower.expr, ctx));
        indices.add(upper.expr);
    } else if (lower != null) {
        // Add given lower bound
        indices.add(incrementByOne(lower.expr, ctx));
        if (isSliceImplicit) {
            // Add expression for nrow(X) / ncol(X) for implicit upper bound
            DataIdentifier x = new DataIdentifier(ctx.name.getText());
            indices.add(new BuiltinFunctionExpression(ctx, sizeOp, new Expression[] { x }, currentFile));
        }
    } else if (isSliceImplicit) {
        // Add expression for `1` for implicit lower bound
        // Note: We go ahead and increment by 1 to convert from 0-based to 1-based indexing
        indices.add(new IntIdentifier(ctx, 1, currentFile));
        // Add given upper bound
        indices.add(upper.expr);
    } else {
        return null;
    }
    return indices;
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) BinaryExpression(org.apache.sysml.parser.BinaryExpression) Expression(org.apache.sysml.parser.Expression) ParameterExpression(org.apache.sysml.parser.ParameterExpression) BuiltinFunctionExpression(org.apache.sysml.parser.BuiltinFunctionExpression) IntIdentifier(org.apache.sysml.parser.IntIdentifier) ArrayList(java.util.ArrayList) ExpressionInfo(org.apache.sysml.parser.common.ExpressionInfo) LanguageException(org.apache.sysml.parser.LanguageException) ParseException(org.apache.sysml.parser.ParseException) IndexedIdentifier(org.apache.sysml.parser.IndexedIdentifier)

Example 53 with DataIdentifier

Use of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.

From class FunctionCallCPInstruction, method processInstruction.

/**
 * Invokes the script function identified by {@code _namespace} and {@code _functionName}.
 *
 * <p>Binds the call's input operands to the function's formal parameters. Scalar
 * inputs whose value type differs from the formal parameter's are converted. The
 * function body runs in a fresh execution context. Cacheable data left in that
 * context under names that are not declared outputs is cleaned up. Finally, each
 * declared output is bound to the caller's corresponding output name. The caller's
 * input variables stay pinned while the function runs and are unpinned even if the
 * call fails.
 *
 * @param ec the caller's execution context
 * @throws DMLRuntimeException if fewer inputs are bound than the function declares,
 *         an input variable does not exist, the function fails, or an output is
 *         not assigned
 */
@Override
public void processInstruction(ExecutionContext ec) {
    if (LOG.isTraceEnabled()) {
        LOG.trace("Executing instruction : " + this.toString());
    }
    // get the function program block (stored in the Program object)
    FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(_namespace, _functionName);
    // sanity check number of function parameters
    if (_boundInputs.length < fpb.getInputParams().size()) {
        throw new DMLRuntimeException("Number of bound input parameters does not match the function signature " + "(" + _boundInputs.length + ", but " + fpb.getInputParams().size() + " expected)");
    }
    // create bindings to formal parameters for given function call
    // These are the bindings passed to the FunctionProgramBlock for function execution
    LocalVariableMap functionVariables = new LocalVariableMap();
    for (int i = 0; i < fpb.getInputParams().size(); i++) {
        // error handling non-existing variables
        CPOperand input = _boundInputs[i];
        if (!input.isLiteral() && !ec.containsVariable(input.getName())) {
            throw new DMLRuntimeException("Input variable '" + input.getName() + "' not existing on call of " + DMLProgram.constructFunctionKey(_namespace, _functionName) + " (line " + getLineNum() + ").");
        }
        // get input matrix/frame/scalar
        DataIdentifier currFormalParam = fpb.getInputParams().get(i);
        Data value = ec.getVariable(input);
        // graceful value type conversion for scalar inputs with wrong type
        if (value.getDataType() == DataType.SCALAR && value.getValueType() != currFormalParam.getValueType()) {
            value = ScalarObjectFactory.createScalarObject(currFormalParam.getValueType(), (ScalarObject) value);
        }
        // set input parameter
        functionVariables.put(currFormalParam.getName(), value);
    }
    // Pin the input variables so that they do not get deleted
    // from pb's symbol table at the end of execution of function
    boolean[] pinStatus = ec.pinVariables(_boundInputNames);
    LocalVariableMap retVars;
    try {
        // Create a symbol table under a new execution context for the function invocation,
        // and copy the function arguments into the created table.
        ExecutionContext fn_ec = ExecutionContextFactory.createContext(false, ec.getProgram());
        if (DMLScript.USE_ACCELERATOR) {
            fn_ec.setGPUContexts(ec.getGPUContexts());
            fn_ec.getGPUContext(0).initializeThread();
        }
        fn_ec.setVariables(functionVariables);
        // execute the function block
        try {
            fpb._functionName = this._functionName;
            fpb._namespace = this._namespace;
            fpb.execute(fn_ec);
        } catch (DMLScriptException e) {
            throw e;
        } catch (Exception e) {
            String fname = DMLProgram.constructFunctionKey(_namespace, _functionName);
            throw new DMLRuntimeException("error executing function " + fname, e);
        }
        // cleanup all returned variables w/o binding
        HashSet<String> expectRetVars = new HashSet<>();
        for (DataIdentifier di : fpb.getOutputParams())
            expectRetVars.add(di.getName());
        retVars = fn_ec.getVariables();
        for (Entry<String, Data> var : retVars.entrySet()) {
            if (expectRetVars.contains(var.getKey()))
                continue;
            // cleanup unexpected return values to avoid leaks
            if (var.getValue() instanceof CacheableData)
                fn_ec.cleanupCacheableData((CacheableData<?>) var.getValue());
        }
    } finally {
        // Unpin the pinned variables on every path, including failures of context
        // setup or function execution. On success this still runs after the cleanup
        // above, as in the original ordering, so the pins remain active during it.
        ec.unpinVariables(_boundInputNames, pinStatus);
    }
    // add the updated binding for each return variable to the variables in original symbol table
    for (int i = 0; i < fpb.getOutputParams().size(); i++) {
        String boundVarName = _boundOutputNames.get(i);
        Data boundValue = retVars.get(fpb.getOutputParams().get(i).getName());
        if (boundValue == null)
            throw new DMLRuntimeException(boundVarName + " was not assigned a return value");
        // cleanup existing data bound to output variable name
        Data exdata = ec.removeVariable(boundVarName);
        if (exdata instanceof CacheableData && exdata != boundValue) {
            ec.cleanupCacheableData((CacheableData<?>) exdata);
        }
        // add/replace data in symbol table
        ec.setVariable(boundVarName, boundValue);
    }
}
Also used : FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) DataIdentifier(org.apache.sysml.parser.DataIdentifier) CacheableData(org.apache.sysml.runtime.controlprogram.caching.CacheableData) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) DMLScriptException(org.apache.sysml.runtime.DMLScriptException) ExecutionContext(org.apache.sysml.runtime.controlprogram.context.ExecutionContext) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) HashSet(java.util.HashSet)

Example 54 with DataIdentifier

Use of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.

From class GenerateClassesForMLContext, method createFunctionOutputClass.

/**
 * Create a class that encapsulates the outputs of a function.
 *
 * <p>Nothing is generated when the function has zero or one outputs. Otherwise a
 * class named by {@code getFullFunctionOutputClassName} is generated with:
 * <ul>
 * <li>one public field per output;</li>
 * <li>a constructor taking all outputs in declaration order;</li>
 * <li>a {@code toString()} listing every output with its type.</li>
 * </ul>
 * The class file is written under the {@code destination} directory. Generation
 * errors are printed to standard error and otherwise ignored.
 *
 * @param scriptFilePath
 *            the path to a script file
 * @param fs
 *            a SystemML function statement
 */
public static void createFunctionOutputClass(String scriptFilePath, FunctionStatement fs) {
    try {
        ArrayList<DataIdentifier> oparams = fs.getOutputParams();
        // if there are 0 or 1 outputs, the generated function call method returns
        // nothing or the single output directly rather than encapsulating it in a
        // function output class
        if (oparams.size() <= 1) {
            return;
        }
        String fullFunctionOutputClassName = getFullFunctionOutputClassName(scriptFilePath, fs);
        System.out.println("Generating Class: " + fullFunctionOutputClassName);
        ClassPool pool = ClassPool.getDefault();
        CtClass ctFuncOut = pool.makeClass(fullFunctionOutputClassName);
        // add one public field per output
        for (DataIdentifier oparam : oparams) {
            String fstring = "public " + getParamTypeAsString(oparam) + " " + oparam.getName() + ";";
            ctFuncOut.addField(CtField.make(fstring, ctFuncOut));
        }
        // add constructor taking all outputs in declaration order
        String simpleFuncOutClassName = fullFunctionOutputClassName.substring(fullFunctionOutputClassName.lastIndexOf(".") + 1);
        StringBuilder con = new StringBuilder();
        con.append("public " + simpleFuncOutClassName + "(");
        for (int i = 0; i < oparams.size(); i++) {
            if (i > 0) {
                con.append(", ");
            }
            DataIdentifier oparam = oparams.get(i);
            con.append(getParamTypeAsString(oparam) + " " + oparam.getName());
        }
        con.append(") {\n");
        for (DataIdentifier oparam : oparams) {
            String name = oparam.getName();
            con.append("this." + name + "=" + name + ";\n");
        }
        con.append("}\n");
        ctFuncOut.addConstructor(CtNewConstructor.make(con.toString(), ctFuncOut));
        // add toString
        StringBuilder s = new StringBuilder();
        s.append("public String toString(){\n");
        s.append("StringBuilder sb = new StringBuilder();\n");
        for (DataIdentifier oparam : oparams) {
            String name = oparam.getName();
            s.append("sb.append(\"" + name + " (" + getSimpleParamTypeAsString(oparam) + "): \" + " + name + " + \"\\n\");\n");
        }
        s.append("String str = sb.toString();\nreturn str;\n");
        s.append("}\n");
        ctFuncOut.addMethod(CtNewMethod.make(s.toString(), ctFuncOut));
        ctFuncOut.writeFile(destination);
    } catch (RuntimeException | CannotCompileException | IOException e) {
        e.printStackTrace();
    }
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) ClassPool(javassist.ClassPool) CannotCompileException(javassist.CannotCompileException) IOException(java.io.IOException) CtConstructor(javassist.CtConstructor) CtClass(javassist.CtClass) CtField(javassist.CtField) CtMethod(javassist.CtMethod)

Example 55 with DataIdentifier

Use of org.apache.sysml.parser.DataIdentifier in project incubator-systemml by Apache.

From class GenerateClassesForMLContext, method generateFunctionCallMethod.

/**
 * Obtain method for invoking a script function.
 *
 * <p>First calls {@code createFunctionOutputClass}, which generates the function
 * output class when the function has more than one output. The returned method
 * source takes every input as {@code Object}. It runs {@code dmlFunctionCall} as an
 * MLContext {@code Script}, with all inputs bound and all outputs requested. It
 * returns nothing (no outputs), the output itself (one output), or a new instance
 * of the function output class (several outputs).
 *
 * <p>The generated method's parameters are named after the function's inputs. Its
 * own locals are therefore all underscore-prefixed ({@code _scriptString},
 * {@code _script}, {@code _results}, {@code _res}, and positional
 * {@code _out0, _out1, ...}) so they cannot clash with input names. Positional
 * names also keep outputs that differ only in case (e.g. {@code X} and {@code x})
 * from colliding. NOTE(review): this assumes DML/PyDML identifiers cannot begin
 * with an underscore; confirm against both grammars.
 *
 * @param scriptFilePath
 *            the path to a script file
 * @param fs
 *            a SystemML function statement
 * @param dmlFunctionCall
 *            a string representing the invocation of a script function
 * @return string representation of a method that performs a function call
 */
public static String generateFunctionCallMethod(String scriptFilePath, FunctionStatement fs, String dmlFunctionCall) {
    createFunctionOutputClass(scriptFilePath, fs);
    StringBuilder sb = new StringBuilder();
    sb.append("public ");
    // begin return type
    ArrayList<DataIdentifier> oparams = fs.getOutputParams();
    if (oparams.size() == 0) {
        sb.append("void");
    } else if (oparams.size() == 1) {
        // if 1 output, no need to encapsulate it, so return the output
        // directly
        sb.append(getParamTypeAsString(oparams.get(0)));
    } else {
        sb.append(getFullFunctionOutputClassName(scriptFilePath, fs));
    }
    sb.append(" ");
    // end return type
    sb.append(fs.getName());
    sb.append("(");
    ArrayList<DataIdentifier> inputParams = fs.getInputParams();
    for (int i = 0; i < inputParams.size(); i++) {
        if (i > 0) {
            sb.append(", ");
        }
        /*
         * Note: Using Object is currently preferrable to using
         * datatype/valuetype to explicitly set the input type to
         * Integer/Double/Boolean/String since Object allows the automatic
         * handling of things such as automatic conversions from longs to
         * ints.
         */
        sb.append("Object ");
        sb.append(inputParams.get(i).getName());
    }
    sb.append(") {\n");
    // NOTE(review): dmlFunctionCall is embedded verbatim in a Java string literal,
    // so it is assumed to be escaped by the caller.
    sb.append("String _scriptString = \"" + dmlFunctionCall + "\";\n");
    sb.append("org.apache.sysml.api.mlcontext.Script _script = new org.apache.sysml.api.mlcontext.Script(_scriptString);\n");
    boolean bindsParams = (inputParams.size() > 0) || (oparams.size() > 0);
    if (bindsParams) {
        sb.append("_script");
    }
    for (DataIdentifier inputParam : inputParams) {
        String name = inputParam.getName();
        sb.append(".in(\"" + name + "\", " + name + ")");
    }
    for (DataIdentifier outputParam : oparams) {
        sb.append(".out(\"" + outputParam.getName() + "\")");
    }
    if (bindsParams) {
        sb.append(";\n");
    }
    sb.append("org.apache.sysml.api.mlcontext.MLResults _results = _script.execute();\n");
    if (oparams.size() == 0) {
        sb.append("return;\n");
    } else if (oparams.size() == 1) {
        DataIdentifier o = oparams.get(0);
        DataType dt = o.getDataType();
        ValueType vt = o.getValueType();
        if ((dt == DataType.SCALAR) && (vt == ValueType.INT)) {
            sb.append("long _res = _results.getLong(\"" + o.getName() + "\");\nreturn _res;\n");
        } else if ((dt == DataType.SCALAR) && (vt == ValueType.DOUBLE)) {
            sb.append("double _res = _results.getDouble(\"" + o.getName() + "\");\nreturn _res;\n");
        } else if ((dt == DataType.SCALAR) && (vt == ValueType.BOOLEAN)) {
            sb.append("boolean _res = _results.getBoolean(\"" + o.getName() + "\");\nreturn _res;\n");
        } else if ((dt == DataType.SCALAR) && (vt == ValueType.STRING)) {
            sb.append("String _res = _results.getString(\"" + o.getName() + "\");\nreturn _res;\n");
        } else if (dt == DataType.MATRIX) {
            sb.append("org.apache.sysml.api.mlcontext.Matrix _res = _results.getMatrix(\"" + o.getName() + "\");\nreturn _res;\n");
        } else if (dt == DataType.FRAME) {
            sb.append("org.apache.sysml.api.mlcontext.Frame _res = _results.getFrame(\"" + o.getName() + "\");\nreturn _res;\n");
        }
    } else {
        for (int i = 0; i < oparams.size(); i++) {
            DataIdentifier outputParam = oparams.get(i);
            // positional local name: unique per output and cannot clash with inputs
            String assign = getParamTypeAsString(outputParam) + " _out" + i + " = _results.";
            String key = "(\"" + outputParam.getName() + "\");\n";
            DataType dt = outputParam.getDataType();
            ValueType vt = outputParam.getValueType();
            if ((dt == DataType.SCALAR) && (vt == ValueType.INT)) {
                sb.append(assign + "getLong" + key);
            } else if ((dt == DataType.SCALAR) && (vt == ValueType.DOUBLE)) {
                sb.append(assign + "getDouble" + key);
            } else if ((dt == DataType.SCALAR) && (vt == ValueType.BOOLEAN)) {
                sb.append(assign + "getBoolean" + key);
            } else if ((dt == DataType.SCALAR) && (vt == ValueType.STRING)) {
                sb.append(assign + "getString" + key);
            } else if (dt == DataType.MATRIX) {
                sb.append(assign + "getMatrix" + key);
            } else if (dt == DataType.FRAME) {
                sb.append(assign + "getFrame" + key);
            }
        }
        String ffocn = getFullFunctionOutputClassName(scriptFilePath, fs);
        sb.append(ffocn + " _res = new " + ffocn + "(");
        for (int i = 0; i < oparams.size(); i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append("_out" + i);
        }
        sb.append(");\nreturn _res;\n");
    }
    sb.append("}\n");
    return sb.toString();
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) ValueType(org.apache.sysml.parser.Expression.ValueType) DataType(org.apache.sysml.parser.Expression.DataType)

Aggregations

DataIdentifier (org.apache.sysml.parser.DataIdentifier): 56; ArrayList (java.util.ArrayList): 19; HashMap (java.util.HashMap): 13; ParameterExpression (org.apache.sysml.parser.ParameterExpression): 13; Hop (org.apache.sysml.hops.Hop): 12; Expression (org.apache.sysml.parser.Expression): 12; LiteralOp (org.apache.sysml.hops.LiteralOp): 10; BinaryExpression (org.apache.sysml.parser.BinaryExpression): 8; BuiltinFunctionExpression (org.apache.sysml.parser.BuiltinFunctionExpression): 8; Data (org.apache.sysml.runtime.instructions.cp.Data): 7; DataGenOp (org.apache.sysml.hops.DataGenOp): 6; DataOp (org.apache.sysml.hops.DataOp): 6; StatementBlock (org.apache.sysml.parser.StatementBlock): 6; DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 6; HopsException (org.apache.sysml.hops.HopsException): 4; DataType (org.apache.sysml.parser.Expression.DataType): 4; ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement): 4; IterablePredicate (org.apache.sysml.parser.IterablePredicate): 4; LanguageException (org.apache.sysml.parser.LanguageException): 4; ParForStatement (org.apache.sysml.parser.ParForStatement): 4