Search in sources:

Example 11 with AggregateOperationTypes

Use of org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes in the Apache SystemML project.

Method parseInstruction of class CM_N_COVInstruction.

/**
 * Parses a central-moment ({@code cm}) or covariance ({@code cov}) MR instruction string.
 *
 * <p>For {@code cm} the fields after the opcode are: input index, moment constant, output
 * index. For {@code cov} they are: input index, output index. The opcode is matched
 * case-insensitively.
 *
 * @param str the serialized instruction
 * @return the parsed instruction
 * @throws DMLRuntimeException if the moment constant is not 0, 2, 3 or 4, or if the opcode
 *         is neither {@code cm} nor {@code cov}
 */
public static CM_N_COVInstruction parseInstruction(String str) {
    final String[] fields = InstructionUtils.getInstructionParts(str);
    final String opcode = fields[0];

    if (opcode.equalsIgnoreCase("cm")) {
        final byte input = Byte.parseByte(fields[1]);
        final int cmConstant = Integer.parseInt(fields[2]);
        final byte output = Byte.parseByte(fields[3]);
        // only 0 and the range 2..4 are accepted; 1 and anything outside [0, 4] are rejected
        final boolean supported = cmConstant == 0 || (cmConstant >= 2 && cmConstant <= 4);
        if (!supported) {
            throw new DMLRuntimeException("constant for central moment has to be 0, 2, 3, or 4");
        }
        final AggregateOperationTypes aggType = CMOperator.getCMAggOpType(cmConstant);
        final CMOperator cmOp = new CMOperator(CM.getCMFnObject(aggType), aggType);
        return new CM_N_COVInstruction(cmOp, input, output, str);
    }

    if (opcode.equalsIgnoreCase("cov")) {
        final byte input = Byte.parseByte(fields[1]);
        final byte output = Byte.parseByte(fields[2]);
        final COVOperator covOp = new COVOperator(COV.getCOMFnObject());
        return new CM_N_COVInstruction(covOp, input, output, str);
    }

    throw new DMLRuntimeException("unknown opcode " + opcode);
}
Also used: COVOperator (org.apache.sysml.runtime.matrix.operators.COVOperator), AggregateOperationTypes (org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes), CMOperator (org.apache.sysml.runtime.matrix.operators.CMOperator), DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)

Example 12 with AggregateOperationTypes

Use of org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes in the Apache SystemML project.

Method extractMRInstStatistics of class CostEstimatorStaticRuntime.

/**
 * Extracts the variable statistics and opcode-specific attributes of a single MR
 * instruction for static cost estimation.
 *
 * @param inst  the serialized MR instruction
 * @param stats statistics, looked up by the byte operand indexes referenced in the instruction
 * @return a two-element array: element 0 is a {@code VarStats[3]} holding the statistics of
 *         (first input, second input, output) — for ctable, slot 2 holds the third input;
 *         element 1 is a {@code String[]} of opcode-specific attributes, or {@code null}
 *         if the opcode has none
 */
