Search in sources:

Example 1 with WeightedCell

Use of org.apache.sysml.runtime.matrix.data.WeightedCell in the Apache incubator-systemml project.

From class ExtractGroupNWeights, method call.

/**
 * Turns one aligned (group, target, weight) block triple into weighted cells keyed by group id.
 *
 * <p>For every row {@code i}, emits the cell {@code (value = target[i,0], weight = weight[i,0])}
 * under the index {@code (group[i,0], 1)}. Only column 0 of each block is read.
 *
 * @param arg block index paired with {@code ((group, target), weight)} blocks
 * @return iterator over one (group index, weighted cell) pair per row
 * @throws Exception if the group, target and weight blocks do not all have the same number of
 *     rows, or if a group value is smaller than 1
 */
@Override
public Iterator<Tuple2<MatrixIndexes, WeightedCell>> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg) throws Exception {
    MatrixBlock group = arg._2._1._1;
    MatrixBlock target = arg._2._1._2;
    MatrixBlock weight = arg._2._2;
    //sanity check matching block dimensions: group must match BOTH target and weight,
    //because all three blocks are read row-by-row below
    if (group.getNumRows() != target.getNumRows() || group.getNumRows() != weight.getNumRows()) {
        throw new Exception("The blocksize for group/target/weight blocks are mismatched: " + group.getNumRows() + ", " + target.getNumRows() + ", " + weight.getNumRows());
    }
    //output weighted cells (exactly one per row, so presize the list)
    ArrayList<Tuple2<MatrixIndexes, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<MatrixIndexes, WeightedCell>>(group.getNumRows());
    for (int i = 0; i < group.getNumRows(); i++) {
        long groupVal = UtilFunctions.toLong(group.quickGetValue(i, 0));
        //group ids are used directly as 1-based row indexes of the output
        if (groupVal < 1) {
            throw new Exception("Expected group values to be greater than equal to 1 but found " + groupVal);
        }
        WeightedCell weightedCell = new WeightedCell();
        weightedCell.setValue(target.quickGetValue(i, 0));
        weightedCell.setWeight(weight.quickGetValue(i, 0));
        MatrixIndexes ix = new MatrixIndexes(groupVal, 1);
        groupValuePairs.add(new Tuple2<MatrixIndexes, WeightedCell>(ix, weightedCell));
    }
    return groupValuePairs.iterator();
}
Also used : WeightedCell(org.apache.sysml.runtime.matrix.data.WeightedCell) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) Tuple2(scala.Tuple2) ArrayList(java.util.ArrayList)

Example 2 with WeightedCell

Use of org.apache.sysml.runtime.matrix.data.WeightedCell in the Apache incubator-systemml project.

From class GroupedAggMRCombiner, method reduce.

/**
 * Combines all weighted cells of one grouped-aggregate key into a single partial result. When
 * the operator cannot be partially aggregated, the cells are forwarded to the reducer unchanged.
 *
 * <p>The instruction is looked up by the key's tag in {@code grpaggInstructions}. Results are:
 * <ul>
 *   <li>{@code CMOperator} with partial aggregation: every (value, weight) pair is folded into
 *       {@code cmObj}. The emitted cell carries the operator's required partial result and the
 *       accumulated weight.</li>
 *   <li>{@code CMOperator} without partial aggregation: each incoming cell is collected as-is and
 *       no combined cell is emitted.</li>
 *   <li>{@code AggregateOperator} (sum): {@code value * weight} is accumulated, in a
 *       {@code KahanObject} buffer when the operator declares a correction. The emitted cell has
 *       weight 1 because the weights are already folded into the value.</li>
 * </ul>
 * The elapsed wall-clock time in milliseconds is added to
 * {@code Counters.COMBINE_OR_REDUCE_TIME}.
 *
 * @param key tagged index; the tag selects the grouped-aggregate instruction
 * @param values weighted cells for this key
 * @param out collector receiving either the combined cell or the forwarded cells
 * @param reporter Hadoop reporter used for the timing counter
 * @throws IOException if the operator is unsupported or aggregation/collection fails; any
 *     exception raised while combining is wrapped in an {@code IOException}
 */
