Search in sources:

Example 6 with Builtin

Use of org.apache.sysml.runtime.functionobjects.Builtin in the Apache incubator-systemml project.

Class LibMatrixCUDA, method unaryAggregate.

// ********************************************************************/
// ******** End of TRANSPOSE SELF MATRIX MULTIPLY Functions ***********/
// ********************************************************************/
// ********************************************************************/
// ****************  UNARY AGGREGATE Functions ************************/
// ********************************************************************/
/**
 * Entry point to perform unary aggregate operations on the GPU.
 * The execution context object is used to allocate memory for the GPU.
 * <p>
 * A sparse input is first converted to dense. Full reductions ({@code ReduceAll}) store a
 * {@link DoubleObject} scalar under {@code output}. Row/column reductions allocate a dense
 * matrix output. Kahan operators are executed as plain (non-compensated) summation.
 *
 * @param ec       Instance of {@link ExecutionContext}, from which the output variable will be allocated
 * @param gCtx     a valid {@link GPUContext}; must be the same context registered with {@code ec}
 * @param instName name of the invoking instruction to record {@link Statistics}.
 * @param in1      input matrix; must already be allocated on the GPU
 * @param output   output matrix/scalar name
 * @param op       Instance of {@link AggregateUnaryOperator} which encapsulates the direction of reduction/aggregation and the reduction operation.
 * @throws DMLRuntimeException if the GPU context does not match, the input is not allocated, or the
 *                             combination of reduction direction and aggregation function is unsupported
 */
