Search in sources:

Example 6 with IntIdentifier

Use of org.apache.sysml.parser.IntIdentifier in the Apache project systemml.

Class CommonSyntacticValidator, method getConstIdFromString.

/**
 * Converts a raw constant string into the most specific constant identifier.
 *
 * <p>The value is tried, in order, as:
 * <ol>
 *   <li>the dialect's boolean literals, as returned by {@code trueStringLiteral()} and
 *       {@code falseStringLiteral()};</li>
 *   <li>a {@code long};</li>
 *   <li>a {@code double};</li>
 *   <li>a string literal.</li>
 * </ol>
 * A value enclosed in matching double or single quotes is handled by
 * {@code extractStringInQuotes(text, true)}. An unquoted value (e.g. a command-line
 * parameter) is handled by {@code extractStringInQuotes(text, false)}.
 *
 * @param ctx parse context handed to the created identifier
 * @param varValue the raw value text; must be non-null (it is dereferenced immediately)
 * @return a BooleanIdentifier, IntIdentifier, DoubleIdentifier or StringIdentifier for the value
 */
protected ConstIdentifier getConstIdFromString(ParserRuleContext ctx, String varValue) {
    // Compare to the dialect's true literal (e.g. "True"/"TRUE")
    if (varValue.equals(trueStringLiteral()))
        return new BooleanIdentifier(ctx, true, currentFile);
    // Compare to the dialect's false literal (e.g. "False"/"FALSE")
    if (varValue.equals(falseStringLiteral()))
        return new BooleanIdentifier(ctx, false, currentFile);
    // Try long before double so integral values stay exact. Using Ints.tryParse and
    // falling back to double would not be lossless in all cases.
    try {
        long lval = Long.parseLong(varValue);
        return new IntIdentifier(ctx, lval, currentFile);
    } catch (NumberFormatException ex) {
        // not an integer literal; try double next
    }
    // NOTE: we use exception handling instead of Doubles.tryParse for backwards compatibility with guava <14.0
    try {
        double dval = Double.parseDouble(varValue);
        return new DoubleIdentifier(ctx, dval, currentFile);
    } catch (NumberFormatException ex) {
        // not a numeric literal; treat it as a string below
    }
    // Otherwise it is a string literal (optionally enclosed within single or double quotes)
    String val = "";
    String text = varValue;
    if ((text.startsWith("\"") && text.endsWith("\"")) || (text.startsWith("\'") && text.endsWith("\'"))) {
        // Empty quotes ("" or '') yield the empty string.
        // NOTE(review): a lone quote character also passes this check and yields "" as well.
        if (text.length() > 2) {
            val = extractStringInQuotes(text, true);
        }
    } else {
        // the commandline parameters can be passed without any quotes
        val = extractStringInQuotes(text, false);
    }
    return new StringIdentifier(ctx, val, currentFile);
}
Also used : DoubleIdentifier(org.apache.sysml.parser.DoubleIdentifier) IntIdentifier(org.apache.sysml.parser.IntIdentifier) StringIdentifier(org.apache.sysml.parser.StringIdentifier) BooleanIdentifier(org.apache.sysml.parser.BooleanIdentifier) LanguageException(org.apache.sysml.parser.LanguageException)

Example 7 with IntIdentifier

Use of org.apache.sysml.parser.IntIdentifier in the Apache project incubator-systemml.

Class CommonSyntacticValidator, method getConstIdFromString.

/**
 * Converts a raw constant string into the most specific constant identifier.
 *
 * <p>The value is tried, in order, as:
 * <ol>
 *   <li>the dialect's boolean literals, as returned by {@code trueStringLiteral()} and
 *       {@code falseStringLiteral()};</li>
 *   <li>a {@code long};</li>
 *   <li>a {@code double};</li>
 *   <li>a string literal.</li>
 * </ol>
 * A value enclosed in matching double or single quotes is handled by
 * {@code extractStringInQuotes(text, true)}. An unquoted value (e.g. a command-line
 * parameter) is handled by {@code extractStringInQuotes(text, false)}.
 *
 * @param varValue the raw value text; must be non-null (it is dereferenced immediately)
 * @param start token whose line and character position are used as both the begin and
 *              the end position of the created identifier
 * @return a BooleanIdentifier, IntIdentifier, DoubleIdentifier or StringIdentifier for the value
 */
