Search in sources:

Example 1 with AggregateTernaryOperator

Use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in the Apache project incubator-systemml.

Class InstructionUtils, method parseAggregateTernaryOperator.

/**
 * Builds the aggregate-ternary operator (element-wise multiply followed by a KahanPlus aggregate)
 * for the given opcode.
 * <p>
 * {@code tak+*} reduces over all cells ({@link ReduceAll}) and places the correction in the last
 * column; any other opcode (the callers only pass {@code tack+*}) reduces over rows
 * ({@link ReduceRow}) and places the correction in the last row.
 *
 * @param opcode     instruction opcode, compared case-insensitively against {@code "tak+*"}
 * @param numThreads degree of parallelism handed to the operator
 * @return a newly constructed {@link AggregateTernaryOperator}
 */
public static AggregateTernaryOperator parseAggregateTernaryOperator(String opcode, int numThreads) {
    final boolean reduceAll = opcode.equalsIgnoreCase("tak+*");

    // Correction location and index function are chosen together from the same opcode test.
    final CorrectionLocationType correction;
    final IndexFunction indexFn;
    if (reduceAll) {
        correction = CorrectionLocationType.LASTCOLUMN;
        indexFn = ReduceAll.getReduceAllFnObject();
    } else {
        correction = CorrectionLocationType.LASTROW;
        indexFn = ReduceRow.getReduceRowFnObject();
    }

    AggregateOperator kahanAgg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, correction);
    return new AggregateTernaryOperator(Multiply.getMultiplyFnObject(), kahanAgg, indexFn, numThreads);
}
Also used : IndexFunction(org.apache.sysml.runtime.functionobjects.IndexFunction) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) AggregateTernaryOperator(org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator) CorrectionLocationType(org.apache.sysml.lops.PartialAggregate.CorrectionLocationType)

Example 2 with AggregateTernaryOperator

Use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in the Apache project incubator-systemml.

Class AggregateTernarySPInstruction, method processInstruction.

/**
 * Executes the aggregate-ternary instruction on Spark.
 * <p>
 * The inputs are joined by block index and reduced block-wise by an
 * {@code RDDAggregateTernaryFunction}. The partial results are then aggregated in one of three
 * ways:
 * <ul>
 *   <li>{@link ReduceAll} ({@code tak+*}): summed into a scalar {@link DoubleObject};</li>
 *   <li>known dimensions with all columns in one block: aggregated into a single
 *       {@link MatrixBlock} with the correction dropped;</li>
 *   <li>otherwise: aggregated per block key into an output RDD, with lineage recorded to the
 *       matrix inputs.</li>
 * </ul>
 *
 * @param ec execution context; cast to {@link SparkExecutionContext}
 * @throws DMLRuntimeException propagated from the context and aggregation utilities
 */
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    final SparkExecutionContext sparkEc = (SparkExecutionContext) ec;

    // Inputs: two matrices plus a third operand that is either a matrix or the literal 1.
    MatrixCharacteristics mcIn = sparkEc.getMatrixCharacteristics(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> first = sparkEc.getBinaryBlockRDDHandleForVariable(input1.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> second = sparkEc.getBinaryBlockRDDHandleForVariable(input2.getName());
    JavaPairRDD<MatrixIndexes, MatrixBlock> third = null;
    if (!input3.isLiteral()) {
        third = sparkEc.getBinaryBlockRDDHandleForVariable(input3.getName());
    }

    // Block-wise ternary operation on the joined inputs.
    AggregateTernaryOperator aggop = (AggregateTernaryOperator) _optr;
    JavaPairRDD<MatrixIndexes, MatrixBlock> partials;
    if (third == null) {
        // Third operand is the literal 1, so only two inputs need joining.
        partials = first.join(second).mapToPair(new RDDAggregateTernaryFunction2(aggop));
    } else {
        partials = first.join(second).join(third).mapToPair(new RDDAggregateTernaryFunction(aggop));
    }

    if (aggop.indexFn instanceof ReduceAll) {
        // tak+*: the full aggregate is a scalar, so no lineage is recorded.
        MatrixBlock total = RDDAggregateUtils.sumStable(partials.values());
        DoubleObject scalar = new DoubleObject(total.getValue(0, 0));
        sparkEc.setVariable(output.getName(), scalar);
        return;
    }

    if (mcIn.dimsKnown() && mcIn.getCols() <= mcIn.getColsPerBlock()) {
        // tack+* with all columns in a single block: aggregate to one block and drop the correction.
        MatrixBlock single = RDDAggregateUtils.aggStable(partials, aggop.aggOp);
        single.dropLastRowsOrColums(aggop.aggOp.correctionLocation);
        // setMatrixOutput also maintains the output's matrix characteristics; no lineage needed.
        sparkEc.setMatrixOutput(output.getName(), single);
        return;
    }

    // tack+* spanning several column blocks: aggregate per block key, then strip the correction.
    JavaPairRDD<MatrixIndexes, MatrixBlock> aggregated =
        RDDAggregateUtils.aggByKeyStable(partials, aggop.aggOp, false)
            .mapValues(new AggregateDropCorrectionFunction(aggop.aggOp));

    // Publish the RDD handle and record lineage to every matrix input.
    updateUnaryAggOutputMatrixCharacteristics(sparkEc, aggop.indexFn);
    sparkEc.setRDDHandleForVariable(output.getName(), aggregated);
    sparkEc.addLineageRDD(output.getName(), input1.getName());
    sparkEc.addLineageRDD(output.getName(), input2.getName());
    if (third != null) {
        sparkEc.addLineageRDD(output.getName(), input3.getName());
    }
}
Also used : ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) DoubleObject(org.apache.sysml.runtime.instructions.cp.DoubleObject) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) AggregateDropCorrectionFunction(org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction) AggregateTernaryOperator(org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics)