public static void unaryAggregate(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, String output, AggregateUnaryOperator op) {
    if (ec.getGPUContext(0) != gCtx)
        throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
    if (LOG.isTraceEnabled()) {
        LOG.trace("GPU : unaryAggregate" + ", GPUContext=" + gCtx);
    }
    final int REDUCTION_ALL = 1;
    final int REDUCTION_ROW = 2;
    final int REDUCTION_COL = 3;
    final int REDUCTION_DIAG = 4;
    // A Kahan sum implementation is not provided. If a "uak+" or other Kahan operator is encountered,
    // a regular summation reduction is performed instead.
    final int OP_PLUS = 1;
    final int OP_PLUS_SQ = 2;
    final int OP_MEAN = 3;
    final int OP_VARIANCE = 4;
    final int OP_MULTIPLY = 5;
    final int OP_MAX = 6;
    final int OP_MIN = 7;
    final int OP_MAXINDEX = 8;
    final int OP_MININDEX = 9;
    // Sanity Checks
    if (!in1.getGPUObject(gCtx).isAllocated())
        throw new DMLRuntimeException("Internal Error - The input is not allocated for a GPU Aggregate Unary:" + in1.getGPUObject(gCtx).isAllocated());
    boolean isSparse = in1.getGPUObject(gCtx).isSparse();
    IndexFunction indexFn = op.indexFn;
    AggregateOperator aggOp = op.aggOp;
    // Convert Reduction direction to a number
    int reductionDirection = -1;
    if (indexFn instanceof ReduceAll) {
        reductionDirection = REDUCTION_ALL;
    } else if (indexFn instanceof ReduceRow) {
        reductionDirection = REDUCTION_ROW;
    } else if (indexFn instanceof ReduceCol) {
        reductionDirection = REDUCTION_COL;
    } else if (indexFn instanceof ReduceDiag) {
        reductionDirection = REDUCTION_DIAG;
    } else {
        throw new DMLRuntimeException("Internal Error - Invalid index function type, only reducing along rows, columns, diagonals or all elements is supported in Aggregate Unary operations");
    }
    // Defensive guard: every branch above either assigns a direction or throws.
    if (reductionDirection == -1)
        throw new DMLRuntimeException("Internal Error - Incorrect type of reduction direction set for aggregate unary GPU instruction");
    // Convert function type to a number
    int opIndex = -1;
    if (aggOp.increOp.fn instanceof KahanPlus) {
        opIndex = OP_PLUS;
    } else if (aggOp.increOp.fn instanceof KahanPlusSq) {
        opIndex = OP_PLUS_SQ;
    } else if (aggOp.increOp.fn instanceof Mean) {
        opIndex = OP_MEAN;
    } else if (aggOp.increOp.fn instanceof CM) {
        if (((CM) aggOp.increOp.fn).getAggOpType() != CMOperator.AggregateOperationTypes.VARIANCE)
            throw new DMLRuntimeException("Internal Error - Invalid Type of CM operator for Aggregate Unary operation on GPU");
        opIndex = OP_VARIANCE;
    } else if (aggOp.increOp.fn instanceof Plus) {
        opIndex = OP_PLUS;
    } else if (aggOp.increOp.fn instanceof Multiply) {
        opIndex = OP_MULTIPLY;
    } else if (aggOp.increOp.fn instanceof Builtin) {
        Builtin b = (Builtin) aggOp.increOp.fn;
        switch(b.bFunc) {
            case MAX:
                opIndex = OP_MAX;
                break;
            case MIN:
                opIndex = OP_MIN;
                break;
            case MAXINDEX:
                opIndex = OP_MAXINDEX;
                break;
            case MININDEX:
                opIndex = OP_MININDEX;
                break;
            default:
                // Previously the exception was constructed but never thrown, which let
                // execution continue with opIndex == -1 and a misleading error message.
                throw new DMLRuntimeException("Internal Error - Unsupported Builtin Function for Aggregate unary being done on GPU");
        }
    } else {
        throw new DMLRuntimeException("Internal Error - Aggregate operator has invalid Value function");
    }
    // Defensive guard: every branch above either assigns an operation or throws.
    if (opIndex == -1)
        throw new DMLRuntimeException("Internal Error - Incorrect type of operation set for aggregate unary GPU instruction");
    int rlen = (int) in1.getNumRows();
    int clen = (int) in1.getNumColumns();
    if (isSparse) {
        // The strategy for the time being is to convert sparse to dense
        // until a sparse specific kernel is written.
        in1.getGPUObject(gCtx).sparseToDense(instName);
    }
    long outRLen = -1;
    long outCLen = -1;
    if (indexFn instanceof ReduceRow) {
        // COL{SUM, MAX...}
        outRLen = 1;
        outCLen = clen;
    } else if (indexFn instanceof ReduceCol) {
        // ROW{SUM, MAX,...}
        outRLen = rlen;
        outCLen = 1;
    }
    Pointer out = null;
    if (reductionDirection == REDUCTION_COL || reductionDirection == REDUCTION_ROW) {
        // Matrix output
        MatrixObject out1 = getDenseMatrixOutputForGPUInstruction(ec, instName, output, outRLen, outCLen);
        out = getDensePointer(gCtx, out1, instName);
    }
    Pointer in = getDensePointer(gCtx, in1, instName);
    int size = rlen * clen;
    // For scalars, set the scalar output in the Execution Context object
    switch(opIndex) {
        case OP_PLUS:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                            ec.setScalarOutput(output, new DoubleObject(result));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            // The names are a bit misleading, REDUCTION_COL refers to the direction (reduce all elements in a column)
                            reduceRow(gCtx, instName, "reduce_row_sum", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_sum", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_DIAG:
                        throw new DMLRuntimeException("Internal Error - Row, Column and Diag summation not implemented yet");
                }
                break;
            }
        case OP_PLUS_SQ:
            {
                // Calculate the squares in a temporary object tmp, then reduce it.
                // The temporary is released even if the reduction throws.
                Pointer tmp = gCtx.allocate(instName, size * sizeOfDataType);
                try {
                    squareMatrix(gCtx, instName, in, tmp, rlen, clen);
                    switch(reductionDirection) {
                        case REDUCTION_ALL:
                            {
                                double result = reduceAll(gCtx, instName, "reduce_sum", tmp, size);
                                ec.setScalarOutput(output, new DoubleObject(result));
                                break;
                            }
                        case REDUCTION_COL:
                            {
                                // The names are a bit misleading, REDUCTION_COL refers to the direction (reduce all elements in a column)
                                reduceRow(gCtx, instName, "reduce_row_sum", tmp, out, rlen, clen);
                                break;
                            }
                        case REDUCTION_ROW:
                            {
                                reduceCol(gCtx, instName, "reduce_col_sum", tmp, out, rlen, clen);
                                break;
                            }
                        default:
                            throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for summation squared");
                    }
                } finally {
                    gCtx.cudaFreeHelper(instName, tmp);
                }
                break;
            }
        case OP_MEAN:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                            double mean = result / size;
                            ec.setScalarOutput(output, new DoubleObject(mean));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for mean");
                }
                break;
            }
        case OP_MULTIPLY:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_prod", in, size);
                            ec.setScalarOutput(output, new DoubleObject(result));
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for multiplication");
                }
                break;
            }
        case OP_MAX:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_max", in, size);
                            ec.setScalarOutput(output, new DoubleObject(result));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            reduceRow(gCtx, instName, "reduce_row_max", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_max", in, out, rlen, clen);
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for max");
                }
                break;
            }
        case OP_MIN:
            {
                switch(reductionDirection) {
                    case REDUCTION_ALL:
                        {
                            double result = reduceAll(gCtx, instName, "reduce_min", in, size);
                            ec.setScalarOutput(output, new DoubleObject(result));
                            break;
                        }
                    case REDUCTION_COL:
                        {
                            reduceRow(gCtx, instName, "reduce_row_min", in, out, rlen, clen);
                            break;
                        }
                    case REDUCTION_ROW:
                        {
                            reduceCol(gCtx, instName, "reduce_col_min", in, out, rlen, clen);
                            break;
                        }
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for min");
                }
                break;
            }
        case OP_VARIANCE:
            {
                // Temporary GPU arrays holding (x - mean) and (x - mean)^2 respectively.
                // Both are released even if a reduction step throws.
                Pointer tmp = null;
                Pointer tmp2 = null;
                try {
                    tmp = gCtx.allocate(instName, size * sizeOfDataType);
                    tmp2 = gCtx.allocate(instName, size * sizeOfDataType);
                    switch(reductionDirection) {
                        case REDUCTION_ALL:
                            {
                                double result = reduceAll(gCtx, instName, "reduce_sum", in, size);
                                double mean = result / size;
                                // Subtract mean from every element in the matrix
                                ScalarOperator minusOp = new RightScalarOperator(Minus.getMinusFnObject(), mean);
                                matrixScalarOp(gCtx, instName, in, mean, rlen, clen, tmp, minusOp);
                                squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
                                double result2 = reduceAll(gCtx, instName, "reduce_sum", tmp2, size);
                                // Sample variance (divide by n - 1)
                                double variance = result2 / (size - 1);
                                ec.setScalarOutput(output, new DoubleObject(variance));
                                break;
                            }
                        case REDUCTION_COL:
                            {
                                reduceRow(gCtx, instName, "reduce_row_mean", in, out, rlen, clen);
                                // Subtract the row-wise mean from every element in the matrix
                                BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
                                matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.COLUMN.code(), tmp, minusOp);
                                squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
                                Pointer tmpRow = gCtx.allocate(instName, rlen * sizeOfDataType);
                                try {
                                    reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
                                    ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
                                    matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
                                } finally {
                                    gCtx.cudaFreeHelper(instName, tmpRow);
                                }
                                break;
                            }
                        case REDUCTION_ROW:
                            {
                                reduceCol(gCtx, instName, "reduce_col_mean", in, out, rlen, clen);
                                // Subtract the columns-wise mean from every element in the matrix
                                BinaryOperator minusOp = new BinaryOperator(Minus.getMinusFnObject());
                                matrixMatrixOp(gCtx, instName, in, out, rlen, clen, VectorShape.NONE.code(), VectorShape.ROW.code(), tmp, minusOp);
                                squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
                                Pointer tmpCol = gCtx.allocate(instName, clen * sizeOfDataType);
                                try {
                                    reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
                                    ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
                                    matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
                                } finally {
                                    gCtx.cudaFreeHelper(instName, tmpCol);
                                }
                                break;
                            }
                        default:
                            throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for variance");
                    }
                } finally {
                    if (tmp != null)
                        gCtx.cudaFreeHelper(instName, tmp);
                    if (tmp2 != null)
                        gCtx.cudaFreeHelper(instName, tmp2);
                }
                break;
            }
        case OP_MAXINDEX:
            {
                // Not supported on the GPU for any direction; every path throws.
                switch(reductionDirection) {
                    case REDUCTION_COL:
                        throw new DMLRuntimeException("Internal Error - Column maxindex of matrix not implemented yet for GPU ");
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for maxindex");
                }
            }
        case OP_MININDEX:
            {
                // Not supported on the GPU for any direction; every path throws.
                switch(reductionDirection) {
                    case REDUCTION_COL:
                        throw new DMLRuntimeException("Internal Error - Column minindex of matrix not implemented yet for GPU ");
                    default:
                        throw new DMLRuntimeException("Internal Error - Unsupported reduction direction for minindex");
                }
            }
        default:
            throw new DMLRuntimeException("Internal Error - Invalid GPU Unary aggregate function!");
    }
}
Also used: ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) ScalarOperator(org.apache.sysml.runtime.matrix.operators.ScalarOperator) LeftScalarOperator(org.apache.sysml.runtime.matrix.operators.LeftScalarOperator) RightScalarOperator(org.apache.sysml.runtime.matrix.operators.RightScalarOperator) ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) Mean(org.apache.sysml.runtime.functionobjects.Mean) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) ReduceDiag(org.apache.sysml.runtime.functionobjects.ReduceDiag) DoubleObject(org.apache.sysml.runtime.instructions.cp.DoubleObject) CM(org.apache.sysml.runtime.functionobjects.CM) CSRPointer(org.apache.sysml.runtime.instructions.gpu.context.CSRPointer) Pointer(jcuda.Pointer) ReduceRow(org.apache.sysml.runtime.functionobjects.ReduceRow) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) IndexFunction(org.apache.sysml.runtime.functionobjects.IndexFunction) Multiply(org.apache.sysml.runtime.functionobjects.Multiply) Minus1Multiply(org.apache.sysml.runtime.functionobjects.Minus1Multiply) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) KahanPlusSq(org.apache.sysml.runtime.functionobjects.KahanPlusSq) Plus(org.apache.sysml.runtime.functionobjects.Plus) BinaryOperator(org.apache.sysml.runtime.matrix.operators.BinaryOperator) Builtin(org.apache.sysml.runtime.functionobjects.Builtin)

