Search in sources :

Example 11 with CM_COV_Object

use of org.apache.sysml.runtime.instructions.cp.CM_COV_Object in project incubator-systemml by apache.

the class CovarianceSPInstruction method processInstruction.

/**
 * Executes the covariance instruction on Spark and stores the scalar result.
 *
 * <p>The first two inputs are joined block-wise; if a third input is present it is joined
 * in as well and treated as weights. Each joined entry is mapped to a partial
 * {@link CM_COV_Object}, the partials are folded into one object, and the value selected by
 * the instruction's operator is written to the output variable as a {@link DoubleObject}.
 *
 * @param ec execution context; it is cast to {@link SparkExecutionContext}
 */
@Override
public void processInstruction(ExecutionContext ec) {
    final SparkExecutionContext sparkCtx = (SparkExecutionContext) ec;
    final COVOperator covOp = (COVOperator) _optr;

    // resolve the two mandatory inputs
    final JavaPairRDD<MatrixIndexes, MatrixBlock> left = sparkCtx.getBinaryBlockRDDHandleForVariable(input1.getName());
    final JavaPairRDD<MatrixIndexes, MatrixBlock> right = sparkCtx.getBinaryBlockRDDHandleForVariable(input2.getName());

    // aggregate all joined blocks into a single covariance buffer
    final CM_COV_Object aggregate;
    if (input3 != null) {
        // weighted variant: the third input is joined in alongside the two data inputs
        final JavaPairRDD<MatrixIndexes, MatrixBlock> weights = sparkCtx.getBinaryBlockRDDHandleForVariable(input3.getName());
        aggregate = left.join(right).join(weights).values()
            .map(new RDDCOVWeightsFunction(covOp))
            .fold(new CM_COV_Object(), new RDDCOVReduceFunction(covOp));
    } else {
        // unweighted variant
        aggregate = left.join(right).values()
            .map(new RDDCOVFunction(covOp))
            .fold(new CM_COV_Object(), new RDDCOVReduceFunction(covOp));
    }

    // publish the scalar output (no lineage information required)
    ec.setScalarOutput(output.getName(), new DoubleObject(aggregate.getRequiredResult(_optr)));
}
Also used : CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) DoubleObject(org.apache.sysml.runtime.instructions.cp.DoubleObject) COVOperator(org.apache.sysml.runtime.matrix.operators.COVOperator) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext)

Example 12 with CM_COV_Object

use of org.apache.sysml.runtime.instructions.cp.CM_COV_Object in project incubator-systemml by apache.

the class PerformGroupByAggInCombiner method call.

/**
 * Combines two weighted cells into a single partial aggregate cell.
 *
 * <p>For a {@link CMOperator} that is a partial aggregate operator, value and weight of both
 * cells are fed into a {@link CM_COV_Object}; the returned cell carries the partial result and
 * the accumulated weight. For an {@link AggregateOperator}, {@code value * weight} of both
 * cells is folded with the operator's increment function starting from its initial value, and
 * the returned cell has weight 1.
 *
 * @param value1 first cell to combine
 * @param value2 second cell to combine
 * @return a new cell holding the combined partial aggregate
 * @throws Exception a {@link DMLRuntimeException} if a CM operator is not a partial aggregate
 *         operator, or if the operator is neither a CM nor an aggregate operator
 */
