Usage of `org.apache.sysml.runtime.functionobjects.ValueFunction` in the Apache incubator-systemml project.
Class `LibMatrixAgg`, method `getAggType`:
/**
 * Classifies an aggregate-unary operator into the {@link AggType} served by the specialized
 * aggregation kernels.
 *
 * <p>The result depends on the operator's incremental value function
 * ({@code op.aggOp.increOp.fn}), its correction location ({@code op.aggOp.correctionLocation})
 * and its index function ({@code op.indexFn}):
 * <ul>
 *   <li>KahanPlus / KahanPlusSq, LASTCOLUMN or LASTROW correction, ReduceAll / ReduceCol /
 *       ReduceRow / ReduceDiag: {@code KAHAN_SUM} / {@code KAHAN_SUM_SQ}</li>
 *   <li>Mean, LASTTWOCOLUMNS or LASTTWOROWS correction, ReduceAll / ReduceCol / ReduceRow:
 *       {@code MEAN}</li>
 *   <li>CM of type VARIANCE, LASTFOURCOLUMNS or LASTFOURROWS correction, ReduceAll / ReduceCol /
 *       ReduceRow: {@code VAR}</li>
 *   <li>Multiply with ReduceAll: {@code PROD}</li>
 *   <li>Builtin MAX / MIN / MAXINDEX / MININDEX with ReduceAll / ReduceCol / ReduceRow:
 *       {@code MAX} / {@code MIN} / {@code MAX_INDEX} / {@code MIN_INDEX}</li>
 * </ul>
 * Every other combination yields {@code AggType.INVALID}.
 *
 * @param op the aggregate-unary operator to classify
 * @return the matching aggregation type, or {@code AggType.INVALID} if none applies
 */
private static AggType getAggType(AggregateUnaryOperator op) {
    final ValueFunction fn = op.aggOp.increOp.fn;
    final IndexFunction idxFn = op.indexFn;
    final CorrectionLocationType corrLoc = op.aggOp.correctionLocation;

    // full, column-wise or row-wise reduction; shared by most of the checks below
    final boolean fullRowOrCol = idxFn instanceof ReduceAll
        || idxFn instanceof ReduceCol
        || idxFn instanceof ReduceRow;

    // (kahan) sum / sum squared; trace is covered via ReduceDiag
    final boolean kahanCorrection = corrLoc == CorrectionLocationType.LASTCOLUMN
        || corrLoc == CorrectionLocationType.LASTROW;
    if (fn instanceof KahanFunction && kahanCorrection && (fullRowOrCol || idxFn instanceof ReduceDiag)) {
        if (fn instanceof KahanPlus) {
            return AggType.KAHAN_SUM;
        }
        if (fn instanceof KahanPlusSq) {
            return AggType.KAHAN_SUM_SQ;
        }
        // any other Kahan function falls through to the remaining checks
    }

    // mean (two correction slots)
    final boolean meanCorrection = corrLoc == CorrectionLocationType.LASTTWOCOLUMNS
        || corrLoc == CorrectionLocationType.LASTTWOROWS;
    if (fn instanceof Mean && meanCorrection && fullRowOrCol) {
        return AggType.MEAN;
    }

    // variance (four correction slots)
    if (fn instanceof CM
        && ((CM) fn).getAggOpType() == AggregateOperationTypes.VARIANCE
        && (corrLoc == CorrectionLocationType.LASTFOURCOLUMNS
            || corrLoc == CorrectionLocationType.LASTFOURROWS)
        && fullRowOrCol) {
        return AggType.VAR;
    }

    // product: only supported as a full aggregate
    if (fn instanceof Multiply && idxFn instanceof ReduceAll) {
        return AggType.PROD;
    }

    // min / max and their index variants
    if (fn instanceof Builtin && fullRowOrCol) {
        final BuiltinCode code = ((Builtin) fn).bFunc;
        switch (code) {
            case MAX:      return AggType.MAX;
            case MIN:      return AggType.MIN;
            case MAXINDEX: return AggType.MAX_INDEX;
            case MININDEX: return AggType.MIN_INDEX;
            default:       break; // other builtins have no specialized kernel here
        }
    }

    return AggType.INVALID;
}
Usage of `org.apache.sysml.runtime.functionobjects.ValueFunction` in the Apache incubator-systemml project.
Class `CMCOVMRReducer`, method `reduce`:
/**
 * Reduces all partial central-moment / covariance cells that share one tagged index into a
 * single accumulator. It then emits the statistic requested by every CM/COV instruction whose
 * input matches the tag.
 *
 * <p>Steps:
 * <ol>
 *   <li>reset the shared accumulator {@code cmNcovCell};</li>
 *   <li>fold every incoming {@link CM_N_COVCell} into it, using {@code covFn} for covariance
 *       tags and the per-tag CM function otherwise;</li>
 *   <li>add back the zero cells that did not arrive (see the note below);</li>
 *   <li>for each matching instruction, compute its result and collect it for every output
 *       index mapped to that instruction.</li>
 * </ol>
 *
 * @param index  tagged index; its tag selects the value function and the matching instructions
 * @param values partial results; each one is cast to {@link CM_N_COVCell}
 * @param out    Hadoop output collector (results go through collectOutput_N_Increase_Counter)
 * @param report Hadoop progress reporter
 * @throws IOException wrapping any {@link DMLRuntimeException} raised while aggregating or
 *                     computing results, or raised by output collection
 */