Example 7 with Builtin

Use of org.apache.sysml.runtime.functionobjects.Builtin in the Apache incubator-systemml project.

Class UaggOuterChainSPInstruction, method processInstruction.

/**
 * Executes the unary-aggregate-over-outer-product instruction on Spark.
 * <p>
 * One input stays distributed as an RDD and the other is broadcast. If
 * {@code LibMatrixOuterAgg} supports the operator pair, the broadcast side is shipped as a plain
 * double vector. Otherwise a partitioned broadcast is used. A full ({@code ReduceAll}) aggregate
 * is collected into a single output block. Row/column aggregates remain an RDD output with lineage.
 */
@Override
public void processInstruction(ExecutionContext ec) {
    SparkExecutionContext sparkEc = (SparkExecutionContext) ec;

    // Decide which input is kept as an RDD and which one is broadcast.
    boolean cacheRightInput = _uaggOp.indexFn instanceof ReduceCol
        || _uaggOp.indexFn instanceof ReduceAll
        || !LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp);
    String rddName;
    String bcastName;
    if (cacheRightInput) {
        rddName = input1.getName();
        bcastName = input2.getName();
    } else {
        rddName = input2.getName();
        bcastName = input1.getName();
    }

    JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = sparkEc.getBinaryBlockRDDHandleForVariable(rddName);
    MatrixCharacteristics mcIn = sparkEc.getMatrixCharacteristics(rddName);
    boolean keepPartitioning = preservesPartitioning(mcIn, _uaggOp.indexFn);

    JavaPairRDD<MatrixIndexes, MatrixBlock> aggregated;
    if (LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp)) {
        // Specialized path: pull the broadcast side locally and ship it as a vector.
        MatrixBlock bcastBlock = sparkEc.getMatrixInput(bcastName, getExtendedOpcode());
        sparkEc.releaseMatrixInput(bcastName, getExtendedOpcode());
        // Clearing the name suppresses broadcast lineage on the output further below.
        bcastName = null;
        double[] values = DataConverter.convertToDoubleVector(bcastBlock);
        Broadcast<int[]> bcastIndices = null;
        if (!(_uaggOp.aggOp.increOp.fn instanceof Builtin)) {
            Arrays.sort(values);
        } else {
            int[] rowIndices = LibMatrixOuterAgg.prepareRowIndices(bcastBlock.getNumColumns(), values, _bOp, _uaggOp);
            bcastIndices = sparkEc.getSparkContext().broadcast(rowIndices);
        }
        Broadcast<double[]> bcastValues = sparkEc.getSparkContext().broadcast(values);
        // partitioning-preserving map-to-pair (under constraints)
        aggregated = blocks.mapPartitionsToPair(new RDDMapUAggOuterChainFunction(bcastValues, bcastIndices, _bOp, _uaggOp), keepPartitioning);
    } else {
        // Generic path: partitioned broadcast of the full matrix block.
        PartitionedBroadcast<MatrixBlock> partBcast = sparkEc.getBroadcastForVariable(bcastName);
        // partitioning-preserving map-to-pair (under constraints)
        aggregated = blocks.mapPartitionsToPair(new RDDMapGenUAggOuterChainFunction(partBcast, _uaggOp, _aggOp, _bOp, mcIn), keepPartitioning);
    }

    if (!(_uaggOp.indexFn instanceof ReduceAll)) {
        // Row/column aggregate: the result stays distributed; register RDD and lineage.
        updateUnaryAggOutputMatrixCharacteristics(sparkEc);
        if (_uaggOp.aggOp.correctionExists) {
            aggregated = aggregated.mapValues(new AggregateDropCorrectionFunction(_uaggOp.aggOp));
        }
        sparkEc.setRDDHandleForVariable(output.getName(), aggregated);
        sparkEc.addLineageRDD(output.getName(), rddName);
        if (bcastName != null) {
            sparkEc.addLineageBroadcast(output.getName(), bcastName);
        }
        return;
    }

    // Full aggregate: collapse to a single block, strip correction, and store it (no lineage).
    MatrixBlock result = RDDAggregateUtils.aggStable(aggregated, _aggOp);
    result.dropLastRowsOrColumns(_aggOp.correctionLocation);
    sparkEc.setMatrixOutput(output.getName(), result, getExtendedOpcode());
}
Also used: ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) AggregateDropCorrectionFunction(org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) Builtin(org.apache.sysml.runtime.functionobjects.Builtin)

