Search in sources:

Example 1 with ParameterExpression

Use of org.apache.sysml.parser.ParameterExpression in the Apache incubator-systemml project.

Method convertPythonBuiltinFunctionToDMLSyntax of class PydmlSyntacticValidator.

// TODO : Clean up to use Map or some other structure
/**
 * Checks the function name, namespace and parameters (number and, where relevant, values) of a
 * PyDML builtin call, rewrites it into the equivalent DML builtin call, and reports useful
 * messages/hints for unsupported usages.
 *
 * <p>Calls whose namespace is a value of {@code sources} or whose name is contained in
 * {@code functions} (presumably imported or user-defined functions) are returned unchanged.
 * Method-style calls such as {@code x.sum()} arrive with the receiver's name in
 * {@code namespace}. They are rewritten into function-style calls whose first argument is a
 * {@code DataIdentifier} for that receiver, in the default namespace.
 *
 * @param ctx antlr rule context, used for the source positions of synthesized expressions
 * @param namespace Namespace of the function (or the receiver variable for method-style calls)
 * @param functionName Name of the builtin function
 * @param paramExpression Array of parameter names and values; some branches rename or append
 *     entries of this list in place
 * @param fnName Token of the builtin function identifier, used to anchor error messages
 * @return common syntax format for runtime, or {@code null} if an error was reported
 */
