Search in sources:

Example 1 with SpoofMultiAggregate

Use of org.apache.sysml.runtime.codegen.SpoofMultiAggregate in the project incubator-systemml by Apache.

Method processInstruction of the class SpoofSPInstruction.

/**
 * Executes the generated (codegen) operator {@code _class} on Spark.
 *
 * <p>The first input {@code _in[0]} is consumed as a binary-block RDD. Every further matrix
 * input is obtained as a broadcast, and every scalar input is resolved from the execution
 * context. Inputs of any other data type are ignored. The superclass of the generated class
 * selects the execution strategy:
 * <ul>
 *   <li>{@link SpoofCellwise}: a matrix output stays an RDD. Row aggregates are combined per
 *       block key when the input spans more than one column block. A scalar output is produced
 *       by a full aggregation with the operator's aggregate function.</li>
 *   <li>{@link SpoofMultiAggregate}: per-block partial results are folded into a single
 *       {@link MatrixBlock} output.</li>
 *   <li>{@link SpoofOuterProduct}: a matrix output stays an RDD. Left and right outer products
 *       are summed by block key. A scalar output is produced by a full sum.</li>
 *   <li>{@link SpoofRowwise}: column aggregates are summed into a single {@link MatrixBlock}.
 *       Row aggregates (summed by block key when the input spans more than one column block)
 *       and no-aggregate results stay RDDs.</li>
 * </ul>
 *
 * @param ec execution context; must be a {@link SparkExecutionContext}
 * @throws DMLRuntimeException if the generated class has an unsupported superclass, or if
 *         reading inputs or writing the output fails
 */
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;

    //get input rdd and variable name
    ArrayList<String> bcVars = new ArrayList<String>();
    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(_in[0].getName());

    //simple case: map-side only operation (one rdd input, broadcast all)
    //keep track of broadcast variables (by name, for lineage) and scalar inputs
    ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<PartitionedBroadcast<MatrixBlock>>();
    ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>();
    for (int i = 1; i < _in.length; i++) {
        if (_in[i].getDataType() == DataType.MATRIX) {
            bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName()));
            bcVars.add(_in[i].getName());
        } else if (_in[i].getDataType() == DataType.SCALAR) {
            //note: even if literal, it might be compiled as scalar placeholder
            scalars.add(sec.getScalarInput(_in[i].getName(), _in[i].getValueType(), _in[i].isLiteral()));
        }
    }

    Class<?> superClass = _class.getSuperclass();
    if (superClass == SpoofCellwise.class) {
        //cellwise operator
        SpoofCellwise op = (SpoofCellwise) CodegenUtils.createInstance(_class);
        AggregateOperator aggop = getAggregateOperator(op.getAggOp());
        //both output kinds start from the same map-side evaluation
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapPartitionsToPair(
            new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
        if (_out.getDataType() == DataType.MATRIX) {
            //with several column blocks, each block yields a partial row aggregate
            //that must be combined per output block key
            if (op.getCellType() == CellType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock()) {
                //TODO investigate if some other side effect of correct blocks
                if (out.partitions().size() > mcIn.getNumRowBlocks()) {
                    out = RDDAggregateUtils.aggByKeyStable(out, aggop, (int) mcIn.getNumRowBlocks(), false);
                } else {
                    out = RDDAggregateUtils.aggByKeyStable(out, aggop, false);
                }
            }
            setOutputRDDWithLineage(sec, out, bcVars);
            //update matrix characteristics
            updateOutputMatrixCharacteristics(sec, op);
        } else {
            //SCALAR
            MatrixBlock tmpMB = RDDAggregateUtils.aggStable(out, aggop);
            sec.setVariable(_out.getName(), new DoubleObject(tmpMB.getValue(0, 0)));
        }
    } else if (superClass == SpoofMultiAggregate.class) {
        //multi-aggregate operator: fold all per-block partial aggregates into one block
        SpoofMultiAggregate op = (SpoofMultiAggregate) CodegenUtils.createInstance(_class);
        AggOp[] aggOps = op.getAggOps();
        //NOTE(review): Spark applies the fold zero value once per partition (and once more
        //when merging), so the empty MatrixBlock must act as the identity for every entry of
        //aggOps - confirm this in MultiAggAggregateFunction
        MatrixBlock tmpMB = in
            .mapToPair(new MultiAggregateFunction(_class.getName(), _classBytes, bcMatrices, scalars))
            .values()
            .fold(new MatrixBlock(), new MultiAggAggregateFunction(aggOps));
        sec.setMatrixOutput(_out.getName(), tmpMB);
    } else if (superClass == SpoofOuterProduct.class) {
        //outer product operator
        if (_out.getDataType() == DataType.MATRIX) {
            SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class);
            OutProdType type = ((SpoofOuterProduct) op).getOuterProdType();
            //update matrix characteristics (the output block count is needed below)
            updateOutputMatrixCharacteristics(sec, op);
            MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());
            JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapPartitionsToPair(
                new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
            if (type == OutProdType.LEFT_OUTER_PRODUCT || type == OutProdType.RIGHT_OUTER_PRODUCT) {
                //TODO investigate if some other side effect of correct blocks
                long numOutBlocks = mcOut.getNumRowBlocks() * mcOut.getNumColBlocks();
                //out has the same partitions as in (mapPartitionsToPair); checked on out
                //for consistency with the cellwise and rowwise branches
                if (out.partitions().size() > numOutBlocks) {
                    out = RDDAggregateUtils.sumByKeyStable(out, (int) numOutBlocks, false);
                } else {
                    out = RDDAggregateUtils.sumByKeyStable(out, false);
                }
            }
            setOutputRDDWithLineage(sec, out, bcVars);
        } else {
            //SCALAR
            JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapPartitionsToPair(
                new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
            MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
            sec.setVariable(_out.getName(), new DoubleObject(tmp.getValue(0, 0)));
        }
    } else if (superClass == SpoofRowwise.class) {
        //row aggregate operator
        SpoofRowwise op = (SpoofRowwise) CodegenUtils.createInstance(_class);
        RowwiseFunction fmmc = new RowwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars, (int) mcIn.getCols());
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapPartitionsToPair(fmmc, op.getRowType() == RowType.ROW_AGG);
        if (op.getRowType().isColumnAgg()) {
            MatrixBlock tmpMB = RDDAggregateUtils.sumStable(out);
            sec.setMatrixOutput(_out.getName(), tmpMB);
        } else {
            //row-agg or no-agg
            if (op.getRowType() == RowType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock()) {
                //TODO investigate if some other side effect of correct blocks
                if (out.partitions().size() > mcIn.getNumRowBlocks()) {
                    out = RDDAggregateUtils.sumByKeyStable(out, (int) mcIn.getNumRowBlocks(), false);
                } else {
                    out = RDDAggregateUtils.sumByKeyStable(out, false);
                }
            }
            setOutputRDDWithLineage(sec, out, bcVars);
            //update matrix characteristics
            updateOutputMatrixCharacteristics(sec, op);
        }
    } else {
        throw new DMLRuntimeException("Operator " + superClass + " is not supported on Spark");
    }
}