Example 8 with Builtin

Use of org.apache.sysml.runtime.functionobjects.Builtin in the Apache incubator-systemml project.

Class LibMatrixAgg, method getAggType.

/**
 * Classifies a unary aggregate operator into the {@link AggType} handled by the
 * specialized aggregation kernels.
 *
 * @param op the unary aggregate operator (value function, index function, correction location)
 * @return the matching {@code AggType}, or {@code AggType.INVALID} if no combination matches
 */
private static AggType getAggType(AggregateUnaryOperator op) {
    ValueFunction fn = op.aggOp.increOp.fn;
    IndexFunction idx = op.indexFn;
    // Most aggregates are only supported for full, row-wise, or column-wise reduction.
    boolean rowColOrAll = idx instanceof ReduceAll || idx instanceof ReduceCol || idx instanceof ReduceRow;

    // (kahan) sum / sum squared; ReduceDiag is additionally allowed here (trace)
    boolean kahanCorrection = op.aggOp.correctionLocation == CorrectionLocationType.LASTCOLUMN
        || op.aggOp.correctionLocation == CorrectionLocationType.LASTROW;
    if (fn instanceof KahanFunction && kahanCorrection && (rowColOrAll || idx instanceof ReduceDiag)) {
        if (fn instanceof KahanPlus) {
            return AggType.KAHAN_SUM;
        }
        if (fn instanceof KahanPlusSq) {
            return AggType.KAHAN_SUM_SQ;
        }
    }

    // mean: needs two correction rows/columns
    if (fn instanceof Mean && rowColOrAll
        && (op.aggOp.correctionLocation == CorrectionLocationType.LASTTWOCOLUMNS
            || op.aggOp.correctionLocation == CorrectionLocationType.LASTTWOROWS)) {
        return AggType.MEAN;
    }

    // variance: central-moment function of type VARIANCE with four correction rows/columns
    if (fn instanceof CM && ((CM) fn).getAggOpType() == AggregateOperationTypes.VARIANCE && rowColOrAll
        && (op.aggOp.correctionLocation == CorrectionLocationType.LASTFOURCOLUMNS
            || op.aggOp.correctionLocation == CorrectionLocationType.LASTFOURROWS)) {
        return AggType.VAR;
    }

    // product: full reduction only
    if (fn instanceof Multiply && idx instanceof ReduceAll) {
        return AggType.PROD;
    }

    // min / max / maxindex / minindex
    if (fn instanceof Builtin && rowColOrAll) {
        switch (((Builtin) fn).bFunc) {
            case MAX:
                return AggType.MAX;
            case MIN:
                return AggType.MIN;
            case MAXINDEX:
                return AggType.MAX_INDEX;
            case MININDEX:
                return AggType.MIN_INDEX;
            default:
                break;
        }
    }
    return AggType.INVALID;
}
Also used: ValueFunction(org.apache.sysml.runtime.functionobjects.ValueFunction) ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) Mean(org.apache.sysml.runtime.functionobjects.Mean) ReduceDiag(org.apache.sysml.runtime.functionobjects.ReduceDiag) CM(org.apache.sysml.runtime.functionobjects.CM) ReduceRow(org.apache.sysml.runtime.functionobjects.ReduceRow) IndexFunction(org.apache.sysml.runtime.functionobjects.IndexFunction) BuiltinCode(org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode) KahanFunction(org.apache.sysml.runtime.functionobjects.KahanFunction) Multiply(org.apache.sysml.runtime.functionobjects.Multiply) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) KahanPlusSq(org.apache.sysml.runtime.functionobjects.KahanPlusSq) Builtin(org.apache.sysml.runtime.functionobjects.Builtin)

