Search in sources :

Example 1 with WDivMMType

Use of org.apache.sysml.lops.WeightedDivMM.WDivMMType in the Apache project incubator-systemml.

Found in class QuaternaryInstruction, method parseInstruction.

/**
 * Parses a distributed quaternary instruction string into a {@link QuaternaryInstruction}.
 *
 * <p>Dispatches on the opcode to one of four operator families: wsloss, wumm, wdivmm, and
 * (in the final branch) wsigmoid / wcemm. The reducer variants carry two extra trailing
 * fields holding the cacheU / cacheV flags. Mapper variants always read U and V through the
 * distributed cache, so both flags are {@code true}.
 *
 * @param str the serialized instruction; the opcode is extracted via
 *            {@code InstructionUtils.getOpCode} and the operand fields via
 *            {@code InstructionUtils.getInstructionParts} (fields 1..n are read below)
 * @return the parsed instruction; never {@code null}
 * @throws DMLRuntimeException if the opcode is not a distributed quaternary opcode, if the
 *         field-count check in {@code InstructionUtils.checkNumFields} rejects the string,
 *         or if the opcode matches none of the supported operator families
 */
public static QuaternaryInstruction parseInstruction(String str) throws DMLRuntimeException {
    String opcode = InstructionUtils.getOpCode(str);
    //validity check
    if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
        throw new DMLRuntimeException("Unexpected opcode in QuaternaryInstruction: " + str);
    }

    //instruction parsing
    if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode)
            || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode)) {
        //wsloss
        boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);
        //check number of fields (4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(str, isRed ? 8 : 6);
        //parse instruction parts (without exec type)
        String[] parts = InstructionUtils.getInstructionParts(str);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte in3 = Byte.parseByte(parts[3]);
        byte in4 = Byte.parseByte(parts[4]);
        byte out = Byte.parseByte(parts[5]);
        WeightsType wtype = WeightsType.valueOf(parts[6]);
        //in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
        return new QuaternaryInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, str);
    } else if (WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode)
            || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode)) {
        //wumm
        boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
        //check number of fields (unary opcode, 3 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(str, isRed ? 8 : 6);
        //parse instruction parts (without exec type)
        String[] parts = InstructionUtils.getInstructionParts(str);
        String uopcode = parts[1];
        byte in1 = Byte.parseByte(parts[2]);
        byte in2 = Byte.parseByte(parts[3]);
        byte in3 = Byte.parseByte(parts[4]);
        byte out = Byte.parseByte(parts[5]);
        WUMMType wtype = WUMMType.valueOf(parts[6]);
        //in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
        //wumm has only three matrix inputs; -1 is passed for the unused fourth input index
        return new QuaternaryInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, (byte) -1, out, cacheU, cacheV, str);
    } else if (WeightedDivMM.OPCODE.equalsIgnoreCase(opcode)
            || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode)) {
        //wdivmm
        //derive the reducer flag case-insensitively, exactly like the branch condition above
        boolean isRed = WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode);
        //check number of fields (4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(str, isRed ? 8 : 6);
        //parse instruction parts (without exec type)
        String[] parts = InstructionUtils.getInstructionParts(str);
        final WDivMMType wtype = WDivMMType.valueOf(parts[6]);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte in3 = Byte.parseByte(parts[3]);
        //for scalar variants field 4 is not a byte input index, so no fourth input is registered
        //NOTE(review): the scalar in parts[4] is not passed to the QuaternaryOperator here, while
        //QuaternarySPInstruction.parseInstruction does pass it -- confirm the MR runtime does not need it
        byte in4 = wtype.hasScalar() ? -1 : Byte.parseByte(parts[4]);
        byte out = Byte.parseByte(parts[5]);
        //in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
        return new QuaternaryInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, str);
    } else {
        //wsigmoid / wcemm
        boolean isRed = opcode.startsWith("red");
        //wcemm carries one more operand (field 4) than wsigmoid, shifting all later fields by one
        int addInput4 = opcode.endsWith("wcemm") ? 1 : 0;
        //check number of fields (3 or 4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(str, (isRed ? 7 : 5) + addInput4);
        //parse instruction parts (without exec type)
        String[] parts = InstructionUtils.getInstructionParts(str);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte in3 = Byte.parseByte(parts[3]);
        byte out = Byte.parseByte(parts[4 + addInput4]);
        //in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[6 + addInput4]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[7 + addInput4]) : true;
        if (opcode.endsWith("wsigmoid")) {
            return new QuaternaryInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, (byte) -1, out, cacheU, cacheV, str);
        } else if (opcode.endsWith("wcemm")) {
            //NOTE(review): the wcemm operand in parts[4] is not parsed here, while
            //QuaternarySPInstruction.parseInstruction uses it -- confirm the MR runtime does not need it
            return new QuaternaryInstruction(new QuaternaryOperator(WCeMMType.valueOf(parts[6])), in1, in2, in3, (byte) -1, out, cacheU, cacheV, str);
        }
    }
    //only reached if an opcode passes the validity check but matches none of the branches above;
    //fail loudly instead of handing a null instruction back to the caller
    throw new DMLRuntimeException("Unexpected opcode in QuaternaryInstruction: " + str);
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) WDivMMType(org.apache.sysml.lops.WeightedDivMM.WDivMMType) WeightsType(org.apache.sysml.lops.WeightedSquaredLoss.WeightsType) WUMMType(org.apache.sysml.lops.WeightedUnaryMM.WUMMType) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 2 with WDivMMType

