Search in sources :

Example 6 with KahanFunction

use of org.apache.sysml.runtime.functionobjects.KahanFunction in project incubator-systemml by apache.

the class SpoofCellwise method executeSparseAggSum.

/**
 * Aggregates the generated cell function over rows {@code [rl, ru)} of a sparse
 * block into a single sum, using the Kahan aggregation function returned by
 * {@code getAggFunction()}.
 *
 * @param sblock     sparse input block; may be {@code null}, in which case every row is treated as empty
 * @param b          side inputs forwarded unchanged to {@code genexec}
 * @param scalars    scalar inputs forwarded unchanged to {@code genexec}
 * @param m          forwarded unchanged to {@code genexec}
 * @param n          number of columns scanned per row; also forwarded to {@code genexec}
 * @param sparseSafe if {@code true} only stored non-zeros are evaluated; otherwise zero cells are evaluated too
 * @param rl         first row (inclusive)
 * @param ru         last row (exclusive)
 * @return the accumulated sum
 */
private double executeSparseAggSum(SparseBlock sblock, SideInput[] b, double[] scalars, int m, int n, boolean sparseSafe, int rl, int ru) {
    final KahanFunction aggFn = (KahanFunction) getAggFunction();
    final KahanObject acc = new KahanObject(0, 0);
    // One sequential left-to-right scan per row serves both the sparse-safe and
    // the sparse-unsafe mode, which avoids a binary search for the unsafe case.
    for (int row = rl; row < ru; row++) {
        // first column of this row that has not been processed yet
        int nextCol = 0;
        if (sblock != null && !sblock.isEmpty(row)) {
            final int begin = sblock.pos(row);
            final int len = sblock.size(row);
            final int[] colIdx = sblock.indexes(row);
            final double[] vals = sblock.values(row);
            final int end = begin + len;
            for (int p = begin; p < end; p++) {
                final int col = colIdx[p];
                if (!sparseSafe) {
                    // zeros located before the current non-zero
                    for (; nextCol < col; nextCol++) {
                        aggFn.execute2(acc, genexec(0, b, scalars, m, n, row, nextCol));
                    }
                }
                // the stored non-zero itself
                aggFn.execute2(acc, genexec(vals[p], b, scalars, m, n, row, col));
                nextCol = col + 1;
            }
        }
        if (!sparseSafe) {
            // trailing zeros of a non-empty row, or the whole of an empty row
            for (; nextCol < n; nextCol++) {
                aggFn.execute2(acc, genexec(0, b, scalars, m, n, row, nextCol));
            }
        }
    }
    return acc._sum;
}
Also used : KahanFunction(org.apache.sysml.runtime.functionobjects.KahanFunction) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject)

Example 7 with KahanFunction

use of org.apache.sysml.runtime.functionobjects.KahanFunction in project incubator-systemml by apache.

the class SpoofCellwise method executeCompressedAggSum.

/**
 * Aggregates the generated cell function over rows {@code [rl, ru)} of a
 * compressed block into a single sum.
 *
 * <p>Two paths exist: a shortcut that works on the value tuples of the column
 * groups only, and a general path that iterates over individual cells.
 *
 * @param a          compressed input block
 * @param b          side inputs; the shortcut is only taken when this array is empty
 * @param scalars    scalar inputs forwarded unchanged to {@code genexec}
 * @param m          forwarded unchanged to {@code genexec}
 * @param n          forwarded unchanged to {@code genexec}
 * @param sparseSafe if {@code false} the cell iterator is asked to include zeros and the shortcut is disabled
 * @param rl         first row (inclusive)
 * @param ru         last row (exclusive)
 * @return the accumulated sum
 */
