Search in sources :

Example 1 with SpoofOuterProduct

Use of org.apache.sysml.runtime.codegen.SpoofOuterProduct in the Apache incubator-systemml project.

In class SpoofSPInstruction, method updateOutputMatrixCharacteristics:

/**
 * Sets the characteristics (dimensions and block sizes) of the output matrix
 * based on the characteristics of the inputs and on the kind of fused operator.
 * <p>
 * Cellwise and rowwise operators derive the output from the first input only;
 * outer-product operators pick one of the first three inputs depending on the
 * outer-product type. Operator kinds or aggregation types that are not listed
 * below leave the output characteristics untouched.
 *
 * @param sec Spark execution context used to look up the matrix characteristics
 * @param op  instantiated fused operator whose type determines the output shape
 * @throws DMLRuntimeException declared for the characteristics lookups
 */
private void updateOutputMatrixCharacteristics(SparkExecutionContext sec, SpoofOperator op) throws DMLRuntimeException {
    if (op instanceof SpoofCellwise) {
        final MatrixCharacteristics source = sec.getMatrixCharacteristics(_in[0].getName());
        final MatrixCharacteristics target = sec.getMatrixCharacteristics(_out.getName());
        final CellType cellType = ((SpoofCellwise) op).getCellType();
        if (cellType == CellType.ROW_AGG) {
            // one aggregated value per input row
            target.set(source.getRows(), 1, source.getRowsPerBlock(), source.getColsPerBlock());
        } else if (cellType == CellType.NO_AGG) {
            // output has the same characteristics as the input
            target.set(source);
        }
        return;
    }

    if (op instanceof SpoofOuterProduct) {
        // inputs 0, 1 and 2 are labelled X, U and V respectively
        final MatrixCharacteristics mcX = sec.getMatrixCharacteristics(_in[0].getName());
        final MatrixCharacteristics mcU = sec.getMatrixCharacteristics(_in[1].getName());
        final MatrixCharacteristics mcV = sec.getMatrixCharacteristics(_in[2].getName());
        final MatrixCharacteristics target = sec.getMatrixCharacteristics(_out.getName());
        final OutProdType prodType = ((SpoofOuterProduct) op).getOuterProdType();
        if (prodType == OutProdType.CELLWISE_OUTER_PRODUCT) {
            // output takes dimensions and block sizes of X
            target.set(mcX.getRows(), mcX.getCols(), mcX.getRowsPerBlock(), mcX.getColsPerBlock());
        } else if (prodType == OutProdType.LEFT_OUTER_PRODUCT) {
            // output takes dimensions and block sizes of V
            target.set(mcV.getRows(), mcV.getCols(), mcV.getRowsPerBlock(), mcV.getColsPerBlock());
        } else if (prodType == OutProdType.RIGHT_OUTER_PRODUCT) {
            // output takes dimensions and block sizes of U
            target.set(mcU.getRows(), mcU.getCols(), mcU.getRowsPerBlock(), mcU.getColsPerBlock());
        }
        return;
    }

    if (!(op instanceof SpoofRowwise)) {
        // any other operator kind: nothing to update
        return;
    }

    final MatrixCharacteristics source = sec.getMatrixCharacteristics(_in[0].getName());
    final MatrixCharacteristics target = sec.getMatrixCharacteristics(_out.getName());
    final RowType rowType = ((SpoofRowwise) op).getRowType();
    if (rowType == RowType.NO_AGG) {
        target.set(source);
    } else if (rowType == RowType.ROW_AGG) {
        // column vector: one value per input row
        target.set(source.getRows(), 1, source.getRowsPerBlock(), source.getColsPerBlock());
    } else if (rowType == RowType.COL_AGG) {
        // row vector: one value per input column
        target.set(1, source.getCols(), source.getRowsPerBlock(), source.getColsPerBlock());
    } else if (rowType == RowType.COL_AGG_T) {
        // column vector with one value per input column
        target.set(source.getCols(), 1, source.getRowsPerBlock(), source.getColsPerBlock());
    }
}
Also used : OutProdType(org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType) SpoofRowwise(org.apache.sysml.runtime.codegen.SpoofRowwise) RowType(org.apache.sysml.runtime.codegen.SpoofRowwise.RowType) SpoofOuterProduct(org.apache.sysml.runtime.codegen.SpoofOuterProduct) SpoofCellwise(org.apache.sysml.runtime.codegen.SpoofCellwise) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics)

Example 2 with SpoofOuterProduct

Use of org.apache.sysml.runtime.codegen.SpoofOuterProduct in the Apache incubator-systemml project.

