Search in sources :

Example 1 with DiagIndex

use of org.apache.sysml.runtime.functionobjects.DiagIndex in project incubator-systemml by apache.

the class ReorgInstruction method processInstruction.

/**
 * Applies this instruction's reorg operator to every cached block of the input and
 * registers the produced blocks under the output index in {@code cachedValues}.
 * <p>
 * Three cases are distinguished by the operator's index function:
 * <ul>
 *   <li>{@code DiagIndex}: decided at runtime from the input dimensions. If the input has a
 *       single row or column (vector-to-matrix), every block is reorganized and, when
 *       {@code _outputEmptyBlocks} is set and blocks are {@code MatrixBlock}s, empty blocks are
 *       emitted for the remaining row-block positions of the same block column. Otherwise
 *       (matrix-to-vector) only blocks whose row index equals their column index are processed.</li>
 *   <li>{@code RevIndex}: delegated to {@code LibMatrixReorg.rev}; all returned blocks are added
 *       to the output.</li>
 *   <li>anything else (e.g., transpose): delegated to
 *       {@code OperationsOnMatrixValues.performReorg}.</li>
 * </ul>
 *
 * @param valueClass      block class used when reserving output slots
 * @param cachedValues    cache the input blocks are read from and the output blocks are written to
 * @param tempValue       not used by this implementation
 * @param zeroInput       not used by this implementation
 * @param blockRowFactor  not used by this implementation
 * @param blockColFactor  not used by this implementation
 * @throws DMLRuntimeException declared by the signature; presumably raised by the invoked reorg routines
 */
@Override
public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) throws DMLRuntimeException {
    ArrayList<IndexedMatrixValue> inputBlocks = cachedValues.get(input);
    //nothing cached for the input: nothing to do
    if (inputBlocks == null)
        return;

    for (IndexedMatrixValue block : inputBlocks) {
        if (block == null)
            continue;

        //offsets/length handed through to the reorg calls (always zero here)
        final int startRow = 0, startColumn = 0, length = 0;
        ReorgOperator reorgOp = (ReorgOperator) optr;

        if (reorgOp.fn instanceof DiagIndex) {
            //special diag handling (overloaded, size-dependent operation; hence decided during runtime)
            boolean vectorToMatrix = (_mcIn.getRows() == 1 || _mcIn.getCols() == 1);
            //input can be a row or a column vector, so take the longer dimension
            long vectorLen = Math.max(_mcIn.getRows(), _mcIn.getCols());

            if (!vectorToMatrix) {
                //M2V: non-diagonal blocks are skipped directly
                if (block.getIndexes().getRowIndex() != block.getIndexes().getColumnIndex())
                    continue;
                //allocate space for the output value
                IndexedMatrixValue diagOut = cachedValues.holdPlace(output, valueClass);
                //the result is a column vector, hence column block index 1
                diagOut.getIndexes().setIndexes(block.getIndexes().getRowIndex(), 1);
                //compute result block
                block.getValue().reorgOperations(reorgOp, diagOut.getValue(), startRow, startColumn, length);
                continue;
            }

            //V2M: allocate space for the output value and compute the diagonal block
            IndexedMatrixValue diagOut = cachedValues.holdPlace(output, valueClass);
            OperationsOnMatrixValues.performReorg(block.getIndexes(), block.getValue(), diagOut.getIndexes(), diagOut.getValue(), reorgOp, startRow, startColumn, length);

            //empty off-diagonal blocks are only produced on request and only for block representation
            if (!_outputEmptyBlocks || !valueClass.equals(MatrixBlock.class))
                continue;

            //row index is equal to the col index
            long diagIndex = diagOut.getIndexes().getRowIndex();
            long blockLen = Math.max(_mcIn.getRowsPerBlock(), _mcIn.getColsPerBlock());
            //number of row blocks, rounding up for a partial last block
            long numRowBlocks = vectorLen / blockLen;
            if (vectorLen % blockLen != 0)
                numRowBlocks++;

            for (long rowBlock = 1; rowBlock <= numRowBlocks; rowBlock++) {
                //prevent duplicate output for the diagonal block itself
                if (rowBlock != diagIndex) {
                    IndexedMatrixValue emptyBlock = cachedValues.holdPlace(output, valueClass);
                    //full block height unless this is a partial last block
                    long rowsInBlock;
                    if (rowBlock * blockLen <= vectorLen)
                        rowsInBlock = blockLen;
                    else
                        rowsInBlock = vectorLen % blockLen;
                    int localRows = (int) rowsInBlock;
                    emptyBlock.getIndexes().setIndexes(rowBlock, diagIndex);
                    emptyBlock.getValue().reset(localRows, diagOut.getValue().getNumColumns(), true);
                }
            }
        }
        else if (reorgOp.fn instanceof RevIndex) {
            //execute reverse operation
            ArrayList<IndexedMatrixValue> reversed = new ArrayList<IndexedMatrixValue>();
            LibMatrixReorg.rev(block, _mcIn.getRows(), _mcIn.getRowsPerBlock(), reversed);
            //output indexed matrix values
            for (IndexedMatrixValue revBlock : reversed) {
                cachedValues.add(output, revBlock);
            }
        }
        else {
            //general case (e.g., transpose): allocate space for the output value and reorganize
            IndexedMatrixValue genericOut = cachedValues.holdPlace(output, valueClass);
            OperationsOnMatrixValues.performReorg(block.getIndexes(), block.getValue(), genericOut.getIndexes(), genericOut.getValue(), reorgOp, startRow, startColumn, length);
        }
    }
}
Also used : RevIndex(org.apache.sysml.runtime.functionobjects.RevIndex) DiagIndex(org.apache.sysml.runtime.functionobjects.DiagIndex) ArrayList(java.util.ArrayList) ReorgOperator(org.apache.sysml.runtime.matrix.operators.ReorgOperator) IndexedMatrixValue(org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue)