private ConvertedDMLSyntax convertPythonBuiltinFunctionToDMLSyntax(ParserRuleContext ctx, String namespace, String functionName, ArrayList<ParameterExpression> paramExpression, Token fnName) {
    if (sources.containsValue(namespace) || functions.contains(functionName)) {
        return new ConvertedDMLSyntax(namespace, functionName, paramExpression);
    }
    String fileName = currentFile;
    int line = ctx.start.getLine();
    int col = ctx.start.getCharPositionInLine();
    if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("len")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 1 arguments", fnName);
            return null;
        }
        functionName = "length";
    } else if (functionName.equals("sum") || functionName.equals("mean") || functionName.equals("avg") || functionName.equals("min") || functionName.equals("max") || functionName.equals("argmax") || functionName.equals("argmin") || functionName.equals("cumsum") || functionName.equals("transpose") || functionName.equals("trace") || functionName.equals("var") || functionName.equals("sd")) {
        // can mean sum of all cells or row-wise or columnwise sum
        if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && paramExpression.size() == 1) {
            // sum(x): aggregate over the entire matrix
            functionName = toWholeMatrixAggregateName(functionName, fnName);
            if (functionName == null) {
                return null;
            }
        } else if (!(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) && paramExpression.size() == 0) {
            // x.sum() => sum(x)
            paramExpression = new ArrayList<ParameterExpression>();
            paramExpression.add(new ParameterExpression(null, new DataIdentifier(namespace)));
            functionName = toWholeMatrixAggregateName(functionName, fnName);
            if (functionName == null) {
                return null;
            }
        } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && paramExpression.size() == 2) {
            // sum(x, axis=1) => rowSums(x)
            int axis = getAxis(paramExpression.get(1));
            if (axis == -1 && (functionName.equals("min") || functionName.equals("max"))) {
                // Two-argument min/max without an axis, e.g. min(2, 3): keep the call unchanged
            } else if (axis == -1) {
                notifyErrorListeners("The builtin function \'" + functionName + "\' for given arguments is not supported", fnName);
                return null;
            } else {
                // Resolve into a separate variable so the error message names the function the
                // user actually called rather than the "Not Supported" marker.
                String aggName = getPythonAggFunctionNames(functionName, axis);
                if (aggName.equals("Not Supported")) {
                    notifyErrorListeners("The builtin function \'" + functionName + "\' for given arguments is not supported", fnName);
                    return null;
                }
                ArrayList<ParameterExpression> temp = new ArrayList<ParameterExpression>();
                temp.add(paramExpression.get(0));
                paramExpression = temp;
                functionName = aggName;
            }
        } else if (!(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) && paramExpression.size() == 1) {
            // x.sum(axis=1) => rowSums(x)
            int axis = getAxis(paramExpression.get(0));
            if (axis == -1) {
                notifyErrorListeners("The builtin function \'" + functionName + "\' for given arguments is not supported", fnName);
                return null;
            }
            String aggName = getPythonAggFunctionNames(functionName, axis);
            if (aggName.equals("Not Supported")) {
                notifyErrorListeners("The builtin function \'" + functionName + "\' for given arguments is not supported", fnName);
                return null;
            }
            paramExpression = new ArrayList<ParameterExpression>();
            paramExpression.add(new ParameterExpression(null, new DataIdentifier(namespace)));
            functionName = aggName;
        } else {
            notifyErrorListeners("Incorrect number of arguments for the builtin function \'" + functionName + "\'.", fnName);
            return null;
        }
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("concatenate")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 2 arguments (Note: concatenate append columns of two matrices)", fnName);
            return null;
        }
        functionName = "append";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("minimum")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 2 arguments", fnName);
            return null;
        }
        functionName = "min";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("maximum")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 2 arguments", fnName);
            return null;
        }
        functionName = "max";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (!(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) && functionName.equals("shape")) {
        // x.shape(0) => nrow(x), x.shape(1) => ncol(x)
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts only 1 argument (0 or 1)", fnName);
            return null;
        }
        int axis = getAxis(paramExpression.get(0));
        if (axis == -1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts only 1 argument (0 or 1)", fnName);
            return null;
        }
        paramExpression = new ArrayList<ParameterExpression>();
        paramExpression.add(new ParameterExpression(null, new DataIdentifier(namespace)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
        if (axis == 0) {
            functionName = "nrow";
        } else if (axis == 1) {
            functionName = "ncol";
        }
    } else if (namespace.equals("random") && functionName.equals("normal")) {
        if (paramExpression.size() != 3) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 3 arguments (number of rows, number of columns, sparsity)", fnName);
            return null;
        }
        paramExpression.get(0).setName("rows");
        paramExpression.get(1).setName("cols");
        paramExpression.get(2).setName("sparsity");
        paramExpression.add(new ParameterExpression("pdf", new StringIdentifier("normal", fileName, line, col, line, col)));
        functionName = "rand";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("random") && functionName.equals("poisson")) {
        if (paramExpression.size() != 4) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 4 arguments (number of rows, number of columns, sparsity, lambda)", fnName);
            return null;
        }
        paramExpression.get(0).setName("rows");
        paramExpression.get(1).setName("cols");
        paramExpression.get(2).setName("sparsity");
        paramExpression.get(3).setName("lambda");
        paramExpression.add(new ParameterExpression("pdf", new StringIdentifier("poisson", fileName, line, col, line, col)));
        functionName = "rand";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("random") && functionName.equals("uniform")) {
        if (paramExpression.size() != 5) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 5 arguments (number of rows, number of columns, sparsity, min, max)", fnName);
            return null;
        }
        paramExpression.get(0).setName("rows");
        paramExpression.get(1).setName("cols");
        paramExpression.get(2).setName("sparsity");
        paramExpression.get(3).setName("min");
        paramExpression.get(4).setName("max");
        paramExpression.add(new ParameterExpression("pdf", new StringIdentifier("uniform", fileName, line, col, line, col)));
        functionName = "rand";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("full")) {
        if (paramExpression.size() != 3) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 3 arguments (constant float value, number of rows, number of columns)", fnName);
            return null;
        }
        paramExpression.get(1).setName("rows");
        paramExpression.get(2).setName("cols");
        functionName = "matrix";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("matrix")) {
        // This can either be string initializer or as.matrix function
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument (either str or float value)", fnName);
            return null;
        }
        if (paramExpression.get(0).getExpr() instanceof StringIdentifier) {
            String initializerString = ((StringIdentifier) paramExpression.get(0).getExpr()).getValue().trim();
            if (!initializerString.startsWith("[") || !initializerString.endsWith("]")) {
                notifyErrorListeners("Incorrect initializer string for builtin function \'" + functionName + "\' (Eg: matrix(\"[1 2 3; 4 5 6]\"))", fnName);
                return null;
            }
            int rows = StringUtils.countMatches(initializerString, ";") + 1;
            // Make sure user doesnot have pretty string
            initializerString = initializerString.replaceAll("; ", ";");
            initializerString = initializerString.replaceAll(" ;", ";");
            initializerString = initializerString.replaceAll("\\[ ", "\\[");
            initializerString = initializerString.replaceAll(" \\]", "\\]");
            // Each row has ncol-1 spaces
            // #spaces = nrow * (ncol-1)
            // ncol = (#spaces / nrow) + 1
            int cols = (StringUtils.countMatches(initializerString, " ") / rows) + 1;
            initializerString = initializerString.replaceAll(";", " ");
            initializerString = initializerString.replaceAll("\\[", "");
            initializerString = initializerString.replaceAll("\\]", "");
            paramExpression = new ArrayList<ParameterExpression>();
            paramExpression.add(new ParameterExpression(null, new StringIdentifier(initializerString, fileName, line, col, line, col)));
            paramExpression.add(new ParameterExpression("rows", new IntIdentifier(rows, fileName, line, col, line, col)));
            paramExpression.add(new ParameterExpression("cols", new IntIdentifier(cols, fileName, line, col, line, col)));
        } else {
            functionName = "as.matrix";
        }
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("scalar")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument", fnName);
            return null;
        }
        functionName = "as.scalar";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("float")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument", fnName);
            return null;
        }
        functionName = "as.double";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("int")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument", fnName);
            return null;
        }
        functionName = "as.integer";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("bool")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument", fnName);
            return null;
        }
        functionName = "as.logical";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (!(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) && functionName.equals("reshape")) {
        // x.reshape(r, c) => matrix(x, rows=r, cols=c)
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments (number of rows, number of columns)", fnName);
            return null;
        }
        paramExpression.get(0).setName("rows");
        paramExpression.get(1).setName("cols");
        ArrayList<ParameterExpression> temp = new ArrayList<ParameterExpression>();
        temp.add(new ParameterExpression(null, new DataIdentifier(namespace)));
        temp.add(paramExpression.get(0));
        temp.add(paramExpression.get(1));
        paramExpression = temp;
        functionName = "matrix";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("removeEmpty")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments (matrix, axis=0 or 1)", fnName);
            return null;
        }
        int axis = getAxis(paramExpression.get(1));
        if (axis == -1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments (matrix, axis=0 or 1)", fnName);
            return null;
        }
        StringIdentifier marginVal = null;
        if (axis == 0) {
            marginVal = new StringIdentifier("rows", fileName, line, col, line, col);
        } else {
            marginVal = new StringIdentifier("cols", fileName, line, col, line, col);
        }
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("margin");
        paramExpression.get(1).setExpr(marginVal);
        functionName = "removeEmpty";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("replace")) {
        if (paramExpression.size() != 3) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 3 arguments (matrix, scalar value that should be replaced (pattern), scalar value (replacement))", fnName);
            return null;
        }
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("pattern");
        paramExpression.get(2).setName("replacement");
        functionName = "replace";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("range")) {
        if (paramExpression.size() < 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 3 arguments (from, to, increment), with the first 2 lacking default values", fnName);
            return null;
        } else if (paramExpression.size() > 3) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 3 arguments (from, to, increment)", fnName);
            // Reject the call rather than emitting a seq(...) call with too many arguments.
            return null;
        }
        functionName = "seq";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("norm") && functionName.equals("cdf")) {
        if (paramExpression.size() != 3) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 3 arguments (target, mean, sd)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("mean");
        paramExpression.get(2).setName("sd");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier("normal", fileName, line, col, line, col)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("expon") && functionName.equals("cdf")) {
        if (paramExpression.size() != 2) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 2 arguments (target, mean)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("mean");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier("exp", fileName, line, col, line, col)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("chi") && functionName.equals("cdf")) {
        if (paramExpression.size() != 2) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 2 arguments (target, df)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("df");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier("chisq", fileName, line, col, line, col)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("f") && functionName.equals("cdf")) {
        if (paramExpression.size() != 3) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 3 arguments (target, df1, df2)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("df1");
        paramExpression.get(2).setName("df2");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier("f", fileName, line, col, line, col)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("t") && functionName.equals("cdf")) {
        if (paramExpression.size() != 2) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 2 arguments (target, df)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("df");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier("t", fileName, line, col, line, col)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("percentile")) {
        if (paramExpression.size() != 2 && paramExpression.size() != 3) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts either 2 or 3 arguments", fnName);
            return null;
        }
        functionName = "quantile";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("arcsin")) {
        functionName = "asin";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("arccos")) {
        functionName = "acos";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("arctan")) {
        functionName = "atan";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("load")) {
        functionName = "read";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("eigen")) {
        functionName = "eig";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("power")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments", fnName);
            return null;
        }
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("dot")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments", fnName);
            return null;
        }
    }
    return new ConvertedDMLSyntax(namespace, functionName, paramExpression);
}

