Search in sources:

Example 11 with ReduceAll

Use of org.apache.sysml.runtime.functionobjects.ReduceAll in project incubator-systemml by Apache.

The class UaggOuterChainSPInstruction, method processInstruction.

/**
 * Executes a unary aggregate over an outer-product chain on Spark.
 *
 * <p>One input is distributed as a blocked RDD and the other is broadcast. If
 * {@code LibMatrixOuterAgg.isSupportedUaggOp} accepts the operator, the broadcast side is
 * converted into a local double vector and {@code RDDMapUAggOuterChainFunction} is applied
 * per partition. Otherwise the generic {@code RDDMapGenUAggOuterChainFunction} is used with
 * a partitioned broadcast of that input.
 *
 * <p>For a {@code ReduceAll} aggregate, the partial blocks are aggregated into one block,
 * which is stored as the output matrix. For any other aggregate, the output is registered
 * as an RDD handle together with its lineage.
 *
 * @param ec execution context; cast to {@code SparkExecutionContext}
 */
@Override
public void processInstruction(ExecutionContext ec) {
    SparkExecutionContext sec = (SparkExecutionContext) ec;
    // input2 is broadcast (input1 is the RDD) for ReduceCol / ReduceAll, or whenever the
    // specialized LibMatrixOuterAgg path does not support the operator; otherwise the roles swap
    boolean rightCached = (_uaggOp.indexFn instanceof ReduceCol || _uaggOp.indexFn instanceof ReduceAll || !LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp));
    String rddVar = (rightCached) ? input1.getName() : input2.getName();
    String bcastVar = (rightCached) ? input2.getName() : input1.getName();
    // get rdd input and its characteristics
    JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(rddVar);
    // passed below as mapPartitionsToPair's preservesPartitioning flag
    boolean noKeyChange = preservesPartitioning(mcIn, _uaggOp.indexFn);
    // execute UAggOuterChain instruction
    JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
    if (LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp)) {
        // materialize the broadcast-side input locally as a double vector
        MatrixBlock mb = sec.getMatrixInput(bcastVar, getExtendedOpcode());
        // NOTE(review): the input is released right away, but mb is still read below
        // (convertToDoubleVector, getNumColumns); presumably safe because the local reference
        // keeps the block alive -- confirm against the buffer-pool release semantics
        sec.releaseMatrixInput(bcastVar, getExtendedOpcode());
        // prevent lineage tracking: the Spark broadcast below is built from the local copy
        // (vmb), so clearing bcastVar skips addLineageBroadcast at the end
        bcastVar = null;
        double[] vmb = DataConverter.convertToDoubleVector(mb);
        Broadcast<int[]> bvi = null;
        if (_uaggOp.aggOp.increOp.fn instanceof Builtin) {
            // Builtin aggregates: precompute row indices and broadcast them as well
            int[] vix = LibMatrixOuterAgg.prepareRowIndices(mb.getNumColumns(), vmb, _bOp, _uaggOp);
            bvi = sec.getSparkContext().broadcast(vix);
        } else
            // all other supported aggregates: broadcast the values in sorted order
            Arrays.sort(vmb);
        Broadcast<double[]> bv = sec.getSparkContext().broadcast(vmb);
        // partitioning-preserving map-to-pair (under constraints)
        out = in1.mapPartitionsToPair(new RDDMapUAggOuterChainFunction(bv, bvi, _bOp, _uaggOp), noKeyChange);
    } else {
        // generic path: use the partitioned broadcast of the other input
        PartitionedBroadcast<MatrixBlock> bv = sec.getBroadcastForVariable(bcastVar);
        // partitioning-preserving map-to-pair (under constraints)
        out = in1.mapPartitionsToPair(new RDDMapGenUAggOuterChainFunction(bv, _uaggOp, _aggOp, _bOp, mcIn), noKeyChange);
    }
    // final aggregation if required
    if (// RC AGG (full aggregate, output is scalar)
    _uaggOp.indexFn instanceof ReduceAll) {
        MatrixBlock tmp = RDDAggregateUtils.aggStable(out, _aggOp);
        // drop correction after aggregation
        tmp.dropLastRowsOrColumns(_aggOp.correctionLocation);
        // put output block into symbol table (no lineage because single block)
        sec.setMatrixOutput(output.getName(), tmp, getExtendedOpcode());
    } else // R/C AGG (output is rdd)
    {
        // infer output characteristics if they are not yet known
        updateUnaryAggOutputMatrixCharacteristics(sec);
        // strip correction columns/rows from each output block if the aggregate has them
        if (_uaggOp.aggOp.correctionExists)
            out = out.mapValues(new AggregateDropCorrectionFunction(_uaggOp.aggOp));
        // put output RDD handle into symbol table and record lineage
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddVar);
        // null on the specialized path (see above), so only the generic path adds broadcast lineage
        if (bcastVar != null)
            sec.addLineageBroadcast(output.getName(), bcastVar);
    }
}
Also used: ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) AggregateDropCorrectionFunction(org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) Builtin(org.apache.sysml.runtime.functionobjects.Builtin)