private Object[] extractMRInstStatistics(String inst, VarStats[] stats) {
    // stats, attrs
    Object[] ret = new Object[2];
    VarStats[] vs = new VarStats[3];
    String[] attr = null;
    String[] parts = InstructionUtils.getInstructionParts(inst);
    String opcode = parts[0];
    if (opcode.equals(DataGen.RAND_OPCODE)) {
        vs[0] = _unknownStats;
        vs[1] = _unknownStats;
        vs[2] = stats[Integer.parseInt(parts[2])];
        // attribute: 0 if min == max == 0; 1 if sparsity == 1.0 and min == max;
        // 2 otherwise, including when min/max are not yet known
        int type = 2;
        // awareness of instruction patching min/max: skip if either is still a placeholder
        if (!parts[7].contains(Lop.VARIABLE_NAME_PLACEHOLDER) && !parts[8].contains(Lop.VARIABLE_NAME_PLACEHOLDER)) {
            double minValue = Double.parseDouble(parts[7]);
            double maxValue = Double.parseDouble(parts[8]);
            double sparsity = Double.parseDouble(parts[9]);
            if (minValue == 0.0 && maxValue == 0.0)
                type = 0;
            else if (sparsity == 1.0 && minValue == maxValue)
                type = 1;
        }
        attr = new String[] { String.valueOf(type) };
    } else if (opcode.equals(DataGen.SEQ_OPCODE)) {
        vs[0] = _unknownStats;
        vs[1] = _unknownStats;
        vs[2] = stats[Integer.parseInt(parts[2])];
    } else {
        // general case: parse the (patched) instruction and read its operand indexes
        String inst2 = replaceInstructionPatch(inst);
        MRInstruction mrinst = MRInstructionParser.parseSingleInstruction(inst2);
        if (mrinst instanceof UnaryMRInstructionBase) {
            UnaryMRInstructionBase uinst = (UnaryMRInstructionBase) mrinst;
            vs[0] = uinst.input >= 0 ? stats[uinst.input] : _unknownStats;
            vs[1] = _unknownStats;
            vs[2] = stats[uinst.output];
            // scalar input (e.g., print) or scalar output
            defaultNullStatsToScalar(vs);
            if (mrinst instanceof MMTSJMRInstruction) {
                String type = ((MMTSJMRInstruction) mrinst).getMMTSJType().toString();
                attr = new String[] { type };
            } else if (mrinst instanceof CM_N_COVInstruction) {
                // the moment constant is the second-to-last instruction field
                if (opcode.equals("cm"))
                    attr = new String[] { parts[parts.length - 2] };
            } else if (mrinst instanceof GroupedAggregateInstruction) {
                if (opcode.equals("groupedagg")) {
                    AggregateOperationTypes type = CMOperator.getAggOpType(parts[2], parts[3]);
                    attr = new String[] { String.valueOf(type.ordinal()) };
                }
            }
        } else if (mrinst instanceof BinaryMRInstructionBase) {
            BinaryMRInstructionBase binst = (BinaryMRInstructionBase) mrinst;
            vs[0] = stats[binst.input1];
            vs[1] = stats[binst.input2];
            vs[2] = stats[binst.output];
            // scalar inputs or scalar output
            defaultNullStatsToScalar(vs);
            if (opcode.equals("rmempty")) {
                RemoveEmptyMRInstruction rbinst = (RemoveEmptyMRInstruction) mrinst;
                attr = new String[] { rbinst.isRemoveRows() ? "0" : "1" };
            }
        } else if (mrinst instanceof TernaryInstruction) {
            TernaryInstruction tinst = (TernaryInstruction) mrinst;
            // all indexes but the last are inputs; the last one is the output
            byte[] ix = tinst.getAllIndexes();
            // vs has two input slots: fill them from the first two inputs in order
            // (previously every input was written to vs[0], leaving vs[1] unset)
            int numInputSlots = Math.min(2, ix.length - 1);
            for (int i = 0; i < numInputSlots; i++)
                vs[i] = stats[ix[i]];
            vs[2] = stats[ix[ix.length - 1]];
            // scalar inputs or scalar output
            defaultNullStatsToScalar(vs);
        } else if (mrinst instanceof CtableInstruction) {
            CtableInstruction tinst = (CtableInstruction) mrinst;
            vs[0] = stats[tinst.input1];
            vs[1] = stats[tinst.input2];
            // NOTE(review): slot 2 holds the third input here, not the output, so the
            // in-memory flag set at the end applies to it — confirm this is intended
            vs[2] = stats[tinst.input3];
            // scalar inputs
            defaultNullStatsToScalar(vs);
        } else if (mrinst instanceof PickByCountInstruction) {
            PickByCountInstruction pinst = (PickByCountInstruction) mrinst;
            vs[0] = stats[pinst.input1];
            vs[2] = stats[pinst.output];
            // scalar input, unused second input slot, or scalar output
            defaultNullStatsToScalar(vs);
        } else if (mrinst instanceof MapMultChainInstruction) {
            MapMultChainInstruction minst = (MapMultChainInstruction) mrinst;
            vs[0] = stats[minst.getInput1()];
            vs[1] = stats[minst.getInput2()];
            // the third input is optional (negative index when absent)
            if (minst.getInput3() >= 0)
                vs[2] = stats[minst.getInput3()];
            // scalar or absent inputs
            defaultNullStatsToScalar(vs);
        }
    }
    // maintain var status (CP output always inmem)
    // NOTE(review): vs[2] is still null here if the parsed instruction matched none of the
    // branches above (or a RAND/SEQ output has no stats entry); this then throws an NPE
    vs[2]._inmem = true;
    ret[0] = vs;
    ret[1] = attr;
    return ret;
}

