Search in sources :

Example 1 with WCeMMType

Use of org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType in the project systemml by apache.

From the class QuaternarySPInstruction, the method parseInstruction.

/**
 * Parses the textual form of a distributed quaternary instruction into a
 * {@link QuaternarySPInstruction}.
 *
 * <p>After splitting, {@code parts[0]} is the opcode, which selects the layout of the
 * remaining fields:
 * <ul>
 *   <li>wsloss: in1, in2, in3, in4, out, weights type</li>
 *   <li>wumm: unary opcode, in1, in2, in3, out, wumm type</li>
 *   <li>wdivmm: in1, in2, in3, in4, out, wdivmm type</li>
 *   <li>wsigmoid: in1, in2, in3, out, wsigmoid type</li>
 *   <li>wcemm: in1, in2, in3, in4, out, wcemm type</li>
 * </ul>
 * The "red" variants carry two additional trailing boolean fields (cacheU, cacheV);
 * for every other variant both flags are fixed to {@code true}.
 *
 * @param str the serialized instruction string
 * @return the parsed instruction; never {@code null}
 * @throws DMLRuntimeException if the opcode is not accepted by
 *         {@code InstructionUtils.isDistQuaternaryOpcode} or matches none of the
 *         layouts handled here
 */
public static QuaternarySPInstruction parseInstruction(String str) {
    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
    String opcode = parts[0];
    // validity check
    if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
        throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode);
    }
    // instruction parsing
    if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode)) {
        // wsloss
        boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);
        // check number of fields (4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand in4 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        WeightsType wtype = WeightsType.valueOf(parts[6]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = !isRed || Boolean.parseBoolean(parts[7]);
        boolean cacheV = !isRed || Boolean.parseBoolean(parts[8]);
        return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    }
    if (WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode)) {
        // wumm
        boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
        // check number of fields (unary opcode, 3 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        String uopcode = parts[1];
        CPOperand in1 = new CPOperand(parts[2]);
        CPOperand in2 = new CPOperand(parts[3]);
        CPOperand in3 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        WUMMType wtype = WUMMType.valueOf(parts[6]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = !isRed || Boolean.parseBoolean(parts[7]);
        boolean cacheV = !isRed || Boolean.parseBoolean(parts[8]);
        // this layout has no fourth input operand
        return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
    }
    if (WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode)) {
        // wdivmm
        // Detect the "red" variant the same case-insensitive way this branch was
        // selected, matching the wsloss and wumm branches above.
        boolean isRed = WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode);
        // check number of fields (4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand in4 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = !isRed || Boolean.parseBoolean(parts[7]);
        boolean cacheV = !isRed || Boolean.parseBoolean(parts[8]);
        final WDivMMType wt = WDivMMType.valueOf(parts[6]);
        // when the type reports a scalar, the name of the fourth operand is parsed as its value
        QuaternaryOperator qop = wt.hasScalar()
            ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName()))
            : new QuaternaryOperator(wt);
        return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    }
    // map/redwsigmoid, map/redwcemm
    boolean isRed = opcode.startsWith("red");
    boolean isWcemm = opcode.endsWith("wcemm");
    // wcemm has one more input than wsigmoid, shifting every later field by one
    int addInput4 = isWcemm ? 1 : 0;
    // check number of fields (3 or 4 inputs, output, type [, cacheU, cacheV])
    InstructionUtils.checkNumFields(parts, (isRed ? 7 : 5) + addInput4);
    CPOperand in1 = new CPOperand(parts[1]);
    CPOperand in2 = new CPOperand(parts[2]);
    CPOperand in3 = new CPOperand(parts[3]);
    CPOperand out = new CPOperand(parts[4 + addInput4]);
    // in mappers always through distcache, in reducers through distcache/shuffle
    boolean cacheU = !isRed || Boolean.parseBoolean(parts[6 + addInput4]);
    boolean cacheV = !isRed || Boolean.parseBoolean(parts[7 + addInput4]);
    if (opcode.endsWith("wsigmoid")) {
        return new QuaternarySPInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
    }
    if (isWcemm) {
        CPOperand in4 = new CPOperand(parts[4]);
        final WCeMMType wt = WCeMMType.valueOf(parts[6]);
        // with four inputs, the name of the fourth operand is parsed as a double value
        QuaternaryOperator qop = wt.hasFourInputs()
            ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName()))
            : new QuaternaryOperator(wt);
        return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    }
    // The opcode passed the validity check but matches no layout above: fail loudly
    // instead of returning null to the caller.
    throw new DMLRuntimeException("Quaternary.parseInstruction():: Unsupported opcode " + opcode);
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) WDivMMType(org.apache.sysml.lops.WeightedDivMM.WDivMMType) WCeMMType(org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType) WeightsType(org.apache.sysml.lops.WeightedSquaredLoss.WeightsType) CPOperand(org.apache.sysml.runtime.instructions.cp.CPOperand) WUMMType(org.apache.sysml.lops.WeightedUnaryMM.WUMMType) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 2 with WCeMMType

Use of org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType in the project incubator-systemml by apache.

From the class QuaternarySPInstruction, the method parseInstruction.