protected ConstIdentifier getConstIdFromString(String varValue, Token start) {
    int linePosition = start.getLine();
    int charPosition = start.getCharPositionInLine();
    // Compare to the dialect's true literal (e.g. "True"/"TRUE")
    if (varValue.equals(trueStringLiteral()))
        return new BooleanIdentifier(true, currentFile, linePosition, charPosition, linePosition, charPosition);
    // Compare to the dialect's false literal (e.g. "False"/"FALSE")
    if (varValue.equals(falseStringLiteral()))
        return new BooleanIdentifier(false, currentFile, linePosition, charPosition, linePosition, charPosition);
    // Try long before double so integral values stay exact. Using Ints.tryParse and
    // falling back to double would not be lossless in all cases.
    try {
        long lval = Long.parseLong(varValue);
        return new IntIdentifier(lval, currentFile, linePosition, charPosition, linePosition, charPosition);
    } catch (NumberFormatException ex) {
        // not an integer literal; try double next
    }
    // NOTE: we use exception handling instead of Doubles.tryParse for backwards compatibility with guava <14.0
    try {
        double dval = Double.parseDouble(varValue);
        return new DoubleIdentifier(dval, currentFile, linePosition, charPosition, linePosition, charPosition);
    } catch (NumberFormatException ex) {
        // not a numeric literal; treat it as a string below
    }
    // Otherwise it is a string literal (optionally enclosed within single or double quotes)
    String val = "";
    String text = varValue;
    if ((text.startsWith("\"") && text.endsWith("\"")) || (text.startsWith("\'") && text.endsWith("\'"))) {
        // Empty quotes ("" or '') yield the empty string.
        // NOTE(review): a lone quote character also passes this check and yields "" as well.
        if (text.length() > 2) {
            val = extractStringInQuotes(text, true);
        }
    } else {
        // the commandline parameters can be passed without any quotes
        val = extractStringInQuotes(text, false);
    }
    return new StringIdentifier(val, currentFile, linePosition, charPosition, linePosition, charPosition);
}
Also used : DoubleIdentifier(org.apache.sysml.parser.DoubleIdentifier) IntIdentifier(org.apache.sysml.parser.IntIdentifier) StringIdentifier(org.apache.sysml.parser.StringIdentifier) BooleanIdentifier(org.apache.sysml.parser.BooleanIdentifier) LanguageException(org.apache.sysml.parser.LanguageException)

Example 8 with IntIdentifier

Use of org.apache.sysml.parser.IntIdentifier in the Apache project incubator-systemml.

Class PydmlSyntacticValidator, method exitIndexedExpression.

/**
 * Translates a PyDML indexing expression (e.g. {@code X[a:b, c:d]}) into an
 * {@code IndexedIdentifier} stored in {@code ctx.dataInfo.expr}.
 *
 * <p>PyDML uses 0-based indexing. Explicitly given lower indices are therefore passed
 * through {@code incrementByOne} when translating to DML; explicitly given upper bounds
 * are used as-is. For each dimension:
 * <ul>
 *   <li>no bound given: both DML bounds are {@code null};</li>
 *   <li>lower and upper given: {@code incrementByOne(lower)} and {@code upper};</li>
 *   <li>only lower given: {@code incrementByOne(lower)}. When the slice is implicit,
 *       {@code nrow(X)} / {@code ncol(X)} is added as the upper bound;</li>
 *   <li>only upper given with an implicit slice: literal {@code 1} and {@code upper};</li>
 *   <li>anything else is reported as an error, and no indices are set.</li>
 * </ul>
 * Any exception while building or setting the indices is reported as
 * "cannot set the indices".
 *
 * @param ctx the parse tree
 */
