Search in sources :

Example 1 with MMTSJType

use of org.apache.sysml.lops.MMTSJ.MMTSJType in project incubator-systemml by apache.

the class Tsmm2SPInstruction method parseInstruction.

/**
 * Builds a {@code Tsmm2SPInstruction} from its serialized string form.
 *
 * <p>The string is split via
 * {@code InstructionUtils.getInstructionPartsWithValueType}; the resulting
 * fields are read as: [0] opcode, [1] input operand, [2] output operand,
 * [3] name of the {@code MMTSJType} constant.
 *
 * @param str serialized instruction
 * @return the parsed instruction (constructed with a {@code null} first argument)
 * @throws DMLRuntimeException if the opcode is not {@code tsmm2} (compared case-insensitively)
 */
public static Tsmm2SPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    final String opcode = fields[0];
    // only the tsmm2 opcode is accepted; everything else is rejected below
    if (opcode.equalsIgnoreCase("tsmm2")) {
        final CPOperand input = new CPOperand(fields[1]);
        final CPOperand output = new CPOperand(fields[2]);
        final MMTSJType tsmmType = MMTSJType.valueOf(fields[3]);
        return new Tsmm2SPInstruction(null, input, output, tsmmType, opcode, str);
    }
    throw new DMLRuntimeException("Tsmm2SPInstruction.parseInstruction():: Unknown opcode " + opcode);
}
Also used : CPOperand(org.apache.sysml.runtime.instructions.cp.CPOperand) MMTSJType(org.apache.sysml.lops.MMTSJ.MMTSJType) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 2 with MMTSJType

use of org.apache.sysml.lops.MMTSJ.MMTSJType in project incubator-systemml by apache.

the class TsmmSPInstruction method parseInstruction.

/**
 * Builds a {@code TsmmSPInstruction} from its serialized string form.
 *
 * <p>After splitting with
 * {@code InstructionUtils.getInstructionPartsWithValueType}, the fields are
 * interpreted as: [0] opcode, [1] input operand, [2] output operand,
 * [3] name of the {@code MMTSJType} constant.
 *
 * @param str serialized instruction
 * @return the parsed instruction (constructed with a {@code null} first argument)
 * @throws DMLRuntimeException if the opcode is not {@code tsmm} (compared case-insensitively)
 */
public static TsmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
    final String opcode = tokens[0];
    // reject anything that is not a tsmm instruction before touching the operands
    final boolean supported = opcode.equalsIgnoreCase("tsmm");
    if (!supported) {
        throw new DMLRuntimeException("TsmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
    }
    final CPOperand input = new CPOperand(tokens[1]);
    final CPOperand output = new CPOperand(tokens[2]);
    final MMTSJType tsmmType = MMTSJType.valueOf(tokens[3]);
    return new TsmmSPInstruction(null, input, output, tsmmType, opcode, str);
}
Also used : CPOperand(org.apache.sysml.runtime.instructions.cp.CPOperand) MMTSJType(org.apache.sysml.lops.MMTSJ.MMTSJType) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 3 with MMTSJType

use of org.apache.sysml.lops.MMTSJ.MMTSJType in project incubator-systemml by apache.

the class MatrixCharacteristics method computeDimension.

/**
 * Derives the output matrix characteristics of an MR instruction from the
 * characteristics of its input(s) and records them in {@code dims} under
 * {@code ins.output}.
 *
 * <p>If {@code dims} has no entry for {@code ins.output} yet, a fresh
 * {@code MatrixCharacteristics} is created and registered first. The
 * instruction is then dispatched by runtime type through a chain of
 * {@code instanceof} checks; the order of the checks matters (see the
 * inline notes), so new cases must be inserted with care. Instructions that
 * match none of the cases get unknown dimensions (-1 rows/columns) and
 * block sizes of 1.
 *
 * <p>Input entries are looked up with {@code dims.get(...)} and, in the
 * branches that read their fields directly, dereferenced without a null
 * check.
 *
 * @param dims map from matrix index to its characteristics; read for the
 *             inputs and updated in place for the output
 * @param ins  instruction whose output characteristics are computed
 * @throws DMLRuntimeException declared by the signature; presumably raised by
 *                             the delegated helper computations (not visible here)
 */