private double executeCompressedAggSum(CompressedMatrixBlock a, SideInput[] b, double[] scalars, int m, int n, boolean sparseSafe, int rl, int ru) {
    final KahanFunction aggFn = (KahanFunction) getAggFunction();
    final KahanObject total = new KahanObject(0, 0);
    final KahanObject tupleAcc = new KahanObject(0, 0);

    // general case of arbitrary side inputs: evaluate cell by cell
    if (!sparseSafe || b.length != 0 || a.hasUncompressedColGroup()) {
        final Iterator<IJV> cells = a.getIterator(rl, ru, !sparseSafe);
        while (cells.hasNext()) {
            final IJV cell = cells.next();
            final double cellResult = genexec(cell.getV(), b, scalars, m, n, cell.getI(), cell.getJ());
            aggFn.execute2(total, cellResult);
        }
        return total._sum;
    }

    // special case: computation over value-tuples only
    // note: all remaining groups are guaranteed ColGroupValue
    final boolean wholeBlock = (rl == 0 && ru == a.getNumRows());
    final int largestDictionary = a.getColGroups().stream()
        .mapToInt(g -> ((ColGroupValue) g).getNumValues())
        .max()
        .orElse(0);
    // buffer sized for the largest group and handed to every getCounts call
    int[] counts = new int[largestDictionary];
    for (ColGroup group : a.getColGroups()) {
        final ColGroupValue valueGroup = (ColGroupValue) group;
        if (wholeBlock) {
            counts = valueGroup.getCounts(counts);
        } else {
            counts = valueGroup.getCounts(rl, ru, counts);
        }
        for (int k = 0; k < valueGroup.getNumValues(); k++) {
            tupleAcc.set(0, 0);
            final double tupleSum = valueGroup.sumValues(k, aggFn, tupleAcc);
            // no single cell position applies on this path, hence row/column -1
            final double tupleResult = genexec(tupleSum, b, scalars, m, n, -1, -1);
            // presumably adds tupleResult counts[k] times; verify against KahanFunction.execute3
            aggFn.execute3(total, tupleResult, counts[k]);
        }
    }
    return total._sum;
}
Also used : ColGroup(org.apache.sysml.runtime.compress.ColGroup) IJV(org.apache.sysml.runtime.matrix.data.IJV) KahanFunction(org.apache.sysml.runtime.functionobjects.KahanFunction) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) ColGroupValue(org.apache.sysml.runtime.compress.ColGroupValue)

Example 8 with KahanFunction

use of org.apache.sysml.runtime.functionobjects.KahanFunction in project incubator-systemml by apache.

the class SpoofCellwise method executeDenseColAggSum.

/**
 * Aggregates the generated cell function over rows {@code [rl, ru)} of a dense
 * block into per-column sums, which are accumulated in place in the first
 * block of {@code c}.
 *
 * <p>A correction term is kept per column in a local array for the duration of
 * the call, so each column is summed with the Kahan function independently.
 *
 * @param a          dense input block; {@code null} means there are no stored values
 * @param b          side inputs forwarded unchanged to {@code genexec}
 * @param scalars    scalar inputs forwarded unchanged to {@code genexec}
 * @param c          output block; only the array returned by {@code valuesAt(0)} is updated
 * @param m          forwarded unchanged to {@code genexec}
 * @param n          number of columns
 * @param sparseSafe if {@code true} zero cells are skipped; otherwise every cell is evaluated
 * @param rl         first row (inclusive)
 * @param ru         last row (exclusive)
 * @return always {@code -1}
 */
private long executeDenseColAggSum(DenseBlock a, SideInput[] b, double[] scalars, DenseBlock c, int m, int n, boolean sparseSafe, int rl, int ru) {
    // the output is addressed as a single block
    final double[] colSums = c.valuesAt(0);
    final KahanFunction aggFn = (KahanFunction) getAggFunction();
    final KahanObject acc = new KahanObject(0, 0);
    final double[] colCorr = new double[n];

    if (a != null) {
        for (int row = rl; row < ru; row++) {
            final double[] rowVals = a.values(row);
            final int rowOff = a.pos(row);
            for (int col = 0; col < n; col++) {
                final double cellVal = rowVals[rowOff + col];
                if (sparseSafe && cellVal == 0) {
                    continue;
                }
                // load the running state of this column, add, and store it back
                acc.set(colSums[col], colCorr[col]);
                aggFn.execute2(acc, genexec(cellVal, b, scalars, m, n, row, col));
                colSums[col] = acc._sum;
                colCorr[col] = acc._correction;
            }
        }
    } else if (!sparseSafe) {
        // no input block: every cell is a zero that still has to be evaluated
        for (int row = rl; row < ru; row++) {
            for (int col = 0; col < n; col++) {
                acc.set(colSums[col], colCorr[col]);
                aggFn.execute2(acc, genexec(0, b, scalars, m, n, row, col));
                colSums[col] = acc._sum;
                colCorr[col] = acc._correction;
            }
        }
    }
    return -1;
}
Also used : KahanFunction(org.apache.sysml.runtime.functionobjects.KahanFunction) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject)