@Override
public void reduce(TaggedMatrixIndexes key, Iterator<WeightedCell> values, OutputCollector<TaggedMatrixIndexes, WeightedCell> out, Reporter reporter) throws IOException {
    long start = System.currentTimeMillis();
    //get aggregate operator
    GroupedAggregateInstruction ins = grpaggInstructions.get(key.getTag());
    Operator op = ins.getOperator();
    boolean isPartialAgg = true;
    //combine iterator to single value
    try {
        if (op instanceof CMOperator) {
            //everything except sum
            if (((CMOperator) op).isPartialAggregateOperator()) {
                //partial aggregate cm operator; cmObj is not a local, so clear any prior state
                cmObj.reset();
                CM lcmFn = cmFn.get(key.getTag());
                while (values.hasNext()) {
                    WeightedCell value = values.next();
                    lcmFn.execute(cmObj, value.getValue(), value.getWeight());
                }
                outCell.setValue(cmObj.getRequiredPartialResult(op));
                outCell.setWeight(cmObj.getWeight());
            } else {
                //not partially aggregable: forward tuples to reducer unchanged
                isPartialAgg = false;
                while (values.hasNext()) {
                    out.collect(key, values.next());
                }
            }
        } else if (op instanceof AggregateOperator) {
            //sum
            AggregateOperator aggop = (AggregateOperator) op;
            if (aggop.correctionExists) {
                //partial aggregate with correction
                KahanObject buffer = new KahanObject(aggop.initialValue, 0);
                while (values.hasNext()) {
                    WeightedCell value = values.next();
                    aggop.increOp.fn.execute(buffer, value.getValue() * value.getWeight());
                }
                outCell.setValue(buffer._sum);
                outCell.setWeight(1);
            } else {
                //partial aggregate without correction
                double v = aggop.initialValue;
                while (values.hasNext()) {
                    WeightedCell value = values.next();
                    v = aggop.increOp.fn.execute(v, value.getValue() * value.getWeight());
                }
                outCell.setValue(v);
                outCell.setWeight(1);
            }
        } else {
            throw new IOException("Unsupported operator in instruction: " + ins);
        }
    } catch (Exception ex) {
        throw new IOException(ex);
    }
    //collect the output (to reducer); forwarded tuples were already collected above
    if (isPartialAgg) {
        out.collect(key, outCell);
    }
    reporter.incrCounter(Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis() - start);
}
Also used : CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) Operator(org.apache.sysml.runtime.matrix.operators.Operator) WeightedCell(org.apache.sysml.runtime.matrix.data.WeightedCell) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) CM(org.apache.sysml.runtime.functionobjects.CM) IOException(java.io.IOException) CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator) GroupedAggregateInstruction(org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction) IOException(java.io.IOException)

Example 3 with WeightedCell

Use of org.apache.sysml.runtime.matrix.data.WeightedCell in the Apache incubator-systemml project.

From class GroupedAggMRReducer, method reduce.

/**
 * Computes the final grouped aggregate for one key and passes it through the reducer-instruction
 * pipeline to the outputs.
 *
 * <p>{@code CMOperator} instructions fold every (value, weight) pair into {@code cmObj} and emit
 * the operator's required result. {@code AggregateOperator} (sum) instructions accumulate
 * {@code value * weight}, in a {@code KahanObject} buffer when the operator declares a
 * correction. The result is placed in {@code cachedValues} under the key's tag before the
 * reducer instructions run.
 *
 * @param key tagged index; the tag selects the grouped-aggregate instruction
 * @param values weighted cells for this key
 * @param out not referenced directly; results are emitted via
 *     {@code outputResultsFromCachedValues}
 * @param report Hadoop reporter handed to the setup and output helpers
 * @throws IOException if the operator is unsupported or aggregation fails (wrapped)
 */