Example 9 with Builtin

Use of org.apache.sysml.runtime.functionobjects.Builtin in the Apache incubator-systemml project.

Class MatrixBlock, method incrementalAggregate.

/**
 * Folds a partial aggregate that carries correction cells into this block, updating this
 * block in place.
 *
 * <p>The layout of value and correction cells is selected by {@code aggOp.correctionLocation}:
 * <ul>
 *   <li>{@code LASTROW} / {@code LASTCOLUMN}: a value cell followed by one Kahan correction
 *       cell. {@link KahanPlus} is delegated to {@code LibMatrixAgg.aggregateBinaryMatrix}.
 *       Any other function is applied cell by cell.</li>
 *   <li>{@code LASTCOLUMN} with a {@code MAXINDEX}/{@code MININDEX} builtin: column 0 holds an
 *       index and column 1 the corresponding value.</li>
 *   <li>{@code LASTTWOROWS} / {@code LASTTWOCOLUMNS}: value, count and correction cells, merged
 *       with the incremental-mean update {@code (mu2 - mu) * n2 / (n + n2)}.</li>
 *   <li>{@code LASTFOURROWS} / {@code LASTFOURCOLUMNS} (only for a {@link CM} variance
 *       function): {@code { var | mean, count, m2 correction, mean correction }}, where
 *       {@code m2 = var * (n - 1)}.</li>
 * </ul>
 *
 * @param aggOp             aggregate operator; supplies the correction layout and the increment function
 * @param newWithCorrection partial aggregate to merge; must be a {@link MatrixBlock}
 * @throws DMLRuntimeException if the correction location (or location/function pairing) is not supported
 */