public static void computeDimension(HashMap<Byte, MatrixCharacteristics> dims, MRInstruction ins) throws DMLRuntimeException {
    // fetch the output entry, creating and registering it on first use
    MatrixCharacteristics dimOut = dims.get(ins.output);
    if (dimOut == null) {
        dimOut = new MatrixCharacteristics();
        dims.put(ins.output, dimOut);
    }
    if (ins instanceof ReorgInstruction) {
        ReorgInstruction realIns = (ReorgInstruction) ins;
        reorg(dims.get(realIns.input), (ReorgOperator) realIns.getOperator(), dimOut);
    } else if (ins instanceof AppendInstruction) {
        AppendInstruction realIns = (AppendInstruction) ins;
        MatrixCharacteristics in_dim1 = dims.get(realIns.input1);
        MatrixCharacteristics in_dim2 = dims.get(realIns.input2);
        // cbind: columns add up, rows come from the first input;
        // otherwise: rows add up, columns come from the first input.
        // In both cases the row block size is taken from input 1 and the column block size from input 2.
        if (realIns.isCBind())
            dimOut.set(in_dim1.numRows, in_dim1.numColumns + in_dim2.numColumns, in_dim1.numRowsPerBlock, in_dim2.numColumnsPerBlock);
        else
            dimOut.set(in_dim1.numRows + in_dim2.numRows, in_dim1.numColumns, in_dim1.numRowsPerBlock, in_dim2.numColumnsPerBlock);
    } else if (ins instanceof CumulativeAggregateInstruction) {
        // NOTE(review): tested before the AggregateUnaryInstruction case and cast to that type,
        // so this is presumably a subtype that needs its own handling - confirm against the class hierarchy
        AggregateUnaryInstruction realIns = (AggregateUnaryInstruction) ins;
        MatrixCharacteristics in = dims.get(realIns.input);
        // output has ceil(rows / rowsPerBlock) rows, i.e. one row per row block of the input; columns and block sizes are unchanged
        dimOut.set((long) Math.ceil((double) in.getRows() / in.getRowsPerBlock()), in.getCols(), in.getRowsPerBlock(), in.getColsPerBlock());
    } else if (ins instanceof AggregateUnaryInstruction) {
        AggregateUnaryInstruction realIns = (AggregateUnaryInstruction) ins;
        aggregateUnary(dims.get(realIns.input), (AggregateUnaryOperator) realIns.getOperator(), dimOut);
    } else if (ins instanceof AggregateBinaryInstruction) {
        AggregateBinaryInstruction realIns = (AggregateBinaryInstruction) ins;
        aggregateBinary(dims.get(realIns.input1), dims.get(realIns.input2), (AggregateBinaryOperator) realIns.getOperator(), dimOut);
    } else if (ins instanceof MapMultChainInstruction) {
        //output size independent of chain type
        MapMultChainInstruction realIns = (MapMultChainInstruction) ins;
        MatrixCharacteristics mc1 = dims.get(realIns.getInput1());
        MatrixCharacteristics mc2 = dims.get(realIns.getInput2());
        // rows = columns of input 1, columns = columns of input 2, block sizes from input 1
        dimOut.set(mc1.numColumns, mc2.numColumns, mc1.numRowsPerBlock, mc1.numColumnsPerBlock);
    } else if (ins instanceof QuaternaryInstruction) {
        // the instruction itself computes the output characteristics from its three inputs
        QuaternaryInstruction realIns = (QuaternaryInstruction) ins;
        MatrixCharacteristics mc1 = dims.get(realIns.getInput1());
        MatrixCharacteristics mc2 = dims.get(realIns.getInput2());
        MatrixCharacteristics mc3 = dims.get(realIns.getInput3());
        realIns.computeMatrixCharacteristics(mc1, mc2, mc3, dimOut);
    } else if (ins instanceof ReblockInstruction) {
        ReblockInstruction realIns = (ReblockInstruction) ins;
        MatrixCharacteristics in_dim = dims.get(realIns.input);
        // same dimensions and non-zero count as the input, but with the instruction's block sizes
        dimOut.set(in_dim.numRows, in_dim.numColumns, realIns.brlen, realIns.bclen, in_dim.nonZero);
    } else if (ins instanceof MatrixReshapeMRInstruction) {
        MatrixReshapeMRInstruction mrinst = (MatrixReshapeMRInstruction) ins;
        MatrixCharacteristics in_dim = dims.get(mrinst.input);
        // target dimensions come from the instruction; block sizes and non-zeros are carried over from the input
        dimOut.set(mrinst.getNumRows(), mrinst.getNumColunms(), in_dim.getRowsPerBlock(), in_dim.getColsPerBlock(), in_dim.getNonZeros());
    } else if (ins instanceof RandInstruction || ins instanceof SeqInstruction) {
        // data generators: output characteristics are copied from the entry registered for the generator's input
        DataGenMRInstruction dataIns = (DataGenMRInstruction) ins;
        dimOut.set(dims.get(dataIns.getInput()));
    } else if (ins instanceof ReplicateInstruction) {
        ReplicateInstruction realIns = (ReplicateInstruction) ins;
        realIns.computeOutputDimension(dims.get(realIns.input), dimOut);
    } else if (//before unary
    ins instanceof ParameterizedBuiltinMRInstruction) {
        ParameterizedBuiltinMRInstruction realIns = (ParameterizedBuiltinMRInstruction) ins;
        realIns.computeOutputCharacteristics(dims.get(realIns.input), dimOut);
    } else if (ins instanceof ScalarInstruction || ins instanceof AggregateInstruction || (ins instanceof UnaryInstruction && !(ins instanceof MMTSJMRInstruction)) || ins instanceof ZeroOutInstruction) {
        // these operations keep the input's characteristics; MMTSJMRInstruction is explicitly
        // excluded from the UnaryInstruction test here and handled by the next branch
        UnaryMRInstructionBase realIns = (UnaryMRInstructionBase) ins;
        dimOut.set(dims.get(realIns.input));
    } else if (ins instanceof MMTSJMRInstruction) {
        MMTSJMRInstruction mmtsj = (MMTSJMRInstruction) ins;
        MMTSJType tstype = mmtsj.getMMTSJType();
        MatrixCharacteristics mc = dims.get(mmtsj.input);
        // square output: cols x cols when the type is "left", rows x rows otherwise
        dimOut.set(tstype.isLeft() ? mc.numColumns : mc.numRows, tstype.isLeft() ? mc.numColumns : mc.numRows, mc.numRowsPerBlock, mc.numColumnsPerBlock);
    } else if (ins instanceof PMMJMRInstruction) {
        PMMJMRInstruction pmmins = (PMMJMRInstruction) ins;
        MatrixCharacteristics mc = dims.get(pmmins.input2);
        // row count is supplied by the instruction; columns and block sizes come from the second input
        dimOut.set(pmmins.getNumRows(), mc.numColumns, mc.numRowsPerBlock, mc.numColumnsPerBlock);
    } else if (ins instanceof RemoveEmptyMRInstruction) {
        RemoveEmptyMRInstruction realIns = (RemoveEmptyMRInstruction) ins;
        MatrixCharacteristics mc = dims.get(realIns.input1);
        // the instruction's output length replaces the row count (row removal) or the column count (otherwise)
        if (realIns.isRemoveRows())
            dimOut.set(realIns.getOutputLen(), mc.getCols(), mc.numRowsPerBlock, mc.numColumnsPerBlock);
        else
            dimOut.set(mc.getRows(), realIns.getOutputLen(), mc.numRowsPerBlock, mc.numColumnsPerBlock);
    } else if (//needs to be checked before binary
    ins instanceof UaggOuterChainInstruction) {
        UaggOuterChainInstruction realIns = (UaggOuterChainInstruction) ins;
        MatrixCharacteristics mc1 = dims.get(realIns.input1);
        MatrixCharacteristics mc2 = dims.get(realIns.input2);
        realIns.computeOutputCharacteristics(mc1, mc2, dimOut);
    } else if (ins instanceof GroupedAggregateMInstruction) {
        GroupedAggregateMInstruction realIns = (GroupedAggregateMInstruction) ins;
        MatrixCharacteristics mc1 = dims.get(realIns.input1);
        realIns.computeOutputCharacteristics(mc1, dimOut);
    } else if (ins instanceof BinaryInstruction || ins instanceof BinaryMInstruction || ins instanceof CombineBinaryInstruction) {
        BinaryMRInstructionBase realIns = (BinaryMRInstructionBase) ins;
        MatrixCharacteristics mc1 = dims.get(realIns.input1);
        MatrixCharacteristics mc2 = dims.get(realIns.input2);
        // column vector (n x 1) combined with row vector (1 x m): the result is n x m
        if (mc1.getRows() > 1 && mc1.getCols() == 1 && mc2.getRows() == 1 && //outer
        mc2.getCols() > 1) {
            dimOut.set(mc1.getRows(), mc2.getCols(), mc1.getRowsPerBlock(), mc2.getColsPerBlock());
        } else {
            //default case
            dimOut.set(mc1);
        }
    } else if (ins instanceof CombineTernaryInstruction) {
        // NOTE(review): tested before the generic TernaryInstruction case further down and cast to
        // that type, so presumably a subtype - confirm against the class hierarchy
        TernaryInstruction realIns = (TernaryInstruction) ins;
        dimOut.set(dims.get(realIns.input1));
    } else if (ins instanceof CombineUnaryInstruction) {
        dimOut.set(dims.get(((CombineUnaryInstruction) ins).input));
    } else if (ins instanceof CM_N_COVInstruction || ins instanceof GroupedAggregateInstruction) {
        // output is recorded as a 1 x 1 matrix with 1 x 1 blocks
        dimOut.set(1, 1, 1, 1);
    } else if (ins instanceof RangeBasedReIndexInstruction) {
        RangeBasedReIndexInstruction realIns = (RangeBasedReIndexInstruction) ins;
        MatrixCharacteristics dimIn = dims.get(realIns.input);
        realIns.computeOutputCharacteristics(dimIn, dimOut);
    } else if (ins instanceof TernaryInstruction) {
        TernaryInstruction realIns = (TernaryInstruction) ins;
        MatrixCharacteristics in_dim = dims.get(realIns.input1);
        // output dimensions are supplied by the instruction; block sizes come from the first input
        dimOut.set(realIns.getOutputDim1(), realIns.getOutputDim2(), in_dim.numRowsPerBlock, in_dim.numColumnsPerBlock);
    } else {
        /*
         * ins matched none of the cases above: mark the output dimensions as
         * unknown (-1) and fall back to block sizes of 1
         */
        dimOut.numRows = -1;
        dimOut.numColumns = -1;
        dimOut.numRowsPerBlock = 1;
        dimOut.numColumnsPerBlock = 1;
    }
}
Also used : TernaryInstruction(org.apache.sysml.runtime.instructions.mr.TernaryInstruction) CombineTernaryInstruction(org.apache.sysml.runtime.instructions.mr.CombineTernaryInstruction) CombineTernaryInstruction(org.apache.sysml.runtime.instructions.mr.CombineTernaryInstruction) ReblockInstruction(org.apache.sysml.runtime.instructions.mr.ReblockInstruction) DataGenMRInstruction(org.apache.sysml.runtime.instructions.mr.DataGenMRInstruction) AggregateUnaryInstruction(org.apache.sysml.runtime.instructions.mr.AggregateUnaryInstruction) CombineUnaryInstruction(org.apache.sysml.runtime.instructions.mr.CombineUnaryInstruction) UnaryInstruction(org.apache.sysml.runtime.instructions.mr.UnaryInstruction) PMMJMRInstruction(org.apache.sysml.runtime.instructions.mr.PMMJMRInstruction) GroupedAggregateMInstruction(org.apache.sysml.runtime.instructions.mr.GroupedAggregateMInstruction) CombineUnaryInstruction(org.apache.sysml.runtime.instructions.mr.CombineUnaryInstruction) MatrixReshapeMRInstruction(org.apache.sysml.runtime.instructions.mr.MatrixReshapeMRInstruction) RemoveEmptyMRInstruction(org.apache.sysml.runtime.instructions.mr.RemoveEmptyMRInstruction) AggregateBinaryInstruction(org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction) CombineBinaryInstruction(org.apache.sysml.runtime.instructions.mr.CombineBinaryInstruction) BinaryInstruction(org.apache.sysml.runtime.instructions.mr.BinaryInstruction) ZeroOutInstruction(org.apache.sysml.runtime.instructions.mr.ZeroOutInstruction) QuaternaryInstruction(org.apache.sysml.runtime.instructions.mr.QuaternaryInstruction) UnaryMRInstructionBase(org.apache.sysml.runtime.instructions.mr.UnaryMRInstructionBase) MapMultChainInstruction(org.apache.sysml.runtime.instructions.mr.MapMultChainInstruction) MMTSJType(org.apache.sysml.lops.MMTSJ.MMTSJType) ReplicateInstruction(org.apache.sysml.runtime.instructions.mr.ReplicateInstruction) 
CumulativeAggregateInstruction(org.apache.sysml.runtime.instructions.mr.CumulativeAggregateInstruction) GroupedAggregateInstruction(org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction) AggregateInstruction(org.apache.sysml.runtime.instructions.mr.AggregateInstruction) CombineBinaryInstruction(org.apache.sysml.runtime.instructions.mr.CombineBinaryInstruction) BinaryMRInstructionBase(org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase) CM_N_COVInstruction(org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction) AggregateUnaryInstruction(org.apache.sysml.runtime.instructions.mr.AggregateUnaryInstruction) ParameterizedBuiltinMRInstruction(org.apache.sysml.runtime.instructions.mr.ParameterizedBuiltinMRInstruction) SeqInstruction(org.apache.sysml.runtime.instructions.mr.SeqInstruction) RandInstruction(org.apache.sysml.runtime.instructions.mr.RandInstruction) RangeBasedReIndexInstruction(org.apache.sysml.runtime.instructions.mr.RangeBasedReIndexInstruction) AppendInstruction(org.apache.sysml.runtime.instructions.mr.AppendInstruction) ScalarInstruction(org.apache.sysml.runtime.instructions.mr.ScalarInstruction) ReorgInstruction(org.apache.sysml.runtime.instructions.mr.ReorgInstruction) CumulativeAggregateInstruction(org.apache.sysml.runtime.instructions.mr.CumulativeAggregateInstruction) AggregateBinaryInstruction(org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction) AggregateUnaryOperator(org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator) UaggOuterChainInstruction(org.apache.sysml.runtime.instructions.mr.UaggOuterChainInstruction) MMTSJMRInstruction(org.apache.sysml.runtime.instructions.mr.MMTSJMRInstruction) GroupedAggregateInstruction(org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction) BinaryMInstruction(org.apache.sysml.runtime.instructions.mr.BinaryMInstruction)

