Use of org.apache.sysml.runtime.functionobjects.CM in the project incubator-systemml by Apache.
Class InstructionUtils, method parseBasicAggregateUnaryOperator.
/**
 * Translates a unary-aggregate opcode (e.g. "uak+", "uarmean", "uacvar") into the
 * corresponding {@link AggregateUnaryOperator}.
 *
 * The opcode is compared case-insensitively. Each supported opcode pairs an
 * {@link AggregateOperator} (initial value, value function and, where applicable, a
 * correction location) with the index function that selects the reduction direction.
 *
 * @param opcode the instruction opcode; must not be null
 * @return the matching operator, or {@code null} if the opcode is not recognized
 */
public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode) {
    // ---- Kahan-compensated sums -------------------------------------------------
    if (opcode.equalsIgnoreCase("uak+")) {
        AggregateOperator sumOp = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
        return new AggregateUnaryOperator(sumOp, ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uark+")) {
        // RowSums
        AggregateOperator sumOp = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
        return new AggregateUnaryOperator(sumOp, ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uack+")) {
        // ColSums
        AggregateOperator sumOp = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW);
        return new AggregateUnaryOperator(sumOp, ReduceRow.getReduceRowFnObject());
    }

    // ---- Kahan-compensated sums of squares ---------------------------------------
    if (opcode.equalsIgnoreCase("uasqk+")) {
        AggregateOperator sumSqOp = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), true, CorrectionLocationType.LASTCOLUMN);
        return new AggregateUnaryOperator(sumSqOp, ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uarsqk+")) {
        // Row-wise sums of squares
        AggregateOperator sumSqOp = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), true, CorrectionLocationType.LASTCOLUMN);
        return new AggregateUnaryOperator(sumSqOp, ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uacsqk+")) {
        // Column-wise sums of squares
        AggregateOperator sumSqOp = new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), true, CorrectionLocationType.LASTROW);
        return new AggregateUnaryOperator(sumSqOp, ReduceRow.getReduceRowFnObject());
    }

    // ---- Means --------------------------------------------------------------------
    if (opcode.equalsIgnoreCase("uamean")) {
        // Mean
        AggregateOperator meanOp = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOCOLUMNS);
        return new AggregateUnaryOperator(meanOp, ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uarmean")) {
        // RowMeans
        AggregateOperator meanOp = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOCOLUMNS);
        return new AggregateUnaryOperator(meanOp, ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uacmean")) {
        // ColMeans
        AggregateOperator meanOp = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOROWS);
        return new AggregateUnaryOperator(meanOp, ReduceRow.getReduceRowFnObject());
    }

    // ---- Variances ----------------------------------------------------------------
    if (opcode.equalsIgnoreCase("uavar")) {
        // Variance
        CM variance = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
        AggregateOperator varOp = new AggregateOperator(0, variance, true, CorrectionLocationType.LASTFOURCOLUMNS);
        return new AggregateUnaryOperator(varOp, ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uarvar")) {
        // RowVariances
        CM variance = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
        AggregateOperator varOp = new AggregateOperator(0, variance, true, CorrectionLocationType.LASTFOURCOLUMNS);
        return new AggregateUnaryOperator(varOp, ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uacvar")) {
        // ColVariances
        CM variance = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
        AggregateOperator varOp = new AggregateOperator(0, variance, true, CorrectionLocationType.LASTFOURROWS);
        return new AggregateUnaryOperator(varOp, ReduceRow.getReduceRowFnObject());
    }

    // ---- Plain (uncorrected) sums -------------------------------------------------
    if (opcode.equalsIgnoreCase("ua+")) {
        AggregateOperator plusOp = new AggregateOperator(0, Plus.getPlusFnObject());
        return new AggregateUnaryOperator(plusOp, ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uar+")) {
        // RowSums
        AggregateOperator plusOp = new AggregateOperator(0, Plus.getPlusFnObject());
        return new AggregateUnaryOperator(plusOp, ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uac+")) {
        // ColSums
        AggregateOperator plusOp = new AggregateOperator(0, Plus.getPlusFnObject());
        return new AggregateUnaryOperator(plusOp, ReduceRow.getReduceRowFnObject());
    }

    // ---- Product and full min/max -------------------------------------------------
    if (opcode.equalsIgnoreCase("ua*")) {
        AggregateOperator prodOp = new AggregateOperator(1, Multiply.getMultiplyFnObject());
        return new AggregateUnaryOperator(prodOp, ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uamax")) {
        AggregateOperator maxOp = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
        return new AggregateUnaryOperator(maxOp, ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uamin")) {
        AggregateOperator minOp = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
        return new AggregateUnaryOperator(minOp, ReduceAll.getReduceAllFnObject());
    }

    // ---- Trace (plain and Kahan-compensated) --------------------------------------
    if (opcode.equalsIgnoreCase("uatrace")) {
        AggregateOperator traceOp = new AggregateOperator(0, Plus.getPlusFnObject());
        return new AggregateUnaryOperator(traceOp, ReduceDiag.getReduceDiagFnObject());
    }
    if (opcode.equalsIgnoreCase("uaktrace")) {
        AggregateOperator traceOp = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
        return new AggregateUnaryOperator(traceOp, ReduceDiag.getReduceDiagFnObject());
    }

    // ---- Row-wise min/max and their index variants --------------------------------
    if (opcode.equalsIgnoreCase("uarmax")) {
        AggregateOperator maxOp = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
        return new AggregateUnaryOperator(maxOp, ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uarimax")) {
        AggregateOperator maxIdxOp = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("maxindex"), true, CorrectionLocationType.LASTCOLUMN);
        return new AggregateUnaryOperator(maxIdxOp, ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uarmin")) {
        AggregateOperator minOp = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
        return new AggregateUnaryOperator(minOp, ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uarimin")) {
        AggregateOperator minIdxOp = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("minindex"), true, CorrectionLocationType.LASTCOLUMN);
        return new AggregateUnaryOperator(minIdxOp, ReduceCol.getReduceColFnObject());
    }

    // ---- Column-wise min/max ------------------------------------------------------
    if (opcode.equalsIgnoreCase("uacmax")) {
        AggregateOperator maxOp = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
        return new AggregateUnaryOperator(maxOp, ReduceRow.getReduceRowFnObject());
    }
    if (opcode.equalsIgnoreCase("uacmin")) {
        AggregateOperator minOp = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
        return new AggregateUnaryOperator(minOp, ReduceRow.getReduceRowFnObject());
    }

    // Unknown opcode: callers receive null, as before.
    return null;
}
Use of org.apache.sysml.runtime.functionobjects.CM in the project incubator-systemml by Apache.
Class InstructionUtils, method parseAggregateOperator.
/**
 * Builds the {@link AggregateOperator} for an aggregation opcode (e.g. "ak+", "amax", "avar").
 *
 * The opcode is compared case-insensitively. For the opcodes "ak+", "aktrace", "asqk+",
 * "amean" and "avar" the correction settings can be supplied by the caller; a {@code null}
 * argument selects the default (correction enabled, opcode-specific default location).
 * For all other opcodes {@code corrExists} and {@code corrLoc} are ignored.
 *
 * @param opcode     the aggregation opcode; must not be null
 * @param corrExists textual boolean telling whether a correction exists, or null for the default (true)
 * @param corrLoc    name of a {@link CorrectionLocationType} constant, or null for the opcode's default
 * @return the matching operator, or {@code null} if the opcode is not recognized
 */
public static AggregateOperator parseAggregateOperator(String opcode, String corrExists, String corrLoc) {
    if (opcode.equalsIgnoreCase("ak+") || opcode.equalsIgnoreCase("aktrace")) {
        boolean hasCorrection = (corrExists == null) || Boolean.parseBoolean(corrExists);
        CorrectionLocationType location = (corrLoc != null) ? CorrectionLocationType.valueOf(corrLoc) : CorrectionLocationType.LASTCOLUMN;
        return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), hasCorrection, location);
    }
    if (opcode.equalsIgnoreCase("asqk+")) {
        boolean hasCorrection = (corrExists == null) || Boolean.parseBoolean(corrExists);
        CorrectionLocationType location = (corrLoc != null) ? CorrectionLocationType.valueOf(corrLoc) : CorrectionLocationType.LASTCOLUMN;
        return new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), hasCorrection, location);
    }
    if (opcode.equalsIgnoreCase("a+")) {
        return new AggregateOperator(0, Plus.getPlusFnObject());
    }
    if (opcode.equalsIgnoreCase("a*")) {
        return new AggregateOperator(1, Multiply.getMultiplyFnObject());
    }
    if (opcode.equalsIgnoreCase("arimax")) {
        return new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("maxindex"), true, CorrectionLocationType.LASTCOLUMN);
    }
    if (opcode.equalsIgnoreCase("amax")) {
        return new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
    }
    if (opcode.equalsIgnoreCase("amin")) {
        return new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
    }
    if (opcode.equalsIgnoreCase("arimin")) {
        return new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("minindex"), true, CorrectionLocationType.LASTCOLUMN);
    }
    if (opcode.equalsIgnoreCase("amean")) {
        // Uses the Kahan-plus value function with a two-column correction by default.
        boolean hasCorrection = (corrExists == null) || Boolean.parseBoolean(corrExists);
        CorrectionLocationType location = (corrLoc != null) ? CorrectionLocationType.valueOf(corrLoc) : CorrectionLocationType.LASTTWOCOLUMNS;
        return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), hasCorrection, location);
    }
    if (opcode.equalsIgnoreCase("avar")) {
        boolean hasCorrection = (corrExists == null) || Boolean.parseBoolean(corrExists);
        CorrectionLocationType location = (corrLoc != null) ? CorrectionLocationType.valueOf(corrLoc) : CorrectionLocationType.LASTFOURCOLUMNS;
        CM variance = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
        return new AggregateOperator(0, variance, hasCorrection, location);
    }

    // Unknown opcode: callers receive null, as before.
    return null;
}
Use of org.apache.sysml.runtime.functionobjects.CM in the project incubator-systemml by Apache.
Class LibMatrixAgg, method aggregateUnaryMatrixDense.
/**
 * Dispatches a unary aggregation over a dense input block to the matching dense kernel
 * ({@code d_*} methods), chosen by aggregation type and index function.
 *
 * If the index function is not one of those handled for the given aggregation type,
 * no kernel is invoked and the method returns without touching the output.
 *
 * @param in     input block; its dense array, {@code rlen} and {@code clen} are read
 * @param out    output block; its dense array is handed to the kernel as the result buffer
 * @param optype aggregation type selecting the kernel family
 * @param vFn    value function, cast to the type the selected kernel expects
 * @param ixFn   index function selecting the reduction direction
 * @param rl     passed through to the kernel (presumably the lower row bound - confirm against kernels)
 * @param ru     passed through to the kernel (presumably the upper row bound - confirm against kernels)
 * @throws DMLRuntimeException if the aggregation type is not supported
 */