In class SpoofSPInstruction, method processInstruction:

/**
 * Executes the generated (fused) operator of this instruction on Spark.
 * <p>
 * The first input is read as a binary-block RDD; every further matrix input is
 * obtained as a broadcast and every further scalar input is collected as a
 * scalar object. The generated class is then dispatched on its superclass
 * (cellwise, multi-aggregate, outer-product or rowwise), and the result is
 * registered in the execution context either as an RDD handle, a matrix
 * output, or a scalar variable.
 *
 * @param ec execution context; cast to {@code SparkExecutionContext}
 * @throws DMLRuntimeException if the superclass of the generated class is not
 *         one of the supported operator templates
 */
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;

    // main RDD input plus bookkeeping for the broadcast variable names
    ArrayList<String> broadcastNames = new ArrayList<String>();
    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(_in[0].getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;

    // simple case: map-side only operation (one rdd input, broadcast all others)
    ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<PartitionedBroadcast<MatrixBlock>>();
    ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>();
    for (int pos = 1; pos < _in.length; pos++) {
        if (_in[pos].getDataType() == DataType.MATRIX) {
            bcMatrices.add(sec.getBroadcastForVariable(_in[pos].getName()));
            broadcastNames.add(_in[pos].getName());
            continue;
        }
        if (_in[pos].getDataType() == DataType.SCALAR) {
            // note: even if literal, it might be compiled as scalar placeholder
            scalars.add(sec.getScalarInput(_in[pos].getName(), _in[pos].getValueType(), _in[pos].isLiteral()));
        }
    }

    // dispatch on the operator template the generated class extends
    final Class<?> template = _class.getSuperclass();
    final String className = _class.getName();

    if (template == SpoofCellwise.class) {
        // cellwise operator
        SpoofCellwise cellOp = (SpoofCellwise) CodegenUtils.createInstance(_class);
        AggregateOperator aggop = getAggregateOperator(cellOp.getAggOp());
        out = in.mapPartitionsToPair(new CellwiseFunction(className, _classBytes, bcMatrices, scalars), true);

        if (_out.getDataType() != DataType.MATRIX) {
            // scalar output: aggregate all blocks and read the single cell
            MatrixBlock aggregated = RDDAggregateUtils.aggStable(out, aggop);
            sec.setVariable(_out.getName(), new DoubleObject(aggregated.getValue(0, 0)));
        } else {
            boolean mergeRowPartials = cellOp.getCellType() == CellType.ROW_AGG
                && mcIn.getCols() > mcIn.getColsPerBlock();
            if (mergeRowPartials) {
                //TODO investigate if some other side effect of correct blocks
                long rowBlocks = mcIn.getNumRowBlocks();
                if (out.partitions().size() > rowBlocks) {
                    out = RDDAggregateUtils.aggByKeyStable(out, aggop, (int) rowBlocks, false);
                } else {
                    out = RDDAggregateUtils.aggByKeyStable(out, aggop, false);
                }
            }
            sec.setRDDHandleForVariable(_out.getName(), out);
            addOutputLineage(sec, broadcastNames);
            // update matrix characteristics
            updateOutputMatrixCharacteristics(sec, cellOp);
        }
    } else if (template == SpoofMultiAggregate.class) {
        // multi-aggregate operator: fold all per-block results into one block
        SpoofMultiAggregate multiOp = (SpoofMultiAggregate) CodegenUtils.createInstance(_class);
        AggOp[] aggOps = multiOp.getAggOps();
        MatrixBlock folded = in
            .mapToPair(new MultiAggregateFunction(className, _classBytes, bcMatrices, scalars))
            .values()
            .fold(new MatrixBlock(), new MultiAggAggregateFunction(aggOps));
        sec.setMatrixOutput(_out.getName(), folded);
    } else if (template == SpoofOuterProduct.class) {
        // outer product operator
        if (_out.getDataType() != DataType.MATRIX) {
            // scalar output: sum all blocks and read the single cell
            out = in.mapPartitionsToPair(new OuterProductFunction(className, _classBytes, bcMatrices, scalars), true);
            MatrixBlock summed = RDDAggregateUtils.sumStable(out);
            sec.setVariable(_out.getName(), new DoubleObject(summed.getValue(0, 0)));
        } else {
            SpoofOuterProduct outerOp = (SpoofOuterProduct) CodegenUtils.createInstance(_class);
            OutProdType prodType = outerOp.getOuterProdType();
            // output characteristics are updated first because they are read right below
            updateOutputMatrixCharacteristics(sec, outerOp);
            MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());
            out = in.mapPartitionsToPair(new OuterProductFunction(className, _classBytes, bcMatrices, scalars), true);

            boolean sumPartials = prodType == OutProdType.LEFT_OUTER_PRODUCT
                || prodType == OutProdType.RIGHT_OUTER_PRODUCT;
            if (sumPartials) {
                //TODO investigate if some other side effect of correct blocks
                long outBlocks = mcOut.getNumRowBlocks() * mcOut.getNumColBlocks();
                if (in.partitions().size() > outBlocks) {
                    out = RDDAggregateUtils.sumByKeyStable(out, (int) outBlocks, false);
                } else {
                    out = RDDAggregateUtils.sumByKeyStable(out, false);
                }
            }
            sec.setRDDHandleForVariable(_out.getName(), out);
            addOutputLineage(sec, broadcastNames);
        }
    } else if (template == SpoofRowwise.class) {
        // row aggregate operator
        SpoofRowwise rowOp = (SpoofRowwise) CodegenUtils.createInstance(_class);
        RowwiseFunction rowFn = new RowwiseFunction(className, _classBytes, bcMatrices, scalars, (int) mcIn.getCols());
        RowType rowType = rowOp.getRowType();
        boolean rowAgg = rowType == RowType.ROW_AGG;
        out = in.mapPartitionsToPair(rowFn, rowAgg);

        if (rowType.isColumnAgg()) {
            // column aggregation: sum all blocks into a single matrix output
            MatrixBlock summed = RDDAggregateUtils.sumStable(out);
            sec.setMatrixOutput(_out.getName(), summed);
        } else {
            // row-agg or no-agg
            if (rowAgg && mcIn.getCols() > mcIn.getColsPerBlock()) {
                //TODO investigate if some other side effect of correct blocks
                long rowBlocks = mcIn.getNumRowBlocks();
                if (out.partitions().size() > rowBlocks) {
                    out = RDDAggregateUtils.sumByKeyStable(out, (int) rowBlocks, false);
                } else {
                    out = RDDAggregateUtils.sumByKeyStable(out, false);
                }
            }
            sec.setRDDHandleForVariable(_out.getName(), out);
            addOutputLineage(sec, broadcastNames);
            // update matrix characteristics
            updateOutputMatrixCharacteristics(sec, rowOp);
        }
    } else {
        throw new DMLRuntimeException("Operator " + template + " is not supported on Spark");
    }
}