@Override
public void exitIndexedExpression(IndexedExpressionContext ctx) {
    // A bound counts as given only if it exists, is non-empty and produced an expression.
    // Below, a null ExpressionInfo therefore means "bound not given".
    ExpressionInfo rowLower = (ctx.rowLower != null && !ctx.rowLower.isEmpty() && ctx.rowLower.info.expr != null) ? ctx.rowLower.info : null;
    ExpressionInfo rowUpper = (ctx.rowUpper != null && !ctx.rowUpper.isEmpty() && ctx.rowUpper.info.expr != null) ? ctx.rowUpper.info : null;
    ExpressionInfo colLower = (ctx.colLower != null && !ctx.colLower.isEmpty() && ctx.colLower.info.expr != null) ? ctx.colLower.info : null;
    ExpressionInfo colUpper = (ctx.colUpper != null && !ctx.colUpper.isEmpty() && ctx.colUpper.info.expr != null) ? ctx.colUpper.info : null;
    // rowImplicitSlice/colImplicitSlice presumably mark a slice with an omitted bound (X[a:] / X[:b])
    boolean rowSliceImplicit = ctx.rowImplicitSlice != null;
    boolean colSliceImplicit = ctx.colImplicitSlice != null;
    String matrixName = ctx.name.getText();

    ctx.dataInfo.expr = new IndexedIdentifier(matrixName, false, false);
    setFileLineColumn(ctx.dataInfo.expr, ctx);
    try {
        ArrayList<Expression> rowIndices = new ArrayList<>();
        if (rowLower == null && rowUpper == null) {
            // no row bound given
            rowIndices.add(null);
            rowIndices.add(null);
        } else if (rowLower != null) {
            // shift the 0-based lower bound
            rowIndices.add(incrementByOne(rowLower.expr, ctx));
            if (rowUpper != null) {
                rowIndices.add(rowUpper.expr);
            } else if (rowSliceImplicit) {
                // implicit upper bound: nrow(X)
                rowIndices.add(new BuiltinFunctionExpression(ctx, Expression.BuiltinFunctionOp.NROW,
                        new Expression[] { new DataIdentifier(matrixName) }, currentFile));
            }
            // otherwise only the lower bound is recorded (presumably a single index)
        } else if (rowSliceImplicit) {
            // implicit lower bound: literal 1, which is already 1-based
            rowIndices.add(new IntIdentifier(ctx, 1, currentFile));
            rowIndices.add(rowUpper.expr);
        } else {
            notifyErrorListeners("incorrect index expression for row", ctx.start);
            return;
        }

        ArrayList<Expression> colIndices = new ArrayList<>();
        if (colLower == null && colUpper == null) {
            // no column bound given
            colIndices.add(null);
            colIndices.add(null);
        } else if (colLower != null) {
            // shift the 0-based lower bound
            colIndices.add(incrementByOne(colLower.expr, ctx));
            if (colUpper != null) {
                colIndices.add(colUpper.expr);
            } else if (colSliceImplicit) {
                // implicit upper bound: ncol(X)
                colIndices.add(new BuiltinFunctionExpression(ctx, Expression.BuiltinFunctionOp.NCOL,
                        new Expression[] { new DataIdentifier(matrixName) }, currentFile));
            }
            // otherwise only the lower bound is recorded (presumably a single index)
        } else if (colSliceImplicit) {
            // implicit lower bound: literal 1, which is already 1-based
            colIndices.add(new IntIdentifier(ctx, 1, currentFile));
            colIndices.add(colUpper.expr);
        } else {
            notifyErrorListeners("incorrect index expression for column", ctx.start);
            return;
        }

        ArrayList<ArrayList<Expression>> exprList = new ArrayList<>();
        exprList.add(rowIndices);
        exprList.add(colIndices);
        ((IndexedIdentifier) ctx.dataInfo.expr).setIndices(exprList);
    } catch (Exception e) {
        notifyErrorListeners("cannot set the indices", ctx.start);
    }
}
Also used: DataIdentifier(org.apache.sysml.parser.DataIdentifier) BinaryExpression(org.apache.sysml.parser.BinaryExpression) Expression(org.apache.sysml.parser.Expression) ParameterExpression(org.apache.sysml.parser.ParameterExpression) BuiltinFunctionExpression(org.apache.sysml.parser.BuiltinFunctionExpression) IntIdentifier(org.apache.sysml.parser.IntIdentifier) ArrayList(java.util.ArrayList) ExpressionInfo(org.apache.sysml.parser.common.ExpressionInfo) LanguageException(org.apache.sysml.parser.LanguageException) ParseException(org.apache.sysml.parser.ParseException) IndexedIdentifier(org.apache.sysml.parser.IndexedIdentifier)

Example 9 with IntIdentifier

Use of org.apache.sysml.parser.IntIdentifier in the Apache project systemml.

Class PydmlSyntacticValidator, method convertPythonBuiltinFunctionToDMLSyntax.