@Override
public void reduce(TaggedMatrixIndexes key, Iterator<WeightedCell> values, OutputCollector<MatrixIndexes, MatrixCell> out, Reporter report) throws IOException {
    commonSetup(report);
    GroupedAggregateInstruction instruction = grpaggInstructions.get(key.getTag());
    Operator operator = instruction.getOperator();
    try {
        boolean isCm = operator instanceof CMOperator;
        //guard: only CM and aggregate (sum) operators are supported
        if (!isCm && !(operator instanceof AggregateOperator)) {
            throw new IOException("Unsupported operator in instruction: " + instruction);
        }
        if (isCm) {
            //everything except sum: fold all weighted values into the CM object
            cmObj.reset();
            CM cmFunction = cmFn.get(key.getTag());
            while (values.hasNext()) {
                WeightedCell cell = values.next();
                cmFunction.execute(cmObj, cell.getValue(), cell.getWeight());
            }
            outCell.setValue(cmObj.getRequiredResult(operator));
        } else {
            //sum: accumulate value * weight
            AggregateOperator aggOp = (AggregateOperator) operator;
            if (!aggOp.correctionExists) {
                double acc = aggOp.initialValue;
                while (values.hasNext()) {
                    WeightedCell cell = values.next();
                    acc = aggOp.increOp.fn.execute(acc, cell.getValue() * cell.getWeight());
                }
                outCell.setValue(acc);
            } else {
                KahanObject kahan = new KahanObject(aggOp.initialValue, 0);
                while (values.hasNext()) {
                    WeightedCell cell = values.next();
                    aggOp.increOp.fn.execute(kahan, cell.getValue() * cell.getWeight());
                }
                outCell.setValue(kahan._sum);
            }
        }
    } catch (Exception ex) {
        throw new IOException(ex);
    }
    //stage the aggregate for the reducer instructions, then emit all final outputs
    outIndex.setIndexes(key.getBaseObject());
    cachedValues.reset();
    cachedValues.set(key.getTag(), outIndex, outCell);
    processReducerInstructions();
    outputResultsFromCachedValues(report);
}
Also used : CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) Operator(org.apache.sysml.runtime.matrix.operators.Operator) WeightedCell(org.apache.sysml.runtime.matrix.data.WeightedCell) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) CM(org.apache.sysml.runtime.functionobjects.CM) IOException(java.io.IOException) CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator) GroupedAggregateInstruction(org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction) IOException(java.io.IOException)

Example 4 with WeightedCell

Use of org.apache.sysml.runtime.matrix.data.WeightedCell in the Apache incubator-systemml project.

From class ExtractGroup, method execute.

/**
 * Emits weighted cells (all with weight 1) for one aligned pair of group and target blocks.
 *
 * <p>Output column indexes are shifted by the block's column offset
 * {@code (ix.getColumnIndex() - 1) * _bclen}. When the operator is an
 * {@code AggregateOperator}, {@code _ngroups > 0}, and the result dimensions pass
 * {@code OptimizerUtils.isValidCPDimensions}, the block is pre-aggregated locally via
 * {@code groupedAggOperations} and only non-zero results are emitted, keyed by
 * {@code (groupRow + 1, column)}. Otherwise every target cell is emitted, keyed by
 * {@code (group[row,0], column)}.
 *
 * @param ix index of the input block (its column index determines the column offset)
 * @param group block whose column 0 holds the group id of each row
 * @param target block of values to aggregate
 * @return list of (output index, weighted cell) pairs
 * @throws Exception if the blocks have different row counts, or (general path only) a group
 *     value is smaller than 1
 */
protected Iterable<Tuple2<MatrixIndexes, WeightedCell>> execute(MatrixIndexes ix, MatrixBlock group, MatrixBlock target) throws Exception {
    if (group.getNumRows() != target.getNumRows()) {
        throw new Exception("The blocksize for group and target blocks are mismatched: " + group.getNumRows() + " != " + target.getNumRows());
    }
    ArrayList<Tuple2<MatrixIndexes, WeightedCell>> pairs = new ArrayList<Tuple2<MatrixIndexes, WeightedCell>>();
    long colOffset = (ix.getColumnIndex() - 1) * _bclen;
    boolean preAggregate = _op instanceof AggregateOperator && _ngroups > 0 && OptimizerUtils.isValidCPDimensions(_ngroups, target.getNumColumns());

    if (!preAggregate) {
        //general case: one cell per target entry, keyed by that row's group id
        for (int row = 0; row < group.getNumRows(); row++) {
            long groupId = UtilFunctions.toLong(group.quickGetValue(row, 0));
            if (groupId < 1) {
                throw new Exception("Expected group values to be greater than equal to 1 but found " + groupId);
            }
            for (int col = 0; col < target.getNumColumns(); col++) {
                WeightedCell cell = new WeightedCell();
                cell.setValue(target.quickGetValue(row, col));
                cell.setWeight(1);
                pairs.add(new Tuple2<MatrixIndexes, WeightedCell>(new MatrixIndexes(groupId, colOffset + col + 1), cell));
            }
        }
        return pairs;
    }

    //local pre-aggregation for sum with known output dimensions; zeros are not emitted
    MatrixBlock partial = group.groupedAggOperations(target, null, new MatrixBlock(), (int) _ngroups, _op);
    for (int row = 0; row < partial.getNumRows(); row++) {
        for (int col = 0; col < partial.getNumColumns(); col++) {
            double val = partial.quickGetValue(row, col);
            if (val == 0) {
                continue;
            }
            WeightedCell cell = new WeightedCell();
            cell.setValue(val);
            cell.setWeight(1);
            pairs.add(new Tuple2<MatrixIndexes, WeightedCell>(new MatrixIndexes(row + 1, colOffset + col + 1), cell));
        }
    }
    return pairs;
}
Also used : WeightedCell(org.apache.sysml.runtime.matrix.data.WeightedCell) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) Tuple2(scala.Tuple2) ArrayList(java.util.ArrayList) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator)

