Usage of org.apache.sysml.runtime.matrix.data.TripleIndexes in the Apache incubator-systemml project.
Class RmmSPInstruction, method processInstruction.
/**
 * Executes the Spark RMM matrix-multiplication instruction for {@code input1 %*% input2}.
 *
 * <p>Pipeline:
 * <ol>
 *   <li>Both inputs are fetched as binary-block RDDs together with their matrix characteristics.</li>
 *   <li>Each input's blocks are re-keyed by a {@link TripleIndexes} (i/j/k) key via
 *       {@code RmmReplicateFunction}. The first input is expanded using the second input's column
 *       count and column block size; the second input uses the first input's row count and row
 *       block size.</li>
 *   <li>The two keyed datasets are joined on that i/j/k key, and each joined block pair is
 *       multiplied by {@code RmmMultiplyFunction}.</li>
 *   <li>Partial products are summed per result block.</li>
 *   <li>The resulting RDD is bound to the output variable, with both inputs registered as its
 *       lineage.</li>
 * </ol>
 *
 * @param ec the execution context; it is cast to {@link SparkExecutionContext}, so any other
 *     context type fails with a {@code ClassCastException}
 * @throws DMLRuntimeException as declared; presumably raised by the execution-context or helper
 *     calls made here (TODO confirm which ones)
 */
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;

    // get input characteristics and binary-block rdds
    MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
    MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = sec.getBinaryBlockRDDHandleForVariable(input2.getName());

    // step 1: prepare i/j/k join keys (w/ replication); each side's blocks are expanded
    // according to the dimension and block size of the *other* input
    JavaPairRDD<TripleIndexes, MatrixBlock> tmp1 = in1.flatMapToPair(
        new RmmReplicateFunction(mc2.getCols(), mc2.getColsPerBlock(), true));
    JavaPairRDD<TripleIndexes, MatrixBlock> tmp2 = in2.flatMapToPair(
        new RmmReplicateFunction(mc1.getRows(), mc1.getRowsPerBlock(), false));

    // step 2: join prepared datasets, multiply, and aggregate
    JavaPairRDD<MatrixIndexes, MatrixBlock> out = tmp1
        .join(tmp2)                             // join on the i/j/k TripleIndexes key
        .mapToPair(new RmmMultiplyFunction());  // multiply each joined block pair
    // sum the partial products that map to the same result block
    out = RDDAggregateUtils.sumByKeyStable(out, false);

    // update the output's matrix characteristics, then bind the output rdd to the output variable
    updateBinaryMMOutputMatrixCharacteristics(sec, true);
    sec.setRDDHandleForVariable(output.getName(), out);
    // the output rdd is derived from both inputs, so both are registered as its lineage
    sec.addLineageRDD(output.getName(), input1.getName());
    sec.addLineageRDD(output.getName(), input2.getName());
}
Aggregations