Use of org.apache.sysml.lops.Ctable in project incubator-systemml by Apache.
Class CtableCPInstruction, method processInstruction.
/**
 * Executes this ctable instruction against the given execution context.
 *
 * <p>The method acquires the first matrix input, determines which ctable
 * variant to run (the result of {@code findCtableOperation()}, overridden by
 * the expand variant when {@code _isExpand} is set), acquires the remaining
 * matrix/scalar inputs required by that variant, runs the operation,
 * releases all matrix inputs, and finally registers the result as the
 * matrix output of this instruction.
 *
 * <p>Results are aggregated either directly into a pre-allocated result
 * block or into a {@code CTableMap} that is converted into a matrix block
 * after the operation.
 *
 * @param ec execution context used to read inputs and to set the output
 * @throws DMLRuntimeException if the selected ctable operation is not one of
 *         the variants handled by this method
 */
@Override
public void processInstruction(ExecutionContext ec) {
    MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
    MatrixBlock matBlock2 = null, wtBlock = null;
    double cst1, cst2;
    CTableMap resultMap = new CTableMap(EntryType.INT);
    MatrixBlock resultBlock = null;

    Ctable.OperationTypes ctableOp = findCtableOperation();
    ctableOp = _isExpand ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp;

    // Output dimensions are either literals (parsed as double, truncated to long)
    // or scalar variables read from the execution context; -1 is treated as unknown.
    long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (ec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue());
    long outputDim2 = (_dim2Literal ? (long) Double.parseDouble(_outDim2) : (ec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue());
    boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1);

    if (outputDimsKnown) {
        int inputRows = matBlock1.getNumRows();
        int inputCols = matBlock1.getNumColumns();
        // Widen to long BEFORE multiplying: the int product of rows and columns
        // overflows for large inputs and would corrupt the non-zero estimate.
        boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, (long) inputRows * inputCols);
        // Only pre-allocate a result block for dense outputs; aggregating into sparse
        // blocks would implicitly turn the O(N) algorithm into O(N log N).
        if (!sparse)
            resultBlock = new MatrixBlock((int) outputDim1, (int) outputDim2, false);
    }
    if (_isExpand) {
        // Expand always aggregates into a sparse block; the column count is a
        // placeholder here (see the note at the expand case below).
        resultBlock = new MatrixBlock(matBlock1.getNumRows(), Integer.MAX_VALUE, true);
    }

    switch (ctableOp) {
        case CTABLE_TRANSFORM: // (VECTOR)
            // F = ctable(A,B,W)
            matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode());
            wtBlock = ec.getMatrixInput(input3.getName(), getExtendedOpcode());
            matBlock1.ctableOperations((SimpleOperator) _optr, matBlock2, wtBlock, resultMap, resultBlock);
            break;
        case CTABLE_TRANSFORM_SCALAR_WEIGHT: // (VECTOR/MATRIX)
            // F = ctable(A,B) or F = ctable(A,B,1)
            matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode());
            cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            matBlock1.ctableOperations((SimpleOperator) _optr, matBlock2, cst1, _ignoreZeros, resultMap, resultBlock);
            break;
        case CTABLE_EXPAND_SCALAR_WEIGHT: // (VECTOR)
            // F = ctable(seq,A) or F = ctable(seq,B,1)
            matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode());
            cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            // only resultBlock.rlen known, resultBlock.clen set in operation
            matBlock1.ctableOperations((SimpleOperator) _optr, matBlock2, cst1, resultBlock);
            break;
        case CTABLE_TRANSFORM_HISTOGRAM: // (VECTOR)
            // F = ctable(A,1) or F = ctable(A,1,1)
            cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            cst2 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
            matBlock1.ctableOperations((SimpleOperator) _optr, cst1, cst2, resultMap, resultBlock);
            break;
        case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: // (VECTOR)
            // F = ctable(A,1,W)
            wtBlock = ec.getMatrixInput(input3.getName(), getExtendedOpcode());
            cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
            matBlock1.ctableOperations((SimpleOperator) _optr, cst1, wtBlock, resultMap, resultBlock);
            break;
        default:
            throw new DMLRuntimeException("Encountered an invalid ctable operation (" + ctableOp + ") while executing instruction: " + this.toString());
    }

    // Release every operand that was acquired as a matrix (scalars need no release).
    if (input1.getDataType() == DataType.MATRIX)
        ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
    if (input2.getDataType() == DataType.MATRIX)
        ec.releaseMatrixInput(input2.getName(), getExtendedOpcode());
    if (input3.getDataType() == DataType.MATRIX)
        ec.releaseMatrixInput(input3.getName(), getExtendedOpcode());

    if (resultBlock == null) {
        // No result block was pre-allocated, so the result is built from the map.
        // Respect the specified output dimensions when they are known, because
        // hash-aggregation may have been chosen just to prevent inefficiency in
        // case of sparse outputs.
        if (outputDimsKnown)
            resultBlock = DataConverter.convertToMatrixBlock(resultMap, (int) outputDim1, (int) outputDim2);
        else
            resultBlock = DataConverter.convertToMatrixBlock(resultMap);
    } else
        resultBlock.examSparsity();

    // Ensure the right dense/sparse output representation for special cases
    // such as ctable expand (guarded by released input memory).
    if (checkGuardedRepresentationChange(matBlock1, matBlock2, resultBlock)) {
        resultBlock.examSparsity();
    }
    ec.setMatrixOutput(output.getName(), resultBlock, getExtendedOpcode());
}
Aggregations