Search in sources:

Example 1 with SideInput

Usage of org.apache.sysml.runtime.codegen.SpoofOperator.SideInput in the Apache incubator-systemml project.

Class RowClassMeet, method execute:

/**
 * Computes the row-wise "meet" of two class-label matrices: A (function input 0) and
 * B (function input 1).
 *
 * <p>For every row i, the columns j (j &lt; min(cols of A, cols of B)) where both
 * {@code (int) A[i,j]} and {@code (int) B[i',j]} are non-zero are grouped by the label pair
 * {@code ((int) A[i,j], (int) B[i',j])}. Here {@code i' = 0} if B has exactly one row
 * (broadcast), and {@code i' = i} otherwise. Two outputs of size
 * {@code max(rows) x max(cols)} are produced:
 * <ul>
 *   <li>{@code CMat}: a group id per cell, numbered from 1 within each row in hash-map
 *       iteration order (not necessarily column order);</li>
 *   <li>{@code NMat}: the number of cells in that cell's group.</li>
 * </ul>
 * Cells where either label is zero, or that lie outside the rows/columns both inputs
 * share, are 0 in both outputs.
 *
 * <p>Both inputs are acquired for reading and are always released again, even if the
 * computation throws.
 *
 * @throws RuntimeException wrapping any {@link DMLRuntimeException} or {@link IOException}
 */
@Override
public void execute() {
    try {
        MatrixBlock C;
        MatrixBlock N;
        MatrixBlock A = ((Matrix) getFunctionInput(0)).getMatrixObject().acquireRead();
        try {
            MatrixBlock B = ((Matrix) getFunctionInput(1)).getMatrixObject().acquireRead();
            try {
                int nr = Math.max(A.getNumRows(), B.getNumRows());
                int nc = Math.max(A.getNumColumns(), B.getNumColumns());
                C = new MatrixBlock(nr, nc, false).allocateBlock();
                N = new MatrixBlock(nr, nc, false).allocateBlock();
                computeMeet(A, B, C, N);
            } finally {
                // release the read lock on B even if the computation failed
                ((Matrix) getFunctionInput(1)).getMatrixObject().release();
            }
        } finally {
            // release the read lock on A even if acquiring B or the computation failed
            ((Matrix) getFunctionInput(0)).getMatrixObject().release();
        }
        // prepare outputs
        CMat = toOutputMatrix(C);
        NMat = toOutputMatrix(N);
    } catch (DMLRuntimeException | IOException e) {
        throw new RuntimeException("Error while executing RowClassMeet", e);
    }
}

/**
 * Fills the dense, pre-allocated result blocks C (group ids) and N (group sizes) row by row.
 *
 * @param A left class-label matrix (dense or sparse)
 * @param B right class-label matrix; a single-row B is broadcast to every row of A
 * @param C dense output receiving per-row group ids
 * @param N dense output receiving group sizes; same shape as C
 */