/**
 * Maps the name of an aggregate applied to an entire matrix ({@code sum(x)} or {@code x.sum()})
 * to its DML name. {@code avg} becomes {@code mean}, {@code transpose} becomes {@code t}, and all
 * other names are returned unchanged.
 *
 * @param functionName PyDML function name
 * @param fnName token used to anchor the error message
 * @return the DML function name, or {@code null} after reporting an error for {@code argmax},
 *     {@code argmin} and {@code cumsum}, which are not supported over an entire matrix
 */
private String toWholeMatrixAggregateName(String functionName, Token fnName) {
    if (functionName.equals("avg")) {
        return "mean";
    } else if (functionName.equals("transpose")) {
        return "t";
    } else if (functionName.equals("argmax") || functionName.equals("argmin") || functionName.equals("cumsum")) {
        notifyErrorListeners("The builtin function \'" + functionName + "\' for entire matrix is not supported", fnName);
        return null;
    }
    return functionName;
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) IntIdentifier(org.apache.sysml.parser.IntIdentifier) StringIdentifier(org.apache.sysml.parser.StringIdentifier) ParameterExpression(org.apache.sysml.parser.ParameterExpression) ArrayList(java.util.ArrayList)

Example 2 with ParameterExpression

Use of org.apache.sysml.parser.ParameterExpression in the Apache incubator-systemml project.

