Search in sources:

Example 1 with PartialAggregate

use of org.apache.sysml.lops.PartialAggregate in project incubator-systemml by apache.

In the class AggUnaryOp, the method constructLops.

/**
 * Builds the Lop for this unary-aggregate hop, or returns the one built earlier.
 *
 * <p>The execution type returned by {@code optFindExecType()} selects one of three
 * construction paths:
 * <ul>
 *   <li><b>CP</b>: a ternary-aggregate rewrite, a {@code UAggOuterChain} over the two
 *       inputs of the binary input hop, or a plain {@code PartialAggregate}. In the
 *       plain case the execution type may be switched to GPU.</li>
 *   <li><b>MR</b>: a {@code UAggOuterChain} or {@code PartialAggregate}, optionally
 *       followed by a {@code Group} + {@code Aggregate} pair, and for scalar results
 *       a {@code CAST_AS_SCALAR} {@code UnaryCP}.</li>
 *   <li><b>SPARK</b>: a ternary-aggregate rewrite, a {@code UAggOuterChain}, or a
 *       {@code PartialAggregate} with a Spark aggregation type; the latter two get a
 *       {@code CAST_AS_SCALAR} {@code UnaryCP} for scalar results.</li>
 * </ul>
 * After the path-specific construction, {@code constructAndSetLopsDataFlowProperties()}
 * is called and the lop stored on this hop is returned.
 *
 * @return the lop stored on this hop (value of {@code getLops()})
 * @throws HopsException if any exception occurs inside the construction block; the
 *         original exception is attached as the cause and the message is prefixed
 *         with {@code printErrorLocation()}
 * @throws LopsException declared by the signature; exceptions raised inside the try
 *         block are wrapped into {@code HopsException} instead
 */