private static void aggregateUnaryMatrixDense(MatrixBlock in, MatrixBlock out, AggType optype, ValueFunction vFn, IndexFunction ixFn, int rl, int ru) throws DMLRuntimeException {
    final int numRows = in.rlen;
    final int numCols = in.clen;
    double[] src = in.getDenseBlock();
    double[] dst = out.getDenseBlock();

    // Classify the reduction direction once; the switch below only consults these flags.
    final boolean overAll = ixFn instanceof ReduceAll;
    final boolean overCols = ixFn instanceof ReduceCol;
    final boolean overRows = ixFn instanceof ReduceRow;
    final boolean overDiag = ixFn instanceof ReduceDiag;

    switch (optype) {
        // SUM/TRACE via k+
        case KAHAN_SUM: {
            KahanObject buffer = new KahanObject(0, 0);
            if (overAll) {
                // SUM
                d_uakp(src, dst, numRows, numCols, buffer, (KahanPlus) vFn, rl, ru);
            } else if (overCols) {
                // ROWSUM
                d_uarkp(src, dst, numRows, numCols, buffer, (KahanPlus) vFn, rl, ru);
            } else if (overRows) {
                // COLSUM
                d_uackp(src, dst, numRows, numCols, buffer, (KahanPlus) vFn, rl, ru);
            } else if (overDiag) {
                // TRACE
                d_uakptrace(src, dst, numRows, numCols, buffer, (KahanPlus) vFn, rl, ru);
            }
            break;
        }
        // SUM_SQ via k+
        case KAHAN_SUM_SQ: {
            KahanObject buffer = new KahanObject(0, 0);
            if (overAll) {
                // SUM_SQ
                d_uasqkp(src, dst, numRows, numCols, buffer, (KahanPlusSq) vFn, rl, ru);
            } else if (overCols) {
                // ROWSUM_SQ
                d_uarsqkp(src, dst, numRows, numCols, buffer, (KahanPlusSq) vFn, rl, ru);
            } else if (overRows) {
                // COLSUM_SQ
                d_uacsqkp(src, dst, numRows, numCols, buffer, (KahanPlusSq) vFn, rl, ru);
            }
            break;
        }
        // CUMSUM
        case CUM_KAHAN_SUM: {
            KahanObject buffer = new KahanObject(0, 0);
            KahanPlus kahanPlus = KahanPlus.getKahanPlusFnObject();
            d_ucumkp(src, null, dst, numRows, numCols, buffer, kahanPlus, rl, ru);
            break;
        }
        // CUMPROD
        case CUM_PROD: {
            d_ucumm(src, null, dst, numRows, numCols, rl, ru);
            break;
        }
        // CUMMIN/CUMMAX
        case CUM_MIN:
        case CUM_MAX: {
            // Start from the identity of the comparison: -MAX for max, +MAX for min.
            double start = (optype == AggType.CUM_MAX) ? -Double.MAX_VALUE : Double.MAX_VALUE;
            d_ucummxx(src, null, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            break;
        }
        // MAX/MIN
        case MIN:
        case MAX: {
            double start = (optype == AggType.MAX) ? -Double.MAX_VALUE : Double.MAX_VALUE;
            if (overAll) {
                // MIN/MAX
                d_uamxx(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            } else if (overCols) {
                // ROWMIN/ROWMAX
                d_uarmxx(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            } else if (overRows) {
                // COLMIN/COLMAX
                d_uacmxx(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MAX_INDEX: {
            double start = -Double.MAX_VALUE;
            if (overCols) {
                // ROWINDEXMAX
                d_uarimxx(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MIN_INDEX: {
            double start = Double.MAX_VALUE;
            if (overCols) {
                // ROWINDEXMIN
                d_uarimin(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            }
            break;
        }
        // MEAN
        case MEAN: {
            KahanObject buffer = new KahanObject(0, 0);
            if (overAll) {
                // MEAN
                d_uamean(src, dst, numRows, numCols, buffer, (Mean) vFn, rl, ru);
            } else if (overCols) {
                // ROWMEAN
                d_uarmean(src, dst, numRows, numCols, buffer, (Mean) vFn, rl, ru);
            } else if (overRows) {
                // COLMEAN
                d_uacmean(src, dst, numRows, numCols, buffer, (Mean) vFn, rl, ru);
            }
            break;
        }
        // VAR
        case VAR: {
            CM_COV_Object moments = new CM_COV_Object();
            if (overAll) {
                // VAR
                d_uavar(src, dst, numRows, numCols, moments, (CM) vFn, rl, ru);
            } else if (overCols) {
                // ROWVAR
                d_uarvar(src, dst, numRows, numCols, moments, (CM) vFn, rl, ru);
            } else if (overRows) {
                // COLVAR
                d_uacvar(src, dst, numRows, numCols, moments, (CM) vFn, rl, ru);
            }
            break;
        }
        // PROD
        case PROD: {
            if (overAll) {
                // PROD
                d_uam(src, dst, numRows, numCols, rl, ru);
            }
            break;
        }
        default:
            throw new DMLRuntimeException("Unsupported aggregation type: " + optype);
    }
}
Use of org.apache.sysml.runtime.functionobjects.CM in the project incubator-systemml by Apache.
Class LibMatrixAgg, method aggregateUnaryMatrixSparse.
/**
 * Dispatches a unary aggregation over a sparse input block to the matching sparse kernel
 * ({@code s_*} methods), chosen by aggregation type and index function. The result is
 * written into the dense array of the output block.
 *
 * If the index function is not one of those handled for the given aggregation type,
 * no kernel is invoked and the method returns without touching the output.
 *
 * @param in     input block; its sparse block, {@code rlen} and {@code clen} are read
 * @param out    output block; its dense array is handed to the kernel as the result buffer
 * @param optype aggregation type selecting the kernel family
 * @param vFn    value function, cast to the type the selected kernel expects
 * @param ixFn   index function selecting the reduction direction
 * @param rl     passed through to the kernel (presumably the lower row bound - confirm against kernels)
 * @param ru     passed through to the kernel (presumably the upper row bound - confirm against kernels)
 * @throws DMLRuntimeException if the aggregation type is not supported
 */
private static void aggregateUnaryMatrixSparse(MatrixBlock in, MatrixBlock out, AggType optype, ValueFunction vFn, IndexFunction ixFn, int rl, int ru) throws DMLRuntimeException {
    final int numRows = in.rlen;
    final int numCols = in.clen;
    SparseBlock src = in.getSparseBlock();
    double[] dst = out.getDenseBlock();

    // Classify the reduction direction once; the switch below only consults these flags.
    final boolean overAll = ixFn instanceof ReduceAll;
    final boolean overCols = ixFn instanceof ReduceCol;
    final boolean overRows = ixFn instanceof ReduceRow;
    final boolean overDiag = ixFn instanceof ReduceDiag;

    switch (optype) {
        // SUM/TRACE via k+
        case KAHAN_SUM: {
            KahanObject buffer = new KahanObject(0, 0);
            if (overAll) {
                // SUM
                s_uakp(src, dst, numRows, numCols, buffer, (KahanPlus) vFn, rl, ru);
            } else if (overCols) {
                // ROWSUM
                s_uarkp(src, dst, numRows, numCols, buffer, (KahanPlus) vFn, rl, ru);
            } else if (overRows) {
                // COLSUM
                s_uackp(src, dst, numRows, numCols, buffer, (KahanPlus) vFn, rl, ru);
            } else if (overDiag) {
                // TRACE
                s_uakptrace(src, dst, numRows, numCols, buffer, (KahanPlus) vFn, rl, ru);
            }
            break;
        }
        // SUM_SQ via k+
        case KAHAN_SUM_SQ: {
            KahanObject buffer = new KahanObject(0, 0);
            if (overAll) {
                // SUM_SQ
                s_uasqkp(src, dst, numRows, numCols, buffer, (KahanPlusSq) vFn, rl, ru);
            } else if (overCols) {
                // ROWSUM_SQ
                s_uarsqkp(src, dst, numRows, numCols, buffer, (KahanPlusSq) vFn, rl, ru);
            } else if (overRows) {
                // COLSUM_SQ
                s_uacsqkp(src, dst, numRows, numCols, buffer, (KahanPlusSq) vFn, rl, ru);
            }
            break;
        }
        // CUMSUM
        case CUM_KAHAN_SUM: {
            KahanObject buffer = new KahanObject(0, 0);
            KahanPlus kahanPlus = KahanPlus.getKahanPlusFnObject();
            s_ucumkp(src, null, dst, numRows, numCols, buffer, kahanPlus, rl, ru);
            break;
        }
        // CUMPROD
        case CUM_PROD: {
            s_ucumm(src, null, dst, numRows, numCols, rl, ru);
            break;
        }
        // CUMMIN/CUMMAX
        case CUM_MIN:
        case CUM_MAX: {
            // Start from the identity of the comparison: -MAX for max, +MAX for min.
            double start = (optype == AggType.CUM_MAX) ? -Double.MAX_VALUE : Double.MAX_VALUE;
            s_ucummxx(src, null, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            break;
        }
        // MAX/MIN
        case MIN:
        case MAX: {
            double start = (optype == AggType.MAX) ? -Double.MAX_VALUE : Double.MAX_VALUE;
            if (overAll) {
                // MIN/MAX
                s_uamxx(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            } else if (overCols) {
                // ROWMIN/ROWMAX
                s_uarmxx(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            } else if (overRows) {
                // COLMIN/COLMAX
                s_uacmxx(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MAX_INDEX: {
            double start = -Double.MAX_VALUE;
            if (overCols) {
                // ROWINDEXMAX
                s_uarimxx(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MIN_INDEX: {
            double start = Double.MAX_VALUE;
            if (overCols) {
                // ROWINDEXMIN
                s_uarimin(src, dst, numRows, numCols, start, (Builtin) vFn, rl, ru);
            }
            break;
        }
        // MEAN
        case MEAN: {
            KahanObject buffer = new KahanObject(0, 0);
            if (overAll) {
                // MEAN
                s_uamean(src, dst, numRows, numCols, buffer, (Mean) vFn, rl, ru);
            } else if (overCols) {
                // ROWMEAN
                s_uarmean(src, dst, numRows, numCols, buffer, (Mean) vFn, rl, ru);
            } else if (overRows) {
                // COLMEAN
                s_uacmean(src, dst, numRows, numCols, buffer, (Mean) vFn, rl, ru);
            }
            break;
        }
        // VAR
        case VAR: {
            CM_COV_Object moments = new CM_COV_Object();
            if (overAll) {
                // VAR
                s_uavar(src, dst, numRows, numCols, moments, (CM) vFn, rl, ru);
            } else if (overCols) {
                // ROWVAR
                s_uarvar(src, dst, numRows, numCols, moments, (CM) vFn, rl, ru);
            } else if (overRows) {
                // COLVAR
                s_uacvar(src, dst, numRows, numCols, moments, (CM) vFn, rl, ru);
            }
            break;
        }
        // PROD
        case PROD: {
            if (overAll) {
                // PROD
                s_uam(src, dst, numRows, numCols, rl, ru);
            }
            break;
        }
        default:
            throw new DMLRuntimeException("Unsupported aggregation type: " + optype);
    }
}
Use of org.apache.sysml.runtime.functionobjects.CM in the project incubator-systemml by Apache.
Class LibMatrixCUDA, method unaryAggregate.
//********************************************************************/
//***************** END OF MATRIX MULTIPLY Functions *****************/
//********************************************************************/
//********************************************************************/
//**************** UNARY AGGREGATE Functions ************************/
//********************************************************************/
/**
 * Entry point to perform Unary aggregate operations on the GPU.
 * The execution context object is used to allocate memory for the GPU.
 *
 * Sparse inputs are converted to dense on the device before aggregating. Full reductions
 * set a scalar output on the execution context; row/column reductions write into a dense
 * output matrix obtained from the execution context. Temporary device buffers allocated
 * here are released even if the aggregation fails part-way.
 *
 * @param ec Instance of {@link ExecutionContext}, from which the output variable will be allocated
 * @param gCtx a valid {@link GPUContext}
 * @param instName name of the invoking instruction to record {@link Statistics}.
 * @param in1 input matrix
 * @param output output matrix/scalar name
 * @param op Instance of {@link AggregateUnaryOperator} which encapsulates the direction of reduction/aggregation and the reduction operation.
 * @throws DMLRuntimeException if the GPU context does not match the execution context, the input is not
 *         allocated on the GPU, the input has more cells than fit in an int, or the requested
 *         operation / reduction direction is not supported on the GPU
 */
public static void unaryAggregate(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, String output, AggregateUnaryOperator op) throws DMLRuntimeException {
    if (ec.getGPUContext() != gCtx)
        throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
    LOG.trace("GPU : unaryAggregate" + ", GPUContext=" + gCtx);

    final int REDUCTION_ALL = 1;
    final int REDUCTION_ROW = 2;
    final int REDUCTION_COL = 3;
    final int REDUCTION_DIAG = 4;

    // A Kahan sum implementation is not provided. If a "uak+" or other Kahan operator is encountered,
    // it just does regular summation reduction.
    final int OP_PLUS = 1;
    final int OP_PLUS_SQ = 2;
    final int OP_MEAN = 3;
    final int OP_VARIANCE = 4;
    final int OP_MULTIPLY = 5;
    final int OP_MAX = 6;
    final int OP_MIN = 7;
    final int OP_MAXINDEX = 8;
    final int OP_MININDEX = 9;

    // Sanity Checks
    if (!in1.getGPUObject(gCtx).isAllocated())
        throw new DMLRuntimeException("Internal Error - The input is not allocated for a GPU Aggregate Unary:" + in1.getGPUObject(gCtx).isAllocated());

    boolean isSparse = in1.getGPUObject(gCtx).isSparse();
    IndexFunction indexFn = op.indexFn;
    AggregateOperator aggOp = op.aggOp;

    // Convert Reduction direction to a number to pass to CUDA kernel
    int reductionDirection = -1;
    if (indexFn instanceof ReduceAll) {
        reductionDirection = REDUCTION_ALL;
    } else if (indexFn instanceof ReduceRow) {
        reductionDirection = REDUCTION_ROW;
    } else if (indexFn instanceof ReduceCol) {
        reductionDirection = REDUCTION_COL;
    } else if (indexFn instanceof ReduceDiag) {
        reductionDirection = REDUCTION_DIAG;
    } else {
        throw new DMLRuntimeException("Internal Error - Invalid index function type, only reducing along rows, columns, diagonals or all elements is supported in Aggregate Unary operations");
    }
    assert reductionDirection != -1 : "Internal Error - Incorrect type of reduction direction set for aggregate unary GPU instruction";

    // Convert function type to a number to pass to the CUDA Kernel
    int opIndex = -1;
    if (aggOp.increOp.fn instanceof KahanPlus) {
        opIndex = OP_PLUS;
    } else if (aggOp.increOp.fn instanceof KahanPlusSq) {
        opIndex = OP_PLUS_SQ;
    } else if (aggOp.increOp.fn instanceof Mean) {
        opIndex = OP_MEAN;
    } else if (aggOp.increOp.fn instanceof CM) {
        assert ((CM) aggOp.increOp.fn).getAggOpType() == CMOperator.AggregateOperationTypes.VARIANCE : "Internal Error - Invalid Type of CM operator for Aggregate Unary operation on GPU";
        opIndex = OP_VARIANCE;
    } else if (aggOp.increOp.fn instanceof Plus) {
        opIndex = OP_PLUS;
    } else if (aggOp.increOp.fn instanceof Multiply) {
        opIndex = OP_MULTIPLY;
    } else if (aggOp.increOp.fn instanceof Builtin) {
        Builtin b = (Builtin) aggOp.increOp.fn;
        switch (b.bFunc) {
            case MAX:
                opIndex = OP_MAX;
                break;
            case MIN:
                opIndex = OP_MIN;
                break;
            case MAXINDEX:
                opIndex = OP_MAXINDEX;
                break;
            case MININDEX:
                opIndex = OP_MININDEX;
                break;
            default:
                // Previously the exception was constructed but never thrown, so an unsupported
                // builtin slipped through with opIndex == -1 and failed later with an unrelated message.
                throw new DMLRuntimeException("Internal Error - Unsupported Builtin Function for Aggregate unary being done on GPU");
        }
    } else {
        throw new DMLRuntimeException("Internal Error - Aggregate operator has invalid Value function");
    }
    assert opIndex != -1 : "Internal Error - Incorrect type of operation set for aggregate unary GPU instruction";

    int rlen = (int) in1.getNumRows();
    int clen = (int) in1.getNumColumns();

    // The kernels take the cell count as an int; fail loudly instead of letting rlen * clen wrap around.
    long numCells = (long) rlen * clen;
    if (numCells > Integer.MAX_VALUE)
        throw new DMLRuntimeException("Internal Error - Input with " + numCells + " cells is too large for a GPU Aggregate Unary, the number of cells must fit in an int");
    int size = (int) numCells;

    if (isSparse) {
        // The strategy for the time being is to convert sparse to dense
        // until a sparse specific kernel is written.
        in1.getGPUObject(gCtx).sparseToDense(instName);
    }

    Pointer out = null;
    if (reductionDirection == REDUCTION_COL || reductionDirection == REDUCTION_ROW) {
        // Matrix output
        MatrixObject out1 = getDenseMatrixOutputForGPUInstruction(ec, instName, output);
        out = getDensePointer(gCtx, out1, instName);
    }
    Pointer in = getDensePointer(gCtx, in1, instName);

    // For scalars, set the scalar output in the Execution Context object
    switch (opIndex) {
        case OP_PLUS: {
            switch (reductionDirection) {
                case REDUCTION_ALL: {
                    double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                    ec.setScalarOutput(output, new DoubleObject(result));
                    break;
                }
                case REDUCTION_COL: {
                    // The names are a bit misleading, REDUCTION_COL refers to the direction (reduce all elements in a column)
                    reduceRow(gCtx, instName, "reduce_row_sum", in, out, rlen, clen);
                    break;
                }
                case REDUCTION_ROW: {
                    reduceCol(gCtx, instName, "reduce_col_sum", in, out, rlen, clen);
                    break;
                }
                case REDUCTION_DIAG:
                    // Only the diagonal (trace) case reaches this branch.
                    throw new DMLRuntimeException("Internal Error - Row, Column and Diag summation not implemented yet");
            }
            break;
        }
        case OP_PLUS_SQ: {
            // Calculate the squares in a temporary object tmp.
            // Byte size is computed in long so that large matrices do not overflow int.
            Pointer tmp = gCtx.allocate(instName, (long) size * Sizeof.DOUBLE);
            try {
                squareMatrix(gCtx, instName, in, tmp, rlen, clen);
                // Then do the sum on the temporary object
                switch (reductionDirection) {
                    case REDUCTION_ALL: {
                        double result = reduceAll(gCtx, instName, "reduce_sum", tmp, size);
                        ec.setScalarOutput(output, new DoubleObject(result));
                        break;
                    }
                    case REDUCTION_COL: {
                        // The names are a bit misleading, REDUCTION_COL refers to the direction (reduce all elements in a column)
                        reduceRow(gCtx, instName, "reduce_row_sum", tmp, out, rlen, clen);
                        break;
                    }
                    case REDUCTION_ROW: {
                        reduceCol(gCtx, instName, "reduce_col_sum", tmp, out, rlen, clen);
                        break;
                    }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for summation squared");
                }
            } finally {
                // Free the temporary even when the reduction throws.
                gCtx.cudaFreeHelper(instName, tmp);
            }
            break;
        }
        case OP_MEAN: {
            switch (reductionDirection) {
                case REDUCTION_ALL: {
                    double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                    double mean = result / size;
                    ec.setScalarOutput(output, new DoubleObject(mean));
                    break;
                }
                case REDUCTION_COL: {
                    reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
                    break;
                }
                case REDUCTION_ROW: {
                    reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
                    break;
                }
                default:
                    throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for mean");
            }
            break;
        }
        case OP_MULTIPLY: {
            switch (reductionDirection) {
                case REDUCTION_ALL: {
                    double result = reduceAll(gCtx, instName, "reduce_prod", in, size);
                    ec.setScalarOutput(output, new DoubleObject(result));
                    break;
                }
                default:
                    throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for multiplication");
            }
            break;
        }
        case OP_MAX: {
            switch (reductionDirection) {
                case REDUCTION_ALL: {
                    double result = reduceAll(gCtx, instName, "reduce_max", in, size);
                    ec.setScalarOutput(output, new DoubleObject(result));
                    break;
                }
                case REDUCTION_COL: {
                    reduceRow(gCtx, instName, "reduce_row_max", in, out, rlen, clen);
                    break;
                }
                case REDUCTION_ROW: {
                    reduceCol(gCtx, instName, "reduce_col_max", in, out, rlen, clen);
                    break;
                }
                default:
                    throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for max");
            }
            break;
        }
        case OP_MIN: {
            switch (reductionDirection) {
                case REDUCTION_ALL: {
                    double result = reduceAll(gCtx, instName, "reduce_min", in, size);
                    ec.setScalarOutput(output, new DoubleObject(result));
                    break;
                }
                case REDUCTION_COL: {
                    reduceRow(gCtx, instName, "reduce_row_min", in, out, rlen, clen);
                    break;
                }
                case REDUCTION_ROW: {
                    reduceCol(gCtx, instName, "reduce_col_min", in, out, rlen, clen);
                    break;
                }
                default:
                    throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for min");
            }
            break;
        }
        case OP_VARIANCE: {
            // Temporary GPU arrays: tmp holds (x - mean), tmp2 holds (x - mean)^2.
            Pointer tmp = null;
            Pointer tmp2 = null;
            try {
                tmp = gCtx.allocate(instName, (long) size * Sizeof.DOUBLE);
                tmp2 = gCtx.allocate(instName, (long) size * Sizeof.DOUBLE);
                switch (reductionDirection) {
                    case REDUCTION_ALL: {
                        double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                        double mean = result / size;
                        // Subtract mean from every element in the matrix
                        ScalarOperator minusOp = new RightScalarOperator(Minus.getMinusFnObject(), mean);
                        matrixScalarOp(gCtx, instName, in, mean, rlen, clen, tmp, minusOp);
                        squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
                        double result2 = reduceAll(gCtx, instName, "reduce_sum", tmp2, size);
                        double variance = result2 / (size - 1);
                        ec.setScalarOutput(output, new DoubleObject(variance));
                        break;
                    }
                    case REDUCTION_COL: {
                        reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
                        // Subtract the row-wise mean from every element in the matrix
                        BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
                        matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp);
                        squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
                        Pointer tmpRow = gCtx.allocate(instName, (long) rlen * Sizeof.DOUBLE);
                        try {
                            reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
                            ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
                            matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
                        } finally {
                            gCtx.cudaFreeHelper(instName, tmpRow);
                        }
                        break;
                    }
                    case REDUCTION_ROW: {
                        reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
                        // Subtract the columns-wise mean from every element in the matrix
                        BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
                        matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp);
                        squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
                        Pointer tmpCol = gCtx.allocate(instName, (long) clen * Sizeof.DOUBLE);
                        try {
                            reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
                            ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
                            matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
                        } finally {
                            gCtx.cudaFreeHelper(instName, tmpCol);
                        }
                        break;
                    }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for variance");
                }
            } finally {
                // Release whatever was allocated, including on failure or unsupported direction.
                if (tmp != null)
                    gCtx.cudaFreeHelper(instName, tmp);
                if (tmp2 != null)
                    gCtx.cudaFreeHelper(instName, tmp2);
            }
            break;
        }
        case OP_MAXINDEX: {
            // Every branch throws, so no break is needed (it would be unreachable).
            switch (reductionDirection) {
                case REDUCTION_COL:
                    throw new DMLRuntimeException("Internal Error - Column maxindex of matrix not implemented yet for GPU ");
                default:
                    throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for maxindex");
            }
        }
        case OP_MININDEX: {
            // Every branch throws, so no break is needed (it would be unreachable).
            switch (reductionDirection) {
                case REDUCTION_COL:
                    throw new DMLRuntimeException("Internal Error - Column minindex of matrix not implemented yet for GPU ");
                default:
                    throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for minindex");
            }
        }
        default:
            throw new DMLRuntimeException("Internal Error - Invalid GPU Unary aggregate function!");
    }
}
Aggregations