Method functionCallAssignmentStatementHelper of class CommonSyntacticValidator.

/**
 * Builds the statement for a function-call statement, with or without an assignment target.
 *
 * <p>The call is first normalized through {@code convertToDMLSyntax}. Calls that end up in the
 * default namespace with a name not contained in {@code functions} are then tried as builtins:
 * <ul>
 *   <li>names in {@code printStatements} become print statements;</li>
 *   <li>names in {@code outputStatements} become write statements;</li>
 *   <li>otherwise {@code buildForBuiltInFunction} is attempted.</li>
 * </ul>
 * If no builtin applies, a {@code FunctionCallIdentifier} is assigned to the target.
 *
 * @param ctx rule context used for source positions
 * @param printStatements function names handled by {@code setPrintStatement}
 * @param outputStatements function names handled by {@code setOutputStatement}
 * @param dataInfo assignment target; may be {@code null}, otherwise it must be a DataIdentifier
 * @param info statement info passed to the statement-building helpers
 * @param nameToken token of the function name, passed to the syntax conversion
 * @param targetListToken token used when reporting an invalid assignment target
 * @param namespace namespace of the called function
 * @param functionName name of the called function
 * @param paramExpression call arguments
 * @param hasLHS not read by this method
 */
protected void functionCallAssignmentStatementHelper(final ParserRuleContext ctx, Set<String> printStatements, Set<String> outputStatements, final Expression dataInfo, final StatementInfo info, final Token nameToken, Token targetListToken, String namespace, String functionName, ArrayList<ParameterExpression> paramExpression, boolean hasLHS) {
    ConvertedDMLSyntax converted = convertToDMLSyntax(ctx, namespace, functionName, paramExpression, nameToken);
    if (converted == null) {
        // Conversion rejected the call (presumably after reporting an error); nothing to build.
        return;
    }
    String ns = converted.namespace;
    String fname = converted.functionName;
    ArrayList<ParameterExpression> args = converted.paramExpression;

    // Builtin candidates: default namespace and not a known (user-defined) function name.
    boolean builtinCandidate = ns.equals(DMLProgram.DEFAULT_NAMESPACE) && !functions.contains(fname);

    // Builtins that are statements on their own and take no LHS.
    if (builtinCandidate) {
        if (printStatements.contains(fname)) {
            setPrintStatement(ctx, fname, args, info);
            return;
        }
        if (outputStatements.contains(fname)) {
            setOutputStatement(ctx, args, info);
            return;
        }
    }

    DataIdentifier target;
    if (dataInfo == null) {
        target = null;
    } else if (dataInfo instanceof DataIdentifier) {
        target = (DataIdentifier) dataInfo;
    } else {
        notifyErrorListeners("incorrect lvalue for function call ", targetListToken);
        return;
    }

    // Builtins whose result may be assigned to the target.
    if (builtinCandidate) {
        final DataIdentifier assignTarget = target;
        Action assignResult = new Action() {

            @Override
            public void execute(Expression e) {
                setAssignmentStatement(ctx, info, assignTarget, e);
            }
        };
        if (buildForBuiltInFunction(ctx, fname, args, assignResult)) {
            return;
        }
    }

    // Not a builtin: emit a regular function call. Calls left in the default namespace are
    // attributed to the current source namespace, when one is set (imported non-built-in function).
    FunctionCallIdentifier call = new FunctionCallIdentifier(args);
    call.setFunctionName(fname);
    boolean useSourceNamespace = sourceNamespace != null && sourceNamespace.length() > 0 && DMLProgram.DEFAULT_NAMESPACE.equals(ns);
    call.setFunctionNamespace(useSourceNamespace ? sourceNamespace : ns);
    call.setAllPositions(currentFile, ctx.start.getLine(), ctx.start.getCharPositionInLine(), ctx.stop.getLine(), ctx.stop.getCharPositionInLine());
    setAssignmentStatement(ctx, info, target, call);
}
Also used : FunctionCallIdentifier(org.apache.sysml.parser.FunctionCallIdentifier) DataIdentifier(org.apache.sysml.parser.DataIdentifier) RelationalExpression(org.apache.sysml.parser.RelationalExpression) BooleanExpression(org.apache.sysml.parser.BooleanExpression) ParameterizedBuiltinFunctionExpression(org.apache.sysml.parser.ParameterizedBuiltinFunctionExpression) BuiltinFunctionExpression(org.apache.sysml.parser.BuiltinFunctionExpression) BinaryExpression(org.apache.sysml.parser.BinaryExpression) Expression(org.apache.sysml.parser.Expression) ParameterExpression(org.apache.sysml.parser.ParameterExpression) DataExpression(org.apache.sysml.parser.DataExpression)

