use of org.apache.sysml.runtime.matrix.data.WeightedCell in project incubator-systemml by apache.
the class ExtractGroupNWeights method call.
@Override
public Iterator<Tuple2<MatrixIndexes, WeightedCell>> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg) throws Exception {
MatrixBlock group = arg._2._1._1;
MatrixBlock target = arg._2._1._2;
MatrixBlock weight = arg._2._2;
//sanity check matching block dimensions
if (group.getNumRows() != target.getNumRows() || group.getNumRows() != target.getNumRows()) {
throw new Exception("The blocksize for group/target/weight blocks are mismatched: " + group.getNumRows() + ", " + target.getNumRows() + ", " + weight.getNumRows());
}
//output weighted cells
ArrayList<Tuple2<MatrixIndexes, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<MatrixIndexes, WeightedCell>>();
for (int i = 0; i < group.getNumRows(); i++) {
WeightedCell weightedCell = new WeightedCell();
weightedCell.setValue(target.quickGetValue(i, 0));
weightedCell.setWeight(weight.quickGetValue(i, 0));
long groupVal = UtilFunctions.toLong(group.quickGetValue(i, 0));
if (groupVal < 1) {
throw new Exception("Expected group values to be greater than equal to 1 but found " + groupVal);
}
MatrixIndexes ix = new MatrixIndexes(groupVal, 1);
groupValuePairs.add(new Tuple2<MatrixIndexes, WeightedCell>(ix, weightedCell));
}
return groupValuePairs.iterator();
}
use of org.apache.sysml.runtime.matrix.data.WeightedCell in project incubator-systemml by apache.
the class GroupedAggMRCombiner method reduce.
@Override
public void reduce(TaggedMatrixIndexes key, Iterator<WeightedCell> values, OutputCollector<TaggedMatrixIndexes, WeightedCell> out, Reporter reporter) throws IOException {
long start = System.currentTimeMillis();
//get aggregate operator
GroupedAggregateInstruction ins = grpaggInstructions.get(key.getTag());
Operator op = ins.getOperator();
boolean isPartialAgg = true;
//combine iterator to single value
try {
if (//everything except sum
op instanceof CMOperator) {
if (((CMOperator) op).isPartialAggregateOperator()) {
cmObj.reset();
CM lcmFn = cmFn.get(key.getTag());
//partial aggregate cm operator
while (values.hasNext()) {
WeightedCell value = values.next();
lcmFn.execute(cmObj, value.getValue(), value.getWeight());
}
outCell.setValue(cmObj.getRequiredPartialResult(op));
outCell.setWeight(cmObj.getWeight());
} else //forward tuples to reducer
{
isPartialAgg = false;
while (values.hasNext()) out.collect(key, values.next());
}
} else if (//sum
op instanceof AggregateOperator) {
AggregateOperator aggop = (AggregateOperator) op;
if (aggop.correctionExists) {
KahanObject buffer = new KahanObject(aggop.initialValue, 0);
KahanPlus.getKahanPlusFnObject();
//partial aggregate with correction
while (values.hasNext()) {
WeightedCell value = values.next();
aggop.increOp.fn.execute(buffer, value.getValue() * value.getWeight());
}
outCell.setValue(buffer._sum);
outCell.setWeight(1);
} else //no correction
{
double v = aggop.initialValue;
//partial aggregate without correction
while (values.hasNext()) {
WeightedCell value = values.next();
v = aggop.increOp.fn.execute(v, value.getValue() * value.getWeight());
}
outCell.setValue(v);
outCell.setWeight(1);
}
} else
throw new IOException("Unsupported operator in instruction: " + ins);
} catch (Exception ex) {
throw new IOException(ex);
}
//collect the output (to reducer)
if (isPartialAgg)
out.collect(key, outCell);
reporter.incrCounter(Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis() - start);
}
use of org.apache.sysml.runtime.matrix.data.WeightedCell in project incubator-systemml by apache.
the class GroupedAggMRReducer method reduce.
@Override
public void reduce(TaggedMatrixIndexes key, Iterator<WeightedCell> values, OutputCollector<MatrixIndexes, MatrixCell> out, Reporter report) throws IOException {
commonSetup(report);
//get operator
GroupedAggregateInstruction ins = grpaggInstructions.get(key.getTag());
Operator op = ins.getOperator();
try {
if (//all, but sum
op instanceof CMOperator) {
cmObj.reset();
CM lcmFn = cmFn.get(key.getTag());
while (values.hasNext()) {
WeightedCell value = values.next();
lcmFn.execute(cmObj, value.getValue(), value.getWeight());
}
outCell.setValue(cmObj.getRequiredResult(op));
} else if (//sum
op instanceof AggregateOperator) {
AggregateOperator aggop = (AggregateOperator) op;
if (aggop.correctionExists) {
KahanObject buffer = new KahanObject(aggop.initialValue, 0);
while (values.hasNext()) {
WeightedCell value = values.next();
aggop.increOp.fn.execute(buffer, value.getValue() * value.getWeight());
}
outCell.setValue(buffer._sum);
} else {
double v = aggop.initialValue;
while (values.hasNext()) {
WeightedCell value = values.next();
v = aggop.increOp.fn.execute(v, value.getValue() * value.getWeight());
}
outCell.setValue(v);
}
} else
throw new IOException("Unsupported operator in instruction: " + ins);
} catch (Exception ex) {
throw new IOException(ex);
}
outIndex.setIndexes(key.getBaseObject());
cachedValues.reset();
cachedValues.set(key.getTag(), outIndex, outCell);
processReducerInstructions();
//output the final result matrices
outputResultsFromCachedValues(report);
}
use of org.apache.sysml.runtime.matrix.data.WeightedCell in project incubator-systemml by apache.
the class ExtractGroup method execute.
protected Iterable<Tuple2<MatrixIndexes, WeightedCell>> execute(MatrixIndexes ix, MatrixBlock group, MatrixBlock target) throws Exception {
//sanity check matching block dimensions
if (group.getNumRows() != target.getNumRows()) {
throw new Exception("The blocksize for group and target blocks are mismatched: " + group.getNumRows() + " != " + target.getNumRows());
}
//output weighted cells
ArrayList<Tuple2<MatrixIndexes, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<MatrixIndexes, WeightedCell>>();
long coloff = (ix.getColumnIndex() - 1) * _bclen;
//local pre-aggregation for sum w/ known output dimensions
if (_op instanceof AggregateOperator && _ngroups > 0 && OptimizerUtils.isValidCPDimensions(_ngroups, target.getNumColumns())) {
MatrixBlock tmp = group.groupedAggOperations(target, null, new MatrixBlock(), (int) _ngroups, _op);
for (int i = 0; i < tmp.getNumRows(); i++) {
for (int j = 0; j < tmp.getNumColumns(); j++) {
double tmpval = tmp.quickGetValue(i, j);
if (tmpval != 0) {
WeightedCell weightedCell = new WeightedCell();
weightedCell.setValue(tmpval);
weightedCell.setWeight(1);
MatrixIndexes ixout = new MatrixIndexes(i + 1, coloff + j + 1);
groupValuePairs.add(new Tuple2<MatrixIndexes, WeightedCell>(ixout, weightedCell));
}
}
}
} else //general case without pre-aggregation
{
for (int i = 0; i < group.getNumRows(); i++) {
long groupVal = UtilFunctions.toLong(group.quickGetValue(i, 0));
if (groupVal < 1) {
throw new Exception("Expected group values to be greater than equal to 1 but found " + groupVal);
}
for (int j = 0; j < target.getNumColumns(); j++) {
WeightedCell weightedCell = new WeightedCell();
weightedCell.setValue(target.quickGetValue(i, j));
weightedCell.setWeight(1);
MatrixIndexes ixout = new MatrixIndexes(groupVal, coloff + j + 1);
groupValuePairs.add(new Tuple2<MatrixIndexes, WeightedCell>(ixout, weightedCell));
}
}
}
return groupValuePairs;
}
use of org.apache.sysml.runtime.matrix.data.WeightedCell in project incubator-systemml by apache.
the class PerformGroupByAggInCombiner method call.
@Override
public WeightedCell call(WeightedCell value1, WeightedCell value2) throws Exception {
WeightedCell outCell = new WeightedCell();
CM_COV_Object cmObj = new CM_COV_Object();
if (//everything except sum
_op instanceof CMOperator) {
if (((CMOperator) _op).isPartialAggregateOperator()) {
cmObj.reset();
// cmFn.get(key.getTag());
CM lcmFn = CM.getCMFnObject(((CMOperator) _op).aggOpType);
//partial aggregate cm operator
lcmFn.execute(cmObj, value1.getValue(), value1.getWeight());
lcmFn.execute(cmObj, value2.getValue(), value2.getWeight());
outCell.setValue(cmObj.getRequiredPartialResult(_op));
outCell.setWeight(cmObj.getWeight());
} else //forward tuples to reducer
{
throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInReducer");
}
} else if (//sum
_op instanceof AggregateOperator) {
AggregateOperator aggop = (AggregateOperator) _op;
if (aggop.correctionExists) {
KahanObject buffer = new KahanObject(aggop.initialValue, 0);
KahanPlus.getKahanPlusFnObject();
//partial aggregate with correction
aggop.increOp.fn.execute(buffer, value1.getValue() * value1.getWeight());
aggop.increOp.fn.execute(buffer, value2.getValue() * value2.getWeight());
outCell.setValue(buffer._sum);
outCell.setWeight(1);
} else //no correction
{
double v = aggop.initialValue;
//partial aggregate without correction
v = aggop.increOp.fn.execute(v, value1.getValue() * value1.getWeight());
v = aggop.increOp.fn.execute(v, value2.getValue() * value2.getWeight());
outCell.setValue(v);
outCell.setWeight(1);
}
} else
throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + _op);
return outCell;
}
Aggregations