Search in sources:

Example 1 with CM

use of org.apache.sysml.runtime.functionobjects.CM in project incubator-systemml by apache.

Class InstructionUtils, method parseBasicAggregateUnaryOperator.

/**
 * Maps a unary-aggregate opcode to the {@link AggregateUnaryOperator} that implements it.
 *
 * <p>The opcode is compared case-insensitively.
 * <ul>
 *   <li>Full aggregates (e.g. {@code uak+}, {@code uamax}) reduce with ReduceAll.</li>
 *   <li>Row aggregates ({@code uar...}) reduce with ReduceCol.</li>
 *   <li>Column aggregates ({@code uac...}) reduce with ReduceRow.</li>
 *   <li>The trace opcodes reduce with ReduceDiag.</li>
 * </ul>
 * Kahan, mean, variance and index opcodes carry a correction location; the plain
 * {@code +}, {@code *}, {@code max} and {@code min} opcodes do not.
 *
 * @param opcode instruction opcode, e.g. {@code "uak+"}, {@code "uarmean"}, {@code "uacvar"}
 * @return the matching operator, or {@code null} if the opcode is not recognized
 */
public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode) {
    // --- Kahan-compensated sums: all / row sums / column sums ---
    if (opcode.equalsIgnoreCase("uak+")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN),
            ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uark+")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN),
            ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uack+")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW),
            ReduceRow.getReduceRowFnObject());
    }

    // --- Kahan-compensated sums of squares: all / row / column ---
    if (opcode.equalsIgnoreCase("uasqk+")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), true, CorrectionLocationType.LASTCOLUMN),
            ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uarsqk+")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), true, CorrectionLocationType.LASTCOLUMN),
            ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uacsqk+")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), true, CorrectionLocationType.LASTROW),
            ReduceRow.getReduceRowFnObject());
    }

    // --- Means: all / row means / column means (two-slot correction) ---
    if (opcode.equalsIgnoreCase("uamean")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOCOLUMNS),
            ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uarmean")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOCOLUMNS),
            ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uacmean")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOROWS),
            ReduceRow.getReduceRowFnObject());
    }

    // --- Variances via the central-moment function object (four-slot correction) ---
    if (opcode.equalsIgnoreCase("uavar")) {
        CM variance = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
        return new AggregateUnaryOperator(
            new AggregateOperator(0, variance, true, CorrectionLocationType.LASTFOURCOLUMNS),
            ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uarvar")) {
        CM variance = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
        return new AggregateUnaryOperator(
            new AggregateOperator(0, variance, true, CorrectionLocationType.LASTFOURCOLUMNS),
            ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uacvar")) {
        CM variance = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
        return new AggregateUnaryOperator(
            new AggregateOperator(0, variance, true, CorrectionLocationType.LASTFOURROWS),
            ReduceRow.getReduceRowFnObject());
    }

    // --- Plain (uncompensated) sums and product ---
    if (opcode.equalsIgnoreCase("ua+")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, Plus.getPlusFnObject()),
            ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uar+")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, Plus.getPlusFnObject()),
            ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uac+")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, Plus.getPlusFnObject()),
            ReduceRow.getReduceRowFnObject());
    }
    if (opcode.equalsIgnoreCase("ua*")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(1, Multiply.getMultiplyFnObject()),
            ReduceAll.getReduceAllFnObject());
    }

    // --- Full max/min; the initial value is the identity of the comparison ---
    if (opcode.equalsIgnoreCase("uamax")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max")),
            ReduceAll.getReduceAllFnObject());
    }
    if (opcode.equalsIgnoreCase("uamin")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min")),
            ReduceAll.getReduceAllFnObject());
    }

    // --- Trace (diagonal sum), plain and Kahan-compensated ---
    if (opcode.equalsIgnoreCase("uatrace")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, Plus.getPlusFnObject()),
            ReduceDiag.getReduceDiagFnObject());
    }
    if (opcode.equalsIgnoreCase("uaktrace")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN),
            ReduceDiag.getReduceDiagFnObject());
    }

    // --- Row-wise max/min and their arg-index variants ---
    if (opcode.equalsIgnoreCase("uarmax")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max")),
            ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uarimax")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("maxindex"), true, CorrectionLocationType.LASTCOLUMN),
            ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uarmin")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min")),
            ReduceCol.getReduceColFnObject());
    }
    if (opcode.equalsIgnoreCase("uarimin")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("minindex"), true, CorrectionLocationType.LASTCOLUMN),
            ReduceCol.getReduceColFnObject());
    }

    // --- Column-wise max/min ---
    if (opcode.equalsIgnoreCase("uacmax")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max")),
            ReduceRow.getReduceRowFnObject());
    }
    if (opcode.equalsIgnoreCase("uacmin")) {
        return new AggregateUnaryOperator(
            new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min")),
            ReduceRow.getReduceRowFnObject());
    }

    // Unknown opcode: callers receive null, as before.
    return null;
}
Also used : AggregateUnaryOperator(org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) CM(org.apache.sysml.runtime.functionobjects.CM) CorrectionLocationType(org.apache.sysml.lops.PartialAggregate.CorrectionLocationType)

