Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache incubator-systemml project.
Class MapmmChainSPInstruction, method parseInstruction.
/**
 * Parses a serialized Spark mapmmchain instruction.
 *
 * <p>After the opcode the instruction carries either {@code in1, in2, out, type}
 * or {@code in1, in2, in3, out, type}, where {@code type} is the name of a
 * {@link ChainType} constant.
 *
 * @param str the serialized instruction string
 * @return the parsed instruction (no operator is attached)
 * @throws DMLRuntimeException if the opcode is not {@code MapMultChain.OPCODE};
 *         field-count validation is delegated to {@code InstructionUtils.checkNumFields}
 */
public static MapmmChainSPInstruction parseInstruction(String str) throws DMLRuntimeException {
    String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(fields, 4, 5);

    String opcode = fields[0];
    if (!opcode.equalsIgnoreCase(MapMultChain.OPCODE)) {
        throw new DMLRuntimeException("MapmmChainSPInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    // operands are parsed without the exec type
    CPOperand in1 = new CPOperand(fields[1]);
    CPOperand in2 = new CPOperand(fields[2]);

    // opcode + 4 fields means two inputs; otherwise (opcode + 5) a third input precedes the output
    boolean hasThirdInput = fields.length != 5;
    int outPos = hasThirdInput ? 4 : 3;
    CPOperand in3 = hasThirdInput ? new CPOperand(fields[3]) : null;
    CPOperand out = new CPOperand(fields[outPos]);
    ChainType type = ChainType.valueOf(fields[outPos + 1]);

    return hasThirdInput
        ? new MapmmChainSPInstruction(null, in1, in2, in3, out, type, opcode, str)
        : new MapmmChainSPInstruction(null, in1, in2, out, type, opcode, str);
}
Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache incubator-systemml project.
Class MapMultChainInstruction, method parseInstruction.
/**
 * Parses a serialized MR mapmultchain instruction.
 *
 * <p>After the opcode the instruction carries byte indexes for 2 or 3 inputs,
 * then the output index, then the name of a {@link ChainType} constant.
 *
 * @param str the serialized instruction string
 * @return the parsed instruction
 * @throws DMLRuntimeException as declared; field-count validation is delegated
 *         to {@code InstructionUtils.checkNumFields}
 */
public static MapMultChainInstruction parseInstruction(String str) throws DMLRuntimeException {
    // 2/3 inputs, one output, one chain type
    InstructionUtils.checkNumFields(str, 4, 5);

    // operands are parsed without the exec type
    String[] fields = InstructionUtils.getInstructionParts(str);
    byte first = Byte.parseByte(fields[1]);
    byte second = Byte.parseByte(fields[2]);

    if (fields.length != 5) {
        // opcode + 5 fields: a third input sits between the first two inputs and the output
        byte third = Byte.parseByte(fields[3]);
        byte output = Byte.parseByte(fields[4]);
        ChainType chain = ChainType.valueOf(fields[5]);
        return new MapMultChainInstruction(chain, first, second, third, output, str);
    }

    byte output = Byte.parseByte(fields[3]);
    ChainType chain = ChainType.valueOf(fields[4]);
    return new MapMultChainInstruction(chain, first, second, output, str);
}
Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache incubator-systemml project.
Class AggBinaryOp, method constructLops.
/**
 * Constructs the lops for this aggregate binary hop, or returns them if they were
 * already constructed.
 *
 * Only matrix multiplication is supported. The steps are:
 * (1) choose the exec type via optFindExecType();
 * (2) detect the tsmm pattern (checkTransposeSelf) and the mmchain pattern (checkMapMultChain);
 * (3) pick a backend-specific matrix mult method, store it in _method, and dispatch to the
 *     matching construct*Lops* helper.
 * Afterwards, data-flow properties and any reblock/checkpoint lops are added.
 *
 * NOTE: overestimated mem in case of transpose-identity matmult, but 3/2 at worst
 * and existing mem estimate advantageous in terms of consistency hops/lops,
 * and some special cases internally materialize the transpose for better cache locality
 *
 * @return the constructed (or previously cached) lop of this hop
 * @throws HopsException if this hop is not a matrix multiply, or if the selected
 *         matrix mult method is not handled for the chosen exec type
 * @throws LopsException declared for failures during lop construction
 *         (presumably raised by the construct* helpers; TODO confirm)
 */
@Override
public Lop constructLops() throws HopsException, LopsException {
//return already created lops
if (getLops() != null)
return getLops();
//construct matrix mult lops (currently only supported aggbinary)
if (isMatrixMultiply()) {
Hop input1 = getInput().get(0);
Hop input2 = getInput().get(1);
//matrix mult operation selection part 1 (CP vs MR vs Spark)
ExecType et = optFindExecType();
//matrix mult operation selection part 2 (specific pattern)
//determine tsmm pattern
MMTSJType mmtsj = checkTransposeSelf();
//determine mmchain pattern
ChainType chain = checkMapMultChain();
// Dispatch by exec type. NOTE(review): exec types other than CP/SPARK/MR match none of
// the branches below, so no lops are constructed here and getLops() may still be null
// at the end; confirm that optFindExecType() only returns these three types.
if (et == ExecType.CP) {
//matrix mult operation selection part 3 (CP type)
_method = optFindMMultMethodCP(input1.getDim1(), input1.getDim2(), input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput);
//dispatch CP lops construction
switch(_method) {
case TSMM:
constructCPLopsTSMM(mmtsj);
break;
case MAPMM_CHAIN:
constructCPLopsMMChain(chain);
break;
case PMM:
constructCPLopsPMM();
break;
case MM:
constructCPLopsMM();
break;
default:
throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing CP lops.");
}
} else if (et == ExecType.SPARK) {
//matrix mult operation selection part 3 (SPARK type)
// whether the left input is itself a transpose operation; passed on to Spark method selection
boolean tmmRewrite = HopRewriteUtils.isTransposeOperation(input1);
_method = optFindMMultMethodSpark(input1.getDim1(), input1.getDim2(), input1.getRowsInBlock(), input1.getColsInBlock(), input1.getNnz(), input2.getDim1(), input2.getDim2(), input2.getRowsInBlock(), input2.getColsInBlock(), input2.getNnz(), mmtsj, chain, _hasLeftPMInput, tmmRewrite);
//dispatch SPARK lops construction
switch(_method) {
case TSMM:
case TSMM2:
// both tsmm variants share one constructor; the flag tells it which variant was chosen
constructSparkLopsTSMM(mmtsj, _method == MMultMethod.TSMM2);
break;
case MAPMM_L:
case MAPMM_R:
// left/right broadcast variants share one constructor, parameterized by _method
constructSparkLopsMapMM(_method);
break;
case MAPMM_CHAIN:
constructSparkLopsMapMMChain(chain);
break;
case PMAPMM:
constructSparkLopsPMapMM();
break;
case CPMM:
constructSparkLopsCPMM();
break;
case RMM:
constructSparkLopsRMM();
break;
case PMM:
constructSparkLopsPMM();
break;
case ZIPMM:
constructSparkLopsZIPMM();
break;
default:
throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops.");
}
} else if (et == ExecType.MR) {
//matrix mult operation selection part 3 (MR type)
_method = optFindMMultMethodMR(input1.getDim1(), input1.getDim2(), input1.getRowsInBlock(), input1.getColsInBlock(), input1.getNnz(), input2.getDim1(), input2.getDim2(), input2.getRowsInBlock(), input2.getColsInBlock(), input2.getNnz(), mmtsj, chain, _hasLeftPMInput);
//dispatch MR lops construction
switch(_method) {
case MAPMM_L:
case MAPMM_R:
constructMRLopsMapMM(_method);
break;
case MAPMM_CHAIN:
constructMRLopsMapMMChain(chain);
break;
case CPMM:
constructMRLopsCPMM();
break;
case RMM:
constructMRLopsRMM();
break;
case TSMM:
constructMRLopsTSMM(mmtsj);
break;
case PMM:
constructMRLopsPMM();
break;
default:
throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing MR lops.");
}
}
} else
// any aggregate binary other than matrix multiply is rejected
throw new HopsException(this.printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops.");
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
return getLops();
}
Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache incubator-systemml project.
Class AggBinaryOp, method checkMapMultChain.
/**
 * MapMultChain: checks whether this aggbinary matches one of the chain patterns
 * below and, if it does, reports which one.
 * <ul>
 *   <li>{@code t(X)%*%(w*(X%*%v))} yields {@code XtwXv}</li>
 *   <li>{@code t(X)%*%((X%*%v)-y)} with matrix {@code y} yields {@code XtXvy}</li>
 *   <li>{@code t(X)%*%(X%*%v)} yields {@code XtXv}</li>
 * </ul>
 *
 * @return the matched ChainType, or {@code ChainType.NONE} if no pattern applies
 */
public ChainType checkMapMultChain() {
    Hop left = getInput().get(0);
    Hop right = getInput().get(1);

    // every chain pattern requires a transposed left operand t(X)
    if (!HopRewriteUtils.isTransposeOperation(left)) {
        return ChainType.NONE;
    }
    Hop X = left.getInput().get(0);

    // operator of the right operand when it is a BinaryOp, otherwise null
    OpOp2 rightOp = (right instanceof BinaryOp) ? ((BinaryOp) right).getOp() : null;

    if (rightOp == OpOp2.MULT) {
        // t(X)%*%(w*(X%*%v)): X%*%v must be the second operand of the mult
        Hop product = right.getInput().get(1);
        boolean sharesX = product instanceof AggBinaryOp && product.getInput().get(0) == X;
        return sharesX ? ChainType.XtwXv : ChainType.NONE;
    }

    if (rightOp == OpOp2.MINUS) {
        // t(X)%*%((X%*%v)-y): X%*%v is the minuend and y must be a matrix
        Hop product = right.getInput().get(0);
        Hop subtrahend = right.getInput().get(1);
        boolean sharesX = product instanceof AggBinaryOp
            && subtrahend.getDataType() == DataType.MATRIX
            && product.getInput().get(0) == X;
        return sharesX ? ChainType.XtXvy : ChainType.NONE;
    }

    // t(X)%*%(X%*%v)
    if (right instanceof AggBinaryOp && right.getInput().get(0) == X) {
        return ChainType.XtXv;
    }
    return ChainType.NONE;
}
Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache incubator-systemml project.
Class MMChainCPInstruction, method parseInstruction.
/**
 * Parses a serialized CP mmchain instruction.
 *
 * <p>After the opcode the instruction carries either {@code in1, in2, out, type, k}
 * or {@code in1, in2, in3, out, type, k}. Here {@code type} is the name of a
 * {@link ChainType} constant and {@code k} is an integer passed through to the
 * instruction.
 *
 * @param str the serialized instruction string
 * @return the parsed instruction (no operator is attached; {@code in3} is null when absent)
 * @throws DMLRuntimeException as declared; field-count validation is delegated
 *         to {@code InstructionUtils.checkNumFields}
 */
public static MMChainCPInstruction parseInstruction(String str) throws DMLRuntimeException {
    // operands are parsed without the exec type
    String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(fields, 5, 6);

    String opcode = fields[0];
    CPOperand in1 = new CPOperand(fields[1]);
    CPOperand in2 = new CPOperand(fields[2]);

    // opcode + 5 fields means two inputs; otherwise (opcode + 6) a third input precedes the output
    boolean hasThirdInput = fields.length != 6;
    int pos = hasThirdInput ? 4 : 3;
    CPOperand in3 = hasThirdInput ? new CPOperand(fields[3]) : null;
    CPOperand out = new CPOperand(fields[pos]);
    ChainType type = ChainType.valueOf(fields[pos + 1]);
    int k = Integer.parseInt(fields[pos + 2]);

    return new MMChainCPInstruction(null, in1, in2, in3, out, type, k, opcode, str);
}
Aggregations