Use of org.apache.sysml.lops.MMTSJ in the incubator-systemml project by Apache.
Class AggBinaryOp, method constructSparkLopsTSMM.
//////////////////////////
// Spark Lops generation
/////////////////////////
/**
 * Creates the Spark MMTSJ lop for this hop and registers it as the hop's lop.
 *
 * The operand is the hop input at position 1 when {@code mmtsj.isLeft()} holds,
 * and the input at position 0 otherwise. After construction the lop receives
 * its output dimensions and line numbers before being stored via {@code setLops}.
 *
 * @param mmtsj     tsmm type; decides which input is read and is passed to the lop
 * @param multiPass flag handed through unchanged to the MMTSJ constructor
 * @throws HopsException declared by this method (presumably raised while building the input lops)
 * @throws LopsException declared by this method (presumably raised while building the lop)
 */
private void constructSparkLopsTSMM(MMTSJType mmtsj, boolean multiPass) throws HopsException, LopsException {
	// pick the operand position based on the tsmm orientation
	final int operandPos = mmtsj.isLeft() ? 1 : 0;
	final Hop operand = getInput().get(operandPos);

	final MMTSJ sparkTsmm = new MMTSJ(
		operand.constructLops(), getDataType(), getValueType(),
		ExecType.SPARK, mmtsj, multiPass);

	setOutputDimensions(sparkTsmm);
	setLineNumbers(sparkTsmm);
	setLops(sparkTsmm);
}
Use of org.apache.sysml.lops.MMTSJ in the incubator-systemml project by Apache.
Class AggBinaryOp, method constructMRLopsTSMM.
/**
 * Creates the MR lop chain for tsmm: an MMTSJ lop followed by an Aggregate lop,
 * and registers the Aggregate as this hop's lop.
 *
 * The operand is the hop input at position 1 when {@code mmtsj.isLeft()} holds,
 * and the input at position 0 otherwise. Both lops are given this hop's
 * dimensions, block sizes and nnz, as well as its line numbers.
 *
 * @param mmtsj tsmm type; decides which input is read and is passed to the MMTSJ lop
 * @throws HopsException declared by this method (presumably raised while building the input lops)
 * @throws LopsException declared by this method (presumably raised while building the lops)
 */
private void constructMRLopsTSMM(MMTSJType mmtsj) throws HopsException, LopsException {
	// pick the operand based on the tsmm orientation
	final Hop operand;
	if (mmtsj.isLeft()) {
		operand = getInput().get(1);
	} else {
		operand = getInput().get(0);
	}

	// step 1: the tsmm lop itself
	final MMTSJ mrTsmm = new MMTSJ(
		operand.constructLops(), getDataType(), getValueType(), ExecType.MR, mmtsj);
	mrTsmm.getOutputParameters().setDimensions(
		getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
	setLineNumbers(mrTsmm);

	// step 2: aggregate over the tsmm output using the lop-level form of outerOp
	final Aggregate finalAgg = new Aggregate(
		mrTsmm, HopsAgg2Lops.get(outerOp), getDataType(), getValueType(), ExecType.MR);
	finalAgg.getOutputParameters().setDimensions(
		getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
	// the aggregation uses kahanSum, but its inputs carry no correction values
	finalAgg.setupCorrectionLocation(CorrectionLocationType.NONE);
	setLineNumbers(finalAgg);

	setLops(finalAgg);
}
Use of org.apache.sysml.lops.MMTSJ in the incubator-systemml project by Apache.
Class AggBinaryOp, method constructCPLopsTSMM.
//////////////////////////
// CP Lops generation
/////////////////////////
/**
 * Creates the single-node MMTSJ lop for this hop and registers it as the hop's lop.
 *
 * The execution type is {@code ExecType.GPU} when the accelerator is enabled and
 * either forced or the hop's memory estimate is below the GPU memory budget;
 * otherwise it is {@code ExecType.CP}. The operand is the hop input at position 1
 * when {@code mmtsj.isLeft()} holds, and the input at position 0 otherwise.
 *
 * @param mmtsj tsmm type; decides which input is read and is passed to the lop
 * @throws HopsException declared by this method (presumably raised while building the input lops)
 * @throws LopsException declared by this method (presumably raised while building the lop)
 */
private void constructCPLopsTSMM(MMTSJType mmtsj) throws HopsException, LopsException {
	// thread count derived from this hop's configured maximum
	final int numThreads = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);

	// GPU only if the accelerator is on and it is forced or the estimate fits the budget
	final boolean runOnGpu = DMLScript.USE_ACCELERATOR
		&& (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < OptimizerUtils.GPU_MEMORY_BUDGET);
	final ExecType execType = runOnGpu ? ExecType.GPU : ExecType.CP;

	// pick the operand based on the tsmm orientation and build its lop first
	final int operandPos = mmtsj.isLeft() ? 1 : 0;
	final Lop operandLop = getInput().get(operandPos).constructLops();

	final Lop tsmmLop = new MMTSJ(
		operandLop, getDataType(), getValueType(), execType, mmtsj, false, numThreads);
	tsmmLop.getOutputParameters().setDimensions(
		getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
	setLineNumbers(tsmmLop);
	setLops(tsmmLop);
}
Aggregations