Example 3 with ParameterExpression

Use of org.apache.sysml.parser.ParameterExpression in the Apache incubator-systemml project.

Method setOutputStatement of class CommonSyntacticValidator.

/**
 * Builds a write statement from the arguments of a write-like call and stores it in
 * {@code info.stmt}.
 *
 * <p>The first argument must be a DataIdentifier naming the value to write. The second is used as
 * the output file name ({@code DataExpression.IO_FILENAME}). Any further arguments are forwarded
 * under their own parameter names.
 *
 * @param ctx rule context used for positions and error reporting
 * @param paramExpression call arguments; at least two are required
 * @param info receives the created {@code OutputStatement}
 */
protected void setOutputStatement(ParserRuleContext ctx, ArrayList<ParameterExpression> paramExpression, StatementInfo info) {
    if (paramExpression.size() < 2) {
        notifyErrorListeners("incorrect usage of write function (at least 2 arguments required)", ctx.start);
        return;
    }
    Expression source = paramExpression.get(0).getExpr();
    if (!(source instanceof DataIdentifier)) {
        notifyErrorListeners("incorrect usage of write function", ctx.start);
        return;
    }

    String file = currentFile;
    int startLine = ctx.start.getLine();
    int startCol = ctx.start.getCharPositionInLine();

    HashMap<String, Expression> writeParams = new HashMap<String, Expression>();
    writeParams.put(DataExpression.IO_FILENAME, paramExpression.get(1).getExpr());
    // Remaining options are keyed by their given names, e.g. DataExpression.FORMAT_TYPE,
    // DataExpression.DELIM_DELIMITER, DataExpression.DELIM_HAS_HEADER_ROW, DataExpression.DELIM_SPARSE
    for (ParameterExpression option : paramExpression.subList(2, paramExpression.size())) {
        writeParams.put(option.getName(), option.getExpr());
    }

    DataExpression writeExpr = new DataExpression(DataOp.WRITE, writeParams, file, startLine, startCol, startLine, startCol);
    OutputStatement stmt = new OutputStatement((DataIdentifier) source, DataOp.WRITE, file, startLine, startCol, startLine, startCol);
    info.stmt = stmt;
    setFileLineColumn(info.stmt, ctx);
    stmt.setExprParams(writeExpr);
}
Also used : DataExpression(org.apache.sysml.parser.DataExpression) DataIdentifier(org.apache.sysml.parser.DataIdentifier) HashMap(java.util.HashMap) RelationalExpression(org.apache.sysml.parser.RelationalExpression) BooleanExpression(org.apache.sysml.parser.BooleanExpression) ParameterizedBuiltinFunctionExpression(org.apache.sysml.parser.ParameterizedBuiltinFunctionExpression) BuiltinFunctionExpression(org.apache.sysml.parser.BuiltinFunctionExpression) BinaryExpression(org.apache.sysml.parser.BinaryExpression) Expression(org.apache.sysml.parser.Expression) ParameterExpression(org.apache.sysml.parser.ParameterExpression) DataExpression(org.apache.sysml.parser.DataExpression) OutputStatement(org.apache.sysml.parser.OutputStatement)