Example 3 with AggregateTernaryOperator

Use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in the Apache project incubator-systemml.

Class AggregateTernaryCPInstruction, method parseInstruction.

/**
 * Parses a CP (control-program) aggregate-ternary instruction string.
 * <p>
 * Accepted opcodes are {@code tak+*} and {@code tack+*} (case-insensitive). After the opcode the
 * instruction carries five fields: three inputs, the output, and the thread count. Field-count
 * validation is delegated to {@code InstructionUtils.checkNumFields}.
 *
 * @param str the serialized instruction
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is not one of the accepted aggregate-ternary opcodes
 */
public static AggregateTernaryCPInstruction parseInstruction(String str) throws DMLRuntimeException {
    String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    String opcode = fields[0];

    boolean supported = opcode.equalsIgnoreCase("tak+*") || opcode.equalsIgnoreCase("tack+*");
    if (!supported) {
        throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    InstructionUtils.checkNumFields(fields, 5);
    CPOperand first = new CPOperand(fields[1]);
    CPOperand second = new CPOperand(fields[2]);
    CPOperand third = new CPOperand(fields[3]);
    CPOperand target = new CPOperand(fields[4]);
    int threads = Integer.parseInt(fields[5]);

    AggregateTernaryOperator operator = InstructionUtils.parseAggregateTernaryOperator(opcode, threads);
    return new AggregateTernaryCPInstruction(operator, first, second, third, target, opcode, str);
}
Also used : AggregateTernaryOperator(org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 4 with AggregateTernaryOperator

Use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in the Apache project incubator-systemml.

Class AggregateTernaryCPInstruction, method processInstruction.

/**
 * Executes the aggregate-ternary instruction in the control program.
 * <p>
 * Acquires the matrix inputs, runs {@code aggregateTernaryOperations}, and releases the inputs.
 * It then stores the result either as a scalar (cell (0,0) wrapped in a {@link DoubleObject})
 * or as a matrix, depending on the output operand's data type.
 *
 * @param ec execution context used to acquire inputs and publish the output
 * @throws DMLRuntimeException propagated from the context or the matrix operation
 */
@Override
public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
    // The third operand is either a matrix or the literal 1; a literal is passed on as null.
    MatrixBlock a = ec.getMatrixInput(input1.getName());
    MatrixBlock b = ec.getMatrixInput(input2.getName());
    boolean thirdIsMatrix = !input3.isLiteral();
    MatrixBlock c = thirdIsMatrix ? ec.getMatrixInput(input3.getName()) : null;

    AggregateTernaryOperator op = (AggregateTernaryOperator) _optr;
    MatrixBlock result = a.aggregateTernaryOperations(a, b, c, new MatrixBlock(), op, true);

    // Release every input acquired above before publishing the output.
    ec.releaseMatrixInput(input1.getName());
    ec.releaseMatrixInput(input2.getName());
    if (thirdIsMatrix) {
        ec.releaseMatrixInput(input3.getName());
    }

    if (!output.getDataType().isScalar()) {
        ec.setMatrixOutput(output.getName(), result);
    } else {
        ec.setScalarOutput(output.getName(), new DoubleObject(result.quickGetValue(0, 0)));
    }
}
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) AggregateTernaryOperator(org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator)

Example 5 with AggregateTernaryOperator

Use of org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator in the Apache project incubator-systemml.

Class AggregateTernarySPInstruction, method parseInstruction.

/**
 * Parses a Spark aggregate-ternary instruction string.
 * <p>
 * Accepted opcodes are {@code tak+*} and {@code tack+*} (case-insensitive). After the opcode the
 * instruction carries four fields: three inputs and the output. There is no thread-count field;
 * the operator is built with the single-argument {@code parseAggregateTernaryOperator} overload.
 * Field-count validation is delegated to {@code InstructionUtils.checkNumFields}.
 *
 * @param str the serialized instruction
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is not one of the accepted aggregate-ternary opcodes
 */
public static AggregateTernarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
    String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    String opcode = fields[0];

    boolean supported = opcode.equalsIgnoreCase("tak+*") || opcode.equalsIgnoreCase("tack+*");
    if (!supported) {
        throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    InstructionUtils.checkNumFields(fields, 4);
    CPOperand first = new CPOperand(fields[1]);
    CPOperand second = new CPOperand(fields[2]);
    CPOperand third = new CPOperand(fields[3]);
    CPOperand target = new CPOperand(fields[4]);

    AggregateTernaryOperator operator = InstructionUtils.parseAggregateTernaryOperator(opcode);
    return new AggregateTernarySPInstruction(operator, first, second, third, target, opcode, str);
}
Also used : CPOperand(org.apache.sysml.runtime.instructions.cp.CPOperand) AggregateTernaryOperator(org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Aggregations

AggregateTernaryOperator (org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator): 5
DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 2
MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock): 2
CorrectionLocationType (org.apache.sysml.lops.PartialAggregate.CorrectionLocationType): 1
SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext): 1
IndexFunction (org.apache.sysml.runtime.functionobjects.IndexFunction): 1
ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll): 1
CPOperand (org.apache.sysml.runtime.instructions.cp.CPOperand): 1
DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject): 1
AggregateDropCorrectionFunction (org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction): 1
MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics): 1
MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes): 1
AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator): 1