Usage of org.apache.sysml.runtime.codegen.SpoofMultiAggregate in the Apache project incubator-systemml.
The class SpoofSPInstruction, method processInstruction.
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;

    //get input rdd and variable name
    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(_in[0].getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;

    //simple case: map-side only operation (one rdd input, broadcast all)
    //keep track of broadcast variables
    ArrayList<String> bcVars = new ArrayList<String>();
    ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<PartitionedBroadcast<MatrixBlock>>();
    ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>();
    for (int i = 1; i < _in.length; i++) {
        if (_in[i].getDataType() == DataType.MATRIX) {
            bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName()));
            bcVars.add(_in[i].getName());
        } else if (_in[i].getDataType() == DataType.SCALAR) {
            //note: even if literal, it might be compiled as scalar placeholder
            scalars.add(sec.getScalarInput(_in[i].getName(), _in[i].getValueType(), _in[i].isLiteral()));
        }
    }

    //initialize Spark Operator
    if (_class.getSuperclass() == SpoofCellwise.class) { //cellwise operator
        SpoofCellwise op = (SpoofCellwise) CodegenUtils.createInstance(_class);
        AggregateOperator aggop = getAggregateOperator(op.getAggOp());
        if (_out.getDataType() == DataType.MATRIX) {
            out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
            if (op.getCellType() == CellType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock()) {
                //TODO investigate if some other side effect of correct blocks
                if (out.partitions().size() > mcIn.getNumRowBlocks())
                    out = RDDAggregateUtils.aggByKeyStable(out, aggop, (int) mcIn.getNumRowBlocks(), false);
                else
                    out = RDDAggregateUtils.aggByKeyStable(out, aggop, false);
            }
            sec.setRDDHandleForVariable(_out.getName(), out);
            //maintain lineage information for output rdd
            sec.addLineageRDD(_out.getName(), _in[0].getName());
            for (String bcVar : bcVars) sec.addLineageBroadcast(_out.getName(), bcVar);
            //update matrix characteristics
            updateOutputMatrixCharacteristics(sec, op);
        } else { //SCALAR
            out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
            MatrixBlock tmpMB = RDDAggregateUtils.aggStable(out, aggop);
            sec.setVariable(_out.getName(), new DoubleObject(tmpMB.getValue(0, 0)));
        }
    } else if (_class.getSuperclass() == SpoofMultiAggregate.class) {
        SpoofMultiAggregate op = (SpoofMultiAggregate) CodegenUtils.createInstance(_class);
        AggOp[] aggOps = op.getAggOps();
        MatrixBlock tmpMB = in.mapToPair(new MultiAggregateFunction(_class.getName(), _classBytes, bcMatrices, scalars))
            .values().fold(new MatrixBlock(), new MultiAggAggregateFunction(aggOps));
        sec.setMatrixOutput(_out.getName(), tmpMB);
        return;
    } else if (_class.getSuperclass() == SpoofOuterProduct.class) { //outer product operator
        if (_out.getDataType() == DataType.MATRIX) {
            SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class);
            OutProdType type = ((SpoofOuterProduct) op).getOuterProdType();
            //update matrix characteristics
            updateOutputMatrixCharacteristics(sec, op);
            MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());
            out = in.mapPartitionsToPair(new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
            if (type == OutProdType.LEFT_OUTER_PRODUCT || type == OutProdType.RIGHT_OUTER_PRODUCT) {
                //TODO investigate if some other side effect of correct blocks
                if (in.partitions().size() > mcOut.getNumRowBlocks() * mcOut.getNumColBlocks())
                    out = RDDAggregateUtils.sumByKeyStable(out, (int) (mcOut.getNumRowBlocks() * mcOut.getNumColBlocks()), false);
                else
                    out = RDDAggregateUtils.sumByKeyStable(out, false);
            }
            sec.setRDDHandleForVariable(_out.getName(), out);
            //maintain lineage information for output rdd
            sec.addLineageRDD(_out.getName(), _in[0].getName());
            for (String bcVar : bcVars) sec.addLineageBroadcast(_out.getName(), bcVar);
        } else {
            out = in.mapPartitionsToPair(new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
            MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
            sec.setVariable(_out.getName(), new DoubleObject(tmp.getValue(0, 0)));
        }
    } else if (_class.getSuperclass() == SpoofRowwise.class) { //row aggregate operator
        SpoofRowwise op = (SpoofRowwise) CodegenUtils.createInstance(_class);
        RowwiseFunction fmmc = new RowwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars, (int) mcIn.getCols());
        out = in.mapPartitionsToPair(fmmc, op.getRowType() == RowType.ROW_AGG);
        if (op.getRowType().isColumnAgg()) {
            MatrixBlock tmpMB = RDDAggregateUtils.sumStable(out);
            sec.setMatrixOutput(_out.getName(), tmpMB);
        } else { //row-agg or no-agg
            if (op.getRowType() == RowType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock()) {
                //TODO investigate if some other side effect of correct blocks
                if (out.partitions().size() > mcIn.getNumRowBlocks())
                    out = RDDAggregateUtils.sumByKeyStable(out, (int) mcIn.getNumRowBlocks(), false);
                else
                    out = RDDAggregateUtils.sumByKeyStable(out, false);
            }
            sec.setRDDHandleForVariable(_out.getName(), out);
            //maintain lineage information for output rdd
            sec.addLineageRDD(_out.getName(), _in[0].getName());
            for (String bcVar : bcVars) sec.addLineageBroadcast(_out.getName(), bcVar);
            //update matrix characteristics
            updateOutputMatrixCharacteristics(sec, op);
        }
        return;
    } else {
        throw new DMLRuntimeException("Operator " + _class.getSuperclass() + " is not supported on Spark");
    }
}
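
A minimal, self-contained sketch of the map-then-fold pattern used in the SpoofMultiAggregate branch above: each input block is mapped to a partial aggregate (the role of MultiAggregateFunction), and the partials are folded into a single result (the role of MultiAggAggregateFunction). MatrixBlock is replaced here by a plain double[] with one slot per aggregate (sum and max); the class and variable names below are illustrative stand-ins, not SystemML APIs.

import java.util.Arrays;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class MultiAggFoldSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("MultiAggFoldSketch").setMaster("local[*]");
        JavaSparkContext sc = new JavaSparkContext(conf);

        //stand-in for the blocked matrix input: a few "blocks" of values
        List<double[]> blocks = Arrays.asList(
            new double[]{1, 2, 3}, new double[]{4, 5}, new double[]{6});
        JavaRDD<double[]> in = sc.parallelize(blocks, 2);

        //map each block to its partial aggregates: [0]=sum, [1]=max
        JavaRDD<double[]> partials = in.map(block -> {
            double sum = 0, max = Double.NEGATIVE_INFINITY;
            for (double v : block) { sum += v; max = Math.max(max, v); }
            return new double[]{sum, max};
        });

        //fold the partials into one result; the combine function must be
        //commutative/associative and the zero element neutral per aggregate
        double[] zero = new double[]{0, Double.NEGATIVE_INFINITY};
        double[] result = partials.fold(zero, (a, b) ->
            new double[]{a[0] + b[0], Math.max(a[1], b[1])});

        System.out.println("sum=" + result[0] + ", max=" + result[1]); //sum=21.0, max=6.0
        sc.stop();
    }
}

The fold call mirrors in.mapToPair(...).values().fold(new MatrixBlock(), new MultiAggAggregateFunction(aggOps)) in the instruction above: a neutral zero element plus a commutative, associative combine function lets Spark merge the partial blocks in any order across partitions.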