Example 2 with CM

use of org.apache.sysml.runtime.functionobjects.CM in project incubator-systemml by apache.

Class InstructionUtils, method parseAggregateOperator.

/**
 * Maps an aggregate opcode to its {@link AggregateOperator}.
 *
 * <p>The opcode is compared case-insensitively. Only the correction-carrying opcodes consult
 * {@code corrExists} and {@code corrLoc}: {@code ak+}, {@code aktrace}, {@code asqk+},
 * {@code amean} and {@code avar}. All other opcodes ignore both.
 *
 * @param opcode     aggregate opcode, e.g. {@code "ak+"}, {@code "amax"}, {@code "avar"}
 * @param corrExists textual boolean (parsed with {@link Boolean#parseBoolean}); {@code null} means {@code true}
 * @param corrLoc    name of a {@code CorrectionLocationType} constant; {@code null} selects the
 *                   opcode's default (LASTCOLUMN, LASTTWOCOLUMNS or LASTFOURCOLUMNS)
 * @return the matching operator, or {@code null} if the opcode is not recognized
 * @throws IllegalArgumentException if {@code corrLoc} is non-null, is consulted by the opcode,
 *                                  and does not name a {@code CorrectionLocationType} constant
 */
public static AggregateOperator parseAggregateOperator(String opcode, String corrExists, String corrLoc) {
    // Side-effect free, so it is safe to evaluate up front even for opcodes that ignore it.
    boolean withCorrection = (corrExists == null) || Boolean.parseBoolean(corrExists);

    if (opcode.equalsIgnoreCase("ak+") || opcode.equalsIgnoreCase("aktrace")) {
        CorrectionLocationType location = (corrLoc == null)
            ? CorrectionLocationType.LASTCOLUMN
            : CorrectionLocationType.valueOf(corrLoc);
        return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), withCorrection, location);
    }
    if (opcode.equalsIgnoreCase("asqk+")) {
        CorrectionLocationType location = (corrLoc == null)
            ? CorrectionLocationType.LASTCOLUMN
            : CorrectionLocationType.valueOf(corrLoc);
        return new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject(), withCorrection, location);
    }

    // Opcodes without a configurable correction.
    if (opcode.equalsIgnoreCase("a+")) {
        return new AggregateOperator(0, Plus.getPlusFnObject());
    }
    if (opcode.equalsIgnoreCase("a*")) {
        return new AggregateOperator(1, Multiply.getMultiplyFnObject());
    }
    if (opcode.equalsIgnoreCase("arimax")) {
        return new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("maxindex"), true, CorrectionLocationType.LASTCOLUMN);
    }
    if (opcode.equalsIgnoreCase("amax")) {
        return new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
    }
    if (opcode.equalsIgnoreCase("amin")) {
        return new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
    }
    if (opcode.equalsIgnoreCase("arimin")) {
        return new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("minindex"), true, CorrectionLocationType.LASTCOLUMN);
    }

    if (opcode.equalsIgnoreCase("amean")) {
        // NOTE(review): mean partials are combined with KahanPlus (not Mean) under a
        // two-column correction; this mirrors the original behaviour — confirm intent.
        CorrectionLocationType location = (corrLoc == null)
            ? CorrectionLocationType.LASTTWOCOLUMNS
            : CorrectionLocationType.valueOf(corrLoc);
        return new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), withCorrection, location);
    }
    if (opcode.equalsIgnoreCase("avar")) {
        CorrectionLocationType location = (corrLoc == null)
            ? CorrectionLocationType.LASTFOURCOLUMNS
            : CorrectionLocationType.valueOf(corrLoc);
        CM variance = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
        return new AggregateOperator(0, variance, withCorrection, location);
    }

    // Unknown opcode.
    return null;
}
Also used : AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) CM(org.apache.sysml.runtime.functionobjects.CM) CorrectionLocationType(org.apache.sysml.lops.PartialAggregate.CorrectionLocationType)