// TODO : Clean up to use Map or some other structure
/**
 * Checks a PyDML builtin function call and rewrites it into the equivalent DML builtin call.
 * It checks the function name, namespace, number of parameters and possible parameter values,
 * and produces useful messages/hints for unsupported forms.
 *
 * <p>Calls whose namespace is in {@code sources}, or whose name is in {@code functions}, are
 * returned unchanged (presumably user-imported namespaces and user-defined functions).
 * For method-style calls such as {@code x.sum()}, {@code namespace} holds the receiver name.
 * Such calls are rewritten so that the receiver becomes the first argument.
 *
 * <p>Several branches rename or append entries of {@code paramExpression} in place
 * (via {@code setName}, {@code setExpr} or {@code add}), so the caller's list may be modified.
 *
 * @param ctx antlr rule context
 * @param namespace Namespace of the function (or the receiver name for method-style calls)
 * @param functionName Name of the builtin function
 * @param paramExpression Array of parameter names and values
 * @param fnName Token of the builtin function identifier, used to position error messages
 * @return common syntax format for runtime, or {@code null} if the call is invalid. In the
 *         null case an error has already been reported through {@code notifyErrorListeners}.
 */
private ConvertedDMLSyntax convertPythonBuiltinFunctionToDMLSyntax(ParserRuleContext ctx, String namespace, String functionName, ArrayList<ParameterExpression> paramExpression, Token fnName) {
    if (sources.containsValue(namespace) || functions.contains(functionName)) {
        return new ConvertedDMLSyntax(namespace, functionName, paramExpression);
    }
    if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("len")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 1 arguments", fnName);
            return null;
        }
        functionName = "length";
    } else if (functionName.equals("sum") || functionName.equals("mean") || functionName.equals("avg") || functionName.equals("min") || functionName.equals("max") || functionName.equals("argmax") || functionName.equals("argmin") || functionName.equals("cumsum") || functionName.equals("transpose") || functionName.equals("trace") || functionName.equals("var") || functionName.equals("sd")) {
        // can mean sum of all cells or row-wise or columnwise sum
        if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && paramExpression.size() == 1) {
            // otherwise same function name
            if (functionName.equals("avg")) {
                functionName = "mean";
            } else if (functionName.equals("transpose")) {
                functionName = "t";
            } else if (functionName.equals("argmax") || functionName.equals("argmin") || functionName.equals("cumsum")) {
                notifyErrorListeners("The builtin function \'" + functionName + "\' for entire matrix is not supported", fnName);
                return null;
            }
        } else if (!(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) && paramExpression.size() == 0) {
            // x.sum() => sum(x)
            paramExpression = new ArrayList<>();
            paramExpression.add(new ParameterExpression(null, new DataIdentifier(namespace)));
            // otherwise same function name
            if (functionName.equals("avg")) {
                functionName = "mean";
            } else if (functionName.equals("transpose")) {
                functionName = "t";
            } else if (functionName.equals("argmax") || functionName.equals("argmin") || functionName.equals("cumsum")) {
                notifyErrorListeners("The builtin function \'" + functionName + "\' for entire matrix is not supported", fnName);
                return null;
            }
        } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && paramExpression.size() == 2) {
            // sum(x, axis=1) => rowSums(x)
            int axis = getAxis(paramExpression.get(1));
            if (axis == -1 && (functionName.equals("min") || functionName.equals("max"))) {
            // Do nothing
            // min(2, 3)
            } else if (axis == -1) {
                notifyErrorListeners("The builtin function \'" + functionName + "\' for given arguments is not supported", fnName);
                return null;
            } else {
                ArrayList<ParameterExpression> temp = new ArrayList<>();
                temp.add(paramExpression.get(0));
                paramExpression = temp;
                // Keep the user's function name intact so the error message names it,
                // not the "Not Supported" sentinel returned by the lookup.
                String aggFunctionName = getPythonAggFunctionNames(functionName, axis);
                if (aggFunctionName.equals("Not Supported")) {
                    notifyErrorListeners("The builtin function \'" + functionName + "\' for given arguments is not supported", fnName);
                    return null;
                }
                functionName = aggFunctionName;
            }
        } else if (!(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) && paramExpression.size() == 1) {
            // x.sum(axis=1) => rowSums(x)
            int axis = getAxis(paramExpression.get(0));
            if (axis == -1) {
                notifyErrorListeners("The builtin function \'" + functionName + "\' for given arguments is not supported", fnName);
                return null;
            } else {
                paramExpression = new ArrayList<>();
                paramExpression.add(new ParameterExpression(null, new DataIdentifier(namespace)));
                // Keep the user's function name intact so the error message names it,
                // not the "Not Supported" sentinel returned by the lookup.
                String aggFunctionName = getPythonAggFunctionNames(functionName, axis);
                if (aggFunctionName.equals("Not Supported")) {
                    notifyErrorListeners("The builtin function \'" + functionName + "\' for given arguments is not supported", fnName);
                    return null;
                }
                functionName = aggFunctionName;
            }
        } else {
            notifyErrorListeners("Incorrect number of arguments for the builtin function \'" + functionName + "\'.", fnName);
            return null;
        }
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("concatenate")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 2 arguments (Note: concatenate append columns of two matrices)", fnName);
            return null;
        }
        functionName = "append";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("minimum")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 2 arguments", fnName);
            return null;
        }
        functionName = "min";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("maximum")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 2 arguments", fnName);
            return null;
        }
        functionName = "max";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (!(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) && functionName.equals("shape")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts only 1 argument (0 or 1)", fnName);
            return null;
        }
        int axis = getAxis(paramExpression.get(0));
        if (axis == -1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts only 1 argument (0 or 1)", fnName);
            return null;
        }
        paramExpression = new ArrayList<>();
        paramExpression.add(new ParameterExpression(null, new DataIdentifier(namespace)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
        if (axis == 0) {
            functionName = "nrow";
        } else if (axis == 1) {
            functionName = "ncol";
        }
    } else if (namespace.equals("random") && functionName.equals("normal")) {
        if (paramExpression.size() != 3) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 3 arguments (number of rows, number of columns, sparsity)", fnName);
            return null;
        }
        paramExpression.get(0).setName("rows");
        paramExpression.get(1).setName("cols");
        paramExpression.get(2).setName("sparsity");
        paramExpression.add(new ParameterExpression("pdf", new StringIdentifier(ctx, "normal", currentFile)));
        functionName = "rand";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("random") && functionName.equals("poisson")) {
        if (paramExpression.size() != 4) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 4 arguments (number of rows, number of columns, sparsity, lambda)", fnName);
            return null;
        }
        paramExpression.get(0).setName("rows");
        paramExpression.get(1).setName("cols");
        paramExpression.get(2).setName("sparsity");
        paramExpression.get(3).setName("lambda");
        paramExpression.add(new ParameterExpression("pdf", new StringIdentifier(ctx, "poisson", currentFile)));
        functionName = "rand";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("random") && functionName.equals("uniform")) {
        if (paramExpression.size() != 5) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 5 arguments (number of rows, number of columns, sparsity, min, max)", fnName);
            return null;
        }
        paramExpression.get(0).setName("rows");
        paramExpression.get(1).setName("cols");
        paramExpression.get(2).setName("sparsity");
        paramExpression.get(3).setName("min");
        paramExpression.get(4).setName("max");
        paramExpression.add(new ParameterExpression("pdf", new StringIdentifier(ctx, "uniform", currentFile)));
        functionName = "rand";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("full")) {
        if (paramExpression.size() != 3) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 3 arguments (constant float value, number of rows, number of columns)", fnName);
            return null;
        }
        paramExpression.get(1).setName("rows");
        paramExpression.get(2).setName("cols");
        functionName = "matrix";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("matrix")) {
        // This can either be string initializer or as.matrix function
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument (either str or float value)", fnName);
            return null;
        }
        if (paramExpression.get(0).getExpr() instanceof StringIdentifier) {
            String initializerString = ((StringIdentifier) paramExpression.get(0).getExpr()).getValue().trim();
            if (!initializerString.startsWith("[") || !initializerString.endsWith("]")) {
                notifyErrorListeners("Incorrect initializer string for builtin function \'" + functionName + "\' (Eg: matrix(\"[1 2 3; 4 5 6]\"))", fnName);
                return null;
            }
            int rows = StringUtils.countMatches(initializerString, ";") + 1;
            // Make sure the user does not have a pretty-printed string
            initializerString = initializerString.replaceAll("; ", ";");
            initializerString = initializerString.replaceAll(" ;", ";");
            initializerString = initializerString.replaceAll("\\[ ", "\\[");
            initializerString = initializerString.replaceAll(" \\]", "\\]");
            // Each row has ncol-1 spaces
            // #spaces = nrow * (ncol-1)
            // ncol = (#spaces / nrow) + 1
            int cols = (StringUtils.countMatches(initializerString, " ") / rows) + 1;
            initializerString = initializerString.replaceAll(";", " ");
            initializerString = initializerString.replaceAll("\\[", "");
            initializerString = initializerString.replaceAll("\\]", "");
            paramExpression = new ArrayList<>();
            paramExpression.add(new ParameterExpression(null, new StringIdentifier(ctx, initializerString, currentFile)));
            paramExpression.add(new ParameterExpression("rows", new IntIdentifier(ctx, rows, currentFile)));
            paramExpression.add(new ParameterExpression("cols", new IntIdentifier(ctx, cols, currentFile)));
        } else {
            functionName = "as.matrix";
        }
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("scalar")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument", fnName);
            return null;
        }
        functionName = "as.scalar";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("float")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument", fnName);
            return null;
        }
        functionName = "as.double";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("int")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument", fnName);
            return null;
        }
        functionName = "as.integer";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("bool")) {
        if (paramExpression.size() != 1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 1 argument", fnName);
            return null;
        }
        functionName = "as.logical";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (!(namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) && functionName.equals("reshape")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments (number of rows, number of columns)", fnName);
            return null;
        }
        paramExpression.get(0).setName("rows");
        paramExpression.get(1).setName("cols");
        ArrayList<ParameterExpression> temp = new ArrayList<>();
        temp.add(new ParameterExpression(null, new DataIdentifier(namespace)));
        temp.add(paramExpression.get(0));
        temp.add(paramExpression.get(1));
        paramExpression = temp;
        functionName = "matrix";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("removeEmpty")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments (matrix, axis=0 or 1)", fnName);
            return null;
        }
        int axis = getAxis(paramExpression.get(1));
        if (axis == -1) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments (matrix, axis=0 or 1)", fnName);
            return null;
        }
        StringIdentifier marginVal = null;
        if (axis == 0) {
            marginVal = new StringIdentifier(ctx, "rows", currentFile);
        } else {
            marginVal = new StringIdentifier(ctx, "cols", currentFile);
        }
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("margin");
        paramExpression.get(1).setExpr(marginVal);
        functionName = "removeEmpty";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("replace")) {
        if (paramExpression.size() != 3) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 3 arguments (matrix, scalar value that should be replaced (pattern), scalar value (replacement))", fnName);
            return null;
        }
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("pattern");
        paramExpression.get(2).setName("replacement");
        functionName = "replace";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("range")) {
        if (paramExpression.size() < 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 3 arguments (from, to, increment), with the first 2 lacking default values", fnName);
            return null;
        } else if (paramExpression.size() > 3) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts 3 arguments (from, to, increment)", fnName);
            // Stop here like every other invalid call instead of emitting seq() with too many arguments
            return null;
        }
        functionName = "seq";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("norm") && functionName.equals("cdf")) {
        if (paramExpression.size() != 3) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 3 arguments (target, mean, sd)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("mean");
        paramExpression.get(2).setName("sd");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier(ctx, "normal", currentFile)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("expon") && functionName.equals("cdf")) {
        if (paramExpression.size() != 2) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 2 arguments (target, mean)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("mean");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier(ctx, "exp", currentFile)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("chi") && functionName.equals("cdf")) {
        if (paramExpression.size() != 2) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 2 arguments (target, df)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("df");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier(ctx, "chisq", currentFile)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("f") && functionName.equals("cdf")) {
        if (paramExpression.size() != 3) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 3 arguments (target, df1, df2)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("df1");
        paramExpression.get(2).setName("df2");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier(ctx, "f", currentFile)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals("t") && functionName.equals("cdf")) {
        if (paramExpression.size() != 2) {
            String qualifiedName = namespace + namespaceResolutionOp() + functionName;
            notifyErrorListeners("The builtin function \'" + qualifiedName + "\' accepts exactly 2 arguments (target, df)", fnName);
            return null;
        }
        functionName = "cdf";
        paramExpression.get(0).setName("target");
        paramExpression.get(1).setName("df");
        paramExpression.add(new ParameterExpression("dist", new StringIdentifier(ctx, "t", currentFile)));
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("percentile")) {
        if (paramExpression.size() != 2 && paramExpression.size() != 3) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts either 2 or 3 arguments", fnName);
            return null;
        }
        functionName = "quantile";
        namespace = DMLProgram.DEFAULT_NAMESPACE;
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("arcsin")) {
        functionName = "asin";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("arccos")) {
        functionName = "acos";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("arctan")) {
        functionName = "atan";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("load")) {
        functionName = "read";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("eigen")) {
        functionName = "eig";
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("power")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments", fnName);
            return null;
        }
    } else if (namespace.equals(DMLProgram.DEFAULT_NAMESPACE) && functionName.equals("dot")) {
        if (paramExpression.size() != 2) {
            notifyErrorListeners("The builtin function \'" + functionName + "\' accepts exactly 2 arguments", fnName);
            return null;
        }
    }
    return new ConvertedDMLSyntax(namespace, functionName, paramExpression);
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) IntIdentifier(org.apache.sysml.parser.IntIdentifier) StringIdentifier(org.apache.sysml.parser.StringIdentifier) ParameterExpression(org.apache.sysml.parser.ParameterExpression) ArrayList(java.util.ArrayList)