/** Replaces every {@code null} entry of {@code vs} with {@code _scalarStats}. */
private void defaultNullStatsToScalar(VarStats[] vs) {
    for (int i = 0; i < vs.length; i++) {
        if (vs[i] == null)
            vs[i] = _scalarStats;
    }
}
Also used: CM_N_COVInstruction (org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction), BinaryMRInstructionBase (org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase), PickByCountInstruction (org.apache.sysml.runtime.instructions.mr.PickByCountInstruction), TernaryInstruction (org.apache.sysml.runtime.instructions.mr.TernaryInstruction), AggregateOperationTypes (org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes), RemoveEmptyMRInstruction (org.apache.sysml.runtime.instructions.mr.RemoveEmptyMRInstruction), CtableInstruction (org.apache.sysml.runtime.instructions.mr.CtableInstruction), UnaryMRInstructionBase (org.apache.sysml.runtime.instructions.mr.UnaryMRInstructionBase), MapMultChainInstruction (org.apache.sysml.runtime.instructions.mr.MapMultChainInstruction), MMTSJMRInstruction (org.apache.sysml.runtime.instructions.mr.MMTSJMRInstruction), DataGenMRInstruction (org.apache.sysml.runtime.instructions.mr.DataGenMRInstruction), MRInstruction (org.apache.sysml.runtime.instructions.mr.MRInstruction), GroupedAggregateInstruction (org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction)

Aggregations

AggregateOperationTypes (org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes): 12; CMOperator (org.apache.sysml.runtime.matrix.operators.CMOperator): 8; DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 6; HashMap (java.util.HashMap): 2; CompressedMatrixBlock (org.apache.sysml.runtime.compress.CompressedMatrixBlock): 2; MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject): 2; AggregateTernaryCPInstruction (org.apache.sysml.runtime.instructions.cp.AggregateTernaryCPInstruction): 2; AggregateUnaryCPInstruction (org.apache.sysml.runtime.instructions.cp.AggregateUnaryCPInstruction): 2; BinaryCPInstruction (org.apache.sysml.runtime.instructions.cp.BinaryCPInstruction): 2; CPOperand (org.apache.sysml.runtime.instructions.cp.CPOperand): 2; DataGenCPInstruction (org.apache.sysml.runtime.instructions.cp.DataGenCPInstruction): 2; MMTSJCPInstruction (org.apache.sysml.runtime.instructions.cp.MMTSJCPInstruction): 2; MultiReturnBuiltinCPInstruction (org.apache.sysml.runtime.instructions.cp.MultiReturnBuiltinCPInstruction): 2; ParameterizedBuiltinCPInstruction (org.apache.sysml.runtime.instructions.cp.ParameterizedBuiltinCPInstruction): 2; StringInitCPInstruction (org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction): 2; UnaryCPInstruction (org.apache.sysml.runtime.instructions.cp.UnaryCPInstruction): 2; VariableCPInstruction (org.apache.sysml.runtime.instructions.cp.VariableCPInstruction): 2; BinaryMRInstructionBase (org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase): 2; CM_N_COVInstruction (org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction): 2; CtableInstruction (org.apache.sysml.runtime.instructions.mr.CtableInstruction): 2