@Override
public void incrementalAggregate(AggregateOperator aggOp, MatrixValue newWithCorrection) {
    MatrixBlock partial = checkType(newWithCorrection);
    KahanObject kbuff = new KahanObject(0, 0);
    CorrectionLocationType loc = aggOp.correctionLocation;

    if (loc == CorrectionLocationType.LASTROW) {
        if (aggOp.increOp.fn instanceof KahanPlus)
            LibMatrixAgg.aggregateBinaryMatrix(partial, this, aggOp);
        else
            incrAggCorrKahan(aggOp, partial, kbuff, true);
    } else if (loc == CorrectionLocationType.LASTCOLUMN) {
        if (isIncrIndexAgg(aggOp))
            incrAggCorrIndex(aggOp, partial);
        else if (aggOp.increOp.fn instanceof KahanPlus)
            LibMatrixAgg.aggregateBinaryMatrix(partial, this, aggOp);
        else
            incrAggCorrKahan(aggOp, partial, kbuff, false);
    } else if (loc == CorrectionLocationType.LASTTWOROWS) {
        incrAggCorrMean(aggOp, partial, kbuff, true);
    } else if (loc == CorrectionLocationType.LASTTWOCOLUMNS) {
        incrAggCorrMean(aggOp, partial, kbuff, false);
    } else if (loc == CorrectionLocationType.LASTFOURROWS && isIncrVarianceAgg(aggOp)) {
        incrAggCorrVariance(aggOp, partial, true);
    } else if (loc == CorrectionLocationType.LASTFOURCOLUMNS && isIncrVarianceAgg(aggOp)) {
        incrAggCorrVariance(aggOp, partial, false);
    } else {
        throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correctionLocation);
    }
}

