use of org.apache.sysml.parser.Expression.DataType in project systemml by apache.
the class GenerateClassesForMLContext method generateFunctionCallMethod.
/**
 * Builds the Java source code of a method that invokes a DML script function through MLContext.
 * <p>
 * {@code createFunctionOutputClass(scriptFilePath, fs)} is invoked first. The generated method is
 * named after the script function, declares every input parameter as {@code Object}, executes
 * {@code dmlFunctionCall} as an MLContext {@code Script}, and:
 * <ul>
 * <li>returns nothing ({@code void}) when the function has no outputs,</li>
 * <li>returns the single output directly when there is exactly one output,</li>
 * <li>returns an instance of the generated function output class when there are several outputs.</li>
 * </ul>
 * Note: {@code dmlFunctionCall} is pasted verbatim inside a Java string literal of the generated
 * code; it is assumed not to contain unescaped quotes, backslashes or line breaks (TODO confirm
 * against the caller).
 *
 * @param scriptFilePath
 *            the path to a script file
 * @param fs
 *            a SystemML function statement
 * @param dmlFunctionCall
 *            a string representing the invocation of a script function
 * @return string representation of a method that performs a function call
 */
public static String generateFunctionCallMethod(String scriptFilePath, FunctionStatement fs, String dmlFunctionCall) {
    createFunctionOutputClass(scriptFilePath, fs);
    StringBuilder code = new StringBuilder("public ");

    // --- method signature: return type ---
    ArrayList<DataIdentifier> outputs = fs.getOutputParams();
    int outputCount = outputs.size();
    if (outputCount == 0) {
        code.append("void");
    } else if (outputCount == 1) {
        // a single output is returned directly instead of being wrapped in an output class
        code.append(getParamTypeAsString(outputs.get(0)));
    } else {
        code.append(getFullFunctionOutputClassName(scriptFilePath, fs));
    }

    // --- method signature: name and parameters ---
    code.append(' ');
    code.append(fs.getName());
    code.append('(');
    ArrayList<DataIdentifier> inputs = fs.getInputParams();
    // Object is used instead of Integer/Double/Boolean/String so that values such as longs can be
    // converted automatically when they are bound to the script.
    String paramSeparator = "";
    for (DataIdentifier input : inputs) {
        code.append(paramSeparator).append("Object ").append(input.getName());
        paramSeparator = ", ";
    }
    code.append(") {\n");

    // --- body: script creation and input/output bindings ---
    code.append("String scriptString = \"" + dmlFunctionCall + "\";\n");
    code.append("org.apache.sysml.api.mlcontext.Script script = new org.apache.sysml.api.mlcontext.Script(scriptString);\n");
    boolean hasBindings = !inputs.isEmpty() || outputCount > 0;
    if (hasBindings) {
        code.append("script");
    }
    for (DataIdentifier input : inputs) {
        String inputName = input.getName();
        code.append(".in(\"").append(inputName).append("\", ").append(inputName).append(')');
    }
    for (DataIdentifier output : outputs) {
        code.append(".out(\"").append(output.getName()).append("\")");
    }
    if (hasBindings) {
        code.append(";\n");
    }
    code.append("org.apache.sysml.api.mlcontext.MLResults results = script.execute();\n");

    // --- body: result extraction and return ---
    if (outputCount == 0) {
        code.append("return;\n");
    } else if (outputCount == 1) {
        DataIdentifier output = outputs.get(0);
        String[] accessor = mlResultsAccessor(output);
        // NOTE: for an unhandled data/value type combination no return statement is generated
        if (accessor != null) {
            code.append(accessor[0] + " res = results." + accessor[1] + "(\"" + output.getName() + "\");\nreturn res;\n");
        }
    } else {
        // one local variable (lower-cased output name) per output ...
        for (DataIdentifier output : outputs) {
            String localName = output.getName().toLowerCase();
            String localType = getParamTypeAsString(output);
            String[] accessor = mlResultsAccessor(output);
            if (accessor != null) {
                code.append(localType + " " + localName + " = results." + accessor[1] + "(\"" + output.getName() + "\");\n");
            }
        }
        // ... all passed, in declaration order, to the generated output class constructor
        String outputClassName = getFullFunctionOutputClassName(scriptFilePath, fs);
        code.append(outputClassName + " res = new " + outputClassName + "(");
        String argSeparator = "";
        for (DataIdentifier output : outputs) {
            code.append(argSeparator).append(output.getName().toLowerCase());
            argSeparator = ", ";
        }
        code.append(");\nreturn res;\n");
    }
    code.append("}\n");
    return code.toString();
}