@Override
public void reduce(TaggedFirstSecondIndexes index, Iterator<MatrixValue> values, OutputCollector<MatrixIndexes, MatrixValue> out, Reporter report) throws IOException {
    commonSetup(report);
    // reset the running accumulator before folding in this key's partial results
    cmNcovCell.setCM_N_COVObject(0, 0, 0);

    // covariance tags use covFn; all other tags use their registered CM function
    ValueFunction fn = cmFn.get(index.getTag());
    if (covTags.contains(index.getTag()))
        fn = covFn;

    while (values.hasNext()) {
        CM_N_COVCell cell = (CM_N_COVCell) values.next();
        try {
            fn.execute(cmNcovCell.getCM_N_COVObject(), cell.getCM_N_COVObject());
        } catch (DMLRuntimeException e) {
            throw new IOException(e);
        }
    }

    // Add 0 values back in. The expected number of cells is rows * cols of the tagged input.
    // Any weight missing from the accumulator is treated as zero-valued cells. Presumably the
    // upstream side does not ship zeros; confirm against the mapper. The missing cells are
    // folded in with ONE call whose weight equals the number of missing zeros, instead of
    // one call per zero.
    long totaln = rlens.get(index.getTag()) * clens.get(index.getTag());
    long zerosToAdd = totaln - (long) (cmNcovCell.getCM_N_COVObject().w);
    if (zerosToAdd > 0) {
        // NOTE(review): zeroObj is assumed to carry the value 0; only its weight is set here
        zeroObj.w = zerosToAdd;
        try {
            fn.execute(cmNcovCell.getCM_N_COVObject(), zeroObj);
        } catch (DMLRuntimeException e) {
            throw new IOException(e);
        }
    }

    // emit the requested statistic for every CM/COV instruction reading this input
    for (CM_N_COVInstruction in : cmNcovInstructions) {
        if (in.input == index.getTag()) {
            try {
                outCell.setValue(cmNcovCell.getCM_N_COVObject().getRequiredResult(in.getOperator()));
            } catch (DMLRuntimeException e) {
                throw new IOException(e);
            }
            ArrayList<Integer> outputIndexes = outputIndexesMapping.get(in.output);
            for (int i : outputIndexes) {
                collectOutput_N_Increase_Counter(outIndex, outCell, i, report);
            }
        }
    }
}
Usage of `org.apache.sysml.runtime.functionobjects.ValueFunction` in the Apache incubator-systemml project.
Class `CMOperator`, method `setCMAggOp`:
/**
 * Creates a new {@code CMOperator} for the central moment of the given order.
 *
 * <p>Despite its name, this method does not modify {@code this}. It returns a fresh operator
 * whose value function and aggregation type are both derived from {@code order}.
 *
 * @param order the moment order, mapped to an aggregation type via {@code getCMAggOpType}
 * @return a new operator built from the resolved aggregation type
 */