@Override
public Lop constructLops() throws HopsException, LopsException {
    //return already created lops
    if (getLops() != null)
        return getLops();
    try {
        ExecType et = optFindExecType();
        Hop input = getInput().get(0);
        if (et == ExecType.CP) {
            Lop agg1 = null;
            if (isTernaryAggregateRewriteApplicable(et)) {
                agg1 = constructLopsTernaryAggregateRewrite(et);
            } else if (isUnaryAggregateOuterCPRewriteApplicable()) {
                // outer-chain rewrite: aggregate directly over the two inputs of the binary input hop
                OperationTypes op = HopsAgg2Lops.get(_op);
                DirectionTypes dir = HopsDirection2Lops.get(_direction);
                BinaryOp binput = (BinaryOp) getInput().get(0);
                agg1 = new UAggOuterChain(binput.getInput().get(0).constructLops(), binput.getInput().get(1).constructLops(), op, dir, HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.CP);
                PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir);
                if (getDataType() == DataType.SCALAR) {
                    UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType());
                    unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                    setLineNumbers(unary1);
                    // NOTE(review): this setLops(unary1) is followed by the unconditional
                    // setLops(agg1) after the if/else chain, so the cast lop does not appear
                    // to remain this hop's lop -- confirm whether that is intended.
                    setLops(unary1);
                }
            } else {
                //general case
                int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                if (DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < OptimizerUtils.GPU_MEMORY_BUDGET)) {
                    // Only implemented methods for GPU:
                    // SUM, SUM_SQ, MAX, MIN, MEAN, VAR in direction RowCol/Row/Col; PROD only in direction RowCol
                    if ((_op == AggOp.SUM && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col)) || (_op == AggOp.SUM_SQ && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col)) || (_op == AggOp.MAX && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col)) || (_op == AggOp.MIN && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col)) || (_op == AggOp.MEAN && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col)) || (_op == AggOp.VAR && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col)) || (_op == AggOp.PROD && (_direction == Direction.RowCol))) {
                        // switch to GPU execution and pass k = 1 to the lop
                        et = ExecType.GPU;
                        k = 1;
                    }
                }
                agg1 = new PartialAggregate(input.constructLops(), HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(), getValueType(), et, k);
            }
            // common finalization for all three CP alternatives
            setOutputDimensions(agg1);
            setLineNumbers(agg1);
            setLops(agg1);
            if (getDataType() == DataType.SCALAR) {
                // scalar result: set the lop's output dimensions to 1x1, using this hop's block sizes
                agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), getColsInBlock(), getNnz());
            }
        } else if (et == ExecType.MR) {
            OperationTypes op = HopsAgg2Lops.get(_op);
            DirectionTypes dir = HopsDirection2Lops.get(_direction);
            //unary aggregate operation
            Lop transform1 = null;
            if (isUnaryAggregateOuterRewriteApplicable()) {
                BinaryOp binput = (BinaryOp) getInput().get(0);
                transform1 = new UAggOuterChain(binput.getInput().get(0).constructLops(), binput.getInput().get(1).constructLops(), op, dir, HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.MR);
                PartialAggregate.setDimensionsBasedOnDirection(transform1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir);
            } else //default
            {
                transform1 = new PartialAggregate(input.constructLops(), op, dir, DataType.MATRIX, getValueType());
                ((PartialAggregate) transform1).setDimensionsBasedOnDirection(getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock());
            }
            setLineNumbers(transform1);
            //aggregation if required
            Lop aggregate = null;
            Group group1 = null;
            Aggregate agg1 = null;
            // the outer-chain lop always gets the Group + Aggregate pair; a PartialAggregate only when requiresAggregation says so
            if (requiresAggregation(input, _direction) || transform1 instanceof UAggOuterChain) {
                group1 = new Group(transform1, Group.OperationTypes.Sort, DataType.MATRIX, getValueType());
                group1.getOutputParameters().setDimensions(getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                setLineNumbers(group1);
                agg1 = new Aggregate(group1, HopsAgg2Lops.get(_op), DataType.MATRIX, getValueType(), et);
                agg1.getOutputParameters().setDimensions(getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                agg1.setupCorrectionLocation(PartialAggregate.getCorrectionLocation(op, dir));
                setLineNumbers(agg1);
                aggregate = agg1;
            } else {
                // no separate aggregate: transform1 is a PartialAggregate here and becomes the result itself
                ((PartialAggregate) transform1).setDropCorrection();
                aggregate = transform1;
            }
            setLops(aggregate);
            //cast if required
            if (getDataType() == DataType.SCALAR) {
                // Set the dimensions of PartialAggregate LOP based on the
                // direction in which aggregation is performed
                // (this time from the input's dimensions rather than this hop's)
                PartialAggregate.setDimensionsBasedOnDirection(transform1, input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir);
                if (group1 != null && agg1 != null) {
                    //if aggregation required
                    group1.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                    agg1.getOutputParameters().setDimensions(1, 1, input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                }
                // the scalar cast replaces the aggregate as this hop's lop
                UnaryCP unary1 = new UnaryCP(aggregate, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType());
                unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                setLineNumbers(unary1);
                setLops(unary1);
            }
        } else if (et == ExecType.SPARK) {
            OperationTypes op = HopsAgg2Lops.get(_op);
            DirectionTypes dir = HopsDirection2Lops.get(_direction);
            //unary aggregate
            if (isTernaryAggregateRewriteApplicable(et)) {
                Lop aggregate = constructLopsTernaryAggregateRewrite(et);
                //0x0 (scalar)
                setOutputDimensions(aggregate);
                setLineNumbers(aggregate);
                setLops(aggregate);
            } else if (isUnaryAggregateOuterSPRewriteApplicable()) {
                BinaryOp binput = (BinaryOp) getInput().get(0);
                Lop transform1 = new UAggOuterChain(binput.getInput().get(0).constructLops(), binput.getInput().get(1).constructLops(), op, dir, HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.SPARK);
                PartialAggregate.setDimensionsBasedOnDirection(transform1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir);
                setLineNumbers(transform1);
                setLops(transform1);
                if (getDataType() == DataType.SCALAR) {
                    // scalar result: the cast lop replaces transform1 as this hop's lop
                    UnaryCP unary1 = new UnaryCP(transform1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType());
                    unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                    setLineNumbers(unary1);
                    setLops(unary1);
                }
            } else //default
            {
                // the Spark aggregation type is derived from whether aggregation is required
                boolean needAgg = requiresAggregation(input, _direction);
                SparkAggType aggtype = getSparkUnaryAggregationType(needAgg);
                PartialAggregate aggregate = new PartialAggregate(input.constructLops(), HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), DataType.MATRIX, getValueType(), aggtype, et);
                aggregate.setDimensionsBasedOnDirection(getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock());
                setLineNumbers(aggregate);
                setLops(aggregate);
                if (getDataType() == DataType.SCALAR) {
                    // scalar result: the cast lop replaces the aggregate as this hop's lop
                    UnaryCP unary1 = new UnaryCP(aggregate, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType());
                    unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                    setLineNumbers(unary1);
                    setLops(unary1);
                }
            }
        }
    } catch (Exception e) {
        // wrap every failure with this hop's error location, keeping the original cause
        throw new HopsException(this.printErrorLocation() + "In AggUnary Hop, error constructing Lops ", e);
    }
    //add reblock/checkpoint lops if necessary
    constructAndSetLopsDataFlowProperties();
    //return created lops
    return getLops();
}
Also used : PartialAggregate(org.apache.sysml.lops.PartialAggregate) Group(org.apache.sysml.lops.Group) SparkAggType(org.apache.sysml.hops.AggBinaryOp.SparkAggType) MultiThreadedHop(org.apache.sysml.hops.Hop.MultiThreadedHop) Lop(org.apache.sysml.lops.Lop) LopsException(org.apache.sysml.lops.LopsException) UAggOuterChain(org.apache.sysml.lops.UAggOuterChain) UnaryCP(org.apache.sysml.lops.UnaryCP) OperationTypes(org.apache.sysml.lops.Aggregate.OperationTypes) DirectionTypes(org.apache.sysml.lops.PartialAggregate.DirectionTypes) ExecType(org.apache.sysml.lops.LopProperties.ExecType) PartialAggregate(org.apache.sysml.lops.PartialAggregate) TernaryAggregate(org.apache.sysml.lops.TernaryAggregate) Aggregate(org.apache.sysml.lops.Aggregate)