/**
 * Maps an output parameter to the Java type of the {@code res} variable used in the single-output
 * case and to the {@code MLResults} getter that fetches the value in generated code.
 *
 * @param param
 *            the output parameter
 * @return a two-element array {@code {javaType, getterName}}, or {@code null} when the
 *         data/value type combination is not handled
 */
private static String[] mlResultsAccessor(DataIdentifier param) {
    DataType dataType = param.getDataType();
    ValueType valueType = param.getValueType();
    if (dataType == DataType.MATRIX) {
        return new String[] { "org.apache.sysml.api.mlcontext.Matrix", "getMatrix" };
    }
    if (dataType == DataType.FRAME) {
        return new String[] { "org.apache.sysml.api.mlcontext.Frame", "getFrame" };
    }
    if (dataType != DataType.SCALAR) {
        return null;
    }
    if (valueType == ValueType.INT) {
        return new String[] { "long", "getLong" };
    }
    if (valueType == ValueType.DOUBLE) {
        return new String[] { "double", "getDouble" };
    }
    if (valueType == ValueType.BOOLEAN) {
        return new String[] { "boolean", "getBoolean" };
    }
    if (valueType == ValueType.STRING) {
        return new String[] { "String", "getString" };
    }
    return null;
}
use of org.apache.sysml.parser.Expression.DataType in project systemml by apache.
the class GenerateClassesForMLContext method getSimpleParamTypeAsString.
/**
 * Returns the Java type name of a function parameter, naming matrices and frames by their simple
 * (unqualified) class names.
 *
 * @param param
 *            the function parameter
 * @return {@code "long"}, {@code "double"}, {@code "boolean"} or {@code "String"} for INT, DOUBLE,
 *         BOOLEAN or STRING scalars; {@code "Matrix"} for matrices; {@code "Frame"} for frames;
 *         {@code null} for any other data/value type combination
 */
public static String getSimpleParamTypeAsString(DataIdentifier param) {
    DataType dataType = param.getDataType();
    ValueType valueType = param.getValueType();
    if (dataType == DataType.MATRIX) {
        return "Matrix";
    }
    if (dataType == DataType.FRAME) {
        return "Frame";
    }
    if (dataType != DataType.SCALAR) {
        return null;
    }
    // scalars: map the DML value type onto the corresponding Java type
    if (valueType == ValueType.INT) {
        return "long";
    }
    if (valueType == ValueType.DOUBLE) {
        return "double";
    }
    if (valueType == ValueType.BOOLEAN) {
        return "boolean";
    }
    if (valueType == ValueType.STRING) {
        return "String";
    }
    return null;
}
use of org.apache.sysml.parser.Expression.DataType in project systemml by apache.
the class ParForStatementBlock method rCheckCandidates.
/**
 * Recursively checks one candidate against a list of statement blocks for output, data and anti
 * dependencies.
 * <p>
 * The check is conservative: no dependency goes undetected (no false negatives), but a dependency
 * may be reported where none exists (false positives). Detected dependencies are reported by setting
 * flags in {@code dep}. A {@link LanguageException} is thrown when a write or read of the candidate
 * variable involves a data type combination this analysis cannot handle.
 *
 * @param c candidate (variable name {@code c._var}, identifier {@code c._dat}, accumulator flag {@code c._isAccum})
 * @param cdt candidate data type
 * @param asb list of statement blocks to check the candidate against
 * @param sCount statement count; note that incrementing this boxed {@code Integer} only rebinds the
 *        local parameter, so the updated count is not visible to the caller
 * @param dep dependency flags written by this method: {@code dep[0]} output dependency,
 *        {@code dep[1]} data dependency, {@code dep[2]} set together with {@code dep[1]} for
 *        conflicting matrix reads (logged as data/anti dependency)
 * @throws LanguageException if the data type of a conflicting access cannot be checked
 */