Example 4 with MMTSJType

use of org.apache.sysml.lops.MMTSJ.MMTSJType in project incubator-systemml by apache.

the class AggBinaryOp method constructLops.

/**
 * Constructs the lops for this aggregate-binary hop, or returns the lops
 * that were already created by an earlier call.
 *
 * <p>Only matrix multiplication is supported. The execution type is chosen
 * first, then the transpose-self (tsmm) and mapmult-chain patterns are
 * detected, and finally a concrete multiplication method is selected and
 * dispatched per execution type (CP, SPARK or MR).
 *
 * <p>NOTE: the memory estimate is an overestimate for transpose-identity
 * matmult (3/2 at worst); the existing estimate is kept because it is
 * advantageous in terms of hops/lops consistency, and because some special
 * cases internally materialize the transpose for better cache locality.
 *
 * @return the lops of this hop
 * @throws HopsException if the hop is not a matrix multiplication or the
 *                       selected method is invalid for the execution type
 * @throws LopsException declared by the signature; presumably propagated from lop construction
 */
@Override
public Lop constructLops() throws HopsException, LopsException {
    // lops are built only once; later calls reuse them
    if (getLops() != null) {
        return getLops();
    }
    // matrix multiplication is currently the only supported aggregate-binary operation
    if (!isMatrixMultiply()) {
        throw new HopsException(this.printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops.");
    }

    final Hop left = getInput().get(0);
    final Hop right = getInput().get(1);
    // selection part 1: execution backend (CP vs MR vs Spark)
    final ExecType execType = optFindExecType();
    // selection part 2: special patterns (tsmm and mmchain)
    final MMTSJType tsmmType = checkTransposeSelf();
    final ChainType chainType = checkMapMultChain();

    if (execType == ExecType.CP) {
        // selection part 3: concrete CP method
        _method = optFindMMultMethodCP(left.getDim1(), left.getDim2(), right.getDim1(), right.getDim2(), tsmmType, chainType, _hasLeftPMInput);
        switch(_method) {
            case TSMM:
                constructCPLopsTSMM(tsmmType);
                break;
            case MAPMM_CHAIN:
                constructCPLopsMMChain(chainType);
                break;
            case PMM:
                constructCPLopsPMM();
                break;
            case MM:
                constructCPLopsMM();
                break;
            default:
                throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing CP lops.");
        }
    } else if (execType == ExecType.SPARK) {
        // selection part 3: concrete SPARK method
        final boolean leftIsTranspose = HopRewriteUtils.isTransposeOperation(left);
        _method = optFindMMultMethodSpark(left.getDim1(), left.getDim2(), left.getRowsInBlock(), left.getColsInBlock(), left.getNnz(), right.getDim1(), right.getDim2(), right.getRowsInBlock(), right.getColsInBlock(), right.getNnz(), tsmmType, chainType, _hasLeftPMInput, leftIsTranspose);
        switch(_method) {
            case TSMM:
            case TSMM2:
                // both tsmm variants share one constructor; the flag marks the TSMM2 variant
                constructSparkLopsTSMM(tsmmType, _method == MMultMethod.TSMM2);
                break;
            case MAPMM_L:
            case MAPMM_R:
                constructSparkLopsMapMM(_method);
                break;
            case MAPMM_CHAIN:
                constructSparkLopsMapMMChain(chainType);
                break;
            case PMAPMM:
                constructSparkLopsPMapMM();
                break;
            case CPMM:
                constructSparkLopsCPMM();
                break;
            case RMM:
                constructSparkLopsRMM();
                break;
            case PMM:
                constructSparkLopsPMM();
                break;
            case ZIPMM:
                constructSparkLopsZIPMM();
                break;
            default:
                throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops.");
        }
    } else if (execType == ExecType.MR) {
        // selection part 3: concrete MR method
        _method = optFindMMultMethodMR(left.getDim1(), left.getDim2(), left.getRowsInBlock(), left.getColsInBlock(), left.getNnz(), right.getDim1(), right.getDim2(), right.getRowsInBlock(), right.getColsInBlock(), right.getNnz(), tsmmType, chainType, _hasLeftPMInput);
        switch(_method) {
            case MAPMM_L:
            case MAPMM_R:
                constructMRLopsMapMM(_method);
                break;
            case MAPMM_CHAIN:
                constructMRLopsMapMMChain(chainType);
                break;
            case CPMM:
                constructMRLopsCPMM();
                break;
            case RMM:
                constructMRLopsRMM();
                break;
            case TSMM:
                constructMRLopsTSMM(tsmmType);
                break;
            case PMM:
                constructMRLopsPMM();
                break;
            default:
                throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing MR lops.");
        }
    }

    // add reblock/checkpoint lops if necessary
    constructAndSetLopsDataFlowProperties();
    return getLops();
}
Also used : MultiThreadedHop(org.apache.sysml.hops.Hop.MultiThreadedHop) ChainType(org.apache.sysml.lops.MapMultChain.ChainType) ExecType(org.apache.sysml.lops.LopProperties.ExecType) MMTSJType(org.apache.sysml.lops.MMTSJ.MMTSJType)

Example 5 with MMTSJType

use of org.apache.sysml.lops.MMTSJ.MMTSJType in project incubator-systemml by apache.

the class MMTSJCPInstruction method parseInstruction.

/**
 * Builds an {@code MMTSJCPInstruction} from its serialized string form.
 *
 * <p>The string is split via
 * {@code InstructionUtils.getInstructionPartsWithValueType} and validated
 * with {@code InstructionUtils.checkNumFields(parts, 4)}. The fields are
 * read as: [0] opcode, [1] input operand, [2] output operand, [3] name of
 * the {@code MMTSJType} constant, [4] an integer passed to the constructor
 * as {@code k}.
 *
 * <p>All fields are parsed before the opcode is checked, so malformed
 * operand fields surface their own exceptions first.
 *
 * @param str serialized instruction
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is not {@code tsmm} (compared case-insensitively)
 */
public static MMTSJCPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(fields, 4);
    final String opcode = fields[0];
    final CPOperand input = new CPOperand(fields[1]);
    final CPOperand output = new CPOperand(fields[2]);
    final MMTSJType tsmmType = MMTSJType.valueOf(fields[3]);
    final int k = Integer.parseInt(fields[4]);
    if (opcode.equalsIgnoreCase("tsmm")) {
        return new MMTSJCPInstruction(new Operator(true), input, tsmmType, output, k, opcode, str);
    }
    throw new DMLRuntimeException("Unknown opcode while parsing an MMTSJCPInstruction: " + str);
}
Also used : Operator(org.apache.sysml.runtime.matrix.operators.Operator) MMTSJType(org.apache.sysml.lops.MMTSJ.MMTSJType) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Aggregations