Example 5 with WeightedCell

Use of org.apache.sysml.runtime.matrix.data.WeightedCell in the Apache incubator-systemml project.

From class PerformGroupByAggInCombiner, method call.

/**
 * Merges two weighted cells of the same group into one partially aggregated cell.
 *
 * <p>For a partial-aggregate {@code CMOperator}, both (value, weight) pairs are folded into a
 * fresh {@code CM_COV_Object}. The result carries the operator's required partial result and
 * the accumulated weight. For an {@code AggregateOperator} (sum), {@code value * weight} of both
 * cells is accumulated, in a {@code KahanObject} buffer when the operator declares a correction.
 * The result has weight 1 because the weights are folded into the value.
 *
 * @param value1 first cell to merge
 * @param value2 second cell to merge
 * @return a new cell holding the merged partial aggregate
 * @throws DMLRuntimeException if the CM operator is not a partial-aggregate operator (those must
 *     go through PerformGroupByAggInReducer), or the operator is neither a {@code CMOperator}
 *     nor an {@code AggregateOperator}
 */
@Override
public WeightedCell call(WeightedCell value1, WeightedCell value2) throws Exception {
    WeightedCell outCell = new WeightedCell();
    if (_op instanceof CMOperator) {
        //everything except sum
        CMOperator cmOp = (CMOperator) _op;
        if (!cmOp.isPartialAggregateOperator()) {
            //non-partial CM operators cannot be combined pairwise; tuples must reach the reducer
            throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInReducer");
        }
        //NOTE(review): value1/value2 may themselves be outputs of earlier merges (partial
        //results), yet they are re-fed as raw (value, weight) pairs -- confirm this is valid for
        //every CM operator that reports isPartialAggregateOperator().
        CM_COV_Object cmObj = new CM_COV_Object();
        cmObj.reset();
        CM lcmFn = CM.getCMFnObject(cmOp.aggOpType);
        lcmFn.execute(cmObj, value1.getValue(), value1.getWeight());
        lcmFn.execute(cmObj, value2.getValue(), value2.getWeight());
        outCell.setValue(cmObj.getRequiredPartialResult(_op));
        outCell.setWeight(cmObj.getWeight());
    } else if (_op instanceof AggregateOperator) {
        //sum
        AggregateOperator aggop = (AggregateOperator) _op;
        if (aggop.correctionExists) {
            //partial aggregate with correction
            KahanObject buffer = new KahanObject(aggop.initialValue, 0);
            aggop.increOp.fn.execute(buffer, value1.getValue() * value1.getWeight());
            aggop.increOp.fn.execute(buffer, value2.getValue() * value2.getWeight());
            outCell.setValue(buffer._sum);
            outCell.setWeight(1);
        } else {
            //partial aggregate without correction
            double v = aggop.initialValue;
            v = aggop.increOp.fn.execute(v, value1.getValue() * value1.getWeight());
            v = aggop.increOp.fn.execute(v, value2.getValue() * value2.getWeight());
            outCell.setValue(v);
            outCell.setWeight(1);
        }
    } else {
        throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + _op);
    }
    return outCell;
}
Also used : WeightedCell(org.apache.sysml.runtime.matrix.data.WeightedCell) CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) CM(org.apache.sysml.runtime.functionobjects.CM) CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Aggregations

WeightedCell (org.apache.sysml.runtime.matrix.data.WeightedCell)6 AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator)5 CM (org.apache.sysml.runtime.functionobjects.CM)4 KahanObject (org.apache.sysml.runtime.instructions.cp.KahanObject)4 CMOperator (org.apache.sysml.runtime.matrix.operators.CMOperator)4 IOException (java.io.IOException)2 ArrayList (java.util.ArrayList)2 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)2 CM_COV_Object (org.apache.sysml.runtime.instructions.cp.CM_COV_Object)2 GroupedAggregateInstruction (org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction)2 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)2 MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes)2 Operator (org.apache.sysml.runtime.matrix.operators.Operator)2 Tuple2 (scala.Tuple2)2