Example 4 with ParameterExpression

Use of org.apache.sysml.parser.ParameterExpression in the Apache incubator-systemml project.

Method setPrintStatement of class CommonSyntacticValidator.

// -----------------------------------------------------------------
// Helper Functions for exit*FunctionCall*AssignmentStatement
// -----------------------------------------------------------------
/**
 * Builds a print-like statement (e.g. {@code print} or {@code stop}) from the call's arguments
 * and stores it in {@code thisinfo.stmt}.
 *
 * <p>With a single argument, that expression is the statement's only expression. With several
 * arguments the call is printf-style: the first argument must be a string literal
 * ({@code StringIdentifier}) and all arguments are passed on in order. {@code stop} accepts at
 * most one argument. Any argument whose expression is {@code null} is rejected.
 *
 * <p>When {@code DMLScript.VALIDATOR_IGNORE_ISSUES} is set, no validation is done. A dummy
 * statement with an empty expression list is created instead.
 *
 * @param ctx rule context used for positions and error reporting
 * @param functionName name of the print-like function
 * @param paramExpression call arguments
 * @param thisinfo receives the created {@code PrintStatement}
 */
protected void setPrintStatement(ParserRuleContext ctx, String functionName, ArrayList<ParameterExpression> paramExpression, StatementInfo thisinfo) {
    if (DMLScript.VALIDATOR_IGNORE_ISSUES) {
        // create dummy print statement
        try {
            int line = ctx.start.getLine();
            int col = ctx.start.getCharPositionInLine();
            ArrayList<Expression> expList = new ArrayList<Expression>();
            thisinfo.stmt = new PrintStatement(functionName, expList, line, col, line, col);
            setFileLineColumn(thisinfo.stmt, ctx);
        } catch (LanguageException e) {
            // Issues are being ignored in this mode, so the failure is only logged, not reported.
            e.printStackTrace();
        }
        return;
    }
    int numParams = paramExpression.size();
    if (numParams == 0) {
        notifyErrorListeners(functionName + "() must have more than 0 parameters", ctx.start);
        return;
    }
    if (numParams > 1 && "stop".equals(functionName)) {
        notifyErrorListeners("stop() function cannot have more than 1 parameter", ctx.start);
        return;
    }
    Expression firstExp = paramExpression.get(0).getExpr();
    if (firstExp == null) {
        notifyErrorListeners("cannot process " + functionName + "() function", ctx.start);
        return;
    }
    if (numParams > 1 && !(firstExp instanceof StringIdentifier)) {
        notifyErrorListeners("printf-style functionality requires first print parameter to be a string", ctx.start);
        return;
    }
    List<Expression> expressions = new ArrayList<Expression>(numParams);
    expressions.add(firstExp);
    // Reject missing argument expressions everywhere, not just in the first position, so a
    // null never reaches the PrintStatement.
    for (ParameterExpression pe : paramExpression.subList(1, numParams)) {
        Expression expression = pe.getExpr();
        if (expression == null) {
            notifyErrorListeners("cannot process " + functionName + "() function", ctx.start);
            return;
        }
        expressions.add(expression);
    }
    try {
        int line = ctx.start.getLine();
        int col = ctx.start.getCharPositionInLine();
        thisinfo.stmt = new PrintStatement(functionName, expressions, line, col, line, col);
        setFileLineColumn(thisinfo.stmt, ctx);
    } catch (LanguageException e) {
        notifyErrorListeners("cannot process " + functionName + "() function", ctx.start);
    }
}
Also used : LanguageException(org.apache.sysml.parser.LanguageException) RelationalExpression(org.apache.sysml.parser.RelationalExpression) BooleanExpression(org.apache.sysml.parser.BooleanExpression) ParameterizedBuiltinFunctionExpression(org.apache.sysml.parser.ParameterizedBuiltinFunctionExpression) BuiltinFunctionExpression(org.apache.sysml.parser.BuiltinFunctionExpression) BinaryExpression(org.apache.sysml.parser.BinaryExpression) Expression(org.apache.sysml.parser.Expression) ParameterExpression(org.apache.sysml.parser.ParameterExpression) DataExpression(org.apache.sysml.parser.DataExpression) StringIdentifier(org.apache.sysml.parser.StringIdentifier) ParameterExpression(org.apache.sysml.parser.ParameterExpression) ArrayList(java.util.ArrayList) PrintStatement(org.apache.sysml.parser.PrintStatement)

