Example use of org.apache.sysml.runtime.matrix.operators.AggregateOperator in the Apache project incubator-systemml.
From the class ExtractGroup, method execute.
/**
 * Turns one block of group ids and the row-aligned block of target values into a list of
 * (output index, weighted cell) pairs.
 *
 * <p>When the operator is an {@code AggregateOperator}, the number of groups is known
 * ({@code _ngroups > 0}) and the output dimensions pass
 * {@code OptimizerUtils.isValidCPDimensions}, the block is aggregated locally first and only
 * the non-zero cells of that partial result are emitted. Otherwise one cell is emitted per
 * (row, column) of {@code target}, keyed by the group id read from column 0 of {@code group}.
 *
 * @param ix     index of the input block; its column index determines the output column offset
 * @param group  block whose first column holds the group id of each row
 * @param target block of values to be grouped; must have as many rows as {@code group}
 * @return the emitted (index, weighted cell) pairs, every cell carrying weight 1
 * @throws Exception if the row counts of the two blocks differ or a group id is smaller than 1
 */
protected Iterable<Tuple2<MatrixIndexes, WeightedCell>> execute(MatrixIndexes ix, MatrixBlock group, MatrixBlock target) throws Exception {
    // both blocks have to describe the same rows
    if (group.getNumRows() != target.getNumRows()) {
        throw new Exception("The blocksize for group and target blocks are mismatched: " + group.getNumRows() + " != " + target.getNumRows());
    }

    final ArrayList<Tuple2<MatrixIndexes, WeightedCell>> emitted = new ArrayList<>();
    // 0-based offset of this block's first column in the overall matrix
    final long colOffset = (ix.getColumnIndex() - 1) * _bclen;

    final boolean preAggregate = _op instanceof AggregateOperator
        && _ngroups > 0
        && OptimizerUtils.isValidCPDimensions(_ngroups, target.getNumColumns());

    if (!preAggregate) {
        // general case: forward every target cell under its row's group id
        for (int row = 0; row < group.getNumRows(); row++) {
            long groupId = UtilFunctions.toLong(group.quickGetValue(row, 0));
            if (groupId < 1) {
                throw new Exception("Expected group values to be greater than equal to 1 but found " + groupId);
            }
            for (int col = 0; col < target.getNumColumns(); col++) {
                WeightedCell cell = new WeightedCell();
                cell.setValue(target.quickGetValue(row, col));
                cell.setWeight(1);
                emitted.add(new Tuple2<>(new MatrixIndexes(groupId, colOffset + col + 1), cell));
            }
        }
        return emitted;
    }

    // local pre-aggregation for sum with known output dimensions
    MatrixBlock partial = group.groupedAggOperations(target, null, new MatrixBlock(), (int) _ngroups, _op);
    for (int row = 0; row < partial.getNumRows(); row++) {
        for (int col = 0; col < partial.getNumColumns(); col++) {
            double aggregated = partial.quickGetValue(row, col);
            if (aggregated == 0) {
                // zero cells of the partial result are not emitted
                continue;
            }
            WeightedCell cell = new WeightedCell();
            cell.setValue(aggregated);
            cell.setWeight(1);
            emitted.add(new Tuple2<>(new MatrixIndexes(row + 1, colOffset + col + 1), cell));
        }
    }
    return emitted;
}
Example use of org.apache.sysml.runtime.matrix.operators.AggregateOperator in the Apache project incubator-systemml.
From the class PerformGroupByAggInCombiner, method call.
/**
 * Merges two weighted cells of the same key into a single partial aggregate.
 *
 * <p>For a {@code CMOperator} that is a partial aggregate operator, both cells are fed into
 * the matching {@code CM} function and the partial result and accumulated weight are returned.
 * For an {@code AggregateOperator}, each cell contributes {@code value * weight} through the
 * operator's increment function (with a {@code KahanObject} buffer when the operator carries a
 * correction) and the result has weight 1.
 *
 * @param value1 first cell to combine
 * @param value2 second cell to combine
 * @return a new cell holding the combined value and weight
 * @throws Exception in particular a {@code DMLRuntimeException} if the operator is a
 *         {@code CMOperator} that is not a partial aggregate operator, or is of an
 *         unsupported operator type
 */
@Override
public WeightedCell call(WeightedCell value1, WeightedCell value2) throws Exception {
    WeightedCell outCell = new WeightedCell();
    if (_op instanceof CMOperator) {
        // everything except sum
        CMOperator cmOp = (CMOperator) _op;
        if (!cmOp.isPartialAggregateOperator()) {
            throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInReducer");
        }
        // the CM state object is only needed on this path, so it is not allocated for sums
        CM_COV_Object cmObj = new CM_COV_Object();
        cmObj.reset();
        CM lcmFn = CM.getCMFnObject(cmOp.aggOpType);
        // partial aggregate cm operator
        lcmFn.execute(cmObj, value1.getValue(), value1.getWeight());
        lcmFn.execute(cmObj, value2.getValue(), value2.getWeight());
        outCell.setValue(cmObj.getRequiredPartialResult(_op));
        outCell.setWeight(cmObj.getWeight());
    } else if (_op instanceof AggregateOperator) {
        // sum
        AggregateOperator aggop = (AggregateOperator) _op;
        if (aggop.correctionExists) {
            // partial aggregate with correction: the buffer is updated in place
            KahanObject buffer = new KahanObject(aggop.initialValue, 0);
            aggop.increOp.fn.execute(buffer, value1.getValue() * value1.getWeight());
            aggop.increOp.fn.execute(buffer, value2.getValue() * value2.getWeight());
            outCell.setValue(buffer._sum);
        } else {
            // partial aggregate without correction
            double v = aggop.initialValue;
            v = aggop.increOp.fn.execute(v, value1.getValue() * value1.getWeight());
            v = aggop.increOp.fn.execute(v, value2.getValue() * value2.getWeight());
            outCell.setValue(v);
        }
        // weights were already folded into the value above
        outCell.setWeight(1);
    } else {
        throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + _op);
    }
    return outCell;
}
Example use of org.apache.sysml.runtime.matrix.operators.AggregateOperator in the Apache project incubator-systemml.
From the class PerformGroupByAggInReducer, method call.
/**
 * Aggregates all weighted cells of one key into a single result cell.
 *
 * <p>For a {@code CMOperator} that is not a partial aggregate operator, every cell is fed into
 * the matching {@code CM} function and the required result is returned. For an
 * {@code AggregateOperator}, each cell contributes {@code value * weight} through the
 * operator's increment function (with a {@code KahanObject} buffer when the operator carries a
 * correction). The returned cell always has weight 1.
 *
 * @param kv the cells to aggregate
 * @return a new cell holding the aggregated value with weight 1
 * @throws Exception in particular a {@code DMLRuntimeException} if the operator is a partial
 *         aggregate {@code CMOperator}, or is of an unsupported operator type
 */
