Search in sources:

Example 1 with ChainType

Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache project incubator-systemml.

Class MapmmChainSPInstruction, method parseInstruction.

/**
 * Parses a serialized Spark map-mult-chain instruction.
 * <p>
 * The value-type-annotated tokens are laid out either as
 * {@code opcode, in1, in2, out, chainType} (two inputs) or as
 * {@code opcode, in1, in2, in3, out, chainType} (three inputs).
 * No operator is attached: the first constructor argument is always {@code null}.
 *
 * @param str the serialized instruction
 * @return the parsed instruction
 * @throws DMLRuntimeException if {@code checkNumFields} rejects the field count, or if the
 *         opcode is not {@link MapMultChain#OPCODE} (case-insensitive). An unknown chain-type
 *         name surfaces as the {@code IllegalArgumentException} thrown by {@code ChainType.valueOf}.
 */
public static MapmmChainSPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(tokens, 4, 5);

    final String op = tokens[0];
    // only the map-mult-chain opcode is handled by this instruction class
    if (!op.equalsIgnoreCase(MapMultChain.OPCODE))
        throw new DMLRuntimeException("MapmmChainSPInstruction.parseInstruction():: Unknown opcode " + op);

    // the first two inputs are common to both layouts (exec type already stripped)
    final CPOperand first = new CPOperand(tokens[1]);
    final CPOperand second = new CPOperand(tokens[2]);

    if (tokens.length != 5) {
        // three-input layout: in1, in2, in3, out, chainType
        final CPOperand third = new CPOperand(tokens[3]);
        final CPOperand result = new CPOperand(tokens[4]);
        final ChainType chain = ChainType.valueOf(tokens[5]);
        return new MapmmChainSPInstruction(null, first, second, third, result, chain, op, str);
    }

    // two-input layout: in1, in2, out, chainType
    final CPOperand output = new CPOperand(tokens[3]);
    final ChainType kind = ChainType.valueOf(tokens[4]);
    return new MapmmChainSPInstruction(null, first, second, output, kind, op, str);
}
Also used: ChainType (org.apache.sysml.lops.MapMultChain.ChainType), CPOperand (org.apache.sysml.runtime.instructions.cp.CPOperand), DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)

Example 2 with ChainType

Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache project incubator-systemml.

Class MapMultChainInstruction, method parseInstruction.

/**
 * Parses a serialized MR map-mult-chain instruction.
 * <p>
 * After the exec type is stripped, the tokens are laid out either as
 * {@code opcode, in1, in2, out, chainType} or as
 * {@code opcode, in1, in2, in3, out, chainType}. Every operand is parsed as a
 * {@code byte} (presumably an MR matrix index; confirm against callers).
 *
 * @param str the serialized instruction
 * @return the parsed instruction
 * @throws DMLRuntimeException if {@code checkNumFields} rejects the field count. Malformed
 *         operands or chain types surface as {@code NumberFormatException} /
 *         {@code IllegalArgumentException} from {@code Byte.parseByte} / {@code ChainType.valueOf}.
 */
public static MapMultChainInstruction parseInstruction(String str) throws DMLRuntimeException {
    // 2 or 3 inputs, plus the output and the chain type
    InstructionUtils.checkNumFields(str, 4, 5);

    final String[] tok = InstructionUtils.getInstructionParts(str);
    final byte first = Byte.parseByte(tok[1]);
    final byte second = Byte.parseByte(tok[2]);

    if (tok.length != 5) {
        // three-input layout: in1, in2, in3, out, chainType
        final byte third = Byte.parseByte(tok[3]);
        final byte result = Byte.parseByte(tok[4]);
        final ChainType chain = ChainType.valueOf(tok[5]);
        return new MapMultChainInstruction(chain, first, second, third, result, str);
    }

    // two-input layout: in1, in2, out, chainType
    final byte output = Byte.parseByte(tok[3]);
    final ChainType kind = ChainType.valueOf(tok[4]);
    return new MapMultChainInstruction(kind, first, second, output, str);
}
Also used: ChainType (org.apache.sysml.lops.MapMultChain.ChainType)

Example 3 with ChainType

Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache project incubator-systemml.

Class AggBinaryOp, method constructLops.

/**
 * Constructs the lops for this aggregate-binary hop. Currently only matrix multiplication is
 * supported. Operator selection happens in three steps:
 * <ol>
 *   <li>the execution type (CP, SPARK or MR),</li>
 *   <li>the special tsmm / mmchain patterns,</li>
 *   <li>the backend-specific matrix mult method, which is dispatched to the matching
 *       construct*Lops* helper.</li>
 * </ol>
 * Previously constructed lops are returned unchanged.
 *
 * NOTE: overestimated mem in case of transpose-identity matmult, but 3/2 at worst
 *       and existing mem estimate advantageous in terms of consistency hops/lops,
 *       and some special cases internally materialize the transpose for better cache locality
 *
 * @return the output lop of this hop
 * @throws HopsException in any of these cases:
 *         <ul>
 *           <li>this hop is not a matrix multiply,</li>
 *           <li>the selected mult method is not supported by the chosen backend,</li>
 *           <li>the execution type is not CP, SPARK or MR.</li>
 *         </ul>
 * @throws LopsException propagated from lop construction
 */
