Usage of org.apache.sysml.runtime.functionobjects.ValueFunction in project incubator-systemml (Apache).
Class BuiltinUnarySPInstruction, method parseInstruction:
/**
 * Parses a serialized Spark builtin unary instruction.
 *
 * <p>Two empty operand placeholders are filled in place by {@code parseUnaryInstruction}, which
 * also yields the opcode. The opcode selects the builtin function object, which is wrapped in a
 * {@code UnaryOperator}.
 *
 * @param str the serialized instruction
 * @return a {@code MatrixBuiltinSPInstruction} for the parsed opcode
 * @throws DMLRuntimeException propagated from the parsing and function-lookup helpers
 */
public static BuiltinUnarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
    CPOperand input = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    CPOperand output = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);

    // fills input/output and hands back the opcode token
    String opcode = parseUnaryInstruction(str, input, output);

    ValueFunction builtinFn = Builtin.getBuiltinFnObject(opcode);
    UnaryOperator unaryOp = new UnaryOperator(builtinFn);
    return new MatrixBuiltinSPInstruction(unaryOp, input, output, opcode, str);
}
Usage of org.apache.sysml.runtime.functionobjects.ValueFunction in project incubator-systemml (Apache).
Class CMCOVMRReducer, method reduce:
/**
 * Merges every partial CM/COV cell that shares one input tag into a single aggregate, then emits
 * the requested statistic for each CM/COV instruction that reads that tag.
 *
 * <p>The aggregation function is {@code covFn} when the tag is listed in {@code covTags}, and
 * otherwise the tag's entry in {@code cmFn}. Some cells are not accounted for by the incoming
 * values: their count is rows * cols of the input minus the accumulated weight {@code w}. They are
 * folded in with a single weighted update through {@code zeroObj}, presumably an object that
 * represents zero-valued cells (confirm), rather than one update per missing cell.
 *
 * @param index  the tagged key; only its tag is read here
 * @param values partial CM/COV cells for this key
 * @param out    not used by this method; results go through
 *               {@code collectOutput_N_Increase_Counter}
 * @param report reporter passed to setup and output collection
 * @throws IOException wrapping any {@code DMLRuntimeException} raised while aggregating or while
 *     computing a result
 */
@Override
public void reduce(TaggedFirstSecondIndexes index, Iterator<MatrixValue> values,
        OutputCollector<MatrixIndexes, MatrixValue> out, Reporter report) throws IOException {
    commonSetup(report);
    cmNcovCell.setCM_N_COVObject(0, 0, 0);

    // covariance tags override the per-tag central-moment function
    ValueFunction fn = cmFn.get(index.getTag());
    if (covTags.contains(index.getTag()))
        fn = covFn;

    // fold every partial cell into the running aggregate
    while (values.hasNext()) {
        CM_N_COVCell cell = (CM_N_COVCell) values.next();
        try {
            fn.execute(cmNcovCell.getCM_N_COVObject(), cell.getCM_N_COVObject());
        } catch (DMLRuntimeException e) {
            throw new IOException(e);
        }
    }

    // add the cells missing from `values` back in with one weighted update
    long totaln = rlens.get(index.getTag()) * clens.get(index.getTag());
    long zerosToAdd = totaln - (long) (cmNcovCell.getCM_N_COVObject().w);
    if (zerosToAdd > 0) {
        zeroObj.w = zerosToAdd;
        try {
            fn.execute(cmNcovCell.getCM_N_COVObject(), zeroObj);
        } catch (DMLRuntimeException e) {
            throw new IOException(e);
        }
    }

    // emit the requested statistic for every instruction that reads this tag
    for (CM_N_COVInstruction in : cmNcovInstructions) {
        if (in.input != index.getTag())
            continue;
        try {
            outCell.setValue(cmNcovCell.getCM_N_COVObject().getRequiredResult(in.getOperator()));
        } catch (DMLRuntimeException e) {
            throw new IOException(e);
        }
        ArrayList<Integer> outputIndexes = outputIndexesMapping.get(in.output);
        for (int i : outputIndexes)
            collectOutput_N_Increase_Counter(outIndex, outCell, i, report);
    }
}
Usage of org.apache.sysml.runtime.functionobjects.ValueFunction in project incubator-systemml (Apache).
Class BuiltinBinaryGPUInstruction, method parseInstruction:
/**
 * Parses a serialized binary builtin GPU instruction with fields {@code opcode, in1, in2, out}.
 *
 * <p>If either input is a matrix, the output must also be a matrix. Only the matrix-matrix
 * combination yields an instruction. Two scalars, or a mix of matrix and scalar, are rejected.
 *
 * @param str the serialized instruction
 * @return a {@code MatrixMatrixBuiltinGPUInstruction} for two matrix inputs
 * @throws DMLRuntimeException on an invalid output data type or an unsupported input combination,
 *     or propagated from the parsing and function-lookup helpers
 */