private void computeMeet(MatrixBlock A, MatrixBlock B, MatrixBlock C, MatrixBlock N)
    throws DMLRuntimeException
{
    final int nc = C.getNumColumns();
    final double[] dC = C.getDenseBlockValues();
    final double[] dN = N.getDenseBlockValues();

    // only B is wrapped into a side input; cells are then probed via getValue(row, col)
    SideInput sB = CodegenUtils.createSideInput(B);
    boolean mv = (B.getNumRows() == 1);

    // compare only rows/columns present in both inputs (a single-row B covers all rows of A)
    int numRows = mv ? A.getNumRows() : Math.min(A.getNumRows(), B.getNumRows());
    int numCols = Math.min(A.getNumColumns(), B.getNumColumns());

    boolean sparseA = A.isInSparseFormat();
    double[] denseA = sparseA ? null : A.getDenseBlockValues();
    if (!sparseA && denseA == null)
        return; // dense A without allocated values holds only zeros: nothing to label

    HashMap<ClassLabel, IntArrayList> classLabelMapping = new HashMap<>();
    for (int i = 0; i < numRows; i++) {
        classLabelMapping.clear();
        sB.reset(); // reset the side input before probing a new row
        int bRow = mv ? 0 : i;
        if (sparseA) {
            if (A.getSparseBlock() == null || A.getSparseBlock().isEmpty(i))
                continue;
            int alen = A.getSparseBlock().size(i);
            int apos = A.getSparseBlock().pos(i);
            int[] aix = A.getSparseBlock().indexes(i);
            double[] avals = A.getSparseBlock().values(i);
            // assumes column indexes within a sparse row are ascending, so the scan
            // stops at the first column that B does not have
            for (int k = apos; k < apos + alen && aix[k] < numCols; k++)
                appendClassLabel(classLabelMapping, (int) avals[k],
                    (int) sB.getValue(bRow, aix[k]), aix[k]);
        } else {
            int ai = i * A.getNumColumns();
            for (int j = 0; j < numCols; j++)
                appendClassLabel(classLabelMapping, (int) denseA[ai + j],
                    (int) sB.getValue(bRow, j), j);
        }
        assignRowLabels(classLabelMapping, dC, dN, i * nc);
    }
}

/**
 * Records column {@code col} under the label pair (aVal, bVal). Pairs in which either
 * label is zero are ignored.
 */
private void appendClassLabel(HashMap<ClassLabel, IntArrayList> mapping, int aVal, int bVal, int col) {
    // zero is treated as "unlabeled"; checking aVal here keeps sparse and dense A
    // consistent when a stored non-zero value truncates to 0 (e.g. 0.5)
    if (aVal == 0 || bVal == 0)
        return;
    ClassLabel key = new ClassLabel(aVal, bVal);
    IntArrayList cols = mapping.get(key);
    if (cols == null) {
        cols = new IntArrayList();
        mapping.put(key, cols);
    }
    cols.appendValue(col);
}

/**
 * Writes one row of results: each group gets an id (1, 2, ... in map iteration order)
 * into dC, and its size into dN, for every column belonging to that group.
 *
 * @param off linear offset of the row's first cell in dC/dN
 */
private static void assignRowLabels(HashMap<ClassLabel, IntArrayList> mapping,
    double[] dC, double[] dN, int off)
{
    int labelID = 1;
    for (Entry<ClassLabel, IntArrayList> entry : mapping.entrySet()) {
        int nVal = entry.getValue().size();
        int[] cols = entry.getValue().extractValues();
        for (int k = 0; k < nVal; k++) {
            dN[off + cols[k]] = nVal;
            dC[off + cols[k]] = labelID;
        }
        labelID++;
    }
}

/**
 * Finalizes a result block (non-zero count, sparsity decision) and wraps it into a
 * binary-block output {@link Matrix} backed by a newly generated temporary file name.
 */
private Matrix toOutputMatrix(MatrixBlock mb) throws DMLRuntimeException, IOException {
    mb.recomputeNonZeros();
    mb.examSparsity();
    Matrix out = new Matrix(createOutputFilePathAndName("TMP"),
        mb.getNumRows(), mb.getNumColumns(), ValueType.Double);
    out.setMatrixDoubleArray(mb, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
    return out;
}
Also used: MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock), HashMap (java.util.HashMap), IOException (java.io.IOException), DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException), Matrix (org.apache.sysml.udf.Matrix), IntArrayList (org.apache.sysml.runtime.compress.utils.IntArrayList), SideInput (org.apache.sysml.runtime.codegen.SpoofOperator.SideInput)

Aggregations

IOException (java.io.IOException): 1, HashMap (java.util.HashMap): 1, DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 1, SideInput (org.apache.sysml.runtime.codegen.SpoofOperator.SideInput): 1, IntArrayList (org.apache.sysml.runtime.compress.utils.IntArrayList): 1, MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock): 1, Matrix (org.apache.sysml.udf.Matrix): 1