@Override
public Lop constructLops() throws HopsException, LopsException {
    //return already created lops
    if (getLops() != null)
        return getLops();
    //construct matrix mult lops (currently only supported aggbinary)
    if (!isMatrixMultiply())
        throw new HopsException(this.printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops.");

    Hop input1 = getInput().get(0);
    Hop input2 = getInput().get(1);
    //matrix mult operation selection part 1 (CP vs MR vs Spark)
    ExecType et = optFindExecType();
    //matrix mult operation selection part 2 (specific pattern)
    //determine tsmm pattern
    MMTSJType mmtsj = checkTransposeSelf();
    //determine mmchain pattern
    ChainType chain = checkMapMultChain();
    if (et == ExecType.CP) {
        //matrix mult operation selection part 3 (CP type)
        _method = optFindMMultMethodCP(input1.getDim1(), input1.getDim2(), input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput);
        //dispatch CP lops construction
        switch(_method) {
            case TSMM:
                constructCPLopsTSMM(mmtsj);
                break;
            case MAPMM_CHAIN:
                constructCPLopsMMChain(chain);
                break;
            case PMM:
                constructCPLopsPMM();
                break;
            case MM:
                constructCPLopsMM();
                break;
            default:
                throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing CP lops.");
        }
    } else if (et == ExecType.SPARK) {
        //matrix mult operation selection part 3 (SPARK type)
        boolean tmmRewrite = HopRewriteUtils.isTransposeOperation(input1);
        _method = optFindMMultMethodSpark(input1.getDim1(), input1.getDim2(), input1.getRowsInBlock(), input1.getColsInBlock(), input1.getNnz(), input2.getDim1(), input2.getDim2(), input2.getRowsInBlock(), input2.getColsInBlock(), input2.getNnz(), mmtsj, chain, _hasLeftPMInput, tmmRewrite);
        //dispatch SPARK lops construction
        switch(_method) {
            case TSMM:
            case TSMM2:
                constructSparkLopsTSMM(mmtsj, _method == MMultMethod.TSMM2);
                break;
            case MAPMM_L:
            case MAPMM_R:
                constructSparkLopsMapMM(_method);
                break;
            case MAPMM_CHAIN:
                constructSparkLopsMapMMChain(chain);
                break;
            case PMAPMM:
                constructSparkLopsPMapMM();
                break;
            case CPMM:
                constructSparkLopsCPMM();
                break;
            case RMM:
                constructSparkLopsRMM();
                break;
            case PMM:
                constructSparkLopsPMM();
                break;
            case ZIPMM:
                constructSparkLopsZIPMM();
                break;
            default:
                throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops.");
        }
    } else if (et == ExecType.MR) {
        //matrix mult operation selection part 3 (MR type)
        _method = optFindMMultMethodMR(input1.getDim1(), input1.getDim2(), input1.getRowsInBlock(), input1.getColsInBlock(), input1.getNnz(), input2.getDim1(), input2.getDim2(), input2.getRowsInBlock(), input2.getColsInBlock(), input2.getNnz(), mmtsj, chain, _hasLeftPMInput);
        //dispatch MR lops construction
        switch(_method) {
            case MAPMM_L:
            case MAPMM_R:
                constructMRLopsMapMM(_method);
                break;
            case MAPMM_CHAIN:
                constructMRLopsMapMMChain(chain);
                break;
            case CPMM:
                constructMRLopsCPMM();
                break;
            case RMM:
                constructMRLopsRMM();
                break;
            case TSMM:
                constructMRLopsTSMM(mmtsj);
                break;
            case PMM:
                constructMRLopsPMM();
                break;
            default:
                throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing MR lops.");
        }
    } else {
        // Fail fast on any other exec type. Otherwise no lops would be constructed, and the
        // data-flow step below would run (and this method return) with the lops still unset.
        throw new HopsException(this.printErrorLocation() + "Invalid execution type (" + et + ") while constructing matrix mult lops.");
    }
    //add reblock/checkpoint lops if necessary
    constructAndSetLopsDataFlowProperties();
    return getLops();
}
Also used: MultiThreadedHop (org.apache.sysml.hops.Hop.MultiThreadedHop), ChainType (org.apache.sysml.lops.MapMultChain.ChainType), ExecType (org.apache.sysml.lops.LopProperties.ExecType), MMTSJType (org.apache.sysml.lops.MMTSJ.MMTSJType)

Example 4 with ChainType

Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache project incubator-systemml.

Class AggBinaryOp, method checkMapMultChain.

/**
 * MapMultChain: determines whether this aggbinary matches one of the
 * XtwXv / XtXvy / XtXv patterns and, if so, which one.
 * <ul>
 *   <li>{@code t(X) %*% (w * (X %*% v))} yields {@link ChainType#XtwXv}</li>
 *   <li>{@code t(X) %*% ((X %*% v) - y)} with a matrix {@code y} yields {@link ChainType#XtXvy}</li>
 *   <li>{@code t(X) %*% (X %*% v)} yields {@link ChainType#XtXv}</li>
 * </ul>
 * The check is purely structural. It looks at the operator kinds, the operand positions
 * shown above, and the shared input {@code X}, which is compared by reference.
 *
 * @return the matched chain type, or {@link ChainType#NONE} if no pattern applies
 */
public ChainType checkMapMultChain() {
    final Hop left = getInput().get(0);
    final Hop right = getInput().get(1);

    // every chain pattern requires a transposed left operand t(X)
    if (!HopRewriteUtils.isTransposeOperation(left)) {
        return ChainType.NONE;
    }
    final Hop sharedX = left.getInput().get(0);

    // t(X)%*%(w*(X%*%v)): the second factor of the product must be X%*%v
    if (right instanceof BinaryOp && ((BinaryOp) right).getOp() == OpOp2.MULT) {
        final Hop product = right.getInput().get(1);
        final boolean sameX = product instanceof AggBinaryOp
            && product.getInput().get(0) == sharedX;
        return sameX ? ChainType.XtwXv : ChainType.NONE;
    }

    // t(X)%*%((X%*%v)-y): the minuend must be X%*%v and the subtrahend a matrix
    if (right instanceof BinaryOp && ((BinaryOp) right).getOp() == OpOp2.MINUS) {
        final Hop minuend = right.getInput().get(0);
        final Hop subtrahend = right.getInput().get(1);
        final boolean sameX = minuend instanceof AggBinaryOp
            && subtrahend.getDataType() == DataType.MATRIX
            && minuend.getInput().get(0) == sharedX;
        return sameX ? ChainType.XtXvy : ChainType.NONE;
    }

    // t(X)%*%(X%*%v)
    if (right instanceof AggBinaryOp && right.getInput().get(0) == sharedX) {
        return ChainType.XtXv;
    }
    return ChainType.NONE;
}
Also used: MultiThreadedHop (org.apache.sysml.hops.Hop.MultiThreadedHop), ChainType (org.apache.sysml.lops.MapMultChain.ChainType)

Example 5 with ChainType

Use of org.apache.sysml.lops.MapMultChain.ChainType in the Apache project incubator-systemml.

Class MMChainCPInstruction, method parseInstruction.

/**
 * Parses a serialized CP map-mult-chain instruction.
 * <p>
 * The value-type-annotated tokens are laid out either as
 * {@code opcode, in1, in2, out, chainType, k} or as
 * {@code opcode, in1, in2, in3, out, chainType, k}. In the two-input form the third
 * input is passed as {@code null}. {@code k} is parsed as an int (presumably the degree
 * of parallelism; confirm against the constructor). No operator is attached: the first
 * constructor argument is always {@code null}.
 *
 * @param str the serialized instruction
 * @return the parsed instruction
 * @throws DMLRuntimeException if {@code checkNumFields} rejects the field count. Malformed
 *         chain types or {@code k} values surface as {@code IllegalArgumentException} /
 *         {@code NumberFormatException}.
 */
public static MMChainCPInstruction parseInstruction(String str) throws DMLRuntimeException {
    // exec type is dropped; opcode stays at index 0
    final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(tokens, 5, 6);

    final String op = tokens[0];
    final CPOperand first = new CPOperand(tokens[1]);
    final CPOperand second = new CPOperand(tokens[2]);

    if (tokens.length != 6) {
        // three-input layout: in1, in2, in3, out, chainType, k
        final CPOperand third = new CPOperand(tokens[3]);
        final CPOperand result = new CPOperand(tokens[4]);
        final ChainType chain = ChainType.valueOf(tokens[5]);
        final int threads = Integer.parseInt(tokens[6]);
        return new MMChainCPInstruction(null, first, second, third, result, chain, threads, op, str);
    }

    // two-input layout: in1, in2, out, chainType, k
    final CPOperand output = new CPOperand(tokens[3]);
    final ChainType kind = ChainType.valueOf(tokens[4]);
    final int par = Integer.parseInt(tokens[5]);
    return new MMChainCPInstruction(null, first, second, null, output, kind, par, op, str);
}
Also used: ChainType (org.apache.sysml.lops.MapMultChain.ChainType)

Aggregations

ChainType (org.apache.sysml.lops.MapMultChain.ChainType): 5 · MultiThreadedHop (org.apache.sysml.hops.Hop.MultiThreadedHop): 2 · ExecType (org.apache.sysml.lops.LopProperties.ExecType): 1 · MMTSJType (org.apache.sysml.lops.MMTSJ.MMTSJType): 1 · DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 1 · CPOperand (org.apache.sysml.runtime.instructions.cp.CPOperand): 1