@Override
public WeightedCell call(WeightedCell value1, WeightedCell value2) throws Exception {
    WeightedCell outCell = new WeightedCell();
    if (_op instanceof CMOperator) {
        // everything except sum
        CMOperator cmOp = (CMOperator) _op;
        if (!cmOp.isPartialAggregateOperator()) {
            // non-partial operators cannot be combined pairwise; they belong in the reducer
            throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInReducer");
        }
        // the buffer is only needed on this path, so it is allocated here rather than per call
        CM_COV_Object cmObj = new CM_COV_Object();
        cmObj.reset();
        CM lcmFn = CM.getCMFnObject(cmOp.aggOpType);
        // partial aggregate cm operator
        lcmFn.execute(cmObj, value1.getValue(), value1.getWeight());
        lcmFn.execute(cmObj, value2.getValue(), value2.getWeight());
        outCell.setValue(cmObj.getRequiredPartialResult(_op));
        outCell.setWeight(cmObj.getWeight());
    } else if (_op instanceof AggregateOperator) {
        // sum
        AggregateOperator aggop = (AggregateOperator) _op;
        if (aggop.correctionExists) {
            // partial aggregate with correction: the buffer is updated in place
            KahanObject buffer = new KahanObject(aggop.initialValue, 0);
            aggop.increOp.fn.execute(buffer, value1.getValue() * value1.getWeight());
            aggop.increOp.fn.execute(buffer, value2.getValue() * value2.getWeight());
            outCell.setValue(buffer._sum);
            outCell.setWeight(1);
        } else {
            // partial aggregate without correction
            double v = aggop.initialValue;
            v = aggop.increOp.fn.execute(v, value1.getValue() * value1.getWeight());
            v = aggop.increOp.fn.execute(v, value2.getValue() * value2.getWeight());
            outCell.setValue(v);
            outCell.setWeight(1);
        }
    } else {
        throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + _op);
    }
    return outCell;
}
Also used : WeightedCell(org.apache.sysml.runtime.matrix.data.WeightedCell) CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) CM(org.apache.sysml.runtime.functionobjects.CM) CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 13 with CM_COV_Object

use of org.apache.sysml.runtime.instructions.cp.CM_COV_Object in project incubator-systemml by apache.

the class PerformGroupByAggInReducer method call.

/**
 * Aggregates all weighted cells of one group into a single result cell.
 *
 * <p>For a {@link CMOperator} that is not a partial aggregate operator, value and weight of
 * every cell are fed into a {@link CM_COV_Object} and the required result is returned. For an
 * {@link AggregateOperator}, {@code value * weight} of every cell is folded with the
 * operator's increment function starting from its initial value. The returned cell always has
 * weight 1.
 *
 * @param kv the cells of one group
 * @return a new cell holding the aggregated value
 * @throws Exception a {@link DMLRuntimeException} if a CM operator is a partial aggregate
 *         operator, or if the operator is neither a CM nor an aggregate operator
 */
@Override
public WeightedCell call(Iterable<WeightedCell> kv) throws Exception {
    WeightedCell outCell = new WeightedCell();
    if (op instanceof CMOperator) {
        // everything except sum
        CMOperator cmOp = (CMOperator) op;
        CM lcmFn = CM.getCMFnObject(cmOp.aggOpType);
        if (cmOp.isPartialAggregateOperator()) {
            // partial aggregation is the combiner's job, not the reducer's
            throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInCombiner");
        }
        // the buffer is only needed on this path, so it is allocated here rather than per call
        CM_COV_Object cmObj = new CM_COV_Object();
        cmObj.reset();
        // final aggregation over all cells of the group
        for (WeightedCell value : kv) {
            lcmFn.execute(cmObj, value.getValue(), value.getWeight());
        }
        outCell.setValue(cmObj.getRequiredResult(op));
        outCell.setWeight(1);
    } else if (op instanceof AggregateOperator) {
        // sum
        AggregateOperator aggop = (AggregateOperator) op;
        if (aggop.correctionExists) {
            // aggregate with correction: the buffer is updated in place
            KahanObject buffer = new KahanObject(aggop.initialValue, 0);
            for (WeightedCell value : kv) {
                aggop.increOp.fn.execute(buffer, value.getValue() * value.getWeight());
            }
            outCell.setValue(buffer._sum);
            outCell.setWeight(1);
        } else {
            // aggregate without correction
            double v = aggop.initialValue;
            for (WeightedCell value : kv) {
                v = aggop.increOp.fn.execute(v, value.getValue() * value.getWeight());
            }
            outCell.setValue(v);
            outCell.setWeight(1);
        }
    } else {
        throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + op);
    }
    return outCell;
}
Also used : WeightedCell(org.apache.sysml.runtime.matrix.data.WeightedCell) CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) AggregateOperator(org.apache.sysml.runtime.matrix.operators.AggregateOperator) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) CM(org.apache.sysml.runtime.functionobjects.CM) CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 14 with CM_COV_Object