Use of org.apache.sysml.lops.WeightedDivMM.WDivMMType in the Apache project incubator-systemml.

Found in class QuaternarySPInstruction, method parseInstruction.

/**
 * Parses a quaternary instruction string into a {@link QuaternarySPInstruction}.
 *
 * <p>The string is split with {@code InstructionUtils.getInstructionPartsWithValueType}.
 * Field 0 is the opcode, and the remaining fields are operands wrapped as {@link CPOperand}s.
 * The method dispatches on the opcode to one of four families: wsloss, wumm, wdivmm, and
 * (in the final branch) wsigmoid / wcemm. Reducer variants carry two extra trailing
 * cacheU / cacheV flags; mapper variants use {@code true} for both.
 *
 * @param str the serialized instruction
 * @return the parsed instruction; never {@code null}
 * @throws DMLRuntimeException if the opcode is not a distributed quaternary opcode, if the
 *         field-count check in {@code InstructionUtils.checkNumFields} rejects the fields,
 *         or if the opcode matches none of the supported operator families
 */
public static QuaternarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
    String opcode = parts[0];
    //validity check
    if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
        throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode);
    }

    //instruction parsing
    if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode)
            || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode)) {
        //wsloss
        boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);
        //check number of fields (4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand in4 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        WeightsType wtype = WeightsType.valueOf(parts[6]);
        //in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
        return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    } else if (WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode)
            || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode)) {
        //wumm
        boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
        //check number of fields (unary opcode, 3 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        String uopcode = parts[1];
        CPOperand in1 = new CPOperand(parts[2]);
        CPOperand in2 = new CPOperand(parts[3]);
        CPOperand in3 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        WUMMType wtype = WUMMType.valueOf(parts[6]);
        //in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
        //wumm has only three inputs; null is passed for the unused fourth operand
        return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
    } else if (WeightedDivMM.OPCODE.equalsIgnoreCase(opcode)
            || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode)) {
        //wdivmm
        //derive the reducer flag case-insensitively, exactly like the branch condition above
        boolean isRed = WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode);
        //check number of fields (4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand in4 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        //in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
        final WDivMMType wt = WDivMMType.valueOf(parts[6]);
        //for scalar variants, the name of the fourth operand is parsed as the scalar value
        QuaternaryOperator qop = wt.hasScalar()
                ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName()))
                : new QuaternaryOperator(wt);
        return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    } else {
        //map/redwsigmoid, map/redwcemm
        boolean isRed = opcode.startsWith("red");
        //wcemm carries one more operand (field 4) than wsigmoid, shifting all later fields by one
        int addInput4 = opcode.endsWith("wcemm") ? 1 : 0;
        //check number of fields (3 or 4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, (isRed ? 7 : 5) + addInput4);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4 + addInput4]);
        //in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[6 + addInput4]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[7 + addInput4]) : true;
        if (opcode.endsWith("wsigmoid")) {
            return new QuaternarySPInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
        } else if (opcode.endsWith("wcemm")) {
            CPOperand in4 = new CPOperand(parts[4]);
            final WCeMMType wt = WCeMMType.valueOf(parts[6]);
            //for four-input variants, the name of the fourth operand is parsed as a scalar value
            QuaternaryOperator qop = wt.hasFourInputs()
                    ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName()))
                    : new QuaternaryOperator(wt);
            return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
        }
    }
    //only reached if an opcode passes the validity check but matches none of the branches above;
    //fail loudly instead of handing a null instruction back to the caller
    throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode);
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) WDivMMType(org.apache.sysml.lops.WeightedDivMM.WDivMMType) WCeMMType(org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType) WeightsType(org.apache.sysml.lops.WeightedSquaredLoss.WeightsType) CPOperand(org.apache.sysml.runtime.instructions.cp.CPOperand) WUMMType(org.apache.sysml.lops.WeightedUnaryMM.WUMMType) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Aggregations

WDivMMType (org.apache.sysml.lops.WeightedDivMM.WDivMMType): 2, WeightsType (org.apache.sysml.lops.WeightedSquaredLoss.WeightsType): 2, WUMMType (org.apache.sysml.lops.WeightedUnaryMM.WUMMType): 2, DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 2, QuaternaryOperator (org.apache.sysml.runtime.matrix.operators.QuaternaryOperator): 2, WCeMMType (org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType): 1, CPOperand (org.apache.sysml.runtime.instructions.cp.CPOperand): 1