public static BuiltinBinaryGPUInstruction parseInstruction(String str) throws DMLRuntimeException {
    CPOperand first = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
    CPOperand second = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
    CPOperand result = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);

    String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(fields, 3);
    String opcode = fields[0];
    first.split(fields[1]);
    second.split(fields[2]);
    result.split(fields[3]);

    boolean firstIsMatrix = first.getDataType() == Expression.DataType.MATRIX;
    boolean secondIsMatrix = second.getDataType() == Expression.DataType.MATRIX;

    // any matrix operand forces a matrix result
    if ((firstIsMatrix || secondIsMatrix) && result.getDataType() != Expression.DataType.MATRIX) {
        throw new DMLRuntimeException("Element-wise matrix operations between variables "
                + first.getName() + " and " + second.getName()
                + " must produce a matrix, which " + result.getName() + " is not");
    }

    // resolved before the data-type dispatch, so lookup failures surface first
    ValueFunction builtinFn = Builtin.getBuiltinFnObject(opcode);

    // this GPU path is only used for "solve"
    boolean bothScalar = first.getDataType() == Expression.DataType.SCALAR
            && second.getDataType() == Expression.DataType.SCALAR;
    if (bothScalar) {
        throw new DMLRuntimeException("GPU : Unsupported GPU builtin operations on 2 scalars");
    }
    if (!(firstIsMatrix && secondIsMatrix)) {
        throw new DMLRuntimeException("GPU : Unsupported GPU builtin operations on a matrix and a scalar");
    }
    return new MatrixMatrixBuiltinGPUInstruction(new BinaryOperator(builtinFn), first, second, result, opcode, str, 2);
}
Usage of org.apache.sysml.runtime.functionobjects.ValueFunction in project incubator-systemml (Apache).
Class BuiltinUnaryGPUInstruction, method parseInstruction:
/**
 * Parses a serialized unary builtin GPU instruction.
 *
 * <p>The layout is recognised by field count:
 * <ul>
 *   <li>Four fields (print, stop or cumulative aggregates): not supported on GPU, so the
 *       instruction is rejected.</li>
 *   <li>The general {@code opcode, in, out} form: accepted only when the input is a matrix. Any
 *       other input data type is rejected with the same message, so callers never receive
 *       {@code null}.</li>
 * </ul>
 *
 * @param str the serialized instruction
 * @return a {@code MatrixBuiltinGPUInstruction} for a matrix input (never {@code null})
 * @throws DMLRuntimeException if the instruction is not supported on GPU, or propagated from the
 *     parsing and function-lookup helpers
 */