Example 12 with ReduceAll

Use of org.apache.sysml.runtime.functionobjects.ReduceAll in project incubator-systemml by Apache.

The class UaggOuterChainSPInstruction, method updateUnaryAggOutputMatrixCharacteristics.

/**
 * Infers the output matrix characteristics of this instruction from its inputs if they
 * are not already known; known output characteristics are left untouched.
 *
 * <p>The inferred shape depends on the aggregate's index function:
 * <ul>
 *   <li>{@code ReduceAll} yields 1 x 1;</li>
 *   <li>{@code ReduceCol} yields nrow(input1) x 1;</li>
 *   <li>{@code ReduceRow} yields 1 x ncol(input2).</li>
 * </ul>
 * Block sizes are taken from the same input that provides the inferred dimension. That
 * input is the only one whose dimensions are validated.
 *
 * @param sec Spark execution context holding the input and output characteristics
 * @throws DMLRuntimeException if the output dimensions are unknown and the relevant
 *         input's dimensions are unknown as well
 */
protected void updateUnaryAggOutputMatrixCharacteristics(SparkExecutionContext sec) {
    // ReduceCol sizes the output from input1; ReduceRow (and ReduceAll) from input2
    String strInputName = (_uaggOp.indexFn instanceof ReduceCol) ? input1.getName() : input2.getName();
    MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(strInputName);
    MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
    if (mcOut.dimsKnown())
        return;
    if (!mc1.dimsKnown()) {
        throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from input:" + mc1.toString() + " " + mcOut.toString());
    }
    // infer statistics from input based on operator
    if (_uaggOp.indexFn instanceof ReduceAll)
        mcOut.set(1, 1, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
    else if (_uaggOp.indexFn instanceof ReduceCol)
        mcOut.set(mc1.getRows(), 1, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
    else if (_uaggOp.indexFn instanceof ReduceRow)
        // one output column per column of input2 (mc1 here). This mirrors the ReduceCol case.
        // mc1 is also the input whose dims were validated above.
        mcOut.set(1, mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
}
Also used: ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) ReduceRow(org.apache.sysml.runtime.functionobjects.ReduceRow) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 13 with ReduceAll

Use of org.apache.sysml.runtime.functionobjects.ReduceAll in project incubator-systemml by Apache.

The class LibMatrixAgg, method getAggType.

/**
 * Classifies a unary aggregate operator into the {@link AggType} handled by the
 * specialized aggregation kernels.
 *
 * <p>The classification uses three properties of the operator: its value function, its
 * index function (ReduceAll / ReduceCol / ReduceRow, and ReduceDiag for trace), and, where
 * applicable, the location of the correction terms. The checks are evaluated in a fixed
 * order and the first match wins.
 *
 * @param op the unary aggregate operator to classify
 * @return the matching aggregation type, or {@code AggType.INVALID} if none applies
 */
private static AggType getAggType(AggregateUnaryOperator op) {
    final ValueFunction fn = op.aggOp.increOp.fn;
    final IndexFunction ix = op.indexFn;
    final CorrectionLocationType corr = op.aggOp.correctionLocation;
    final boolean allRowOrCol = ix instanceof ReduceAll || ix instanceof ReduceCol || ix instanceof ReduceRow;

    // (kahan) sum / sum squared, including trace via ReduceDiag; one correction row/column
    final boolean oneCorrection = corr == CorrectionLocationType.LASTCOLUMN || corr == CorrectionLocationType.LASTROW;
    if (fn instanceof KahanFunction && oneCorrection && (allRowOrCol || ix instanceof ReduceDiag)) {
        if (fn instanceof KahanPlus)
            return AggType.KAHAN_SUM;
        if (fn instanceof KahanPlusSq)
            return AggType.KAHAN_SUM_SQ;
    }

    // mean: two correction rows/columns
    final boolean twoCorrections = corr == CorrectionLocationType.LASTTWOCOLUMNS || corr == CorrectionLocationType.LASTTWOROWS;
    if (fn instanceof Mean && twoCorrections && allRowOrCol)
        return AggType.MEAN;

    // variance: four correction rows/columns
    final boolean fourCorrections = corr == CorrectionLocationType.LASTFOURCOLUMNS || corr == CorrectionLocationType.LASTFOURROWS;
    if (fn instanceof CM && ((CM) fn).getAggOpType() == AggregateOperationTypes.VARIANCE && fourCorrections && allRowOrCol)
        return AggType.VAR;

    // product: full aggregate only
    if (fn instanceof Multiply && ix instanceof ReduceAll)
        return AggType.PROD;

    // min / max and their index variants
    if (fn instanceof Builtin && allRowOrCol) {
        final BuiltinCode code = ((Builtin) fn).bFunc;
        switch (code) {
            case MAX:      return AggType.MAX;
            case MIN:      return AggType.MIN;
            case MAXINDEX: return AggType.MAX_INDEX;
            case MININDEX: return AggType.MIN_INDEX;
            default:       break; // other builtin functions are not classified here
        }
    }
    return AggType.INVALID;
}
Also used: ValueFunction(org.apache.sysml.runtime.functionobjects.ValueFunction) ReduceCol(org.apache.sysml.runtime.functionobjects.ReduceCol) ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) Mean(org.apache.sysml.runtime.functionobjects.Mean) ReduceDiag(org.apache.sysml.runtime.functionobjects.ReduceDiag) CM(org.apache.sysml.runtime.functionobjects.CM) ReduceRow(org.apache.sysml.runtime.functionobjects.ReduceRow) IndexFunction(org.apache.sysml.runtime.functionobjects.IndexFunction) BuiltinCode(org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode) KahanFunction(org.apache.sysml.runtime.functionobjects.KahanFunction) Multiply(org.apache.sysml.runtime.functionobjects.Multiply) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) KahanPlusSq(org.apache.sysml.runtime.functionobjects.KahanPlusSq) Builtin(org.apache.sysml.runtime.functionobjects.Builtin)