Example 5 with ParameterExpression

use of org.apache.sysml.parser.ParameterExpression in project incubator-systemml by apache.

the class DmlSyntacticValidator method exitFunctionCallMultiAssignmentStatement.

/**
 * Builds the statement for a function call whose result is assigned to a list of targets.
 * <p>
 * Steps:
 * <ul>
 * <li>The call name is split into namespace and function name with {@code getQualifiedNames}.</li>
 * <li>The names and the argument list are passed through {@code convertToDMLSyntax}.</li>
 * <li>Every target must already have been parsed into a {@link DataIdentifier}.</li>
 * <li>For calls in the default namespace, {@code buildForBuiltInFunction} is tried first.
 * The builtin path is used if it reports success.</li>
 * <li>Otherwise, a {@link FunctionCallIdentifier} is assigned to the targets. A default
 * namespace is replaced by {@code sourceNamespace} when that field is non-empty.</li>
 * </ul>
 * Problems are reported via {@code notifyErrorListeners}, after which processing stops.
 *
 * @param ctx parse-tree context of the multi-assignment function call
 */
@Override
public void exitFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) {
    String[] qualifiedNames = getQualifiedNames(ctx.name.getText());
    if (qualifiedNames == null) {
        notifyErrorListeners("incorrect function name (only namespace.functionName allowed. Hint: If you are trying to use builtin functions, you can skip the namespace)", ctx.name);
        return;
    }

    ArrayList<ParameterExpression> params = getParameterExpressionList(ctx.paramExprs);
    ConvertedDMLSyntax converted = convertToDMLSyntax(ctx, qualifiedNames[0], qualifiedNames[1], params, ctx.name);
    // A null result means the call could not be converted; nothing more to build.
    if (converted == null) {
        return;
    }
    String namespace = converted.namespace;
    String functionName = converted.functionName;
    params = converted.paramExpression;

    FunctionCallIdentifier call = new FunctionCallIdentifier(params);
    call.setFunctionName(functionName);
    call.setFunctionNamespace(namespace);

    // Collect the assignment targets; each one must be a plain data identifier.
    final ArrayList<DataIdentifier> targets = new ArrayList<DataIdentifier>();
    for (DataIdentifierContext targetCtx : ctx.targetList) {
        Expression targetExpr = targetCtx.dataInfo.expr;
        if (!(targetExpr instanceof DataIdentifier)) {
            notifyErrorListeners("incorrect type for variable ", targetCtx.start);
            return;
        }
        targets.add((DataIdentifier) targetExpr);
    }

    if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) {
        // The anonymous Action needs final references to the context and target list.
        final FunctionCallMultiAssignmentStatementContext stmtCtx = ctx;
        Action assignToTargets = new Action() {

            @Override
            public void execute(Expression e) {
                setMultiAssignmentStatement(targets, e, stmtCtx, stmtCtx.info);
            }
        };
        if (buildForBuiltInFunction(ctx, functionName, params, assignToTargets)) {
            return;
        }
    }

    // Not a builtin: a default-namespace call from within a sourced file resolves to that
    // file's namespace (sourceNamespace) when one is set.
    String effectiveNamespace = namespace;
    if (sourceNamespace != null && sourceNamespace.length() > 0 && DMLProgram.DEFAULT_NAMESPACE.equals(namespace)) {
        effectiveNamespace = sourceNamespace;
    }
    call.setFunctionNamespace(effectiveNamespace);
    setMultiAssignmentStatement(targets, call, ctx, ctx.info);
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) FunctionCallMultiAssignmentStatementContext(org.apache.sysml.parser.dml.DmlParser.FunctionCallMultiAssignmentStatementContext) ArrayList(java.util.ArrayList) FunctionCallIdentifier(org.apache.sysml.parser.FunctionCallIdentifier) Expression(org.apache.sysml.parser.Expression) ParameterExpression(org.apache.sysml.parser.ParameterExpression) DataIdentifierContext(org.apache.sysml.parser.dml.DmlParser.DataIdentifierContext)