use of org.apache.sysml.runtime.instructions.cp.CM_COV_Object in project incubator-systemml by apache.

the class MatrixBlock method incrementalAggregate.

/**
 * Merges a partial aggregate, including its correction cells, into this block in place.
 *
 * <p>The layout of value and correction cells is selected by {@code aggOp.correctionLocation};
 * each branch below reads the current cells of this block and the matching cells of
 * {@code newWithCorrection}, combines them via {@code aggOp.increOp.fn} (or
 * {@code LibMatrixAgg.aggregateBinaryMatrix} for {@code KahanPlus}), and writes the result
 * back into this block.
 *
 * @param aggOp aggregate operator supplying the increment function and the correction layout
 * @param newWithCorrection partial aggregate to merge in; converted via {@code checkType}
 * @throws DMLRuntimeException if no branch matches, i.e. the correction location is not one
 *         of those handled here, or it is LASTFOURROWS/LASTFOURCOLUMNS but the increment
 *         function is not a {@code CM} with aggregation type VARIANCE
 */
@Override
public void incrementalAggregate(AggregateOperator aggOp, MatrixValue newWithCorrection) {
    // assert(aggOp.correctionExists);
    MatrixBlock newWithCor = checkType(newWithCorrection);
    // scratch accumulator, refilled per cell in the sum/correction branches below
    KahanObject buffer = new KahanObject(0, 0);
    if (aggOp.correctionLocation == CorrectionLocationType.LASTROW) {
        // layout: value in row r, its correction in row r + 1
        if (aggOp.increOp.fn instanceof KahanPlus) {
            // KahanPlus is delegated to the library routine
            LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, aggOp);
        } else {
            for (int r = 0; r < rlen - 1; r++) for (int c = 0; c < clen; c++) {
                buffer._sum = this.quickGetValue(r, c);
                buffer._correction = this.quickGetValue(r + 1, c);
                // combine with the partial value and its correction
                buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newWithCor.quickGetValue(r, c), newWithCor.quickGetValue(r + 1, c));
                quickSetValue(r, c, buffer._sum);
                quickSetValue(r + 1, c, buffer._correction);
            }
        }
    } else if (aggOp.correctionLocation == CorrectionLocationType.LASTCOLUMN) {
        if (aggOp.increOp.fn instanceof Builtin && (((Builtin) (aggOp.increOp.fn)).bFunc == Builtin.BuiltinCode.MAXINDEX || ((Builtin) (aggOp.increOp.fn)).bFunc == Builtin.BuiltinCode.MININDEX)) {
            // Special case for MAXINDEX/MININDEX: column 0 holds the index and column 1 the
            // value it refers to. The locals are named "...Max..." but serve both builtins.
            // NOTE(review): the start of the original remark is missing in this copy; the
            // surviving tail (closed by the END HACK marker below) reads:
            // modified, the other needs to be changed to match.
            for (int r = 0; r < rlen; r++) {
                double currMaxValue = quickGetValue(r, 1);
                long newMaxIndex = (long) newWithCor.quickGetValue(r, 0);
                double newMaxValue = newWithCor.quickGetValue(r, 1);
                // the function's return code tells which of the two values wins (see below)
                double update = aggOp.increOp.fn.execute(newMaxValue, currMaxValue);
                if (2.0 == update) {
                    // Return value of 2 ==> both values the same, break ties
                    // in favor of higher index.
                    long curMaxIndex = (long) quickGetValue(r, 0);
                    quickSetValue(r, 0, Math.max(curMaxIndex, newMaxIndex));
                } else if (1.0 == update) {
                    // Return value of 1 ==> new value is better; use its index
                    quickSetValue(r, 0, newMaxIndex);
                    quickSetValue(r, 1, newMaxValue);
                } else {
                // Other return value ==> current answer is best
                }
            }
        // *** END HACK ***
        } else {
            // layout: value in column c, its correction in column c + 1
            if (aggOp.increOp.fn instanceof KahanPlus) {
                // KahanPlus is delegated to the library routine
                LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, aggOp);
            } else {
                for (int r = 0; r < rlen; r++) for (int c = 0; c < clen - 1; c++) {
                    buffer._sum = this.quickGetValue(r, c);
                    buffer._correction = this.quickGetValue(r, c + 1);
                    // combine with the partial value and its correction
                    buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newWithCor.quickGetValue(r, c), newWithCor.quickGetValue(r, c + 1));
                    quickSetValue(r, c, buffer._sum);
                    quickSetValue(r, c + 1, buffer._correction);
                }
            }
        }
    } else if (aggOp.correctionLocation == CorrectionLocationType.LASTTWOROWS) {
        // layout: value in row r, count in row r + 1, correction in row r + 2
        // (presumably the value is a running mean, judging by the mu2 naming — TODO confirm)
        double n, n2, mu2;
        for (int r = 0; r < rlen - 2; r++) for (int c = 0; c < clen; c++) {
            buffer._sum = this.quickGetValue(r, c);
            n = this.quickGetValue(r + 1, c);
            buffer._correction = this.quickGetValue(r + 2, c);
            mu2 = newWithCor.quickGetValue(r, c);
            n2 = newWithCor.quickGetValue(r + 1, c);
            // combined count of current and partial aggregate
            n = n + n2;
            // difference to the partial value, scaled by the partial's share of the count
            // NOTE(review): there is no guard for n == 0 here; 0/0 would yield NaN
            double toadd = (mu2 - buffer._sum) * n2 / n;
            buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, toadd);
            quickSetValue(r, c, buffer._sum);
            quickSetValue(r + 1, c, n);
            quickSetValue(r + 2, c, buffer._correction);
        }
    } else if (aggOp.correctionLocation == CorrectionLocationType.LASTTWOCOLUMNS) {
        // layout: value in column c, count in column c + 1, correction in column c + 2
        // (column-wise counterpart of the LASTTWOROWS branch)
        double n, n2, mu2;
        for (int r = 0; r < rlen; r++) for (int c = 0; c < clen - 2; c++) {
            buffer._sum = this.quickGetValue(r, c);
            n = this.quickGetValue(r, c + 1);
            buffer._correction = this.quickGetValue(r, c + 2);
            mu2 = newWithCor.quickGetValue(r, c);
            n2 = newWithCor.quickGetValue(r, c + 1);
            // combined count of current and partial aggregate
            n = n + n2;
            // difference to the partial value, scaled by the partial's share of the count
            // NOTE(review): there is no guard for n == 0 here; 0/0 would yield NaN
            double toadd = (mu2 - buffer._sum) * n2 / n;
            buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, toadd);
            quickSetValue(r, c, buffer._sum);
            quickSetValue(r, c + 1, n);
            quickSetValue(r, c + 2, buffer._correction);
        }
    } else if (aggOp.correctionLocation == CorrectionLocationType.LASTFOURROWS && aggOp.increOp.fn instanceof CM && ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE) {
        // variance with row layout: var in row r, then mean, count, m2 correction and
        // mean correction in rows r + 1 .. r + 4
        // create buffers to store results
        CM_COV_Object cbuff_curr = new CM_COV_Object();
        CM_COV_Object cbuff_part = new CM_COV_Object();
        // perform incremental aggregation
        for (int r = 0; r < rlen - 4; r++) for (int c = 0; c < clen; c++) {
            // extract current values: { var | mean, count, m2 correction, mean correction }
            // note: m2 = var * (n - 1)
            // count
            cbuff_curr.w = quickGetValue(r + 2, c);
            // m2
            cbuff_curr.m2._sum = quickGetValue(r, c) * (cbuff_curr.w - 1);
            // mean
            cbuff_curr.mean._sum = quickGetValue(r + 1, c);
            // m2 correction
            cbuff_curr.m2._correction = quickGetValue(r + 3, c);
            // mean correction
            cbuff_curr.mean._correction = quickGetValue(r + 4, c);
            // extract partial values: { var | mean, count, m2 correction, mean correction }
            // note: m2 = var * (n - 1)
            // count
            cbuff_part.w = newWithCor.quickGetValue(r + 2, c);
            // m2
            cbuff_part.m2._sum = newWithCor.quickGetValue(r, c) * (cbuff_part.w - 1);
            // mean
            cbuff_part.mean._sum = newWithCor.quickGetValue(r + 1, c);
            // m2 correction
            cbuff_part.m2._correction = newWithCor.quickGetValue(r + 3, c);
            // mean correction
            cbuff_part.mean._correction = newWithCor.quickGetValue(r + 4, c);
            // calculate incremental aggregated variance
            cbuff_curr = (CM_COV_Object) aggOp.increOp.fn.execute(cbuff_curr, cbuff_part);
            // store updated values: { var | mean, count, m2 correction, mean correction }
            double var = cbuff_curr.getRequiredResult(AggregateOperationTypes.VARIANCE);
            quickSetValue(r, c, var);
            // mean
            quickSetValue(r + 1, c, cbuff_curr.mean._sum);
            // count
            quickSetValue(r + 2, c, cbuff_curr.w);
            // m2 correction
            quickSetValue(r + 3, c, cbuff_curr.m2._correction);
            // mean correction
            quickSetValue(r + 4, c, cbuff_curr.mean._correction);
        }
    } else if (aggOp.correctionLocation == CorrectionLocationType.LASTFOURCOLUMNS && aggOp.increOp.fn instanceof CM && ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE) {
        // variance with column layout: var in column c, then mean, count, m2 correction and
        // mean correction in columns c + 1 .. c + 4
        // create buffers to store results
        CM_COV_Object cbuff_curr = new CM_COV_Object();
        CM_COV_Object cbuff_part = new CM_COV_Object();
        // perform incremental aggregation
        for (int r = 0; r < rlen; r++) for (int c = 0; c < clen - 4; c++) {
            // extract current values: { var | mean, count, m2 correction, mean correction }
            // note: m2 = var * (n - 1)
            // count
            cbuff_curr.w = quickGetValue(r, c + 2);
            // m2
            cbuff_curr.m2._sum = quickGetValue(r, c) * (cbuff_curr.w - 1);
            // mean
            cbuff_curr.mean._sum = quickGetValue(r, c + 1);
            // m2 correction
            cbuff_curr.m2._correction = quickGetValue(r, c + 3);
            // mean correction
            cbuff_curr.mean._correction = quickGetValue(r, c + 4);
            // extract partial values: { var | mean, count, m2 correction, mean correction }
            // note: m2 = var * (n - 1)
            // count
            cbuff_part.w = newWithCor.quickGetValue(r, c + 2);
            // m2
            cbuff_part.m2._sum = newWithCor.quickGetValue(r, c) * (cbuff_part.w - 1);
            // mean
            cbuff_part.mean._sum = newWithCor.quickGetValue(r, c + 1);
            // m2 correction
            cbuff_part.m2._correction = newWithCor.quickGetValue(r, c + 3);
            // mean correction
            cbuff_part.mean._correction = newWithCor.quickGetValue(r, c + 4);
            // calculate incremental aggregated variance
            cbuff_curr = (CM_COV_Object) aggOp.increOp.fn.execute(cbuff_curr, cbuff_part);
            // store updated values: { var | mean, count, m2 correction, mean correction }
            double var = cbuff_curr.getRequiredResult(AggregateOperationTypes.VARIANCE);
            quickSetValue(r, c, var);
            // mean
            quickSetValue(r, c + 1, cbuff_curr.mean._sum);
            // count
            quickSetValue(r, c + 2, cbuff_curr.w);
            // m2 correction
            quickSetValue(r, c + 3, cbuff_curr.m2._correction);
            // mean correction
            quickSetValue(r, c + 4, cbuff_curr.mean._correction);
        }
    } else
        // no branch matched: unknown layout, or a four-cell layout without a CM variance function
        throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correctionLocation);
}
Also used : CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) CM(org.apache.sysml.runtime.functionobjects.CM) Builtin(org.apache.sysml.runtime.functionobjects.Builtin) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 15 with CM_COV_Object