private void rCheckCandidates(Candidate c, DataType cdt, ArrayList<StatementBlock> asb, Integer sCount, boolean[] dep) {
// check candidate only (output dependency if scalar or constant matrix subscript)
if (cdt == DataType.SCALAR || // dat2 checked for other candidate
cdt == DataType.OBJECT) {
// every write to a scalar or complete data object is an output dependency
dep[0] = true;
if (ABORT_ON_FIRST_DEPENDENCY)
return;
} else if (cdt == DataType.MATRIX) {
// a constant subscript on a non-accumulator matrix candidate is an output dependency
if (runConstantCheck(c._dat) && !c._isAccum) {
if (LOG.isTraceEnabled())
LOG.trace("PARFOR: Possible output dependency detected via constant self-check: var '" + c._var + "'.");
dep[0] = true;
if (ABORT_ON_FIRST_DEPENDENCY)
return;
}
}
// check candidate against all statements
for (StatementBlock sb : asb) for (Statement s : sb._statements) {
// NOTE(review): sCount is a boxed Integer; this increment rebinds the local parameter only.
// Recursive calls receive the current value, but their increments are not reflected back here.
sCount++;
// compound statements: recurse into their nested statement blocks
// NOTE: the early returns under ABORT_ON_FIRST_DEPENDENCY only leave the current recursion level;
// the recursive calls below do not check whether a dependency was found, so the caller keeps scanning.
if (s instanceof ForStatement) {
// incl parfor
// despite separate dependency analysis for each nested parfor, we need to
// recursively check nested parfor as well in order to ensure correctness
// of constantChecks with regard to outer indexes
rCheckCandidates(c, cdt, ((ForStatement) s).getBody(), sCount, dep);
} else if (s instanceof WhileStatement) {
rCheckCandidates(c, cdt, ((WhileStatement) s).getBody(), sCount, dep);
} else if (s instanceof IfStatement) {
// both branches are checked
rCheckCandidates(c, cdt, ((IfStatement) s).getIfBody(), sCount, dep);
rCheckCandidates(c, cdt, ((IfStatement) s).getElseBody(), sCount, dep);
} else if (s instanceof FunctionStatement) {
rCheckCandidates(c, cdt, ((FunctionStatement) s).getBody(), sCount, dep);
} else {
// CHECK output dependencies
// (writes to the candidate variable in this statement)
List<DataIdentifier> datsUpdated = getDataIdentifiers(s, true);
if (datsUpdated != null) {
for (DataIdentifier write : datsUpdated) {
if (!c._var.equals(write.getName()))
continue;
if (cdt != DataType.MATRIX) {
// cannot infer type, need to exit (conservative approach)
throw new LanguageException("PARFOR loop dependency analysis: " + "cannot check for dependencies due to unknown datatype of var '" + c._var + "'.");
}
DataIdentifier dat2 = write;
// skip self-check
// (reference identity: the candidate's own identifier object)
if (c._dat == dat2)
continue;
if (runEqualsCheck(c._dat, dat2)) {
// intra-iteration output dependencies (same index function) are OK
} else if (runBanerjeeGCDTest(c._dat, dat2)) {
LOG.trace("PARFOR: Possible output dependency detected via GCD/Banerjee: var '" + write + "'.");
dep[0] = true;
if (ABORT_ON_FIRST_DEPENDENCY)
return;
}
}
}
// reads of the candidate variable in this statement
List<DataIdentifier> datsRead = getDataIdentifiers(s, false);
if (datsRead == null)
continue;
// check data and anti dependencies
for (DataIdentifier read : datsRead) {
if (!c._var.equals(read.getName()))
continue;
DataIdentifier dat2 = read;
// data type of the read variable as known in the parent variable set
// NOTE(review): if getVariables().get(...) returns null for a name missing from _vsParent,
// this line throws a NullPointerException; confirm the variable is always present.
DataType dat2dt = _vsParent.getVariables().get(read.getName()).getDataType();
if (cdt == DataType.SCALAR || cdt == DataType.OBJECT || dat2dt == DataType.SCALAR || dat2dt == DataType.OBJECT) {
// every write, read combination involving a scalar is a data dependency
dep[1] = true;
if (ABORT_ON_FIRST_DEPENDENCY)
return;
} else if (cdt == DataType.MATRIX && dat2dt == DataType.MATRIX) {
boolean invalid = false;
if (runEqualsCheck(c._dat, dat2))
// read/write on same index, and not constant (checked for output) is OK
invalid = runConstantCheck(dat2);
else if (runBanerjeeGCDTest(c._dat, dat2))
invalid = true;
else if (!(dat2 instanceof IndexedIdentifier))
// non-indexed access to candidate result variable -> always a dependency
invalid = true;
if (invalid) {
LOG.trace("PARFOR: Possible data/anti dependency detected via GCD/Banerjee: var '" + read + "'.");
dep[1] = true;
dep[2] = true;
if (ABORT_ON_FIRST_DEPENDENCY)
return;
}
} else {
// cannot infer type, need to exit (conservative approach)
throw new LanguageException("PARFOR loop dependency analysis: " + "cannot check for dependencies due to unknown datatype of var '" + c._var + "'.");
}
}
}
}
}
use of org.apache.sysml.parser.Expression.DataType in project systemml by apache.
the class ScalarInstruction method isFirstArgumentScalar.
/**
 * Tells whether the first operand of an instruction string has data type SCALAR.
 * <p>
 * The first operand is element 1 of {@code InstructionUtils.getInstructionPartsWithValueType(inst)};
 * its data type name is the second token obtained by splitting that operand on
 * {@code Lop.VALUETYPE_PREFIX}.
 *
 * @param inst the instruction string
 * @return {@code true} if the first operand's data type is {@code DataType.SCALAR}
 */