Example 14 with ReduceAll

Use of org.apache.sysml.runtime.functionobjects.ReduceAll in project incubator-systemml by Apache.

The class LibMatrixAgg, method aggregateTernaryGeneric.

/**
 * Generic sparse-driven kernel for the ternary aggregates tak+* and tack+*:
 * <ul>
 *   <li>tak+* is the sum over all cells of the product of the inputs;</li>
 *   <li>tack+* is the per-column sums of that product.</li>
 * </ul>
 * Rows {@code rl} (inclusive) to {@code ru} (exclusive) are processed.
 *
 * <p>The scan is driven by the non-zeros of the sparse input with the fewest non-zeros;
 * the remaining inputs are probed cell by cell. Sums are accumulated with
 * {@code KahanPlus}. For the full aggregate, the sum and its correction are written to
 * cells (0,0) and (0,1) of {@code ret}. Otherwise, the dense values of {@code ret} hold the
 * column sums at [0, n) and their corrections at [n, 2n), with n = {@code in1.clen}.
 *
 * @param in1 first input
 * @param in2 second input
 * @param in3 third input; presumably the only one allowed to be null, since only the
 *            last input after sorting is null-checked -- confirm with callers
 * @param ret output block (layout described above)
 * @param ixFn {@code ReduceAll} selects tak+*, anything else selects tack+*
 * @param rl first row to process (inclusive)
 * @param ru end row (exclusive)
 */