/**
 * Maintains lineage information for the output RDD: links the output variable
 * to the first (RDD) input and to every broadcast variable that was used.
 *
 * @param sec    Spark execution context that tracks the lineage
 * @param bcVars names of the broadcast input variables
 * @throws DMLRuntimeException declared to match the enclosing instruction code
 */
private void addOutputLineage(SparkExecutionContext sec, ArrayList<String> bcVars) throws DMLRuntimeException {
    sec.addLineageRDD(_out.getName(), _in[0].getName());
    for (String bcVar : bcVars) {
        sec.addLineageBroadcast(_out.getName(), bcVar);
    }
}
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) SpoofRowwise(org.apache.sysml.runtime.codegen.SpoofRowwise) DoubleObject(org.apache.sysml.runtime.instructions.cp.DoubleObject) ArrayList(java.util.ArrayList) SpoofOperator(org.apache.sysml.runtime.codegen.SpoofOperator) ScalarObject(org.apache.sysml.runtime.instructions.cp.ScalarObject) PartitionedBroadcast(org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) SpoofMultiAggregate(org.apache.sysml.runtime.codegen.SpoofMultiAggregate) OutProdType(org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) SpoofOuterProduct(org.apache.sysml.runtime.codegen.SpoofOuterProduct) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) SpoofCellwise(org.apache.sysml.runtime.codegen.SpoofCellwise)

Aggregations

SpoofCellwise (org.apache.sysml.runtime.codegen.SpoofCellwise)2 SpoofOuterProduct (org.apache.sysml.runtime.codegen.SpoofOuterProduct)2 OutProdType (org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType)2 SpoofRowwise (org.apache.sysml.runtime.codegen.SpoofRowwise)2 MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics)2 ArrayList (java.util.ArrayList)1 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)1 SpoofMultiAggregate (org.apache.sysml.runtime.codegen.SpoofMultiAggregate)1 SpoofOperator (org.apache.sysml.runtime.codegen.SpoofOperator)1 RowType (org.apache.sysml.runtime.codegen.SpoofRowwise.RowType)1 SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext)1 DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject)1 ScalarObject (org.apache.sysml.runtime.instructions.cp.ScalarObject)1 PartitionedBroadcast (org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast)1 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)1 MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes)1 AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator)1