Example 2 with PartialAggregate

use of org.apache.sysml.lops.PartialAggregate in project incubator-systemml by apache.

In the class UnaryOp, the method constructLopsIQM.

/**
 * Assembles the lop chain that computes IQM over this hop's single input.
 *
 * <p>For every execution type other than MR the chain is a value sort followed by
 * one {@code PickByCount} of type IQM. For MR the chain is: combine, sort by value,
 * range pick with the literal 0.25, partial aggregate (SUM over RowCol), group,
 * aggregate, cast to scalar, and a final CP {@code Unary} of type MR_IQM that
 * consumes both the sorted data and the scalar.
 *
 * @return the root lop of the constructed chain; it is returned to the caller and
 *         not stored via {@code setLops} here
 * @throws HopsException declared by the signature
 * @throws LopsException declared by the signature
 */
private Lop constructLopsIQM() throws HopsException, LopsException {
    final ExecType execType = optFindExecType();
    final Hop in = getInput().get(0);

    // Every backend other than MR: sort, then a single IQM pick.
    if (execType != ExecType.MR) {
        SortKeys sorted = SortKeys.constructSortByValueLop(
                in.constructLops(),
                SortKeys.OperationTypes.WithoutWeights,
                DataType.MATRIX, ValueType.DOUBLE, execType);
        sorted.getOutputParameters().setDimensions(
                in.getDim1(), in.getDim2(),
                in.getRowsInBlock(), in.getColsInBlock(), in.getNnz());

        PickByCount iqmPick = new PickByCount(
                sorted, null, getDataType(), getValueType(),
                PickByCount.OperationTypes.IQM, execType, true);
        iqmPick.getOutputParameters().setDimensions(
                getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(iqmPick);
        return iqmPick;
    }

    // MR backend: combine -> sort -> range pick -> partial agg -> group -> agg -> cast -> IQM.
    CombineUnary combined = CombineUnary.constructCombineLop(
            in.constructLops(), DataType.MATRIX, getValueType());
    combined.getOutputParameters().setDimensions(
            in.getDim1(), in.getDim2(),
            in.getRowsInBlock(), in.getColsInBlock(), in.getNnz());

    SortKeys sorted = SortKeys.constructSortByValueLop(
            combined, SortKeys.OperationTypes.WithoutWeights,
            DataType.MATRIX, ValueType.DOUBLE, ExecType.MR);
    // The sorted output carries the same dimensions as the first input.
    sorted.getOutputParameters().setDimensions(
            in.getDim1(), in.getDim2(),
            in.getRowsInBlock(), in.getColsInBlock(), in.getNnz());

    Data quarter = Data.createLiteralLop(ValueType.DOUBLE, Double.toString(0.25));
    quarter.setAllPositions(getBeginLine(), getBeginColumn(), getEndLine(), getEndColumn());

    PickByCount rangePick = new PickByCount(
            sorted, quarter, DataType.MATRIX, getValueType(),
            PickByCount.OperationTypes.RANGEPICK);
    rangePick.getOutputParameters().setDimensions(
            -1, -1, getRowsInBlock(), getColsInBlock(), -1);
    setLineNumbers(rangePick);

    PartialAggregate partialSum = new PartialAggregate(
            rangePick,
            HopsAgg2Lops.get(Hop.AggOp.SUM),
            HopsDirection2Lops.get(Hop.Direction.RowCol),
            DataType.MATRIX, getValueType());
    setLineNumbers(partialSum);
    // Dimensions of the partial aggregate follow its aggregation direction.
    partialSum.setDimensionsBasedOnDirection(
            getDim1(), getDim2(), getRowsInBlock(), getColsInBlock());

    Group grouped = new Group(
            partialSum, Group.OperationTypes.Sort, DataType.MATRIX, getValueType());
    grouped.getOutputParameters().setDimensions(
            getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
    setLineNumbers(grouped);

    Aggregate fullSum = new Aggregate(
            grouped, HopsAgg2Lops.get(Hop.AggOp.SUM),
            DataType.MATRIX, getValueType(), ExecType.MR);
    fullSum.getOutputParameters().setDimensions(
            getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
    fullSum.setupCorrectionLocation(partialSum.getCorrectionLocation());
    setLineNumbers(fullSum);

    UnaryCP sumAsScalar = new UnaryCP(
            fullSum, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
            getDataType(), getValueType());
    sumAsScalar.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
    setLineNumbers(sumAsScalar);

    Unary result = new Unary(
            sorted, sumAsScalar, Unary.OperationTypes.MR_IQM,
            DataType.SCALAR, ValueType.DOUBLE, ExecType.CP);
    result.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
    setLineNumbers(result);
    return result;
}
Also used : PartialAggregate(org.apache.sysml.lops.PartialAggregate) CumulativePartialAggregate(org.apache.sysml.lops.CumulativePartialAggregate) SortKeys(org.apache.sysml.lops.SortKeys) Group(org.apache.sysml.lops.Group) PickByCount(org.apache.sysml.lops.PickByCount) CombineUnary(org.apache.sysml.lops.CombineUnary) MultiThreadedHop(org.apache.sysml.hops.Hop.MultiThreadedHop) ExecType(org.apache.sysml.lops.LopProperties.ExecType) Data(org.apache.sysml.lops.Data) PartialAggregate(org.apache.sysml.lops.PartialAggregate) CumulativeSplitAggregate(org.apache.sysml.lops.CumulativeSplitAggregate) Aggregate(org.apache.sysml.lops.Aggregate) CumulativePartialAggregate(org.apache.sysml.lops.CumulativePartialAggregate) CombineUnary(org.apache.sysml.lops.CombineUnary) Unary(org.apache.sysml.lops.Unary) UnaryCP(org.apache.sysml.lops.UnaryCP)

Example 3 with PartialAggregate

use of org.apache.sysml.lops.PartialAggregate in project incubator-systemml by apache.

In the class BinaryOp, the method constructLopsIQM.

/**
 * Assembles the lop chain that computes IQM over this hop's two inputs (the sort
 * is done "with weights") and stores its root lop on this hop via {@code setLops}.
 *
 * <p>For every execution type other than MR the chain is a two-input value sort
 * followed by one {@code PickByCount} of type IQM. For MR the chain is: pre-sort
 * combine, sort by value, range pick with the literal 0.25, partial aggregate (SUM
 * over RowCol), group, aggregate, cast to scalar, and a final CP {@code Unary} of
 * type MR_IQM that consumes both the sorted data and the scalar.
 *
 * @param et execution type that selects the MR or the non-MR chain
 * @throws HopsException declared by the signature
 * @throws LopsException declared by the signature
 */
private void constructLopsIQM(ExecType et) throws HopsException, LopsException {
    // Every backend other than MR: sort with weights, then a single IQM pick.
    if (et != ExecType.MR) {
        final Hop first = getInput().get(0);
        SortKeys sorted = SortKeys.constructSortByValueLop(
                first.constructLops(),
                getInput().get(1).constructLops(),
                SortKeys.OperationTypes.WithWeights,
                first.getDataType(), first.getValueType(), et);
        sorted.getOutputParameters().setDimensions(
                first.getDim1(), first.getDim2(),
                first.getRowsInBlock(), first.getColsInBlock(), first.getNnz());

        PickByCount iqmPick = new PickByCount(
                sorted, null, getDataType(), getValueType(),
                PickByCount.OperationTypes.IQM, et, true);
        setOutputDimensions(iqmPick);
        setLineNumbers(iqmPick);
        setLops(iqmPick);
        return;
    }

    // MR backend: combine -> sort -> range pick -> partial agg -> group -> agg -> cast -> IQM.
    final Hop first = getInput().get(0);
    CombineBinary combined = CombineBinary.constructCombineLop(
            OperationTypes.PreSort,
            (Lop) first.constructLops(),
            (Lop) getInput().get(1).constructLops(),
            DataType.MATRIX, getValueType());
    combined.getOutputParameters().setDimensions(
            first.getDim1(), first.getDim2(),
            first.getRowsInBlock(), first.getColsInBlock(), first.getNnz());

    SortKeys sorted = SortKeys.constructSortByValueLop(
            combined, SortKeys.OperationTypes.WithWeights,
            DataType.MATRIX, ValueType.DOUBLE, ExecType.MR);
    // The sorted output carries the same dimensions as the first input.
    sorted.getOutputParameters().setDimensions(
            first.getDim1(), first.getDim2(),
            first.getRowsInBlock(), first.getColsInBlock(), first.getNnz());

    Data quarter = Data.createLiteralLop(ValueType.DOUBLE, Double.toString(0.25));
    setLineNumbers(quarter);

    PickByCount rangePick = new PickByCount(
            sorted, quarter, DataType.MATRIX, getValueType(),
            PickByCount.OperationTypes.RANGEPICK);
    rangePick.getOutputParameters().setDimensions(
            -1, -1, getRowsInBlock(), getColsInBlock(), -1);
    setLineNumbers(rangePick);

    PartialAggregate partialSum = new PartialAggregate(
            rangePick,
            HopsAgg2Lops.get(Hop.AggOp.SUM),
            HopsDirection2Lops.get(Hop.Direction.RowCol),
            DataType.MATRIX, getValueType());
    setLineNumbers(partialSum);
    // Dimensions of the partial aggregate follow its aggregation direction.
    partialSum.setDimensionsBasedOnDirection(
            getDim1(), getDim2(), getRowsInBlock(), getColsInBlock());

    Group grouped = new Group(
            partialSum, Group.OperationTypes.Sort, DataType.MATRIX, getValueType());
    setOutputDimensions(grouped);
    setLineNumbers(grouped);

    Aggregate fullSum = new Aggregate(
            grouped, HopsAgg2Lops.get(Hop.AggOp.SUM),
            DataType.MATRIX, getValueType(), ExecType.MR);
    setOutputDimensions(fullSum);
    fullSum.setupCorrectionLocation(partialSum.getCorrectionLocation());
    setLineNumbers(fullSum);

    UnaryCP sumAsScalar = new UnaryCP(
            fullSum, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
            DataType.SCALAR, getValueType());
    sumAsScalar.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
    setLineNumbers(sumAsScalar);

    Unary result = new Unary(
            sorted, sumAsScalar, Unary.OperationTypes.MR_IQM,
            DataType.SCALAR, ValueType.DOUBLE, ExecType.CP);
    result.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
    setLineNumbers(result);
    setLops(result);
}
Also used : PartialAggregate(org.apache.sysml.lops.PartialAggregate) CombineBinary(org.apache.sysml.lops.CombineBinary) SortKeys(org.apache.sysml.lops.SortKeys) Group(org.apache.sysml.lops.Group) PickByCount(org.apache.sysml.lops.PickByCount) Data(org.apache.sysml.lops.Data) PartialAggregate(org.apache.sysml.lops.PartialAggregate) Aggregate(org.apache.sysml.lops.Aggregate) Unary(org.apache.sysml.lops.Unary) CombineUnary(org.apache.sysml.lops.CombineUnary) UnaryCP(org.apache.sysml.lops.UnaryCP)

Aggregations

Aggregate (org.apache.sysml.lops.Aggregate)3 Group (org.apache.sysml.lops.Group)3 PartialAggregate (org.apache.sysml.lops.PartialAggregate)3 UnaryCP (org.apache.sysml.lops.UnaryCP)3 MultiThreadedHop (org.apache.sysml.hops.Hop.MultiThreadedHop)2 CombineUnary (org.apache.sysml.lops.CombineUnary)2 Data (org.apache.sysml.lops.Data)2 ExecType (org.apache.sysml.lops.LopProperties.ExecType)2 PickByCount (org.apache.sysml.lops.PickByCount)2 SortKeys (org.apache.sysml.lops.SortKeys)2 Unary (org.apache.sysml.lops.Unary)2 SparkAggType (org.apache.sysml.hops.AggBinaryOp.SparkAggType)1 OperationTypes (org.apache.sysml.lops.Aggregate.OperationTypes)1 CombineBinary (org.apache.sysml.lops.CombineBinary)1 CumulativePartialAggregate (org.apache.sysml.lops.CumulativePartialAggregate)1 CumulativeSplitAggregate (org.apache.sysml.lops.CumulativeSplitAggregate)1 Lop (org.apache.sysml.lops.Lop)1 LopsException (org.apache.sysml.lops.LopsException)1 DirectionTypes (org.apache.sysml.lops.PartialAggregate.DirectionTypes)1 TernaryAggregate (org.apache.sysml.lops.TernaryAggregate)1