/**
 * Registers {@code out} as the RDD handle of the output variable and records its lineage.
 * The lineage consists of the primary RDD input {@code _in[0]} and every broadcast matrix
 * input named in {@code bcVars}.
 *
 * @param sec    Spark execution context holding the variables
 * @param out    output RDD to register
 * @param bcVars names of the broadcast matrix inputs
 * @throws DMLRuntimeException if registering the handle fails
 */
private void setOutputRDDWithLineage(SparkExecutionContext sec,
        JavaPairRDD<MatrixIndexes, MatrixBlock> out, Iterable<String> bcVars) throws DMLRuntimeException {
    sec.setRDDHandleForVariable(_out.getName(), out);
    //maintain lineage information for output rdd
    sec.addLineageRDD(_out.getName(), _in[0].getName());
    for (String bcVar : bcVars) {
        sec.addLineageBroadcast(_out.getName(), bcVar);
    }
}
Also used: MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock), SpoofRowwise (org.apache.sysml.runtime.codegen.SpoofRowwise), DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject), ArrayList (java.util.ArrayList), SpoofOperator (org.apache.sysml.runtime.codegen.SpoofOperator), ScalarObject (org.apache.sysml.runtime.instructions.cp.ScalarObject), PartitionedBroadcast (org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast), AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator), SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext), SpoofMultiAggregate (org.apache.sysml.runtime.codegen.SpoofMultiAggregate), OutProdType (org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType), MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes), SpoofOuterProduct (org.apache.sysml.runtime.codegen.SpoofOuterProduct), MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics), DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException), SpoofCellwise (org.apache.sysml.runtime.codegen.SpoofCellwise)

Aggregations

ArrayList (java.util.ArrayList): 1; DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 1; SpoofCellwise (org.apache.sysml.runtime.codegen.SpoofCellwise): 1; SpoofMultiAggregate (org.apache.sysml.runtime.codegen.SpoofMultiAggregate): 1; SpoofOperator (org.apache.sysml.runtime.codegen.SpoofOperator): 1; SpoofOuterProduct (org.apache.sysml.runtime.codegen.SpoofOuterProduct): 1; OutProdType (org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType): 1; SpoofRowwise (org.apache.sysml.runtime.codegen.SpoofRowwise): 1; SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext): 1; DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject): 1; ScalarObject (org.apache.sysml.runtime.instructions.cp.ScalarObject): 1; PartitionedBroadcast (org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast): 1; MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics): 1; MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock): 1; MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes): 1; AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator): 1