use of org.apache.sysml.lops.MMTSJ.MMTSJType in project incubator-systemml by apache.
the class MMTSJGPUInstruction method parseInstruction.
/**
 * Parses the string form of a GPU transpose-self matrix multiply (tsmm) instruction.
 *
 * <p>The instruction is split into fields: field 0 is the opcode, field 1 the input
 * operand, field 2 the output operand and field 3 the name of an {@link MMTSJType}
 * constant. The opcode is validated before the remaining fields are decoded, so an
 * unsupported opcode is always reported as such instead of being masked by a failure
 * raised while decoding the operands or the type name.
 *
 * @param str instruction string
 * @return MMTSJGPUInstruction object
 * @throws DMLRuntimeException if the opcode is not "tsmm" (compared case-insensitively);
 *         the field-count validation is delegated to {@code InstructionUtils.checkNumFields}
 */
public static MMTSJGPUInstruction parseInstruction(String str) throws DMLRuntimeException {
    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(parts, 3);

    // Fail fast on the opcode: MMTSJType.valueOf below throws an unchecked
    // IllegalArgumentException for an unknown name, which would otherwise hide
    // the more informative "unknown opcode" error.
    String opcode = parts[0];
    if (!opcode.equalsIgnoreCase("tsmm")) {
        throw new DMLRuntimeException("Unknown opcode while parsing an MMTSJGPUInstruction: " + str);
    }

    CPOperand in1 = new CPOperand(parts[1]);
    CPOperand out = new CPOperand(parts[2]);
    MMTSJType titype = MMTSJType.valueOf(parts[3]);
    return new MMTSJGPUInstruction(new Operator(true), in1, titype, out, opcode, str);
}
use of org.apache.sysml.lops.MMTSJ.MMTSJType in project incubator-systemml by apache.
the class MMTSJMRInstruction method parseInstruction.
/**
 * Parses the string form of a MapReduce transpose-self matrix multiply (tsmm) instruction.
 *
 * <p>Field 0 is the opcode, fields 1 and 2 are the input and output indexes
 * (each parsed as a {@code byte}), and field 3 is the name of an {@link MMTSJType}
 * constant. The opcode is validated before the remaining fields are decoded, so an
 * unsupported opcode is reported as such instead of being masked by a
 * {@link NumberFormatException} or {@link IllegalArgumentException} from decoding them.
 *
 * @param str instruction string
 * @return MMTSJMRInstruction object
 * @throws DMLRuntimeException if the opcode is not "tsmm" (compared case-insensitively);
 *         the field-count validation is delegated to {@code InstructionUtils.checkNumFields}
 */
public static MMTSJMRInstruction parseInstruction(String str) throws DMLRuntimeException {
    InstructionUtils.checkNumFields(str, 3);
    String[] parts = InstructionUtils.getInstructionParts(str);

    // Fail fast on the opcode so that malformed operands of a foreign
    // instruction do not hide the real problem.
    String opcode = parts[0];
    if (!opcode.equalsIgnoreCase("tsmm")) {
        throw new DMLRuntimeException("Unknown opcode while parsing an MMTSJMRInstruction: " + str);
    }

    byte in = Byte.parseByte(parts[1]);
    byte out = Byte.parseByte(parts[2]);
    MMTSJType titype = MMTSJType.valueOf(parts[3]);
    return new MMTSJMRInstruction(new Operator(true), in, titype, out, str);
}
use of org.apache.sysml.lops.MMTSJ.MMTSJType in project incubator-systemml by apache.
the class AggBinaryOp method computeMemEstimate.
@Override
public void computeMemEstimate(MemoTable memo) {
    // Begin with the default memory estimate, then refine it to reflect
    // the smaller memory requirements of a tsmm.
    super.computeMemEstimate(memo);

    // Only a left tsmm is guaranteed to need X alone and never t(X); a right
    // tsmm may still have to transpose X when it is sparse, so it keeps the
    // default estimate.
    if (!checkTransposeSelf().isLeft()) {
        return;
    }

    // Heuristic: skip the correction for column vectors, because most other
    // vector operations need memory for at least two vectors. Staying consistent
    // with them avoids anomalies in the parfor optimizer that would lead to a
    // small degree of parallelism.
    Hop second = getInput().get(1);
    if (second.dimsKnown() && second.getDim2() > 1) {
        // Drop the output memory estimate of the first input from the total.
        _memEstimate -= getInput().get(0)._outputMemEstimate;
    }
}
use of org.apache.sysml.lops.MMTSJ.MMTSJType in project incubator-systemml by apache.
the class AggBinaryOp method checkTransposeSelf.
/**
 * TSMM: Classifies this aggregate binary operation against the transpose-self
 * (XtX) pattern.
 *
 * <p>The result is {@code LEFT} when the first input is a transpose whose own
 * first input is the second input, {@code RIGHT} when the second input is a
 * transpose whose own first input is the first input, and {@code NONE} otherwise.
 * Inputs are compared by reference. Should both conditions hold, {@code RIGHT}
 * is returned.
 *
 * @return MMTSJType
 */
public MMTSJType checkTransposeSelf() {
    Hop first = getInput().get(0);
    Hop second = getInput().get(1);

    boolean firstIsTransposeOfSecond =
        HopRewriteUtils.isTransposeOperation(first) && first.getInput().get(0) == second;
    boolean secondIsTransposeOfFirst =
        HopRewriteUtils.isTransposeOperation(second) && second.getInput().get(0) == first;

    // The right-hand pattern takes precedence over the left-hand one.
    if (secondIsTransposeOfFirst) {
        return MMTSJType.RIGHT;
    }
    return firstIsTransposeOfSecond ? MMTSJType.LEFT : MMTSJType.NONE;
}
Aggregations