Example 10 with IntIdentifier

Use of org.apache.sysml.parser.IntIdentifier in project systemml by Apache.

Method exitIndexedExpression of class PydmlSyntacticValidator.

/**
 * Translates a PyDML indexed expression (for example {@code X[a:b, c:]}) into a DML
 * {@link IndexedIdentifier}.
 * <p>
 * PyDML uses 0-based indexing, so explicit lower bounds are incremented by 1 when translating
 * to DML, while explicit upper bounds are passed through unchanged. Bounds omitted from an
 * implicit slice are synthesized: a missing lower bound becomes the literal {@code 1} and a
 * missing upper bound becomes {@code nrow(X)} (rows) or {@code ncol(X)} (columns). A dimension
 * given as a single index yields a one-element index list.
 * <p>
 * Problems are reported through {@code notifyErrorListeners}; in that case the
 * {@link IndexedIdentifier} stored in {@code ctx.dataInfo.expr} is left without indices.
 *
 * @param ctx the parse tree
 */
@Override
public void exitIndexedExpression(IndexedExpressionContext ctx) {
    boolean isRowLower = (ctx.rowLower != null && !ctx.rowLower.isEmpty() && (ctx.rowLower.info.expr != null));
    boolean isRowUpper = (ctx.rowUpper != null && !ctx.rowUpper.isEmpty() && (ctx.rowUpper.info.expr != null));
    boolean isColLower = (ctx.colLower != null && !ctx.colLower.isEmpty() && (ctx.colLower.info.expr != null));
    boolean isColUpper = (ctx.colUpper != null && !ctx.colUpper.isEmpty() && (ctx.colUpper.info.expr != null));
    ExpressionInfo rowLower = isRowLower ? ctx.rowLower.info : null;
    ExpressionInfo rowUpper = isRowUpper ? ctx.rowUpper.info : null;
    ExpressionInfo colLower = isColLower ? ctx.colLower.info : null;
    ExpressionInfo colUpper = isColUpper ? ctx.colUpper.info : null;
    ctx.dataInfo.expr = new IndexedIdentifier(ctx.name.getText(), false, false);
    setFileLineColumn(ctx.dataInfo.expr, ctx);
    try {
        // Rows are translated first; on a row error, columns are not examined, so at most one
        // "incorrect index expression" error is reported per indexed expression.
        ArrayList<Expression> rowIndices = buildIndexBounds(rowLower, rowUpper,
                ctx.rowImplicitSlice != null, Expression.BuiltinFunctionOp.NROW, "row", ctx);
        if (rowIndices == null) {
            return;
        }
        ArrayList<Expression> colIndices = buildIndexBounds(colLower, colUpper,
                ctx.colImplicitSlice != null, Expression.BuiltinFunctionOp.NCOL, "column", ctx);
        if (colIndices == null) {
            return;
        }
        ArrayList<ArrayList<Expression>> exprList = new ArrayList<>();
        exprList.add(rowIndices);
        exprList.add(colIndices);
        ((IndexedIdentifier) ctx.dataInfo.expr).setIndices(exprList);
    } catch (Exception e) {
        // NOTE(review): presumably guards exceptions raised by setIndices; the cause is not
        // forwarded because notifyErrorListeners only accepts a message and a token.
        notifyErrorListeners("cannot set the indices", ctx.start);
    }
}