use of org.apache.sysml.runtime.instructions.cp.CM_COV_Object in project incubator-systemml by apache.

the class MatrixBlock method cmOperations.

/**
 * Feeds every entry of this single-column block into the operator's function and returns
 * the resulting accumulator.
 *
 * <p>An empty block is handled by one call with value 0 and the row count as weight. For a
 * sparse block the stored values are fed one by one, followed by one call with value 0
 * weighted by the number of rows that had no stored value. For a dense block every row value
 * is fed individually.
 *
 * @param op operator whose function {@code fn} updates the accumulator
 * @return the accumulator after all entries have been applied
 * @throws DMLRuntimeException if this block does not have exactly one column
 */
public CM_COV_Object cmOperations(CMOperator op) {
    // only column vectors are accepted
    if (this.getNumColumns() != 1) {
        throw new DMLRuntimeException("Central Moment can not be computed on [" + this.getNumRows() + "," + this.getNumColumns() + "] matrix.");
    }

    final CM_COV_Object accumulator = new CM_COV_Object();

    // empty block: apply a single zero weighted by the number of rows
    if (isEmptyBlock(false)) {
        op.fn.execute(accumulator, 0.0, getNumRows());
        return accumulator;
    }

    if (sparse && sparseBlock != null) {
        // SPARSE: visit the stored values row by row
        int storedCount = 0;
        final int rowLimit = Math.min(rlen, sparseBlock.numRows());
        for (int row = 0; row < rowLimit; row++) {
            if (!sparseBlock.isEmpty(row)) {
                final int start = sparseBlock.pos(row);
                final int count = sparseBlock.size(row);
                final double[] rowValues = sparseBlock.values(row);
                final int end = start + count;
                for (int k = start; k < end; k++) {
                    op.fn.execute(accumulator, rowValues[k]);
                }
                storedCount += count;
            }
        }
        // account for zeros in the vector
        op.fn.execute(accumulator, 0.0, this.getNumRows() - storedCount);
    } else if (denseBlock != null) {
        // DENSE: always a vector (see the column check above)
        final double[] values = getDenseBlockValues();
        for (int row = 0; row < rlen; row++) {
            op.fn.execute(accumulator, values[row]);
        }
    }
    return accumulator;
}
Also used : CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Aggregations

CM_COV_Object (org.apache.sysml.runtime.instructions.cp.CM_COV_Object)17 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)10 KahanObject (org.apache.sysml.runtime.instructions.cp.KahanObject)8 CM (org.apache.sysml.runtime.functionobjects.CM)6 KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus)4 CMOperator (org.apache.sysml.runtime.matrix.operators.CMOperator)4 Builtin (org.apache.sysml.runtime.functionobjects.Builtin)3 IOException (java.io.IOException)2 SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext)2 KahanPlusSq (org.apache.sysml.runtime.functionobjects.KahanPlusSq)2 Mean (org.apache.sysml.runtime.functionobjects.Mean)2 ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll)2 ReduceCol (org.apache.sysml.runtime.functionobjects.ReduceCol)2 ReduceDiag (org.apache.sysml.runtime.functionobjects.ReduceDiag)2 ReduceRow (org.apache.sysml.runtime.functionobjects.ReduceRow)2 DoubleObject (org.apache.sysml.runtime.instructions.cp.DoubleObject)2 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)2 MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes)2 WeightedCell (org.apache.sysml.runtime.matrix.data.WeightedCell)2 AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator)2