public CMOperator setCMAggOp(int order) {
    AggregateOperationTypes agg = getCMAggOpType(order);
    // Build the value function from the newly resolved type. The previous code passed the
    // field aggOpType here, so the function could disagree with the operator's declared type.
    ValueFunction fn = CM.getCMFnObject(agg);
    return new CMOperator(fn, agg);
}
Usage of `org.apache.sysml.runtime.functionobjects.ValueFunction` in the Apache incubator-systemml project.
Class `BuiltinBinarySPInstruction`, method `parseInstruction`:
/**
 * Parses a Spark builtin binary instruction string into the matching instruction object.
 *
 * <p>Strings that start with {@code "SPARK" + Lop.OPERAND_DELIMITOR + "map"} are
 * vector-broadcast builtins. They carry the opcode, two inputs, the output and the vector type.
 * The builtin name is the opcode with its first three characters stripped. All other strings
 * are parsed as plain binary instructions.
 *
 * @param str the serialized instruction
 * @return a matrix-scalar instruction if the operand data types differ. Otherwise a
 *         matrix-vector broadcast instruction for {@code map*} opcodes, or a matrix-matrix
 *         instruction.
 * @throws DMLRuntimeException if parsing fails or no builtin value function exists for the opcode
 */
public static BuiltinBinarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    final CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    final CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);

    // "map" builtins broadcast a vector against a matrix
    final boolean mapBuiltin = str.startsWith("SPARK" + Lop.OPERAND_DELIMITOR + "map");

    String opcode;
    ValueFunction func;
    VectorType vtype = null;
    if (mapBuiltin) {
        String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(fields, 5);
        opcode = fields[0];
        in1.split(fields[1]);
        in2.split(fields[2]);
        out.split(fields[3]);
        func = Builtin.getBuiltinFnObject(opcode.substring(3));
        vtype = VectorType.valueOf(fields[5]);
    } else {
        opcode = parseBinaryInstruction(str, in1, in2, out);
        func = Builtin.getBuiltinFnObject(opcode);
    }

    // the builtin lookup may yield no function for an unknown opcode
    if (func == null) {
        throw new DMLRuntimeException("Failed to create builtin value function for opcode: " + opcode);
    }

    // differing data types mean one side is a scalar
    if (in1.getDataType() != in2.getDataType()) {
        return new MatrixScalarBuiltinSPInstruction(new RightScalarOperator(func, 0), in1, in2, out, opcode, str);
    }

    // both sides are matrices: broadcast variant for map builtins, full matrix-matrix otherwise
    BinaryOperator bop = new BinaryOperator(func);
    if (mapBuiltin) {
        return new MatrixBVectorBuiltinSPInstruction(bop, in1, in2, out, vtype, opcode, str);
    }
    return new MatrixMatrixBuiltinSPInstruction(bop, in1, in2, out, opcode, str);
}
Usage of `org.apache.sysml.runtime.functionobjects.ValueFunction` in the Apache incubator-systemml project.
Class `BuiltinUnarySPInstruction`, method `parseInstruction`:
/**
 * Parses a Spark builtin unary instruction string into a {@code MatrixBuiltinSPInstruction}.
 *
 * @param str the serialized instruction
 * @return the parsed instruction, applying the builtin named by the opcode through a
 *         {@code UnaryOperator}
 * @throws DMLRuntimeException if parsing fails or no builtin value function exists for the opcode
 */
public static BuiltinUnarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
    CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
    String opcode = parseUnaryInstruction(str, in, out);
    ValueFunction func = Builtin.getBuiltinFnObject(opcode);
    // Sanity check the value function, matching the binary builtin parser. Without it, an
    // unknown opcode would produce a UnaryOperator over null and fail only later, at execution.
    if (func == null)
        throw new DMLRuntimeException("Failed to create builtin value function for opcode: " + opcode);
    return new MatrixBuiltinSPInstruction(new UnaryOperator(func), in, out, opcode, str);
}
Aggregations