/**
 * Builds the DML index list for one dimension (rows or columns) of a PyDML indexed expression.
 * <p>
 * The returned list has two entries (lower, upper) when a range is produced, both
 * {@code null} when neither bound is given, and a single entry when only a lower bound is
 * given without slice syntax (a single index).
 *
 * @param lower the parsed lower bound, or {@code null} if absent
 * @param upper the parsed upper bound, or {@code null} if absent
 * @param isSliceImplicit whether the grammar matched an implicit slice for this dimension
 *        (presumably a bare {@code :} with a missing bound — confirm against the grammar)
 * @param sizeOp the builtin ({@code NROW} or {@code NCOL}) used to synthesize a missing
 *        upper bound
 * @param dimension the dimension name used in the error message ({@code "row"} or
 *        {@code "column"})
 * @param ctx the parse tree of the whole indexed expression
 * @return the index list, or {@code null} if the bounds form an invalid combination; the
 *         error has already been reported in that case
 */
private ArrayList<Expression> buildIndexBounds(ExpressionInfo lower, ExpressionInfo upper,
        boolean isSliceImplicit, Expression.BuiltinFunctionOp sizeOp, String dimension,
        IndexedExpressionContext ctx) {
    boolean hasLower = lower != null;
    boolean hasUpper = upper != null;
    ArrayList<Expression> indices = new ArrayList<>();
    if (!hasLower && !hasUpper) {
        // Neither bound set: select the whole dimension.
        indices.add(null);
        indices.add(null);
    } else if (hasLower && hasUpper) {
        // Both bounds set: shift the 0-based lower bound, keep the upper bound as-is.
        indices.add(incrementByOne(lower.expr, ctx));
        indices.add(upper.expr);
    } else if (hasLower) {
        // Given lower bound only.
        indices.add(incrementByOne(lower.expr, ctx));
        if (isSliceImplicit) {
            // Implicit upper bound: nrow(X) or ncol(X).
            DataIdentifier x = new DataIdentifier(ctx.name.getText());
            // Give the synthesized identifier a source position so diagnostics on it point here.
            setFileLineColumn(x, ctx);
            indices.add(new BuiltinFunctionExpression(ctx, sizeOp, new Expression[] { x }, currentFile));
        }
    } else if (isSliceImplicit) {
        // Implicit lower bound: PyDML index 0 is DML index 1, so emit the literal 1 directly.
        indices.add(new IntIdentifier(ctx, 1, currentFile));
        indices.add(upper.expr);
    } else {
        notifyErrorListeners("incorrect index expression for " + dimension, ctx.start);
        return null;
    }
    return indices;
}
Also used : DataIdentifier(org.apache.sysml.parser.DataIdentifier) BinaryExpression(org.apache.sysml.parser.BinaryExpression) Expression(org.apache.sysml.parser.Expression) ParameterExpression(org.apache.sysml.parser.ParameterExpression) BuiltinFunctionExpression(org.apache.sysml.parser.BuiltinFunctionExpression) IntIdentifier(org.apache.sysml.parser.IntIdentifier) ArrayList(java.util.ArrayList) ExpressionInfo(org.apache.sysml.parser.common.ExpressionInfo) LanguageException(org.apache.sysml.parser.LanguageException) ParseException(org.apache.sysml.parser.ParseException) IndexedIdentifier(org.apache.sysml.parser.IndexedIdentifier)