public static BuiltinUnaryGPUInstruction parseInstruction(String str) throws DMLRuntimeException {
    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);

    // print, stop, or cumulative aggregates: there is no GPU implementation
    if (parts.length == 4) {
        throw new DMLRuntimeException("The instruction is not supported on GPU:" + str);
    }

    // 2+1, general case: opcode, input, output
    InstructionUtils.checkNumFields(str, 2);
    CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    String opcode = parts[0];
    in.split(parts[1]);
    out.split(parts[2]);
    ValueFunction func = Builtin.getBuiltinFnObject(opcode);

    // only matrix inputs have a GPU kernel; scalars and anything else are rejected explicitly
    if (in.getDataType() != DataType.MATRIX) {
        throw new DMLRuntimeException("The instruction is not supported on GPU:" + str);
    }
    return new MatrixBuiltinGPUInstruction(new UnaryOperator(func), in, out, opcode, str);
}
Usage of org.apache.sysml.runtime.functionobjects.ValueFunction in project incubator-systemml (Apache).
Class ParameterizedBuiltinSPInstruction, method parseInstruction:
/**
 * Parses a serialized parameterized builtin Spark instruction.
 *
 * <p>The opcode is always the first field. {@code mapgroupedagg} uses a fixed positional layout:
 * target, groups, output, number of groups. Every other opcode has its output operand last and its
 * parameters built by {@code constructParameterMap}. Supported opcodes are {@code groupedagg},
 * {@code rmempty}, {@code rexpand}, {@code replace}, {@code transform}, {@code transformapply}
 * and {@code transformdecode}, all matched case-insensitively.
 *
 * @param str the serialized instruction
 * @return the parsed instruction
 * @throws DMLRuntimeException on a missing mandatory groupedagg argument or an unknown opcode, or
 *     propagated from the helpers
 */
public static ParameterizedBuiltinSPInstruction parseInstruction(String str) throws DMLRuntimeException {
    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
    String opcode = parts[0]; // the opcode always comes first

    // positional layout: target, groups, output, number of groups
    if (opcode.equalsIgnoreCase("mapgroupedagg")) {
        CPOperand target = new CPOperand(parts[1]);
        CPOperand groups = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        HashMap<String, String> params = new HashMap<String, String>();
        params.put(Statement.GAGG_TARGET, target.getName());
        params.put(Statement.GAGG_GROUPS, groups.getName());
        params.put(Statement.GAGG_NUM_GROUPS, parts[4]);
        Operator aggOp = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
        return new ParameterizedBuiltinSPInstruction(aggOp, params, out, opcode, str, false);
    }

    // all other opcodes: the output is the last field, the rest form a parameter map
    CPOperand out = new CPOperand(parts[parts.length - 1]);
    HashMap<String, String> paramsMap = constructParameterMap(parts);

    if (opcode.equalsIgnoreCase("groupedagg")) {
        String fnStr = paramsMap.get("fn");
        if (fnStr == null) {
            throw new DMLRuntimeException("Function parameter is missing in groupedAggregate.");
        }
        String order = paramsMap.get("order");
        if (fnStr.equalsIgnoreCase("centralmoment") && order == null) {
            throw new DMLRuntimeException("Mandatory \"order\" must be specified when fn=\"centralmoment\" in groupedAggregate.");
        }
        Operator groupedOp = GroupedAggregateInstruction.parseGroupedAggOperator(fnStr, order);
        return new ParameterizedBuiltinSPInstruction(groupedOp, paramsMap, out, opcode, str, false);
    }

    if (opcode.equalsIgnoreCase("rmempty")) {
        // the longer layout carries a boolean flag at parts[5] (presumably a broadcast switch)
        boolean bRmEmptyBC = parts.length > 6 && Boolean.parseBoolean(parts[5]);
        ValueFunction rmFn = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
        return new ParameterizedBuiltinSPInstruction(new SimpleOperator(rmFn), paramsMap, out, opcode, str, bRmEmptyBC);
    }

    boolean simpleOpcode = opcode.equalsIgnoreCase("rexpand")
            || opcode.equalsIgnoreCase("replace")
            || opcode.equalsIgnoreCase("transform")
            || opcode.equalsIgnoreCase("transformapply")
            || opcode.equalsIgnoreCase("transformdecode");
    if (!simpleOpcode) {
        throw new DMLRuntimeException("Unknown opcode (" + opcode + ") for ParameterizedBuiltin Instruction.");
    }
    ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
    return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
}
Aggregations