/** True when the increment function is the MAXINDEX or MININDEX builtin. */
private static boolean isIncrIndexAgg(AggregateOperator aggOp) {
    if (!(aggOp.increOp.fn instanceof Builtin))
        return false;
    Builtin.BuiltinCode code = ((Builtin) aggOp.increOp.fn).bFunc;
    return code == Builtin.BuiltinCode.MAXINDEX || code == Builtin.BuiltinCode.MININDEX;
}

/** True when the increment function is a CM object configured for VARIANCE. */
private static boolean isIncrVarianceAgg(AggregateOperator aggOp) {
    return aggOp.increOp.fn instanceof CM
        && ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE;
}

/**
 * Merges value/correction pairs with a Kahan-style increment function.
 * When corrInRows is true the correction sits one row below the value; otherwise, one column to the right.
 */
private void incrAggCorrKahan(AggregateOperator aggOp, MatrixBlock partial, KahanObject kbuff, boolean corrInRows) {
    final int dr = corrInRows ? 1 : 0;
    final int dc = 1 - dr;
    KahanObject acc = kbuff;
    for (int i = 0; i < rlen - dr; i++) {
        for (int j = 0; j < clen - dc; j++) {
            acc._sum = quickGetValue(i, j);
            acc._correction = quickGetValue(i + dr, j + dc);
            acc = (KahanObject) aggOp.increOp.fn.execute(acc,
                partial.quickGetValue(i, j), partial.quickGetValue(i + dr, j + dc));
            quickSetValue(i, j, acc._sum);
            quickSetValue(i + dr, j + dc, acc._correction);
        }
    }
}

/**
 * Merges (value, count, correction) triples using the incremental-mean update.
 * The step between the three cells is one row when corrInRows is true, else one column.
 */
private void incrAggCorrMean(AggregateOperator aggOp, MatrixBlock partial, KahanObject kbuff, boolean corrInRows) {
    final int dr = corrInRows ? 1 : 0;
    final int dc = 1 - dr;
    KahanObject acc = kbuff;
    for (int i = 0; i < rlen - 2 * dr; i++) {
        for (int j = 0; j < clen - 2 * dc; j++) {
            acc._sum = quickGetValue(i, j);
            double count = quickGetValue(i + dr, j + dc);
            acc._correction = quickGetValue(i + 2 * dr, j + 2 * dc);
            double otherMean = partial.quickGetValue(i, j);
            double otherCount = partial.quickGetValue(i + dr, j + dc);
            count += otherCount;
            // Shift the running value toward the partial one, weighted by the partial's share of the count.
            double delta = (otherMean - acc._sum) * otherCount / count;
            acc = (KahanObject) aggOp.increOp.fn.execute(acc, delta);
            quickSetValue(i, j, acc._sum);
            quickSetValue(i + dr, j + dc, count);
            quickSetValue(i + 2 * dr, j + 2 * dc, acc._correction);
        }
    }
}

/**
 * Merges variance state laid out as { var | mean, count, m2 correction, mean correction }.
 * The step between the five cells is one row when corrInRows is true, else one column.
 * The two CM_COV_Object buffers are reused across all cells.
 */
