Search in sources:

Example 1 with Or

Use of org.apache.sysml.runtime.functionobjects.Or in project incubator-systemml by Apache.

From class LibMatrixCUDA, method matrixScalarArithmetic.

//********************************************************************/
//****************  END OF UNARY AGGREGATE Functions *****************/
//********************************************************************/
//********************************************************************/
//************ Matrix-Matrix & Matrix-Scalar Functions ***************/
//********************************************************************/
/**
 * Entry point to perform elementwise matrix-scalar operation specified by op.
 *
 * <p>Operator/constant combinations whose result does not depend on the individual elements
 * are short-circuited instead of launching the general elementwise kernel:
 * <ul>
 *   <li>{@code X + 0}, {@code X - 0} and {@code X ^ 1}: device copy of the input;</li>
 *   <li>{@code X * 0} and {@code X & 0}: all-zero output;</li>
 *   <li>{@code X ^ 0} and {@code X | c} with {@code c != 0}: all-one output.</li>
 * </ul>
 * Minus and Power are not commutative. Their shortcuts therefore apply only when the scalar is
 * the right operand ({@link RightScalarOperator}). {@code 0 - X}, {@code 0 ^ X} and
 * {@code 1 ^ X} go through the general kernel. And/Or are logical operators producing 0/1
 * values. {@code X | 0} and {@code X & 1} are therefore not copies of a non-boolean input, and
 * they also go through the general kernel.
 *
 * @param ec execution context
 * @param gCtx a valid {@link GPUContext}; must be the same context that is set on {@code ec}
 * @param instName the invoking instruction's name for record {@link Statistics}.
 * @param in input matrix
 * @param outputName output matrix name
 * @param isInputTransposed true if input transposed
 * @param op scalar operator
 * @throws DMLRuntimeException if {@code gCtx} is not the context set on {@code ec}, or if the
 *     delegated GPU operation fails
 */
public static void matrixScalarArithmetic(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in, String outputName, boolean isInputTransposed, ScalarOperator op) throws DMLRuntimeException {
    if (ec.getGPUContext() != gCtx) {
        throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
    }
    double constant = op.getConstant();
    LOG.trace("GPU : matrixScalarArithmetic, scalar: " + constant + ", GPUContext=" + gCtx);

    // True for "X op constant" (as opposed to "constant op X"); only matters for Minus and Power,
    // the non-commutative operators handled by the shortcuts below.
    boolean scalarOnRight = op instanceof RightScalarOperator;
    boolean isZero = constant == 0;
    boolean isOne = constant == 1.0;

    if ((isZero && (op.fn instanceof Plus || (op.fn instanceof Minus && scalarOnRight)))
            || (isOne && op.fn instanceof Power && scalarOnRight)) {
        // X + 0 = X, X - 0 = X, X ^ 1 = X
        deviceCopy(ec, gCtx, instName, in, outputName, isInputTransposed);
    } else if (isZero && (op.fn instanceof Multiply || op.fn instanceof And)) {
        // X & 0 = 0 and X * 0 = 0.
        // NOTE(review): under IEEE 754, NaN * 0 and +/-Infinity * 0 are NaN, but this shortcut
        // writes 0 for every element - confirm this approximation is intended.
        setOutputToConstant(ec, gCtx, instName, 0.0, outputName);
    } else if ((isZero && op.fn instanceof Power && scalarOnRight) || (!isZero && op.fn instanceof Or)) {
        // X ^ 0 = 1 (even for NaN elements, per the pow special cases); X | c = 1 for any c != 0
        setOutputToConstant(ec, gCtx, instName, 1.0, outputName);
    } else {
        // General elementwise kernel. This also covers combinations whose result varies per
        // element and therefore must not be short-circuited:
        //  - X | 0 and X & 1 (0/1 depending on whether each element is non-zero),
        //  - X / 0 (+Infinity, -Infinity or NaN depending on each element),
        //  - 0 - X, 0 ^ X and 1 ^ X (scalar on the left).
        // TODO: Performance optimization - X * c and X / c could use cuBLAS (dgeam/daxpy) instead.
        matrixScalarOp(ec, gCtx, instName, in, outputName, isInputTransposed, op);
    }
}
Also used: Or (org.apache.sysml.runtime.functionobjects.Or), And (org.apache.sysml.runtime.functionobjects.And), Multiply (org.apache.sysml.runtime.functionobjects.Multiply), KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus), Plus (org.apache.sysml.runtime.functionobjects.Plus), RightScalarOperator (org.apache.sysml.runtime.matrix.operators.RightScalarOperator), Minus (org.apache.sysml.runtime.functionobjects.Minus), Power (org.apache.sysml.runtime.functionobjects.Power), DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)

Aggregations

DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 1, And (org.apache.sysml.runtime.functionobjects.And): 1, KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus): 1, Minus (org.apache.sysml.runtime.functionobjects.Minus): 1, Multiply (org.apache.sysml.runtime.functionobjects.Multiply): 1, Or (org.apache.sysml.runtime.functionobjects.Or): 1, Plus (org.apache.sysml.runtime.functionobjects.Plus): 1, Power (org.apache.sysml.runtime.functionobjects.Power): 1, RightScalarOperator (org.apache.sysml.runtime.matrix.operators.RightScalarOperator): 1