Usage of `org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction` in the Apache SystemML project.
Class `UaggOuterChainSPInstruction`, method `processInstruction`:
/**
 * Executes this unary-aggregate outer-chain instruction on Spark.
 *
 * <p>One input is processed as a blocked RDD while the other is broadcast. The right input
 * ({@code input2}) is broadcast when the index function is {@code ReduceCol} or
 * {@code ReduceAll}, or when {@code LibMatrixOuterAgg.isSupportedUaggOp} rejects the
 * operator pair; otherwise the roles are swapped and the left input ({@code input1}) is
 * broadcast.
 *
 * <p>For operator pairs supported by {@code LibMatrixOuterAgg}, the broadcast input is
 * fetched locally and flattened to a {@code double[]}. For {@code Builtin} aggregation
 * functions, that array is broadcast together with an {@code int[]} from
 * {@code LibMatrixOuterAgg.prepareRowIndices}. For all other aggregation functions, the
 * array is sorted before being broadcast. Unsupported pairs instead use a partitioned
 * broadcast of the whole matrix.
 *
 * <p>A {@code ReduceAll} aggregation is reduced to a single block and stored as a matrix
 * output. All other aggregations are stored as an RDD output with lineage to the RDD input,
 * plus lineage to the broadcast input when the partitioned-broadcast path was taken.
 *
 * @param ec execution context; cast unconditionally to {@code SparkExecutionContext}
 */
@Override
public void processInstruction(ExecutionContext ec) {
	SparkExecutionContext sec = (SparkExecutionContext) ec;

	// The left input is broadcast only if the index function is neither ReduceCol nor ReduceAll
	// AND LibMatrixOuterAgg supports the (uagg, binary) operator pair; otherwise the right input
	// is broadcast. The support check is only evaluated when the first condition holds.
	boolean leftCached = !(_uaggOp.indexFn instanceof ReduceCol || _uaggOp.indexFn instanceof ReduceAll)
		&& LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp);
	String rddVar = leftCached ? input2.getName() : input1.getName();
	String bcastVar = leftCached ? input1.getName() : input2.getName();

	// Blocked RDD handle and metadata of the non-broadcast input.
	JavaPairRDD<MatrixIndexes, MatrixBlock> inRdd = sec.getBinaryBlockRDDHandleForVariable(rddVar);
	MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar);
	boolean keepPartitioning = preservesPartitioning(mcRdd, _uaggOp.indexFn);

	JavaPairRDD<MatrixIndexes, MatrixBlock> out;
	if (!LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp)) {
		// Generic path: broadcast the whole matrix in partitioned form.
		PartitionedBroadcast<MatrixBlock> bcastBlocks = sec.getBroadcastForVariable(bcastVar);
		out = inRdd.mapPartitionsToPair(
			new RDDMapGenUAggOuterChainFunction(bcastBlocks, _uaggOp, _aggOp, _bOp, mcRdd),
			keepPartitioning);
	} else {
		// Specialized path: ship the broadcast matrix to the workers as a flat double[].
		MatrixBlock mb = sec.getMatrixInput(bcastVar, getExtendedOpcode());
		// NOTE(review): mb is still read below, after this release; presumably safe because the
		// local reference keeps the block alive -- confirm against buffer-pool semantics.
		sec.releaseMatrixInput(bcastVar, getExtendedOpcode());
		// Clearing the name suppresses the broadcast lineage entry added at the end.
		bcastVar = null;
		double[] values = DataConverter.convertToDoubleVector(mb);
		Broadcast<int[]> bcastIndexes = null;
		boolean builtinAgg = _uaggOp.aggOp.increOp.fn instanceof Builtin;
		if (builtinAgg) {
			// Builtin aggregates: values are not sorted here; an index array is broadcast alongside.
			bcastIndexes = sec.getSparkContext().broadcast(
				LibMatrixOuterAgg.prepareRowIndices(mb.getNumColumns(), values, _bOp, _uaggOp));
		} else {
			Arrays.sort(values);
		}
		Broadcast<double[]> bcastValues = sec.getSparkContext().broadcast(values);
		out = inRdd.mapPartitionsToPair(
			new RDDMapUAggOuterChainFunction(bcastValues, bcastIndexes, _bOp, _uaggOp),
			keepPartitioning);
	}

	if (_uaggOp.indexFn instanceof ReduceAll) {
		// Full aggregate: combine all blocks into one, drop the correction, store as a single block
		// (no lineage needed for a single local block).
		MatrixBlock agg = RDDAggregateUtils.aggStable(out, _aggOp);
		agg.dropLastRowsOrColumns(_aggOp.correctionLocation);
		sec.setMatrixOutput(output.getName(), agg, getExtendedOpcode());
		return;
	}

	// Row/column aggregates stay distributed as an RDD output.
	updateUnaryAggOutputMatrixCharacteristics(sec);
	JavaPairRDD<MatrixIndexes, MatrixBlock> result = out;
	if (_uaggOp.aggOp.correctionExists) {
		result = out.mapValues(new AggregateDropCorrectionFunction(_uaggOp.aggOp));
	}
	String outVar = output.getName();
	sec.setRDDHandleForVariable(outVar, result);
	sec.addLineageRDD(outVar, rddVar);
	if (bcastVar != null) {
		sec.addLineageBroadcast(outVar, bcastVar);
	}
}
Aggregations