use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in project incubator-systemml by apache.
the class InstructionUtils method parseAggregateTernaryOperator.
/**
 * Creates the aggregate ternary operator that corresponds to an opcode.
 *
 * <p>The opcode is compared case-insensitively against {@code "tak+*"}:
 * <ul>
 *   <li>on a match, the operator uses the ReduceAll index function and the
 *       aggregate's correction location is {@code LASTCOLUMN};</li>
 *   <li>for any other opcode, the operator uses the ReduceRow index function
 *       and the correction location is {@code LASTROW}.</li>
 * </ul>
 * No validation of the opcode is performed here; every value other than
 * {@code "tak+*"} takes the second path.
 *
 * @param opcode     instruction opcode; must not be {@code null}
 * @param numThreads value handed unchanged to the operator constructor
 * @return a new operator built from the Multiply function object, a
 *         KahanPlus-based aggregate operator and the selected index function
 */
public static AggregateTernaryOperator parseAggregateTernaryOperator(String opcode, int numThreads) {
    final boolean fullAggregate = opcode.equalsIgnoreCase("tak+*");

    // where the aggregate keeps its correction value depends on the aggregation direction
    CorrectionLocationType correctionLoc = CorrectionLocationType.LASTROW;
    if (fullAggregate) {
        correctionLoc = CorrectionLocationType.LASTCOLUMN;
    }
    AggregateOperator kahanAgg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, correctionLoc);

    // index function selects between aggregating everything and aggregating rows
    IndexFunction indexFn;
    if (fullAggregate) {
        indexFn = ReduceAll.getReduceAllFnObject();
    } else {
        indexFn = ReduceRow.getReduceRowFnObject();
    }

    return new AggregateTernaryOperator(Multiply.getMultiplyFnObject(), kahanAgg, indexFn, numThreads);
}
use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in project incubator-systemml by apache.
the class AggregateTernarySPInstruction method processInstruction.
/**
 * Executes the aggregate ternary instruction on Spark.
 *
 * <p>The first two inputs (and the third, unless it is a literal) are joined
 * by block index and mapped to partial results, which are then aggregated.
 * The form of the output depends on the operator and the input dimensions:
 * <ol>
 *   <li>ReduceAll index function (tak+*): the partials are summed and the
 *       value at (0,0) is stored as a scalar variable;</li>
 *   <li>known dimensions with at most one column block (tack+*, single
 *       block): the partials are aggregated into one block, the correction
 *       is dropped, and the block is set as matrix output;</li>
 *   <li>otherwise (tack+*, multi block): the partials are aggregated by key,
 *       the correction is dropped per block, and the resulting RDD is
 *       registered together with lineage to its inputs.</li>
 * </ol>
 *
 * @param ec execution context; cast to {@code SparkExecutionContext}
 * @throws DMLRuntimeException declared by this method; propagated from the calls it makes
 */
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;

    // collect input characteristics and RDD handles
    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> first = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> second = sec.getBinaryBlockRDDHandleForVariable(input2.getName());
    // third operand is either a matrix or the literal 1; a literal has no RDD handle
    JavaPairRDD<MatrixIndexes, MatrixBlock> third = null;
    if (!input3.isLiteral()) {
        third = sec.getBinaryBlockRDDHandleForVariable(input3.getName());
    }
    final boolean hasThird = (third != null);

    // compute per-block partial results
    AggregateTernaryOperator ternaryOp = (AggregateTernaryOperator) _optr;
    JavaPairRDD<MatrixIndexes, MatrixBlock> partials;
    if (hasThird) {
        // three matrix inputs
        partials = first.join(second).join(third).mapToPair(new RDDAggregateTernaryFunction(ternaryOp));
    } else {
        // two matrix inputs, third is the literal 1
        partials = first.join(second).mapToPair(new RDDAggregateTernaryFunction2(ternaryOp));
    }

    // tak+*: full aggregation to a scalar (no lineage because scalar)
    if (ternaryOp.indexFn instanceof ReduceAll) {
        MatrixBlock total = RDDAggregateUtils.sumStable(partials.values());
        DoubleObject scalar = new DoubleObject(total.getValue(0, 0));
        sec.setVariable(output.getName(), scalar);
        return;
    }

    // tack+* single block: aggregate into one block and drop the correction
    if (mcIn.dimsKnown() && mcIn.getCols() <= mcIn.getColsPerBlock()) {
        MatrixBlock single = RDDAggregateUtils.aggStable(partials, ternaryOp.aggOp);
        single.dropLastRowsOrColums(ternaryOp.aggOp.correctionLocation);
        // no lineage because single block; per the original note, this call also
        // maintains the output matrix characteristics implicitly
        sec.setMatrixOutput(output.getName(), single);
        return;
    }

    // tack+* multi block: aggregate by key and drop the correction per block
    JavaPairRDD<MatrixIndexes, MatrixBlock> aggregated = RDDAggregateUtils.aggByKeyStable(partials, ternaryOp.aggOp, false);
    JavaPairRDD<MatrixIndexes, MatrixBlock> result = aggregated.mapValues(new AggregateDropCorrectionFunction(ternaryOp.aggOp));

    // register output RDD handle and its lineage in the symbol table
    updateUnaryAggOutputMatrixCharacteristics(sec, ternaryOp.indexFn);
    sec.setRDDHandleForVariable(output.getName(), result);
    sec.addLineageRDD(output.getName(), input1.getName());
    sec.addLineageRDD(output.getName(), input2.getName());
    if (hasThird) {
        sec.addLineageRDD(output.getName(), input3.getName());
    }
}
use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in project incubator-systemml by apache.
the class AggregateTernaryCPInstruction method parseInstruction.
/**
 * Parses the string form of a CP aggregate ternary instruction.
 *
 * <p>Only the opcodes {@code "tak+*"} and {@code "tack+*"} (compared
 * case-insensitively) are accepted. After the field count is checked via
 * {@code InstructionUtils.checkNumFields(parts, 5)}, parts 1-3 are read as the
 * input operands, part 4 as the output operand and part 5 as the thread count.
 *
 * @param str instruction string
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is not one of the two accepted
 *         values, or if thrown by the called parsing helpers
 */