@Override
public WeightedCell call(Iterable<WeightedCell> kv) throws Exception {
    WeightedCell outCell = new WeightedCell();
    if (op instanceof CMOperator) {
        // everything except sum
        CMOperator cmOp = (CMOperator) op;
        // the CM state object is only needed on this path, so it is not allocated for sums
        CM_COV_Object cmObj = new CM_COV_Object();
        cmObj.reset();
        CM lcmFn = CM.getCMFnObject(cmOp.aggOpType);
        if (cmOp.isPartialAggregateOperator()) {
            throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInCombiner");
        }
        // full aggregation over all cells of this key
        for (WeightedCell value : kv) {
            lcmFn.execute(cmObj, value.getValue(), value.getWeight());
        }
        outCell.setValue(cmObj.getRequiredResult(op));
        outCell.setWeight(1);
    } else if (op instanceof AggregateOperator) {
        // sum
        AggregateOperator aggop = (AggregateOperator) op;
        if (aggop.correctionExists) {
            // aggregate with correction: the buffer is updated in place
            KahanObject buffer = new KahanObject(aggop.initialValue, 0);
            for (WeightedCell value : kv) {
                aggop.increOp.fn.execute(buffer, value.getValue() * value.getWeight());
            }
            outCell.setValue(buffer._sum);
        } else {
            // aggregate without correction
            double v = aggop.initialValue;
            for (WeightedCell value : kv) {
                v = aggop.increOp.fn.execute(v, value.getValue() * value.getWeight());
            }
            outCell.setValue(v);
        }
        // weights were already folded into the value above
        outCell.setWeight(1);
    } else {
        throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + op);
    }
    return outCell;
}
Example use of org.apache.sysml.runtime.matrix.operators.AggregateOperator in the Apache project incubator-systemml.
From the class MatrixBlock, method max.
/**
 * Wrapper method for reduceall-max of a matrix.
 *
 * <p>Builds a reduce-all aggregate over the builtin {@code "max"} function, seeded with
 * {@link Double#NEGATIVE_INFINITY}, and runs it on this block via
 * {@code LibMatrixAgg.aggregateUnaryMatrix}.
 *
 * @return the value at cell (0, 0) of the 1x1 output block produced by the aggregation
 */
public double max() {
    // 1x1 output block that receives the aggregate
    MatrixBlock result = new MatrixBlock(1, 1, false);
    AggregateUnaryOperator reduceAllMax = new AggregateUnaryOperator(
        new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("max")),
        ReduceAll.getReduceAllFnObject());
    LibMatrixAgg.aggregateUnaryMatrix(this, result, reduceAllMax);
    return result.quickGetValue(0, 0);
}
Example use of org.apache.sysml.runtime.matrix.operators.AggregateOperator in the Apache project incubator-systemml.
From the class InstructionUtils, method parseCumulativeAggregateUnaryOperator.
/**
 * Maps a cumulative-aggregate opcode to its row-reducing aggregate unary operator.
 *
 * <p>Recognized opcodes are {@code "ucumack+"} (Kahan plus with the correction in the last
 * row), {@code "ucumac*"} (multiply), {@code "ucumacmin"} and {@code "ucumacmax"} (builtin
 * min / max). All of them are combined with the {@code ReduceRow} index function.
 *
 * @param opcode the instruction opcode; may be {@code null}
 * @return the matching operator, or {@code null} if the opcode is {@code null} or not recognized
 */
public static AggregateUnaryOperator parseCumulativeAggregateUnaryOperator(String opcode) {
    // a string switch would throw on null, so keep the "unknown opcode" result for it
    if (opcode == null) {
        return null;
    }

    // NOTE(review): every variant passes 0 as the initial value, including product/min/max;
    // presumably the cumulative kernels do not consult it - TODO confirm.
    final AggregateOperator partial;
    switch (opcode) {
        case "ucumack+":
            partial = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW);
            break;
        case "ucumac*":
            partial = new AggregateOperator(0, Multiply.getMultiplyFnObject(), false, CorrectionLocationType.NONE);
            break;
        case "ucumacmin":
            partial = new AggregateOperator(0, Builtin.getBuiltinFnObject("min"), false, CorrectionLocationType.NONE);
            break;
        case "ucumacmax":
            partial = new AggregateOperator(0, Builtin.getBuiltinFnObject("max"), false, CorrectionLocationType.NONE);
            break;
        default:
            return null;
    }
    return new AggregateUnaryOperator(partial, ReduceRow.getReduceRowFnObject());
}
Aggregations