Example 2 with DiagIndex

use of org.apache.sysml.runtime.functionobjects.DiagIndex in project incubator-systemml by apache.

the class MatrixBlock method reorgOperations.

/**
 * Reorganizes the cells of this block according to the index function of {@code op} and
 * returns the resulting block.
 * <p>
 * Operators accepted by {@code LibMatrixReorg.isSupportedReorgOperator} are delegated to
 * {@code LibMatrixReorg.reorg}; all others go through a generic cell-by-cell path that maps
 * every source index through {@code op.fn.execute}.
 *
 * @param op          reorg operator; its {@code fn} must be a SwapIndex, DiagIndex, SortIndex or RevIndex
 * @param ret         candidate output value, passed through {@code checkType}; a new block is
 *                    allocated if that yields {@code null}, otherwise the block is reset and reused
 * @param startRow    not used by this implementation
 * @param startColumn not used by this implementation
 * @param length      not used by this implementation
 * @return the output block holding the reorganized data
 * @throws DMLRuntimeException if {@code op.fn} is none of the index functions listed above
 */
@Override
public MatrixValue reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) throws DMLRuntimeException {
    boolean knownIndexFn = op.fn instanceof SwapIndex
        || op.fn instanceof DiagIndex
        || op.fn instanceof SortIndex
        || op.fn instanceof RevIndex;
    if (!knownIndexFn)
        throw new DMLRuntimeException("the current reorgOperations cannot support: " + op.fn.getClass() + ".");

    MatrixBlock result = checkType(ret);

    //compute output dimensions and sparsity flag
    CellIndex outDims = new CellIndex(-1, -1);
    op.fn.computeDimension(rlen, clen, outDims);
    boolean sparseOut = evalSparseFormatInMemory(outDims.row, outDims.column, nonZeros);

    //prepare output matrix block w/ right meta data
    if (result != null)
        result.reset(outDims.row, outDims.column, sparseOut, nonZeros);
    else
        result = new MatrixBlock(outDims.row, outDims.column, sparseOut, nonZeros);

    //SPECIAL case (operators with special performance requirements,
    //or size-dependent special behavior)
    //currently supported opcodes: r', rdiag, rsort
    if (LibMatrixReorg.isSupportedReorgOperator(op)) {
        LibMatrixReorg.reorg(this, result, op);
        return result;
    }

    //GENERIC case (any reorg operator)
    CellIndex mapped = new CellIndex(0, 0);

    if (sparse) {
        //no sparse rows allocated: output stays empty
        if (sparseBlock == null)
            return result;
        //the dimension holder is no longer needed and is reused as scratch source index
        CellIndex source = outDims;
        for (int r = 0; r < Math.min(rlen, sparseBlock.numRows()); r++) {
            if (!sparseBlock.isEmpty(r)) {
                int begin = sparseBlock.pos(r);
                int end = begin + sparseBlock.size(r);
                int[] colIdx = sparseBlock.indexes(r);
                double[] vals = sparseBlock.values(r);
                for (int k = begin; k < end; k++) {
                    source.set(r, colIdx[k]);
                    op.fn.execute(source, mapped);
                    result.appendValue(mapped.row, mapped.column, vals[k]);
                }
            }
        }
        return result;
    }

    //no dense values allocated: output stays empty
    if (denseBlock == null)
        return result;

    double[] src = denseBlock;

    if (result.isInSparseFormat()) {
        //SPARSE<-DENSE
        int pos = 0;
        for (int i = 0; i < rlen; i++) {
            for (int j = 0; j < clen; j++) {
                //same cell object serves as input and output of the index mapping
                mapped.set(i, j);
                op.fn.execute(mapped, mapped);
                result.appendValue(mapped.row, mapped.column, src[pos]);
                pos++;
            }
        }
        return result;
    }

    //DENSE<-DENSE
    result.allocateDenseBlock();
    Arrays.fill(result.denseBlock, 0);
    double[] dst = result.denseBlock;
    int outCols = result.clen;
    int pos = 0;
    for (int i = 0; i < rlen; i++) {
        for (int j = 0; j < clen; j++) {
            //same cell object serves as input and output of the index mapping
            mapped.set(i, j);
            op.fn.execute(mapped, mapped);
            //row-major position of the mapped cell in the output array
            dst[mapped.row * outCols + mapped.column] = src[pos];
            pos++;
        }
    }
    //direct array writes bypass nnz maintenance, so carry over the input count
    result.nonZeros = nonZeros;
    return result;
}
Also used : SwapIndex(org.apache.sysml.runtime.functionobjects.SwapIndex) RevIndex(org.apache.sysml.runtime.functionobjects.RevIndex) DiagIndex(org.apache.sysml.runtime.functionobjects.DiagIndex) SortIndex(org.apache.sysml.runtime.functionobjects.SortIndex) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Aggregations

DiagIndex (org.apache.sysml.runtime.functionobjects.DiagIndex)2 RevIndex (org.apache.sysml.runtime.functionobjects.RevIndex)2 ArrayList (java.util.ArrayList)1 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)1 SortIndex (org.apache.sysml.runtime.functionobjects.SortIndex)1 SwapIndex (org.apache.sysml.runtime.functionobjects.SwapIndex)1 IndexedMatrixValue (org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue)1 ReorgOperator (org.apache.sysml.runtime.matrix.operators.ReorgOperator)1