/**
 * Parses the textual form of a distributed quaternary instruction into a
 * {@link QuaternarySPInstruction}.
 *
 * <p>After splitting, {@code parts[0]} is the opcode, which selects the layout of the
 * remaining fields:
 * <ul>
 *   <li>wsloss: in1, in2, in3, in4, out, weights type</li>
 *   <li>wumm: unary opcode, in1, in2, in3, out, wumm type</li>
 *   <li>wdivmm: in1, in2, in3, in4, out, wdivmm type</li>
 *   <li>wsigmoid: in1, in2, in3, out, wsigmoid type</li>
 *   <li>wcemm: in1, in2, in3, in4, out, wcemm type</li>
 * </ul>
 * The "red" variants carry two additional trailing boolean fields (cacheU, cacheV);
 * for every other variant both flags are fixed to {@code true}.
 *
 * @param str the serialized instruction string
 * @return the parsed instruction; never {@code null}
 * @throws DMLRuntimeException if the opcode is not accepted by
 *         {@code InstructionUtils.isDistQuaternaryOpcode} or matches none of the
 *         layouts handled here
 */
public static QuaternarySPInstruction parseInstruction(String str) {
    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
    String opcode = parts[0];
    // validity check
    if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
        throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode);
    }
    // instruction parsing
    if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode)) {
        // wsloss
        boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);
        // check number of fields (4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand in4 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        WeightsType wtype = WeightsType.valueOf(parts[6]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = !isRed || Boolean.parseBoolean(parts[7]);
        boolean cacheV = !isRed || Boolean.parseBoolean(parts[8]);
        return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    }
    if (WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode)) {
        // wumm
        boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
        // check number of fields (unary opcode, 3 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        String uopcode = parts[1];
        CPOperand in1 = new CPOperand(parts[2]);
        CPOperand in2 = new CPOperand(parts[3]);
        CPOperand in3 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        WUMMType wtype = WUMMType.valueOf(parts[6]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = !isRed || Boolean.parseBoolean(parts[7]);
        boolean cacheV = !isRed || Boolean.parseBoolean(parts[8]);
        // this layout has no fourth input operand
        return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
    }
    if (WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode)) {
        // wdivmm
        // Detect the "red" variant the same case-insensitive way this branch was
        // selected, matching the wsloss and wumm branches above.
        boolean isRed = WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode);
        // check number of fields (4 inputs, output, type [, cacheU, cacheV])
        InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand in4 = new CPOperand(parts[4]);
        CPOperand out = new CPOperand(parts[5]);
        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = !isRed || Boolean.parseBoolean(parts[7]);
        boolean cacheV = !isRed || Boolean.parseBoolean(parts[8]);
        final WDivMMType wt = WDivMMType.valueOf(parts[6]);
        // when the type reports a scalar, the name of the fourth operand is parsed as its value
        QuaternaryOperator qop = wt.hasScalar()
            ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName()))
            : new QuaternaryOperator(wt);
        return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    }
    // map/redwsigmoid, map/redwcemm
    boolean isRed = opcode.startsWith("red");
    boolean isWcemm = opcode.endsWith("wcemm");
    // wcemm has one more input than wsigmoid, shifting every later field by one
    int addInput4 = isWcemm ? 1 : 0;
    // check number of fields (3 or 4 inputs, output, type [, cacheU, cacheV])
    InstructionUtils.checkNumFields(parts, (isRed ? 7 : 5) + addInput4);
    CPOperand in1 = new CPOperand(parts[1]);
    CPOperand in2 = new CPOperand(parts[2]);
    CPOperand in3 = new CPOperand(parts[3]);
    CPOperand out = new CPOperand(parts[4 + addInput4]);
    // in mappers always through distcache, in reducers through distcache/shuffle
    boolean cacheU = !isRed || Boolean.parseBoolean(parts[6 + addInput4]);
    boolean cacheV = !isRed || Boolean.parseBoolean(parts[7 + addInput4]);
    if (opcode.endsWith("wsigmoid")) {
        return new QuaternarySPInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
    }
    if (isWcemm) {
        CPOperand in4 = new CPOperand(parts[4]);
        final WCeMMType wt = WCeMMType.valueOf(parts[6]);
        // with four inputs, the name of the fourth operand is parsed as a double value
        QuaternaryOperator qop = wt.hasFourInputs()
            ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName()))
            : new QuaternaryOperator(wt);
        return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
    }
    // The opcode passed the validity check but matches no layout above: fail loudly
    // instead of returning null to the caller.
    throw new DMLRuntimeException("Quaternary.parseInstruction():: Unsupported opcode " + opcode);
}
Also used : QuaternaryOperator(org.apache.sysml.runtime.matrix.operators.QuaternaryOperator) WDivMMType(org.apache.sysml.lops.WeightedDivMM.WDivMMType) WCeMMType(org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType) WeightsType(org.apache.sysml.lops.WeightedSquaredLoss.WeightsType) CPOperand(org.apache.sysml.runtime.instructions.cp.CPOperand) WUMMType(org.apache.sysml.lops.WeightedUnaryMM.WUMMType) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Aggregations

WCeMMType (org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType): 2, WDivMMType (org.apache.sysml.lops.WeightedDivMM.WDivMMType): 2, WeightsType (org.apache.sysml.lops.WeightedSquaredLoss.WeightsType): 2, WUMMType (org.apache.sysml.lops.WeightedUnaryMM.WUMMType): 2, DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 2, CPOperand (org.apache.sysml.runtime.instructions.cp.CPOperand): 2, QuaternaryOperator (org.apache.sysml.runtime.matrix.operators.QuaternaryOperator): 2