Use of org.apache.sysml.runtime.matrix.data.MatrixValue in project SystemML by Apache.
Class MMCJMRCombinerReducerBase, method performAggregateInstructions.
/**
 * Folds all values that share one tagged key into the shared {@code buffer}.
 *
 * <p>The instruction's increment operator is used. The instruction is the aggregate
 * instruction registered for the aggregate-binary input that the key's tag stands for.
 * When no instruction is registered, {@code defaultAggIns} is retargeted to the key's
 * tag and used instead.
 *
 * @param indexes key whose tag is compared against {@code tagForLeft} to pick input1 or input2
 * @param values  values to fold together
 * @return {@code buffer} holding the aggregate, or {@code null} when {@code values} is empty
 * @throws IOException if more than one aggregate instruction is registered for the input,
 *         if that instruction's input and output indexes differ, or (wrapping the cause)
 *         if iterating or aggregating fails
 */
protected MatrixValue performAggregateInstructions(TaggedFirstSecondIndexes indexes, Iterator<MatrixValue> values) throws IOException {
    // Translate the physical tag into the aggregate-binary input it represents.
    final byte realTag = indexes.getTag();
    final byte representTag = (realTag == tagForLeft) ? aggBinInstruction.input1 : aggBinInstruction.input2;

    final ArrayList<AggregateInstruction> candidates = agg_instructions.get(representTag);
    final AggregateInstruction ins;
    if (candidates != null) {
        if (candidates.size() > 1) {
            throw new IOException("only one aggregate operation on input " + indexes.getTag() + " is allowed in BlockMMCJMR");
        }
        ins = candidates.get(0);
        if (ins.input != ins.output) {
            throw new IOException("input index and output index have to be " + "the same for aggregate instructions in BlockMMCJMR");
        }
    } else {
        // No explicit instruction: reuse the shared default, pointed at this tag.
        defaultAggIns.input = realTag;
        defaultAggIns.output = realTag;
        ins = defaultAggIns;
    }

    // Pre-aggregate before the cross-product join.
    // TODO: matrix multiplication only ever needs sum here; this could be specialized.
    boolean sawValue = false;
    try {
        while (values.hasNext()) {
            final MatrixValue next = values.next();
            if (!sawValue) {
                // Size the buffer from the first value only.
                buffer.reset(next.getNumRows(), next.getNumColumns(), next.isInSparseFormat());
                sawValue = true;
            }
            buffer.binaryOperationsInPlace(((AggregateOperator) ins.getOperator()).increOp, next);
        }
    } catch (Exception e) {
        throw new IOException(e);
    }
    return sawValue ? buffer : null;
}
Use of org.apache.sysml.runtime.matrix.data.MatrixValue in project SystemML by Apache.
Class MMCJMRReducerWithAggregator, method reduce.
/**
 * Reduces one tagged key: obtains a single value for it, then joins that value against
 * the cache for the current first index.
 *
 * <p>The cache is reset whenever the first index changes. Within one first index, tags
 * must arrive in non-decreasing order. Elapsed wall-clock time is added to
 * {@code Counters.COMBINE_OR_REDUCE_TIME}, except when cell aggregation yields no value.
 *
 * @throws RuntimeException if tags for the same first index arrive out of order
 */
@Override
public void reduce(TaggedFirstSecondIndexes indexes, Iterator<MatrixValue> values, OutputCollector<Writable, Writable> out, Reporter report) throws IOException {
    final long startMillis = System.currentTimeMillis();
    commonSetup(report);

    final MatrixValue joinInput;
    if (valueClass != MatrixBlock.class) {
        // Binary-cell input: several cells may share this key, so fold them first.
        joinInput = performAggregateInstructions(indexes, values);
        if (joinInput == null) {
            return;
        }
    } else {
        // Blocked input: a single block per key is expected.
        joinInput = values.next();
    }

    final int tag = indexes.getTag();
    final long k = indexes.getFirstIndex();
    final long otherIndex = indexes.getSecondIndex();

    if (prevFirstIndex == k) {
        // Same k as before: tags must not go backwards.
        if (prevTag > tag) {
            throw new RuntimeException("tag is not ordered correctly: " + prevTag + " > " + tag);
        }
    } else {
        // New k: discard blocks cached for the previous one.
        cache.resetCache(true);
        prevFirstIndex = k;
    }
    prevTag = tag;

    processJoin(tag, otherIndex, joinInput);
    report.incrCounter(Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis() - startMillis);
}
Use of org.apache.sysml.runtime.matrix.data.MatrixValue in project SystemML by Apache.
Class MMCJMRReducerWithAggregator, method processJoin.
/**
 * Joins one input block against the blocks cached for the current common index.
 *
 * <p>A block with {@code tag == 0} is only added to the cache. Any other block is
 * multiplied with every cached block into {@code valueBuffer}; the operand order depends
 * on whether the left input is the cached one ({@code tagForLeft == 0}). Each product goes
 * either into {@code aggregator} (MMCJ type AGG) or straight to the final outputs. In the
 * direct case, its non-zero count is added to {@code resultsNonZeros[0]}.
 *
 * @param tag     tag of the incoming block; 0 means "cache it"
 * @param inIndex the incoming block's non-common index
 * @param inValue the incoming block (cast to MatrixBlock when multiplied)
 * @throws IOException if caching, multiplying, aggregating or emitting fails; an
 *         IOException raised by those steps is propagated unchanged rather than re-wrapped
 */
private void processJoin(int tag, long inIndex, MatrixValue inValue) throws IOException {
    try {
        if (tag == 0) {
            // Cached side: remember the block for later probing.
            cache.put(inIndex, inValue);
        } else {
            // Probing side. These values do not change inside the loop, so compute them once.
            final boolean leftCached = (tagForLeft == 0);
            final boolean aggregateOutput = (aggBinInstruction.getMMCJType() == MMCJType.AGG);
            final AggregateBinaryOperator op = (AggregateBinaryOperator) aggBinInstruction.getOperator();

            for (int i = 0; i < cache.getCacheSize(); i++) {
                Pair<MatrixIndexes, MatrixValue> cached = cache.get(i);
                if (leftCached) {
                    // cached (left) x incoming (right)
                    indexesbuffer.setIndexes(cached.getKey().getRowIndex(), inIndex);
                    OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes((MatrixBlock) cached.getValue(), (MatrixBlock) inValue, (MatrixBlock) valueBuffer, op);
                } else {
                    // incoming (left) x cached (right)
                    indexesbuffer.setIndexes(inIndex, cached.getKey().getColumnIndex());
                    OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes((MatrixBlock) inValue, (MatrixBlock) cached.getValue(), (MatrixBlock) valueBuffer, op);
                }

                if (aggregateOutput) {
                    // MMCJType.AGG: accumulate partial products in the aggregator.
                    aggregator.aggregateToBuffer(indexesbuffer, valueBuffer, leftCached);
                } else {
                    // MMCJType.NO_AGG: emit the product directly.
                    collectFinalMultipleOutputs.collectOutput(indexesbuffer, valueBuffer, 0, cachedReporter);
                    resultsNonZeros[0] += valueBuffer.getNonZeros();
                }
            }
        }
    } catch (Exception ex) {
        // Propagate I/O failures as-is so callers see the original cause; wrap anything else.
        if (ex instanceof IOException) {
            throw (IOException) ex;
        }
        throw new IOException(ex);
    }
}
Aggregations