Aggregations

ParameterExpression (org.apache.sysml.parser.ParameterExpression)12 Expression (org.apache.sysml.parser.Expression)9 BinaryExpression (org.apache.sysml.parser.BinaryExpression)7 BuiltinFunctionExpression (org.apache.sysml.parser.BuiltinFunctionExpression)7 ArrayList (java.util.ArrayList)6 DataIdentifier (org.apache.sysml.parser.DataIdentifier)5 BooleanExpression (org.apache.sysml.parser.BooleanExpression)4 DataExpression (org.apache.sysml.parser.DataExpression)4 ParameterizedBuiltinFunctionExpression (org.apache.sysml.parser.ParameterizedBuiltinFunctionExpression)4 RelationalExpression (org.apache.sysml.parser.RelationalExpression)4 FunctionCallIdentifier (org.apache.sysml.parser.FunctionCallIdentifier)3 LanguageException (org.apache.sysml.parser.LanguageException)2 StringIdentifier (org.apache.sysml.parser.StringIdentifier)2 ExpressionInfo (org.apache.sysml.parser.common.ExpressionInfo)2 HashMap (java.util.HashMap)1 IntIdentifier (org.apache.sysml.parser.IntIdentifier)1 OutputStatement (org.apache.sysml.parser.OutputStatement)1 PrintStatement (org.apache.sysml.parser.PrintStatement)1 DataIdentifierContext (org.apache.sysml.parser.dml.DmlParser.DataIdentifierContext)1 FunctionCallMultiAssignmentStatementContext (org.apache.sysml.parser.dml.DmlParser.FunctionCallMultiAssignmentStatementContext)1