Example 3 with CM

use of org.apache.sysml.runtime.functionobjects.CM in project incubator-systemml by apache.

Class LibMatrixAgg, method aggregateUnaryMatrixDense.

/**
 * Runs a unary aggregate over a dense input block and writes into the dense buffer of {@code out}.
 *
 * <p>Dispatches on {@code optype} and, for most types, on the concrete index function
 * (ReduceAll, ReduceCol, ReduceRow, ReduceDiag) to a specialized {@code d_*} kernel. An
 * {@code optype}/{@code ixFn} combination without a kernel here is a silent no-op. The value
 * function is cast to the kernel's expected type only inside the branch that uses it.
 *
 * @param in    dense input block
 * @param out   output block whose dense buffer receives the result
 * @param optype aggregation type to compute
 * @param vFn   value function (KahanPlus, KahanPlusSq, Builtin, Mean or CM depending on {@code optype})
 * @param ixFn  reduction direction
 * @param rl    lower row bound, forwarded to the kernel
 * @param ru    upper row bound, forwarded to the kernel (presumably exclusive — confirm in kernels)
 * @throws DMLRuntimeException if {@code optype} is not supported by this dispatcher
 */
private static void aggregateUnaryMatrixDense(MatrixBlock in, MatrixBlock out, AggType optype, ValueFunction vFn, IndexFunction ixFn, int rl, int ru) throws DMLRuntimeException {
    final int rows = in.rlen;
    final int cols = in.clen;
    double[] inVals = in.getDenseBlock();
    double[] outVals = out.getDenseBlock();

    switch (optype) {
        case KAHAN_SUM: { // SUM / ROWSUM / COLSUM / TRACE via k+
            KahanObject kahan = new KahanObject(0, 0);
            if (ixFn instanceof ReduceAll) {
                d_uakp(inVals, outVals, rows, cols, kahan, (KahanPlus) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                d_uarkp(inVals, outVals, rows, cols, kahan, (KahanPlus) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                d_uackp(inVals, outVals, rows, cols, kahan, (KahanPlus) vFn, rl, ru);
            } else if (ixFn instanceof ReduceDiag) {
                d_uakptrace(inVals, outVals, rows, cols, kahan, (KahanPlus) vFn, rl, ru);
            }
            break;
        }
        case KAHAN_SUM_SQ: { // SUM_SQ / ROWSUM_SQ / COLSUM_SQ via k+
            KahanObject kahan = new KahanObject(0, 0);
            if (ixFn instanceof ReduceAll) {
                d_uasqkp(inVals, outVals, rows, cols, kahan, (KahanPlusSq) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                d_uarsqkp(inVals, outVals, rows, cols, kahan, (KahanPlusSq) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                d_uacsqkp(inVals, outVals, rows, cols, kahan, (KahanPlusSq) vFn, rl, ru);
            }
            break;
        }
        case CUM_KAHAN_SUM: { // CUMSUM; uses its own KahanPlus instance rather than vFn
            KahanObject kahan = new KahanObject(0, 0);
            KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
            d_ucumkp(inVals, null, outVals, rows, cols, kahan, kplus, rl, ru);
            break;
        }
        case CUM_PROD: { // CUMPROD
            d_ucumm(inVals, null, outVals, rows, cols, rl, ru);
            break;
        }
        case CUM_MIN:
        case CUM_MAX: { // CUMMIN / CUMMAX, seeded with the comparison identity
            double initVal = (optype == AggType.CUM_MAX) ? -Double.MAX_VALUE : Double.MAX_VALUE;
            d_ucummxx(inVals, null, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            break;
        }
        case MIN:
        case MAX: { // MIN/MAX, ROWMIN/ROWMAX, COLMIN/COLMAX
            double initVal = (optype == AggType.MAX) ? -Double.MAX_VALUE : Double.MAX_VALUE;
            if (ixFn instanceof ReduceAll) {
                d_uamxx(inVals, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                d_uarmxx(inVals, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                d_uacmxx(inVals, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MAX_INDEX: { // ROWINDEXMAX (row direction only)
            double initVal = -Double.MAX_VALUE;
            if (ixFn instanceof ReduceCol) {
                d_uarimxx(inVals, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MIN_INDEX: { // ROWINDEXMIN (row direction only)
            double initVal = Double.MAX_VALUE;
            if (ixFn instanceof ReduceCol) {
                d_uarimin(inVals, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MEAN: { // MEAN / ROWMEAN / COLMEAN
            KahanObject kahan = new KahanObject(0, 0);
            if (ixFn instanceof ReduceAll) {
                d_uamean(inVals, outVals, rows, cols, kahan, (Mean) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                d_uarmean(inVals, outVals, rows, cols, kahan, (Mean) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                d_uacmean(inVals, outVals, rows, cols, kahan, (Mean) vFn, rl, ru);
            }
            break;
        }
        case VAR: { // VAR / ROWVAR / COLVAR
            CM_COV_Object cmBuff = new CM_COV_Object();
            if (ixFn instanceof ReduceAll) {
                d_uavar(inVals, outVals, rows, cols, cmBuff, (CM) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                d_uarvar(inVals, outVals, rows, cols, cmBuff, (CM) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                d_uacvar(inVals, outVals, rows, cols, cmBuff, (CM) vFn, rl, ru);
            }
            break;
        }
        case PROD: { // PROD (full reduction only)
            if (ixFn instanceof ReduceAll) {
                d_uam(inVals, outVals, rows, cols, rl, ru);
            }
            break;
        }
        default:
            throw new DMLRuntimeException("Unsupported aggregation type: " + optype);
    }
}
Also used : ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) Mean(org.apache.sysml.runtime.functionobjects.Mean) ReduceDiag(org.apache.sysml.runtime.functionobjects.ReduceDiag) CM(org.apache.sysml.runtime.functionobjects.CM) ReduceRow(org.apache.sysml.runtime.functionobjects.ReduceRow) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) KahanPlusSq(org.apache.sysml.runtime.functionobjects.KahanPlusSq) Builtin(org.apache.sysml.runtime.functionobjects.Builtin)

Example 4 with CM

use of org.apache.sysml.runtime.functionobjects.CM in project incubator-systemml by apache.

Class LibMatrixAgg, method aggregateUnaryMatrixSparse.

/**
 * Runs a unary aggregate over a sparse input block and writes into the dense buffer of {@code out}.
 *
 * <p>Dispatches on {@code optype} and, for most types, on the concrete index function
 * (ReduceAll, ReduceCol, ReduceRow, ReduceDiag) to a specialized {@code s_*} kernel. An
 * {@code optype}/{@code ixFn} combination without a kernel here is a silent no-op. The value
 * function is cast to the kernel's expected type only inside the branch that uses it.
 *
 * @param in    input block, read through its sparse representation
 * @param out   output block whose dense buffer receives the result
 * @param optype aggregation type to compute
 * @param vFn   value function (KahanPlus, KahanPlusSq, Builtin, Mean or CM depending on {@code optype})
 * @param ixFn  reduction direction
 * @param rl    lower row bound, forwarded to the kernel
 * @param ru    upper row bound, forwarded to the kernel (presumably exclusive — confirm in kernels)
 * @throws DMLRuntimeException if {@code optype} is not supported by this dispatcher
 */
private static void aggregateUnaryMatrixSparse(MatrixBlock in, MatrixBlock out, AggType optype, ValueFunction vFn, IndexFunction ixFn, int rl, int ru) throws DMLRuntimeException {
    final int rows = in.rlen;
    final int cols = in.clen;
    SparseBlock sparseIn = in.getSparseBlock();
    double[] outVals = out.getDenseBlock();

    switch (optype) {
        case KAHAN_SUM: { // SUM / ROWSUM / COLSUM / TRACE via k+
            KahanObject kahan = new KahanObject(0, 0);
            if (ixFn instanceof ReduceAll) {
                s_uakp(sparseIn, outVals, rows, cols, kahan, (KahanPlus) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                s_uarkp(sparseIn, outVals, rows, cols, kahan, (KahanPlus) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                s_uackp(sparseIn, outVals, rows, cols, kahan, (KahanPlus) vFn, rl, ru);
            } else if (ixFn instanceof ReduceDiag) {
                s_uakptrace(sparseIn, outVals, rows, cols, kahan, (KahanPlus) vFn, rl, ru);
            }
            break;
        }
        case KAHAN_SUM_SQ: { // SUM_SQ / ROWSUM_SQ / COLSUM_SQ via k+
            KahanObject kahan = new KahanObject(0, 0);
            if (ixFn instanceof ReduceAll) {
                s_uasqkp(sparseIn, outVals, rows, cols, kahan, (KahanPlusSq) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                s_uarsqkp(sparseIn, outVals, rows, cols, kahan, (KahanPlusSq) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                s_uacsqkp(sparseIn, outVals, rows, cols, kahan, (KahanPlusSq) vFn, rl, ru);
            }
            break;
        }
        case CUM_KAHAN_SUM: { // CUMSUM; uses its own KahanPlus instance rather than vFn
            KahanObject kahan = new KahanObject(0, 0);
            KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
            s_ucumkp(sparseIn, null, outVals, rows, cols, kahan, kplus, rl, ru);
            break;
        }
        case CUM_PROD: { // CUMPROD
            s_ucumm(sparseIn, null, outVals, rows, cols, rl, ru);
            break;
        }
        case CUM_MIN:
        case CUM_MAX: { // CUMMIN / CUMMAX, seeded with the comparison identity
            double initVal = (optype == AggType.CUM_MAX) ? -Double.MAX_VALUE : Double.MAX_VALUE;
            s_ucummxx(sparseIn, null, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            break;
        }
        case MIN:
        case MAX: { // MIN/MAX, ROWMIN/ROWMAX, COLMIN/COLMAX
            double initVal = (optype == AggType.MAX) ? -Double.MAX_VALUE : Double.MAX_VALUE;
            if (ixFn instanceof ReduceAll) {
                s_uamxx(sparseIn, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                s_uarmxx(sparseIn, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                s_uacmxx(sparseIn, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MAX_INDEX: { // ROWINDEXMAX (row direction only)
            double initVal = -Double.MAX_VALUE;
            if (ixFn instanceof ReduceCol) {
                s_uarimxx(sparseIn, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MIN_INDEX: { // ROWINDEXMIN (row direction only)
            double initVal = Double.MAX_VALUE;
            if (ixFn instanceof ReduceCol) {
                s_uarimin(sparseIn, outVals, rows, cols, initVal, (Builtin) vFn, rl, ru);
            }
            break;
        }
        case MEAN: { // MEAN / ROWMEAN / COLMEAN
            KahanObject kahan = new KahanObject(0, 0);
            if (ixFn instanceof ReduceAll) {
                s_uamean(sparseIn, outVals, rows, cols, kahan, (Mean) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                s_uarmean(sparseIn, outVals, rows, cols, kahan, (Mean) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                s_uacmean(sparseIn, outVals, rows, cols, kahan, (Mean) vFn, rl, ru);
            }
            break;
        }
        case VAR: { // VAR / ROWVAR / COLVAR
            CM_COV_Object cmBuff = new CM_COV_Object();
            if (ixFn instanceof ReduceAll) {
                s_uavar(sparseIn, outVals, rows, cols, cmBuff, (CM) vFn, rl, ru);
            } else if (ixFn instanceof ReduceCol) {
                s_uarvar(sparseIn, outVals, rows, cols, cmBuff, (CM) vFn, rl, ru);
            } else if (ixFn instanceof ReduceRow) {
                s_uacvar(sparseIn, outVals, rows, cols, cmBuff, (CM) vFn, rl, ru);
            }
            break;
        }
        case PROD: { // PROD (full reduction only)
            if (ixFn instanceof ReduceAll) {
                s_uam(sparseIn, outVals, rows, cols, rl, ru);
            }
            break;
        }
        default:
            throw new DMLRuntimeException("Unsupported aggregation type: " + optype);
    }
}
Also used : ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) Mean(org.apache.sysml.runtime.functionobjects.Mean) ReduceDiag(org.apache.sysml.runtime.functionobjects.ReduceDiag) CM(org.apache.sysml.runtime.functionobjects.CM) ReduceRow(org.apache.sysml.runtime.functionobjects.ReduceRow) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) KahanPlusSq(org.apache.sysml.runtime.functionobjects.KahanPlusSq) Builtin(org.apache.sysml.runtime.functionobjects.Builtin)

Example 5 with CM

use of org.apache.sysml.runtime.functionobjects.CM in project incubator-systemml by apache.

Class LibMatrixCUDA, method unaryAggregate.

//********************************************************************/
//***************** END OF MATRIX MULTIPLY Functions *****************/
//********************************************************************/
//********************************************************************/
//****************  UNARY AGGREGATE Functions ************************/
//********************************************************************/
/**
	 * Entry point to perform Unary aggregate operations on the GPU.
	 * The execution context object is used to allocate memory for the GPU.
	 * @param ec			Instance of {@link ExecutionContext}, from which the output variable will be allocated
	 * @param gCtx    a valid {@link GPUContext}
	 * @param instName name of the invoking instruction to record{@link Statistics}.
	 * @param in1			input matrix
	 * @param output	output matrix/scalar name
	 * @param op			Instance of {@link AggregateUnaryOperator} which encapsulates the direction of reduction/aggregation and the reduction operation.
	 * @throws DMLRuntimeException if {@link DMLRuntimeException} occurs
	 */
public static void unaryAggregate(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, String output, AggregateUnaryOperator op) throws DMLRuntimeException {
    if (ec.getGPUContext() != gCtx)
        throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
    LOG.trace("GPU : unaryAggregate" + ", GPUContext=" + gCtx);
    final int REDUCTION_ALL = 1;
    final int REDUCTION_ROW = 2;
    final int REDUCTION_COL = 3;
    final int REDUCTION_DIAG = 4;
    // A kahan sum implemention is not provided. is a "uak+" or other kahan operator is encountered,
    // it just does regular summation reduction.
    final int OP_PLUS = 1;
    final int OP_PLUS_SQ = 2;
    final int OP_MEAN = 3;
    final int OP_VARIANCE = 4;
    final int OP_MULTIPLY = 5;
    final int OP_MAX = 6;
    final int OP_MIN = 7;
    final int OP_MAXINDEX = 8;
    final int OP_MININDEX = 9;
    // Sanity Checks
    if (!in1.getGPUObject(gCtx).isAllocated())
        throw new DMLRuntimeException("Internal Error - The input is not allocated for a GPU Aggregate Unary:" + in1.getGPUObject(gCtx).isAllocated());
    boolean isSparse = in1.getGPUObject(gCtx).isSparse();
    IndexFunction indexFn = op.indexFn;
    AggregateOperator aggOp = op.aggOp;
    // Convert Reduction direction to a number to pass to CUDA kernel
    int reductionDirection = -1;
    if (indexFn instanceof ReduceAll) {
        reductionDirection = REDUCTION_ALL;
    } else if (indexFn instanceof ReduceRow) {
        reductionDirection = REDUCTION_ROW;
    } else if (indexFn instanceof ReduceCol) {
        reductionDirection = REDUCTION_COL;
    } else if (indexFn instanceof ReduceDiag) {
        reductionDirection = REDUCTION_DIAG;
    } else {
        throw new DMLRuntimeException("Internal Error - Invalid index function type, only reducing along rows, columns, diagonals or all elements is supported in Aggregate Unary operations");
    }
    assert reductionDirection != -1 : "Internal Error - Incorrect type of reduction direction set for aggregate unary GPU instruction";
    // Convert function type to a number to pass to the CUDA Kernel
    int opIndex = -1;
    if (aggOp.increOp.fn instanceof KahanPlus) {
        opIndex = OP_PLUS;
    } else if (aggOp.increOp.fn instanceof KahanPlusSq) {
        opIndex = OP_PLUS_SQ;
    } else if (aggOp.increOp.fn instanceof Mean) {
        opIndex = OP_MEAN;
    } else if (aggOp.increOp.fn instanceof CM) {
        assert ((CM) aggOp.increOp.fn).getAggOpType() == CMOperator.AggregateOperationTypes.VARIANCE : "Internal Error - Invalid Type of CM operator for Aggregate Unary operation on GPU";
        opIndex = OP_VARIANCE;
    } else if (aggOp.increOp.fn instanceof Plus) {
        opIndex = OP_PLUS;
    } else if (aggOp.increOp.fn instanceof Multiply) {
        opIndex = OP_MULTIPLY;
    } else if (aggOp.increOp.fn instanceof Builtin) {
        Builtin b = (Builtin) aggOp.increOp.fn;
        switch(b.bFunc) {
            case MAX:
                opIndex = OP_MAX;
                break;
            case MIN:
                opIndex = OP_MIN;
                break;
            case MAXINDEX:
                opIndex = OP_MAXINDEX;
                break;
            case MININDEX:
                opIndex = OP_MININDEX;
                break;
            default:
                new DMLRuntimeException("Internal Error - Unsupported Builtin Function for Aggregate unary being done on GPU");
        }
    } else {
        throw new DMLRuntimeException("Internal Error - Aggregate operator has invalid Value function");
    }
    assert opIndex != -1 : "Internal Error - Incorrect type of operation set for aggregate unary GPU instruction";
    int rlen = (int) in1.getNumRows();
    int clen = (int) in1.getNumColumns();
    if (isSparse) {
        // The strategy for the time being is to convert sparse to dense
        // until a sparse specific kernel is written.
        in1.getGPUObject(gCtx).sparseToDense(instName);
    // long nnz = in1.getNnz();
    // assert nnz > 0 : "Internal Error - number of non zeroes set to " + nnz + " in Aggregate Binary for GPU";
    // MatrixObject out = ec.getSparseMatrixOutputForGPUInstruction(output, nnz);
    // throw new DMLRuntimeException("Internal Error - Not implemented");
    }
    Pointer out = null;
    if (reductionDirection == REDUCTION_COL || reductionDirection == REDUCTION_ROW) {
        // Matrix output
        MatrixObject out1 = getDenseMatrixOutputForGPUInstruction(ec, instName, output);
        out = getDensePointer(gCtx, out1, instName);
    }
    Pointer in = getDensePointer(gCtx, in1, instName);
    int size = rlen * clen;
    // For scalars, set the scalar output in the Execution Context object
    switch(opIndex) {
        case OP_PLUS:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                            ec.setScalarOutput(output, new DoubleObject(result));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            // The names are a bit misleading, REDUCTION_COL refers to the direction (reduce all elements in a column)
                            reduceRow(gCtx, instName, "reduce_row_sum", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_sum", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_DIAG:
                        throw new DMLRuntimeException("Internal Error - Row, Column and Diag summation not implemented yet");
                }
                break;
            }
        case OP_PLUS_SQ:
            {
                // Calculate the squares in a temporary object tmp
                Pointer tmp = gCtx.allocate(instName, size * Sizeof.DOUBLE);
                squareMatrix(gCtx, instName, in, tmp, rlen, clen);
                // Then do the sum on the temporary object and free it
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_sum", tmp, size);
                            ec.setScalarOutput(output, new DoubleObject(result));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            // The names are a bit misleading, REDUCTION_COL refers to the direction (reduce all elements in a column)
                            reduceRow(gCtx, instName, "reduce_row_sum", tmp, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_sum", tmp, out, rlen, clen);
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for summation squared");
                }
                gCtx.cudaFreeHelper(instName, tmp);
                break;
            }
        case OP_MEAN:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                            double mean = result / size;
                            ec.setScalarOutput(output, new DoubleObject(mean));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for mean");
                }
                break;
            }
        case OP_MULTIPLY:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_prod", in, size);
                            ec.setScalarOutput(output, new DoubleObject(result));
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for multiplication");
                }
                break;
            }
        case OP_MAX:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_max", in, size);
                            ec.setScalarOutput(output, new DoubleObject(result));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            reduceRow(gCtx, instName, "reduce_row_max", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_max", in, out, rlen, clen);
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for max");
                }
                break;
            }
        case OP_MIN:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_min", in, size);
                            ec.setScalarOutput(output, new DoubleObject(result));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            reduceRow(gCtx, instName, "reduce_row_min", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_min", in, out, rlen, clen);
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for min");
                }
                break;
            }
        case OP_VARIANCE:
            {
                // Temporary GPU array for
                Pointer tmp = gCtx.allocate(instName, size * Sizeof.DOUBLE);
                Pointer tmp2 = gCtx.allocate(instName, size * Sizeof.DOUBLE);
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                            double mean = result / size;
                            // Subtract mean from every element in the matrix
                            ScalarOperator minusOp = new RightScalarOperator(Minus.getMinusFnObject(), mean);
                            matrixScalarOp(gCtx, instName, in, mean, rlen, clen, tmp, minusOp);
                            squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
                            double result2 = reduceAll(gCtx, instName, "reduce_sum", tmp2, size);
                            double variance = result2 / (size - 1);
                            ec.setScalarOutput(output, new DoubleObject(variance));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
                            // Subtract the row-wise mean from every element in the matrix
                            BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
                            matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp);
                            squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
                            Pointer tmpRow = gCtx.allocate(instName, rlen * Sizeof.DOUBLE);
                            reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
                            ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
                            matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
                            gCtx.cudaFreeHelper(instName, tmpRow);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
                            // Subtract the columns-wise mean from every element in the matrix
                            BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
                            matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp);
                            squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
                            Pointer tmpCol = gCtx.allocate(instName, clen * Sizeof.DOUBLE);
                            reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
                            ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
                            matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
                            gCtx.cudaFreeHelper(instName, tmpCol);
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for variance");
                }
                gCtx.cudaFreeHelper(instName, tmp);
                gCtx.cudaFreeHelper(instName, tmp2);
                break;
            }
        case OP_MAXINDEX:
            {
                switch(reductionDirection) {
                    case REDUCTION_COL:
                        throw new DMLRuntimeException("Internal Error - Column maxindex of matrix not implemented yet for GPU ");
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for maxindex");
                }
            // break;
            }
        case OP_MININDEX:
            {
                switch(reductionDirection) {
                    case REDUCTION_COL:
                        throw new DMLRuntimeException("Internal Error - Column minindex of matrix not implemented yet for GPU ");
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for minindex");
                }
            // break;
            }
        default:
            throw new DMLRuntimeException("Internal Error - Invalid GPU Unary aggregate function!");
    }
}
Also used: ReduceCol (org.apache.sysml.runtime.functionobjects.ReduceCol), ScalarOperator (org.apache.sysml.runtime.matrix.operators.ScalarOperator), LeftScalarOperator (org.apache.sysml.runtime.matrix.operators.LeftScalarOperator), RightScalarOperator (org.apache.sysml.runtime.matrix.operators.RightScalarOperator), ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll), Mean (org.apache.sysml.runtime.functionobjects.Mean), MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject), ReduceDiag (org.apache.sysml.runtime.functionobjects.ReduceDiag), DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject), CM (org.apache.sysml.runtime.functionobjects.CM), CSRPointer (org.apache.sysml.runtime.instructions.gpu.context.CSRPointer), Pointer (jcuda.Pointer), ReduceRow (org.apache.sysml.runtime.functionobjects.ReduceRow), DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException), IndexFunction (org.apache.sysml.runtime.functionobjects.IndexFunction), Multiply (org.apache.sysml.runtime.functionobjects.Multiply), AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator), KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus), KahanPlusSq (org.apache.sysml.runtime.functionobjects.KahanPlusSq), Plus (org.apache.sysml.runtime.functionobjects.Plus), BinaryOperator (org.apache.sysml.runtime.matrix.operators.BinaryOperator), Builtin (org.apache.sysml.runtime.functionobjects.Builtin)

Aggregations

CM (org.apache.sysml.runtime.functionobjects.CM): 13; DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 7; KahanObject (org.apache.sysml.runtime.instructions.cp.KahanObject): 7; AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator): 7; CM_COV_Object (org.apache.sysml.runtime.instructions.cp.CM_COV_Object): 6; Builtin (org.apache.sysml.runtime.functionobjects.Builtin): 5; KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus): 5; KahanPlusSq (org.apache.sysml.runtime.functionobjects.KahanPlusSq): 4; Mean (org.apache.sysml.runtime.functionobjects.Mean): 4; ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll): 4; ReduceCol (org.apache.sysml.runtime.functionobjects.ReduceCol): 4; ReduceDiag (org.apache.sysml.runtime.functionobjects.ReduceDiag): 4; ReduceRow (org.apache.sysml.runtime.functionobjects.ReduceRow): 4; WeightedCell (org.apache.sysml.runtime.matrix.data.WeightedCell): 4; CMOperator (org.apache.sysml.runtime.matrix.operators.CMOperator): 4; IOException (java.io.IOException): 3; CorrectionLocationType (org.apache.sysml.lops.PartialAggregate.CorrectionLocationType): 2; IndexFunction (org.apache.sysml.runtime.functionobjects.IndexFunction): 2; Multiply (org.apache.sysml.runtime.functionobjects.Multiply): 2; GroupedAggregateInstruction (org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction): 2