Usage of `org.apache.sysml.runtime.codegen.SpoofOperator.SideInput` in the Apache incubator-systemml project.
The `execute` method of class `RowClassMeet`.
/**
 * Computes the row-wise "class meet" of the two function inputs A (input 0) and B (input 1).
 *
 * <p>For every row i of A, the cells j with j &lt; min(ncol(A), ncol(B)) whose int-truncated
 * values A[i,j] and B[i,j] are both non-zero are grouped by the label pair (A[i,j], B[i,j]).
 * Within a row, groups receive consecutive IDs starting at 1, in the iteration order of the
 * label map. Two max(nrow) x max(ncol) outputs are produced:
 * <ul>
 *   <li>{@code CMat}: the group ID of each cell (0 if the cell belongs to no group);</li>
 *   <li>{@code NMat}: the size of the group of each cell (0 if none).</li>
 * </ul>
 * If B has exactly one row, it is broadcast against every row of A. Rows beyond A's row count
 * are left as zero.
 *
 * <p>Both inputs are released again even if the computation fails.
 *
 * @throws RuntimeException wrapping any {@code DMLRuntimeException} or {@code IOException}
 */
@Override
public void execute() {
    try {
        Matrix inA = (Matrix) getFunctionInput(0);
        Matrix inB = (Matrix) getFunctionInput(1);
        MatrixBlock C;
        MatrixBlock N;
        int nr;
        int nc;
        MatrixBlock A = inA.getMatrixObject().acquireRead();
        try {
            MatrixBlock B = inB.getMatrixObject().acquireRead();
            try {
                nr = Math.max(A.getNumRows(), B.getNumRows());
                nc = Math.max(A.getNumColumns(), B.getNumColumns());
                C = new MatrixBlock(nr, nc, false).allocateBlock();
                N = new MatrixBlock(nr, nc, false).allocateBlock();
                double[] dC = C.getDenseBlockValues();
                double[] dN = N.getDenseBlockValues();
                // wrap B into a side input for efficient (possibly sparse) cell access
                SideInput sB = CodegenUtils.createSideInput(B);
                // a single-row B is broadcast against all rows of A
                boolean mv = (B.getNumRows() == 1);
                int numCols = Math.min(A.getNumColumns(), B.getNumColumns());
                HashMap<ClassLabel, IntArrayList> classLabelMapping = new HashMap<>();
                for (int i = 0, ai = 0; i < A.getNumRows(); i++, ai += A.getNumColumns()) {
                    classLabelMapping.clear();
                    sB.reset();
                    int bRow = mv ? 0 : i;
                    if (A.isInSparseFormat()) {
                        if (A.getSparseBlock() == null || A.getSparseBlock().isEmpty(i))
                            continue;
                        int alen = A.getSparseBlock().size(i);
                        int apos = A.getSparseBlock().pos(i);
                        int[] aix = A.getSparseBlock().indexes(i);
                        double[] avals = A.getSparseBlock().values(i);
                        for (int k = apos; k < apos + alen; k++) {
                            if (aix[k] >= numCols)
                                break;
                            // Truncate and filter exactly like the dense path, so the result
                            // does not depend on the physical format of A.
                            int aVal = (int) avals[k];
                            int bVal = (int) sB.getValue(bRow, aix[k]);
                            if (aVal != 0 && bVal != 0)
                                addToGroup(classLabelMapping, new ClassLabel(aVal, bVal), aix[k]);
                        }
                    } else {
                        double[] denseBlk = A.getDenseBlockValues();
                        // dense A without allocated values: no row contributes anything
                        if (denseBlk == null)
                            break;
                        for (int j = 0; j < numCols; j++) {
                            int aVal = (int) denseBlk[ai + j];
                            int bVal = (int) sB.getValue(bRow, j);
                            if (aVal != 0 && bVal != 0)
                                addToGroup(classLabelMapping, new ClassLabel(aVal, bVal), j);
                        }
                    }
                    // assign per-row group IDs (1-based) and group sizes
                    int labelID = 1;
                    int off = i * nc;
                    for (Entry<ClassLabel, IntArrayList> entry : classLabelMapping.entrySet()) {
                        IntArrayList cols = entry.getValue();
                        int nVal = cols.size();
                        // extractValues may expose a larger backing array; only the first nVal are valid
                        int[] list = cols.extractValues();
                        for (int k = 0; k < nVal; k++) {
                            dN[off + list[k]] = nVal;
                            dC[off + list[k]] = labelID;
                        }
                        labelID++;
                    }
                }
            } finally {
                inB.getMatrixObject().release();
            }
        } finally {
            inA.getMatrixObject().release();
        }
        // prepare outputs
        C.recomputeNonZeros();
        C.examSparsity();
        CMat = new Matrix(createOutputFilePathAndName("TMP"), nr, nc, ValueType.Double);
        CMat.setMatrixDoubleArray(C, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
        N.recomputeNonZeros();
        N.examSparsity();
        NMat = new Matrix(createOutputFilePathAndName("TMP"), nr, nc, ValueType.Double);
        NMat.setMatrixDoubleArray(N, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
    } catch (DMLRuntimeException | IOException e) {
        throw new RuntimeException("Error while executing RowClassMeet", e);
    }
}

/** Appends column {@code col} to the group for {@code key}, creating the group if absent. */
private static void addToGroup(HashMap<ClassLabel, IntArrayList> groups, ClassLabel key, int col) {
    IntArrayList cols = groups.get(key);
    if (cols == null) {
        cols = new IntArrayList();
        groups.put(key, cols);
    }
    cols.appendValue(col);
}
Aggregations