private static void aggregateTernaryGeneric(MatrixBlock in1, MatrixBlock in2, MatrixBlock in3, MatrixBlock ret, IndexFunction ixFn, int rl, int ru) {
    final KahanObject acc = new KahanObject(0, 0);
    final KahanPlus kplus = KahanPlus.getKahanPlusFnObject();

    // Order the inputs so that the sparse block with the fewest non-zeros drives the scan.
    // Dense and null blocks share the key Long.MAX_VALUE. Arrays.sort on objects is stable,
    // so they keep their original relative order. The caller is expected to guarantee at
    // least one sparse input.
    final MatrixBlock[] ordered = { in1, in2, in3 };
    Arrays.sort(ordered, new Comparator<MatrixBlock>() {

        @Override
        public int compare(MatrixBlock left, MatrixBlock right) {
            return Long.compare(sortKey(left), sortKey(right));
        }

        // nnz for sparse blocks, Long.MAX_VALUE for dense or missing blocks
        private long sortKey(MatrixBlock mb) {
            return (mb == null || !mb.sparse) ? Long.MAX_VALUE : mb.nonZeros;
        }
    });
    final MatrixBlock driver = ordered[0];
    final MatrixBlock second = ordered[1];
    final MatrixBlock third = ordered[2];
    final SparseBlock sb = driver.sparseBlock;
    final int n = in1.clen;

    if (ixFn instanceof ReduceAll) {
        // tak+*: a single Kahan sum over all cells
        for (int i = rl; i < ru; i++) {
            if (sb.isEmpty(i))
                continue;
            final int pos = sb.pos(i);
            final int end = pos + sb.size(i);
            final int[] cix = sb.indexes(i);
            final double[] vals = sb.values(i);
            for (int k = pos; k < end; k++) {
                double prod = vals[k] * second.quickGetValue(i, cix[k]);
                // skip the third lookup once the partial product is zero
                if (prod != 0 && third != null)
                    prod *= third.quickGetValue(i, cix[k]);
                kplus.execute2(acc, prod);
            }
        }
        ret.quickSetValue(0, 0, acc._sum);
        ret.quickSetValue(0, 1, acc._correction);
        return;
    }

    // tack+*: per-column Kahan sums; each correction lives n cells after its sum
    final double[] cvals = ret.getDenseBlockValues();
    for (int i = rl; i < ru; i++) {
        if (sb.isEmpty(i))
            continue;
        final int pos = sb.pos(i);
        final int end = pos + sb.size(i);
        final int[] cix = sb.indexes(i);
        final double[] vals = sb.values(i);
        for (int k = pos; k < end; k++) {
            final int col = cix[k];
            double prod = vals[k] * second.quickGetValue(i, col);
            // skip the third lookup once the partial product is zero
            if (prod != 0 && third != null)
                prod *= third.quickGetValue(i, col);
            acc._sum = cvals[col];
            acc._correction = cvals[col + n];
            kplus.execute2(acc, prod);
            cvals[col] = acc._sum;
            cvals[col + n] = acc._correction;
        }
    }
}
Also used: ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus)

Aggregations

ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll)14 ReduceCol (org.apache.sysml.runtime.functionobjects.ReduceCol)10 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)7 KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus)7 Builtin (org.apache.sysml.runtime.functionobjects.Builtin)6 ReduceRow (org.apache.sysml.runtime.functionobjects.ReduceRow)6 KahanPlusSq (org.apache.sysml.runtime.functionobjects.KahanPlusSq)5 KahanObject (org.apache.sysml.runtime.instructions.cp.KahanObject)5 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)5 CM (org.apache.sysml.runtime.functionobjects.CM)4 Mean (org.apache.sysml.runtime.functionobjects.Mean)4 ReduceDiag (org.apache.sysml.runtime.functionobjects.ReduceDiag)4 MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics)4 MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes)3 SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext)2 IndexFunction (org.apache.sysml.runtime.functionobjects.IndexFunction)2 KahanFunction (org.apache.sysml.runtime.functionobjects.KahanFunction)2 Multiply (org.apache.sysml.runtime.functionobjects.Multiply)2 CM_COV_Object (org.apache.sysml.runtime.instructions.cp.CM_COV_Object)2 DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject)2