Example 9 with KahanFunction

use of org.apache.sysml.runtime.functionobjects.KahanFunction in project incubator-systemml by apache.

the class CompressedMatrixBlock method aggregateUnaryOperations.

/**
 * Computes a unary aggregate (sum, sum of squares, min or max) over this block.
 *
 * <p>If the block is not compressed the call is delegated to the superclass.
 * Otherwise the output block is sized from the index function plus any
 * correction rows/columns, the column groups are aggregated either in a thread
 * pool or sequentially, and the correction is dropped again when {@code inCP}
 * is set.
 *
 * @param op                the aggregate operator; only KahanPlus, KahanPlusSq and Builtin MIN/MAX are accepted
 * @param result            output block to reuse (it is reset), or {@code null} to allocate a new one
 * @param blockingFactorRow only forwarded to the superclass on the uncompressed path
 * @param blockingFactorCol only forwarded to the superclass on the uncompressed path
 * @param indexesIn         only forwarded to the superclass on the uncompressed path
 * @param inCP              if {@code true} and the operator has a correction, the correction rows/columns are dropped
 * @return the block holding the aggregate
 * @throws DMLRuntimeException if the aggregate function is unsupported, the correction
 *                             location is unknown, or the parallel execution fails
 */
@Override
public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, int blockingFactorRow, int blockingFactorCol, MatrixIndexes indexesIn, boolean inCP) {
    // fall back to the uncompressed unary aggregate if necessary
    if (!isCompressed()) {
        return super.aggregateUnaryOperations(op, result, blockingFactorRow, blockingFactorCol, indexesIn, inCP);
    }
    // check for supported operations
    if (!(op.aggOp.increOp.fn instanceof KahanPlus || op.aggOp.increOp.fn instanceof KahanPlusSq || (op.aggOp.increOp.fn instanceof Builtin && (((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN || ((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)))) {
        throw new DMLRuntimeException("Unary aggregates other than sum/sumsq/min/max not supported yet.");
    }
    Timing time = LOG.isDebugEnabled() ? new Timing(true) : null;
    // prepare output dimensions
    CellIndex tempCellIndex = new CellIndex(-1, -1);
    op.indexFn.computeDimension(rlen, clen, tempCellIndex);
    if (op.aggOp.correctionExists) {
        // make room for the correction values next to the aggregate itself
        switch(op.aggOp.correctionLocation) {
            case LASTROW:
                tempCellIndex.row++;
                break;
            case LASTCOLUMN:
                tempCellIndex.column++;
                break;
            case LASTTWOROWS:
                tempCellIndex.row += 2;
                break;
            case LASTTWOCOLUMNS:
                tempCellIndex.column += 2;
                break;
            default:
                throw new DMLRuntimeException("unrecognized correctionLocation: " + op.aggOp.correctionLocation);
        }
    }
    // initialize and allocate the result
    if (result == null)
        result = new MatrixBlock(tempCellIndex.row, tempCellIndex.column, false);
    else
        result.reset(tempCellIndex.row, tempCellIndex.column, false);
    MatrixBlock ret = (MatrixBlock) result;
    ret.allocateDenseBlock();
    // special handling init value for rowmins/rowmax
    if (op.indexFn instanceof ReduceCol && op.aggOp.increOp.fn instanceof Builtin) {
        double val = (((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX) ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
        ret.getDenseBlock().set(val);
    }
    // core unary aggregate
    if (op.getNumThreads() > 1 && getExactSizeOnDisk() > MIN_PAR_AGG_THRESHOLD) {
        // multi-threaded execution of all groups; for ReduceCol a single partition is
        // requested because the work is split by row ranges below instead of by groups
        ArrayList<ColGroup>[] grpParts = createStaticTaskPartitioning((op.indexFn instanceof ReduceCol) ? 1 : op.getNumThreads(), false);
        ColGroupUncompressed uc = getUncompressedColGroup();
        try {
            // compute uncompressed column group in parallel (otherwise bottleneck)
            if (uc != null)
                uc.unaryAggregateOperations(op, ret);
            // compute all compressed column groups
            ExecutorService pool = CommonThreadPool.get(op.getNumThreads());
            List<Future<MatrixBlock>> rtasks;
            try {
                ArrayList<UnaryAggregateTask> tasks = new ArrayList<>();
                if (op.indexFn instanceof ReduceCol && grpParts.length > 0) {
                    int blklen = BitmapEncoder.getAlignedBlocksize((int) (Math.ceil((double) rlen / op.getNumThreads())));
                    for (int i = 0; i < op.getNumThreads() && i * blklen < rlen; i++)
                        tasks.add(new UnaryAggregateTask(grpParts[0], ret, i * blklen, Math.min((i + 1) * blklen, rlen), op));
                } else {
                    for (ArrayList<ColGroup> grp : grpParts)
                        tasks.add(new UnaryAggregateTask(grp, ret, 0, rlen, op));
                }
                rtasks = pool.invokeAll(tasks);
            } finally {
                // release the pool even if task creation or submission failed
                pool.shutdown();
            }
            // aggregate partial results
            if (op.indexFn instanceof ReduceAll) {
                if (op.aggOp.increOp.fn instanceof KahanFunction) {
                    KahanObject kbuff = new KahanObject(ret.quickGetValue(0, 0), 0);
                    for (Future<MatrixBlock> rtask : rtasks) {
                        double tmp = rtask.get().quickGetValue(0, 0);
                        ((KahanFunction) op.aggOp.increOp.fn).execute2(kbuff, tmp);
                    }
                    ret.quickSetValue(0, 0, kbuff._sum);
                } else {
                    double val = ret.quickGetValue(0, 0);
                    for (Future<MatrixBlock> rtask : rtasks) {
                        double tmp = rtask.get().quickGetValue(0, 0);
                        val = op.aggOp.increOp.fn.execute(val, tmp);
                    }
                    ret.quickSetValue(0, 0, val);
                }
            } else {
                // invokeAll reports task failures only through the futures; query each one
                // so that a failed task raises an error instead of leaving a partial result
                for (Future<MatrixBlock> rtask : rtasks)
                    rtask.get();
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    } else {
        // process UC column group
        for (ColGroup grp : _colGroups) if (grp instanceof ColGroupUncompressed)
            grp.unaryAggregateOperations(op, ret);
        // process OLE/RLE column groups
        aggregateUnaryOperations(op, _colGroups, ret, 0, rlen);
    }
    // special handling zeros for rowmins/rowmax: a row with fewer non-zeros than
    // columns contains at least one zero, which has to take part in the min/max
    if (op.indexFn instanceof ReduceCol && op.aggOp.increOp.fn instanceof Builtin) {
        int[] rnnz = new int[rlen];
        for (ColGroup grp : _colGroups) grp.countNonZerosPerRow(rnnz, 0, rlen);
        Builtin builtin = (Builtin) op.aggOp.increOp.fn;
        for (int i = 0; i < rlen; i++) if (rnnz[i] < clen)
            ret.quickSetValue(i, 0, builtin.execute2(ret.quickGetValue(i, 0), 0));
    }
    // drop correction if necessary
    if (op.aggOp.correctionExists && inCP)
        ret.dropLastRowsOrColumns(op.aggOp.correctionLocation);
    // post-processing
    ret.recomputeNonZeros();
    // the timer only exists if debug logging was already enabled on entry
    if (time != null && LOG.isDebugEnabled())
        LOG.debug("Compressed uagg k=" + op.getNumThreads() + " in " + time.stop());
    return ret;
}
Also used : ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) ArrayList(java.util.ArrayList) KahanFunction(org.apache.sysml.runtime.functionobjects.KahanFunction) KahanPlusSq(org.apache.sysml.runtime.functionobjects.KahanPlusSq) ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) IOException(java.io.IOException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) ExecutorService(java.util.concurrent.ExecutorService) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) Future(java.util.concurrent.Future) Timing(org.apache.sysml.runtime.controlprogram.parfor.stat.Timing) Builtin(org.apache.sysml.runtime.functionobjects.Builtin)

Example 10 with KahanFunction

use of org.apache.sysml.runtime.functionobjects.KahanFunction in project systemml by apache.

the class SpoofCellwise method executeCompressedRowAggSum.

/**
 * Aggregates the generated cell function over rows {@code [rl, ru)} of a
 * compressed block into per-row sums, accumulated in place in {@code c}.
 *
 * <p>Each row keeps its own correction term across cells, because the cell
 * iterator is not assumed to deliver the cells of one row consecutively.
 * Resetting the correction to zero for every cell would discard the
 * compensation that the Kahan function computes.
 *
 * @param a          compressed input block
 * @param b          side inputs forwarded unchanged to {@code genexec}
 * @param scalars    scalar inputs forwarded unchanged to {@code genexec}
 * @param c          per-row output sums, indexed by absolute row index
 * @param m          forwarded unchanged to {@code genexec}
 * @param n          forwarded unchanged to {@code genexec}
 * @param sparseSafe if {@code false} the cell iterator is asked to include zeros
 * @param rl         first row (inclusive)
 * @param ru         last row (exclusive)
 * @return number of non-zero entries in {@code c[rl..ru)}
 */
private long executeCompressedRowAggSum(CompressedMatrixBlock a, SideInput[] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru) {
    KahanFunction kplus = (KahanFunction) getAggFunction();
    KahanObject kbuff = new KahanObject(0, 0);
    // per-row correction terms, indexed relative to rl
    // NOTE(review): assumes the iterator only yields rows in [rl, ru) - confirm against getIterator
    double[] corr = new double[Math.max(ru - rl, 0)];
    long lnnz = 0;
    Iterator<IJV> iter = a.getIterator(rl, ru, !sparseSafe);
    while (iter.hasNext()) {
        IJV cell = iter.next();
        int row = cell.getI();
        double val = genexec(cell.getV(), b, scalars, m, n, row, cell.getJ());
        // restore the running state of this row, add, and store it back
        kbuff.set(c[row], corr[row - rl]);
        kplus.execute2(kbuff, val);
        c[row] = kbuff._sum;
        corr[row - rl] = kbuff._correction;
    }
    for (int i = rl; i < ru; i++) {
        lnnz += (c[i] != 0) ? 1 : 0;
    }
    return lnnz;
}
Also used : IJV(org.apache.sysml.runtime.matrix.data.IJV) KahanFunction(org.apache.sysml.runtime.functionobjects.KahanFunction) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject)

Aggregations

KahanFunction (org.apache.sysml.runtime.functionobjects.KahanFunction): 32; KahanObject (org.apache.sysml.runtime.instructions.cp.KahanObject): 28; KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus): 10; ValueFunction (org.apache.sysml.runtime.functionobjects.ValueFunction): 10; ArrayList (java.util.ArrayList): 6; ExecutorService (java.util.concurrent.ExecutorService): 6; Future (java.util.concurrent.Future): 6; DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 6; IJV (org.apache.sysml.runtime.matrix.data.IJV): 6; MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock): 6; CompressedMatrixBlock (org.apache.sysml.runtime.compress.CompressedMatrixBlock): 4; Builtin (org.apache.sysml.runtime.functionobjects.Builtin): 4; KahanPlusSq (org.apache.sysml.runtime.functionobjects.KahanPlusSq): 4; ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll): 4; ReduceCol (org.apache.sysml.runtime.functionobjects.ReduceCol): 4; IOException (java.io.IOException): 2; ColGroup (org.apache.sysml.runtime.compress.ColGroup): 2; ColGroupValue (org.apache.sysml.runtime.compress.ColGroupValue): 2; Timing (org.apache.sysml.runtime.controlprogram.parfor.stat.Timing): 2; BuiltinCode (org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode): 2