private void incrAggCorrVariance(AggregateOperator aggOp, MatrixBlock partial, boolean corrInRows) {
    final int dr = corrInRows ? 1 : 0;
    final int dc = 1 - dr;
    CM_COV_Object curr = new CM_COV_Object();
    CM_COV_Object part = new CM_COV_Object();
    for (int i = 0; i < rlen - 4 * dr; i++) {
        for (int j = 0; j < clen - 4 * dc; j++) {
            loadIncrVarianceState(this, i, j, dr, dc, curr);
            loadIncrVarianceState(partial, i, j, dr, dc, part);
            curr = (CM_COV_Object) aggOp.increOp.fn.execute(curr, part);
            quickSetValue(i, j, curr.getRequiredResult(AggregateOperationTypes.VARIANCE));
            quickSetValue(i + dr, j + dc, curr.mean._sum);
            quickSetValue(i + 2 * dr, j + 2 * dc, curr.w);
            quickSetValue(i + 3 * dr, j + 3 * dc, curr.m2._correction);
            quickSetValue(i + 4 * dr, j + 4 * dc, curr.mean._correction);
        }
    }
}

/** Reads one { var | mean, count, m2 correction, mean correction } group from src into dst; m2 = var * (count - 1). */
private static void loadIncrVarianceState(MatrixBlock src, int i, int j, int dr, int dc, CM_COV_Object dst) {
    dst.w = src.quickGetValue(i + 2 * dr, j + 2 * dc);
    dst.m2._sum = src.quickGetValue(i, j) * (dst.w - 1);
    dst.mean._sum = src.quickGetValue(i + dr, j + dc);
    dst.m2._correction = src.quickGetValue(i + 3 * dr, j + 3 * dc);
    dst.mean._correction = src.quickGetValue(i + 4 * dr, j + 4 * dc);
}

/**
 * Merges per-row (index, value) pairs for MAXINDEX/MININDEX.
 * The increment function returns 1.0 when the partial's value wins and 2.0 on a tie.
 * A tie keeps the larger index; any other result leaves the current pair untouched.
 * NOTE(review): the original code marked this as a "hack" mirrored elsewhere that must be kept
 * in sync; the counterpart is not visible here.
 */
private void incrAggCorrIndex(AggregateOperator aggOp, MatrixBlock partial) {
    for (int i = 0; i < rlen; i++) {
        double currentValue = quickGetValue(i, 1);
        long candidateIndex = (long) partial.quickGetValue(i, 0);
        double candidateValue = partial.quickGetValue(i, 1);
        double verdict = aggOp.increOp.fn.execute(candidateValue, currentValue);
        if (verdict == 2.0) {
            long currentIndex = (long) quickGetValue(i, 0);
            quickSetValue(i, 0, Math.max(currentIndex, candidateIndex));
        } else if (verdict == 1.0) {
            quickSetValue(i, 0, candidateIndex);
            quickSetValue(i, 1, candidateValue);
        }
    }
}
Also used : CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) CM(org.apache.sysml.runtime.functionobjects.CM) Builtin(org.apache.sysml.runtime.functionobjects.Builtin) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Aggregations

Builtin (org.apache.sysml.runtime.functionobjects.Builtin)9 KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus)8 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)7 ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll)6 ReduceCol (org.apache.sysml.runtime.functionobjects.ReduceCol)6 KahanObject (org.apache.sysml.runtime.instructions.cp.KahanObject)6 CM (org.apache.sysml.runtime.functionobjects.CM)5 KahanPlusSq (org.apache.sysml.runtime.functionobjects.KahanPlusSq)5 Mean (org.apache.sysml.runtime.functionobjects.Mean)4 ReduceDiag (org.apache.sysml.runtime.functionobjects.ReduceDiag)4 ReduceRow (org.apache.sysml.runtime.functionobjects.ReduceRow)4 CM_COV_Object (org.apache.sysml.runtime.instructions.cp.CM_COV_Object)3 IndexFunction (org.apache.sysml.runtime.functionobjects.IndexFunction)2 KahanFunction (org.apache.sysml.runtime.functionobjects.KahanFunction)2 Multiply (org.apache.sysml.runtime.functionobjects.Multiply)2 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)2 IOException (java.io.IOException)1 ArrayList (java.util.ArrayList)1 ExecutorService (java.util.concurrent.ExecutorService)1 Future (java.util.concurrent.Future)1