Aggregations

IntIdentifier (org.apache.sysml.parser.IntIdentifier)10 LanguageException (org.apache.sysml.parser.LanguageException)7 ParameterExpression (org.apache.sysml.parser.ParameterExpression)5 StringIdentifier (org.apache.sysml.parser.StringIdentifier)5 ArrayList (java.util.ArrayList)4 DataIdentifier (org.apache.sysml.parser.DataIdentifier)4 DoubleIdentifier (org.apache.sysml.parser.DoubleIdentifier)4 BinaryExpression (org.apache.sysml.parser.BinaryExpression)3 BooleanIdentifier (org.apache.sysml.parser.BooleanIdentifier)3 BuiltinFunctionExpression (org.apache.sysml.parser.BuiltinFunctionExpression)3 Expression (org.apache.sysml.parser.Expression)3 IndexedIdentifier (org.apache.sysml.parser.IndexedIdentifier)2 ParseException (org.apache.sysml.parser.ParseException)2 ExpressionInfo (org.apache.sysml.parser.common.ExpressionInfo)2 Token (org.antlr.v4.runtime.Token)1 BooleanExpression (org.apache.sysml.parser.BooleanExpression)1 DataExpression (org.apache.sysml.parser.DataExpression)1 ParameterizedBuiltinFunctionExpression (org.apache.sysml.parser.ParameterizedBuiltinFunctionExpression)1 RelationalExpression (org.apache.sysml.parser.RelationalExpression)1