MMTSJType (org.apache.sysml.lops.MMTSJ.MMTSJType)9 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)5 CPOperand (org.apache.sysml.runtime.instructions.cp.CPOperand)3 MultiThreadedHop (org.apache.sysml.hops.Hop.MultiThreadedHop)2 Operator (org.apache.sysml.runtime.matrix.operators.Operator)2 ExecType (org.apache.sysml.lops.LopProperties.ExecType)1 ChainType (org.apache.sysml.lops.MapMultChain.ChainType)1 AggregateBinaryInstruction (org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction)1 AggregateInstruction (org.apache.sysml.runtime.instructions.mr.AggregateInstruction)1 AggregateUnaryInstruction (org.apache.sysml.runtime.instructions.mr.AggregateUnaryInstruction)1 AppendInstruction (org.apache.sysml.runtime.instructions.mr.AppendInstruction)1 BinaryInstruction (org.apache.sysml.runtime.instructions.mr.BinaryInstruction)1 BinaryMInstruction (org.apache.sysml.runtime.instructions.mr.BinaryMInstruction)1 BinaryMRInstructionBase (org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase)1 CM_N_COVInstruction (org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction)1 CombineBinaryInstruction (org.apache.sysml.runtime.instructions.mr.CombineBinaryInstruction)1 CombineTernaryInstruction (org.apache.sysml.runtime.instructions.mr.CombineTernaryInstruction)1 CombineUnaryInstruction (org.apache.sysml.runtime.instructions.mr.CombineUnaryInstruction)1 CumulativeAggregateInstruction (org.apache.sysml.runtime.instructions.mr.CumulativeAggregateInstruction)1 DataGenMRInstruction (org.apache.sysml.runtime.instructions.mr.DataGenMRInstruction)1