public static AggregateTernaryCPInstruction parseInstruction(String str) throws DMLRuntimeException {
    String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    String opcode = fields[0];

    // reject anything other than the two supported opcodes up front
    boolean supported = opcode.equalsIgnoreCase("tak+*") || opcode.equalsIgnoreCase("tack+*");
    if (!supported) {
        throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    InstructionUtils.checkNumFields(fields, 5);

    // operands: three inputs followed by the output
    CPOperand firstIn = new CPOperand(fields[1]);
    CPOperand secondIn = new CPOperand(fields[2]);
    CPOperand thirdIn = new CPOperand(fields[3]);
    CPOperand result = new CPOperand(fields[4]);

    // trailing field carries the number of threads
    int threads = Integer.parseInt(fields[5]);
    AggregateTernaryOperator ternaryOp = InstructionUtils.parseAggregateTernaryOperator(opcode, threads);

    return new AggregateTernaryCPInstruction(ternaryOp, firstIn, secondIn, thirdIn, result, opcode, str);
}
use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in project incubator-systemml by apache.
the class AggregateTernaryCPInstruction method processInstruction.
/**
 * Executes the aggregate ternary instruction in CP.
 *
 * <p>Acquires the first two matrix inputs (and the third, unless it is a
 * literal, in which case {@code null} is passed on), runs
 * {@code aggregateTernaryOperations} on the first block, releases the
 * acquired inputs, and stores the result as a scalar or as a matrix
 * depending on the data type of the output operand.
 *
 * @param ec execution context used to acquire inputs and publish the output
 * @throws DMLRuntimeException declared by this method; propagated from the calls it makes
 */
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    // acquire inputs; the third operand is either a matrix or the literal 1
    MatrixBlock lhs = ec.getMatrixInput(input1.getName());
    MatrixBlock mid = ec.getMatrixInput(input2.getName());
    MatrixBlock rhs = null;
    if (!input3.isLiteral()) {
        rhs = ec.getMatrixInput(input3.getName());
    }

    // run the aggregate ternary operation
    AggregateTernaryOperator ternaryOp = (AggregateTernaryOperator) _optr;
    MatrixBlock result = lhs.aggregateTernaryOperations(lhs, mid, rhs, new MatrixBlock(), ternaryOp, true);

    // release exactly the inputs that were acquired above
    ec.releaseMatrixInput(input1.getName());
    ec.releaseMatrixInput(input2.getName());
    if (!input3.isLiteral()) {
        ec.releaseMatrixInput(input3.getName());
    }

    // publish the result: scalar outputs take the value at cell (0,0)
    if (output.getDataType().isScalar()) {
        ec.setScalarOutput(output.getName(), new DoubleObject(result.quickGetValue(0, 0)));
    } else {
        ec.setMatrixOutput(output.getName(), result);
    }
}
use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in project incubator-systemml by apache.
the class AggregateTernarySPInstruction method parseInstruction.
/**
 * Parses the string form of a Spark aggregate ternary instruction.
 *
 * <p>Only the opcodes {@code "tak+*"} and {@code "tack+*"} (compared
 * case-insensitively) are accepted. After the field count is checked via
 * {@code InstructionUtils.checkNumFields(parts, 4)}, parts 1-3 are read as the
 * input operands and part 4 as the output operand.
 *
 * @param str instruction string
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is not one of the two accepted
 *         values, or if thrown by the called parsing helpers
 */
public static AggregateTernarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
    String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    String opcode = fields[0];

    // reject anything other than the two supported opcodes up front
    boolean supported = opcode.equalsIgnoreCase("tak+*") || opcode.equalsIgnoreCase("tack+*");
    if (!supported) {
        throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    InstructionUtils.checkNumFields(fields, 4);

    // operands: three inputs followed by the output
    CPOperand firstIn = new CPOperand(fields[1]);
    CPOperand secondIn = new CPOperand(fields[2]);
    CPOperand thirdIn = new CPOperand(fields[3]);
    CPOperand result = new CPOperand(fields[4]);

    AggregateTernaryOperator ternaryOp = InstructionUtils.parseAggregateTernaryOperator(opcode);
    return new AggregateTernarySPInstruction(ternaryOp, firstIn, secondIn, thirdIn, result, opcode, str);
}
Aggregations