private static boolean isFirstArgumentScalar(String inst) {
    String firstOperand = InstructionUtils.getInstructionPartsWithValueType(inst)[1];
    // note: String.split treats the prefix as a regular expression
    String dataTypeName = firstOperand.split(Lop.VALUETYPE_PREFIX)[1];
    return DataType.valueOf(dataTypeName) == DataType.SCALAR;
}
use of org.apache.sysml.parser.Expression.DataType in project systemml by apache.
the class BinarySPInstruction method parseInstruction.
/**
 * Parses a Spark binary instruction string into the matching {@code BinarySPInstruction} subclass.
 * <p>
 * Strings starting with {@code "SPARK" + Lop.OPERAND_DELIMITOR + "map"} are parsed as broadcast
 * ("map") instructions. They must have 5 fields, with the vector type taken from field 5. All other
 * strings are parsed via {@code parseBinaryInstruction}.
 *
 * @param str the instruction string
 * @return a matrix-broadcast-vector instruction (broadcast, both operands matrices), a matrix-matrix
 *         instruction (both operands matrices), a matrix-scalar instruction (exactly one operand is a
 *         matrix), or {@code null} when neither operand is a matrix
 */
public static BinarySPInstruction parseInstruction(String str) {
    CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);

    final boolean isBroadcast = str.startsWith("SPARK" + Lop.OPERAND_DELIMITOR + "map");
    final String opcode;
    final VectorType vtype;
    if (isBroadcast) {
        String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(fields, 5);
        opcode = fields[0];
        in1.split(fields[1]);
        in2.split(fields[2]);
        out.split(fields[3]);
        vtype = VectorType.valueOf(fields[5]);
    } else {
        opcode = parseBinaryInstruction(str, in1, in2, out);
        vtype = null;
    }

    boolean leftIsMatrix = in1.getDataType() == DataType.MATRIX;
    boolean rightIsMatrix = in2.getDataType() == DataType.MATRIX;
    // the operator is parsed before the operand types are inspected
    Operator operator = InstructionUtils.parseExtendedBinaryOrBuiltinOperator(opcode, in1, in2);

    if (!leftIsMatrix && !rightIsMatrix) {
        // neither operand is a matrix: no Spark binary instruction is created
        return null;
    }
    if (!leftIsMatrix || !rightIsMatrix) {
        return new BinaryMatrixScalarSPInstruction(operator, in1, in2, out, opcode, str);
    }
    if (isBroadcast) {
        return new BinaryMatrixBVectorSPInstruction(operator, in1, in2, out, vtype, opcode, str);
    }
    return new BinaryMatrixMatrixSPInstruction(operator, in1, in2, out, opcode, str);
}
Aggregations