Search in sources:

Example 1 with Multiply2

Use of org.apache.sysml.runtime.functionobjects.Multiply2 in the incubator-systemml project by Apache.

In class LibMatrixBincell, method safeBinaryScalar:

/**
 * Applies a sparse-safe binary scalar operation cell-wise from {@code m1} into {@code ret}.
 * Dense inputs are delegated to {@code denseBinaryScalar}; sparse inputs are processed
 * row by row over the stored (non-zero) entries only.
 *
 * @param m1  input block
 * @param ret output block; must use the same sparse/dense representation as {@code m1}
 * @param op  scalar operator applied to each stored value of a sparse input
 * @throws DMLRuntimeException if {@code m1} and {@code ret} differ in representation
 */
private static void safeBinaryScalar(MatrixBlock m1, MatrixBlock ret, ScalarOperator op) {
    // sparse-safe operation: an empty input needs no work at all
    if (m1.isEmptyBlock(false)) {
        return;
    }

    // input and output must share the same representation
    if (m1.sparse != ret.sparse) {
        throw new DMLRuntimeException("Unsupported safe binary scalar operations over different input/output representation: " + m1.sparse + " " + ret.sparse);
    }

    // (X != 0): every stored entry becomes 1, so rows can be bulk-copied
    final boolean onesPattern = op.fn instanceof NotEquals && op.getConstant() == 0;
    // for these ops each output row is pre-allocated to the input row length (avoids resizing)
    final boolean exactRowSize = op.fn instanceof Multiply
        || op.fn instanceof Multiply2
        || op.fn instanceof Power2
        || Builtin.isBuiltinCode(op.fn, BuiltinCode.MAX)
        || Builtin.isBuiltinCode(op.fn, BuiltinCode.MIN);

    if (!m1.sparse) {
        // DENSE <- DENSE
        denseBinaryScalar(m1, ret, op);
        return;
    }

    // SPARSE <- SPARSE
    ret.allocateSparseRowsBlock();
    final SparseBlock src = m1.sparseBlock;
    final SparseBlock dst = ret.sparseBlock;
    final int numRows = Math.min(m1.rlen, src.numRows());
    long nonZeroCount = 0;

    for (int row = 0; row < numRows; row++) {
        if (src.isEmpty(row)) {
            continue;
        }
        final int start = src.pos(row);
        final int len = src.size(row);
        final int[] cols = src.indexes(row);
        final double[] vals = src.values(row);

        if (onesPattern) {
            // build the row in one shot: same column indexes, all values set to 1
            // (relies on the sparse block storing no explicit zeros)
            SparseRowVector ones = new SparseRowVector(len);
            ones.setSize(len);
            System.arraycopy(cols, start, ones.indexes(), 0, len);
            Arrays.fill(ones.values(), 0, len, 1);
            dst.set(row, ones, false);
            nonZeroCount += len;
            continue;
        }

        if (exactRowSize) {
            dst.allocate(row, len);
        }
        final int end = start + len;
        for (int k = start; k < end; k++) {
            final double v = op.executeScalar(vals[k]);
            dst.append(row, cols[k], v);
            if (v != 0) {
                nonZeroCount++;
            }
        }
    }
    ret.nonZeros = nonZeroCount;
}
Also used: Multiply2 (org.apache.sysml.runtime.functionobjects.Multiply2), NotEquals (org.apache.sysml.runtime.functionobjects.NotEquals), DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException), Multiply (org.apache.sysml.runtime.functionobjects.Multiply), MinusMultiply (org.apache.sysml.runtime.functionobjects.MinusMultiply), PlusMultiply (org.apache.sysml.runtime.functionobjects.PlusMultiply), Power2 (org.apache.sysml.runtime.functionobjects.Power2)

Example 2 with Multiply2

Use of org.apache.sysml.runtime.functionobjects.Multiply2 in the systemml project by Apache.

In class LibMatrixBincell, method safeBinaryScalar:

/**
 * Applies a sparse-safe binary scalar operation cell-wise from {@code m1} into {@code ret}.
 * Dense inputs are delegated to {@code denseBinaryScalar}; sparse inputs are processed
 * row by row over the stored (non-zero) entries only.
 *
 * @param m1  input block
 * @param ret output block; must use the same sparse/dense representation as {@code m1}
 * @param op  scalar operator applied to each stored value of a sparse input
 * @throws DMLRuntimeException if {@code m1} and {@code ret} differ in representation
 */
private static void safeBinaryScalar(MatrixBlock m1, MatrixBlock ret, ScalarOperator op) {
    // sparse-safe operation: an empty input needs no work at all
    if (m1.isEmptyBlock(false)) {
        return;
    }

    // input and output must share the same representation
    if (m1.sparse != ret.sparse) {
        throw new DMLRuntimeException("Unsupported safe binary scalar operations over different input/output representation: " + m1.sparse + " " + ret.sparse);
    }

    // (X != 0): every stored entry becomes 1, so rows can be bulk-copied
    final boolean onesPattern = op.fn instanceof NotEquals && op.getConstant() == 0;
    // for these ops each output row is pre-allocated to the input row length (avoids resizing)
    final boolean exactRowSize = op.fn instanceof Multiply
        || op.fn instanceof Multiply2
        || op.fn instanceof Power2
        || Builtin.isBuiltinCode(op.fn, BuiltinCode.MAX)
        || Builtin.isBuiltinCode(op.fn, BuiltinCode.MIN);

    if (!m1.sparse) {
        // DENSE <- DENSE
        denseBinaryScalar(m1, ret, op);
        return;
    }

    // SPARSE <- SPARSE
    ret.allocateSparseRowsBlock();
    final SparseBlock src = m1.sparseBlock;
    final SparseBlock dst = ret.sparseBlock;
    final int numRows = Math.min(m1.rlen, src.numRows());
    long nonZeroCount = 0;

    for (int row = 0; row < numRows; row++) {
        if (src.isEmpty(row)) {
            continue;
        }
        final int start = src.pos(row);
        final int len = src.size(row);
        final int[] cols = src.indexes(row);
        final double[] vals = src.values(row);

        if (onesPattern) {
            // build the row in one shot: same column indexes, all values set to 1
            // (relies on the sparse block storing no explicit zeros)
            SparseRowVector ones = new SparseRowVector(len);
            ones.setSize(len);
            System.arraycopy(cols, start, ones.indexes(), 0, len);
            Arrays.fill(ones.values(), 0, len, 1);
            dst.set(row, ones, false);
            nonZeroCount += len;
            continue;
        }

        if (exactRowSize) {
            dst.allocate(row, len);
        }
        final int end = start + len;
        for (int k = start; k < end; k++) {
            final double v = op.executeScalar(vals[k]);
            dst.append(row, cols[k], v);
            if (v != 0) {
                nonZeroCount++;
            }
        }
    }
    ret.nonZeros = nonZeroCount;
}
Also used: Multiply2 (org.apache.sysml.runtime.functionobjects.Multiply2), NotEquals (org.apache.sysml.runtime.functionobjects.NotEquals), DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException), Multiply (org.apache.sysml.runtime.functionobjects.Multiply), MinusMultiply (org.apache.sysml.runtime.functionobjects.MinusMultiply), PlusMultiply (org.apache.sysml.runtime.functionobjects.PlusMultiply), Power2 (org.apache.sysml.runtime.functionobjects.Power2)

Aggregations

DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 2; MinusMultiply (org.apache.sysml.runtime.functionobjects.MinusMultiply): 2; Multiply (org.apache.sysml.runtime.functionobjects.Multiply): 2; Multiply2 (org.apache.sysml.runtime.functionobjects.Multiply2): 2; NotEquals (org.apache.sysml.runtime.functionobjects.NotEquals): 2; PlusMultiply (org.apache.sysml.runtime.functionobjects.PlusMultiply): 2; Power2 (org.apache.sysml.runtime.functionobjects.Power2): 2