
Example 1 with KahanPlus

use of org.apache.sysml.runtime.functionobjects.KahanPlus in project incubator-systemml by apache.

the class MVImputeAgent method mergeAndOutputTransformationMetadata.

/** 
	 * Method to merge map output transformation metadata. 
	 */
@Override
public void mergeAndOutputTransformationMetadata(Iterator<DistinctValue> values, String outputDir, int colID, FileSystem fs, TfUtils agents) throws IOException {
    double min = Double.MAX_VALUE;
    double max = -Double.MAX_VALUE;
    int nbins = 0;
    double d;
    long totalRecordCount = 0, totalValidCount = 0;
    String mvConstReplacement = null;
    DistinctValue val = new DistinctValue();
    String w = null;
    class MeanObject {

        double mean, correction;

        long count;

        MeanObject() {
        }

        public String toString() {
            return mean + "," + correction + "," + count;
        }
    }
    HashMap<Integer, MeanObject> mapMeans = new HashMap<Integer, MeanObject>();
    HashMap<Integer, CM_COV_Object> mapVars = new HashMap<Integer, CM_COV_Object>();
    boolean isImputed = false;
    boolean isScaled = false;
    boolean isBinned = false;
    while (values.hasNext()) {
        val.reset();
        val = values.next();
        w = val.getWord();
        if (w.startsWith(MEAN_PREFIX)) {
            String[] parts = w.split("_");
            int taskID = UtilFunctions.parseToInt(parts[1]);
            MeanObject mo = mapMeans.get(taskID);
            if (mo == null)
                mo = new MeanObject();
            mo.mean = UtilFunctions.parseToDouble(parts[2].split(",")[0]);
            // check if this attribute is scaled
            String s = parts[2].split(",")[1];
            if (s.equalsIgnoreCase("scmv"))
                isScaled = isImputed = true;
            else if (s.equalsIgnoreCase("scnomv"))
                isScaled = true;
            else
                isImputed = true;
            mapMeans.put(taskID, mo);
        } else if (w.startsWith(CORRECTION_PREFIX)) {
            String[] parts = w.split("_");
            int taskID = UtilFunctions.parseToInt(parts[1]);
            MeanObject mo = mapMeans.get(taskID);
            if (mo == null)
                mo = new MeanObject();
            mo.correction = UtilFunctions.parseToDouble(parts[2]);
            mapMeans.put(taskID, mo);
        } else if (w.startsWith(CONSTANT_PREFIX)) {
            isImputed = true;
            String[] parts = w.split("_");
            mvConstReplacement = parts[1];
        } else if (w.startsWith(COUNT_PREFIX)) {
            String[] parts = w.split("_");
            int taskID = UtilFunctions.parseToInt(parts[1]);
            MeanObject mo = mapMeans.get(taskID);
            if (mo == null)
                mo = new MeanObject();
            mo.count = UtilFunctions.parseToLong(parts[2]);
            totalValidCount += mo.count;
            mapMeans.put(taskID, mo);
        } else if (w.startsWith(TOTAL_COUNT_PREFIX)) {
            String[] parts = w.split("_");
            //int taskID = UtilFunctions.parseToInt(parts[1]);
            totalRecordCount += UtilFunctions.parseToLong(parts[2]);
        } else if (w.startsWith(VARIANCE_PREFIX)) {
            isScaled = true;
            String[] parts = w.split("_");
            int taskID = UtilFunctions.parseToInt(parts[1]);
            CM_COV_Object cm = decodeCMObj(parts[2]);
            mapVars.put(taskID, cm);
        } else if (w.startsWith(BinAgent.MIN_PREFIX)) {
            isBinned = true;
            d = UtilFunctions.parseToDouble(w.substring(BinAgent.MIN_PREFIX.length()));
            if (d < min)
                min = d;
        } else if (w.startsWith(BinAgent.MAX_PREFIX)) {
            isBinned = true;
            d = UtilFunctions.parseToDouble(w.substring(BinAgent.MAX_PREFIX.length()));
            if (d > max)
                max = d;
        } else if (w.startsWith(BinAgent.NBINS_PREFIX)) {
            isBinned = true;
            nbins = (int) UtilFunctions.parseToLong(w.substring(BinAgent.NBINS_PREFIX.length()));
        } else
            throw new RuntimeException("MVImputeAgent: Invalid prefix while merging map output: " + w);
    }
    // compute global mean across all map outputs
    KahanObject gmean = new KahanObject(0, 0);
    KahanPlus kp = KahanPlus.getKahanPlusFnObject();
    long gcount = 0;
    for (MeanObject mo : mapMeans.values()) {
        gcount = gcount + mo.count;
        if (gcount > 0) {
            double delta = mo.mean - gmean._sum;
            kp.execute2(gmean, delta * mo.count / gcount);
        //_meanFn.execute2(gmean, mo.mean*mo.count, gcount);
        }
    }
    // compute global variance across all map outputs
    CM_COV_Object gcm = new CM_COV_Object();
    try {
        for (CM_COV_Object cm : mapVars.values()) gcm = (CM_COV_Object) _varFn.execute(gcm, cm);
    } catch (DMLRuntimeException e) {
        throw new IOException(e);
    }
    // If the column is imputed with a constant, then adjust min and max based on the value of the constant.
    if (isImputed && isBinned && mvConstReplacement != null) {
        double cst = UtilFunctions.parseToDouble(mvConstReplacement);
        if (cst < min)
            min = cst;
        if (cst > max)
            max = cst;
    }
    // write merged metadata
    if (isImputed) {
        String imputedValue = null;
        if (mvConstReplacement != null)
            imputedValue = mvConstReplacement;
        else
            imputedValue = Double.toString(gcount == 0 ? 0.0 : gmean._sum);
        writeTfMtd(colID, imputedValue, outputDir, fs, agents);
    }
    if (isBinned) {
        double binwidth = (max - min) / nbins;
        writeTfMtd(colID, Double.toString(min), Double.toString(max), Double.toString(binwidth), Integer.toString(nbins), outputDir, fs, agents);
    }
    if (isScaled) {
        try {
            if (totalValidCount != totalRecordCount) {
                // In the presence of missing values, the variance needs to be adjusted.
                // The mean does not need to be adjusted, when mv impute method is global_mean, 
                // since missing values themselves are replaced with gmean.
                long totalMissingCount = (totalRecordCount - totalValidCount);
                int idx = isApplicable(colID);
                if (idx != -1 && _mvMethodList[idx] == MVMethod.CONSTANT)
                    _meanFn.execute(gmean, UtilFunctions.parseToDouble(_replacementList[idx]), totalRecordCount);
                _varFn.execute(gcm, gmean._sum, totalMissingCount);
            }
            double mean = (gcount == 0 ? 0.0 : gmean._sum);
            double var = gcm.getRequiredResult(new CMOperator(_varFn, AggregateOperationTypes.VARIANCE));
            double sdev = (mapVars.size() > 0 ? Math.sqrt(var) : -1.0);
            writeTfMtd(colID, Double.toString(mean), Double.toString(sdev), outputDir, fs, agents);
        } catch (DMLRuntimeException e) {
            throw new IOException(e);
        }
    }
}
Also used : CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) HashMap(java.util.HashMap) IOException(java.io.IOException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator)
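
The merge above computes the global mean incrementally, mean += (mean_i - mean) * n_i / N, with KahanPlus compensating each addition so the partial-mean merge does not drift. Below is a minimal standalone sketch of that pattern; the class name and the partMeans/partCounts inputs are made up for illustration, and only the KahanPlus/KahanObject calls come from the code above.

import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.cp.KahanObject;

public class GlobalMeanMergeSketch {
    public static void main(String[] args) {
        // hypothetical per-task partial results: (mean, count) pairs
        double[] partMeans  = { 2.0, 10.0, 4.0 };
        long[]   partCounts = { 100, 50, 200 };

        KahanObject gmean = new KahanObject(0, 0);   // running global mean (+ correction)
        KahanPlus kp = KahanPlus.getKahanPlusFnObject();
        long gcount = 0;

        for (int i = 0; i < partMeans.length; i++) {
            gcount += partCounts[i];
            if (gcount > 0) {
                // incremental update: gmean += (mean_i - gmean) * count_i / gcount,
                // accumulated with a Kahan-compensated add
                double delta = partMeans[i] - gmean._sum;
                kp.execute2(gmean, delta * partCounts[i] / gcount);
            }
        }
        System.out.println("global mean = " + gmean._sum + " over " + gcount + " values");
    }
}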

Example 2 with KahanPlus

use of org.apache.sysml.runtime.functionobjects.KahanPlus in project incubator-systemml by apache.

the class ColGroupDDC1 method computeRowSums.

@Override
protected void computeRowSums(MatrixBlock result, KahanFunction kplus, int rl, int ru) {
    KahanObject kbuff = new KahanObject(0, 0);
    KahanPlus kplus2 = KahanPlus.getKahanPlusFnObject();
    double[] c = result.getDenseBlock();
    //pre-aggregate nnz per value tuple
    double[] vals = sumAllValues(kplus, kbuff, false);
    //scan data and add to result (use KahanPlus, not the general KahanFunction, for correctness in case of sqk+)
    for (int i = rl; i < ru; i++) {
        kbuff.set(c[2 * i], c[2 * i + 1]);
        kplus2.execute2(kbuff, vals[_data[i] & 0xFF]);
        c[2 * i] = kbuff._sum;
        c[2 * i + 1] = kbuff._correction;
    }
}
Also used : KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus)

Example 3 with KahanPlus

use of org.apache.sysml.runtime.functionobjects.KahanPlus in project incubator-systemml by apache.

the class ColGroupDDC2 method computeRowSums.

@Override
protected void computeRowSums(MatrixBlock result, KahanFunction kplus, int rl, int ru) {
    KahanObject kbuff = new KahanObject(0, 0);
    KahanPlus kplus2 = KahanPlus.getKahanPlusFnObject();
    double[] c = result.getDenseBlock();
    //pre-aggregate nnz per value tuple
    double[] vals = sumAllValues(kplus, kbuff, false);
    //scan data and add to result (use KahanPlus, not the general KahanFunction, for correctness in case of sqk+)
    for (int i = rl; i < ru; i++) {
        kbuff.set(c[2 * i], c[2 * i + 1]);
        kplus2.execute2(kbuff, vals[_data[i]]);
        c[2 * i] = kbuff._sum;
        c[2 * i + 1] = kbuff._correction;
    }
}
Also used : KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus)
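
The DDC1 and DDC2 kernels above differ only in how the per-row dictionary code is read: DDC1 stores codes compactly as signed Java bytes, so it masks with & 0xFF to recover an unsigned index in 0..255, while DDC2 keeps wider codes that can be used as array indices directly. A tiny sketch of the masking step, with a made-up example value:

public class UnsignedByteDecodeSketch {
    public static void main(String[] args) {
        byte code = (byte) 200;     // value 200 wraps to -56 when stored in a signed byte
        int idx = code & 0xFF;      // masking recovers the unsigned dictionary index 200
        System.out.println(code + " -> " + idx);   // prints: -56 -> 200
    }
}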

Example 4 with KahanPlus

use of org.apache.sysml.runtime.functionobjects.KahanPlus in project incubator-systemml by apache.

the class ColGroupRLE method computeRowSums.

@Override
protected final void computeRowSums(MatrixBlock result, KahanFunction kplus, int rl, int ru) {
    KahanObject kbuff = new KahanObject(0, 0);
    KahanPlus kplus2 = KahanPlus.getKahanPlusFnObject();
    final int numVals = getNumValues();
    double[] c = result.getDenseBlock();
    if (ALLOW_CACHE_CONSCIOUS_ROWSUMS && LOW_LEVEL_OPT && numVals > 1 && _numRows > BitmapEncoder.BITMAP_BLOCK_SZ) {
        final int blksz = ColGroupOffset.WRITE_CACHE_BLKSZ / 2;
        //step 1: prepare position and value arrays
        //current pos / values per RLE list
        int[] astart = new int[numVals];
        int[] apos = skipScan(numVals, rl, astart);
        double[] aval = sumAllValues(kplus, kbuff, false);
        //step 2: cache conscious matrix-vector via horizontal scans 
        for (int bi = rl; bi < ru; bi += blksz) {
            int bimax = Math.min(bi + blksz, ru);
            //horizontal segment scan, incl pos maintenance
            for (int k = 0; k < numVals; k++) {
                int boff = _ptr[k];
                int blen = len(k);
                double val = aval[k];
                int bix = apos[k];
                int start = astart[k];
                //compute partial results, not aligned
                while (bix < blen) {
                    int lstart = _data[boff + bix];
                    int llen = _data[boff + bix + 1];
                    int from = Math.max(bi, start + lstart);
                    int to = Math.min(start + lstart + llen, bimax);
                    for (int rix = from; rix < to; rix++) {
                        kbuff.set(c[2 * rix], c[2 * rix + 1]);
                        kplus2.execute2(kbuff, val);
                        c[2 * rix] = kbuff._sum;
                        c[2 * rix + 1] = kbuff._correction;
                    }
                    if (start + lstart + llen >= bimax)
                        break;
                    start += lstart + llen;
                    bix += 2;
                }
                apos[k] = bix;
                astart[k] = start;
            }
        }
    } else {
        for (int k = 0; k < numVals; k++) {
            int boff = _ptr[k];
            int blen = len(k);
            double val = sumValues(k, kplus, kbuff);
            if (val != 0.0) {
                Pair<Integer, Integer> tmp = skipScanVal(k, rl);
                int bix = tmp.getKey();
                int curRunStartOff = tmp.getValue();
                int curRunEnd = tmp.getValue();
                for (; bix < blen && curRunEnd < ru; bix += 2) {
                    curRunStartOff = curRunEnd + _data[boff + bix];
                    curRunEnd = curRunStartOff + _data[boff + bix + 1];
                    for (int rix = curRunStartOff; rix < curRunEnd && rix < ru; rix++) {
                        kbuff.set(c[2 * rix], c[2 * rix + 1]);
                        kplus2.execute2(kbuff, val);
                        c[2 * rix] = kbuff._sum;
                        c[2 * rix + 1] = kbuff._correction;
                    }
                }
            }
        }
    }
}
Also used : KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus)
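
All three column-group row-sum kernels share the same inner step: load a row's running (sum, correction) pair from the two-column dense output buffer, add one value with KahanPlus.execute2, and write both parts back. Below is a minimal sketch of that load/add/store pattern on a plain double[] buffer; the buffer layout and helper method are illustrative assumptions, not the actual MatrixBlock internals.

import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.cp.KahanObject;

public class RowSumBufferSketch {
    // c holds interleaved (sum, correction) pairs per row: c[2*i], c[2*i+1]
    static void addToRow(double[] c, int row, double val, KahanObject kbuff, KahanPlus kplus) {
        kbuff.set(c[2 * row], c[2 * row + 1]);  // load running sum + correction
        kplus.execute2(kbuff, val);             // compensated add
        c[2 * row] = kbuff._sum;                // write both parts back
        c[2 * row + 1] = kbuff._correction;
    }

    public static void main(String[] args) {
        double[] c = new double[2 * 3];         // 3 rows, (sum, correction) per row
        KahanObject kbuff = new KahanObject(0, 0);
        KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
        addToRow(c, 0, 1e16, kbuff, kplus);
        addToRow(c, 0, 1.0, kbuff, kplus);
        System.out.println("row 0 sum = " + c[0] + " (correction " + c[1] + ")");
    }
}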

Example 5 with KahanPlus

use of org.apache.sysml.runtime.functionobjects.KahanPlus in project incubator-systemml by apache.

the class LibMatrixAgg method aggregateTernaryDense.

private static void aggregateTernaryDense(MatrixBlock in1, MatrixBlock in2, MatrixBlock in3, MatrixBlock ret, IndexFunction ixFn, int rl, int ru) {
    //compute block operations
    KahanObject kbuff = new KahanObject(0, 0);
    KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
    double[] a = in1.denseBlock;
    double[] b1 = in2.denseBlock;
    //if null, literal 1
    double[] b2 = (in3 != null) ? in3.denseBlock : null;
    final int n = in1.clen;
    //tak+*
    if (ixFn instanceof ReduceAll) {
        for (int i = rl, ix = rl * n; i < ru; i++) for (int j = 0; j < n; j++, ix++) {
            double b2val = (b2 != null) ? b2[ix] : 1;
            double val = a[ix] * b1[ix] * b2val;
            kplus.execute2(kbuff, val);
        }
        ret.quickSetValue(0, 0, kbuff._sum);
        ret.quickSetValue(0, 1, kbuff._correction);
    } else { //tack+*
        double[] c = ret.getDenseBlock();
        for (int i = rl, ix = rl * n; i < ru; i++) for (int j = 0; j < n; j++, ix++) {
            double b2val = (b2 != null) ? b2[ix] : 1;
            double val = a[ix] * b1[ix] * b2val;
            kbuff._sum = c[j];
            kbuff._correction = c[j + n];
            kplus.execute2(kbuff, val);
            c[j] = kbuff._sum;
            c[j + n] = kbuff._correction;
        }
    }
}
Also used : ReduceAll(org.apache.sysml.runtime.functionobjects.ReduceAll) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus)
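
What the correction term buys is easiest to see in a small numeric experiment: addends below one ULP of the running sum are lost entirely by a plain double accumulator, but survive in the KahanObject correction. A minimal sketch, assuming only the KahanPlus/KahanObject API shown in the examples above; the numbers are made up for illustration.

import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.cp.KahanObject;

public class KahanVsNaiveSketch {
    public static void main(String[] args) {
        int n = 10_000_000;
        double tiny = 1e-16;

        // naive accumulation: every tiny addend is rounded away against 1.0
        double naive = 1.0;
        for (int i = 0; i < n; i++)
            naive += tiny;

        // compensated accumulation: the lost low-order bits survive in _correction
        KahanObject kbuff = new KahanObject(1.0, 0);
        KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
        for (int i = 0; i < n; i++)
            kplus.execute2(kbuff, tiny);

        System.out.println("naive: " + naive);                              // 1.0
        System.out.println("kahan: " + (kbuff._sum + kbuff._correction));   // ~1.000000001
    }
}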

Aggregations

KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus): 28
KahanObject (org.apache.sysml.runtime.instructions.cp.KahanObject): 25
DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 10
Builtin (org.apache.sysml.runtime.functionobjects.Builtin): 8
ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll): 7
CM (org.apache.sysml.runtime.functionobjects.CM): 5
KahanFunction (org.apache.sysml.runtime.functionobjects.KahanFunction): 5
KahanPlusSq (org.apache.sysml.runtime.functionobjects.KahanPlusSq): 5
ReduceCol (org.apache.sysml.runtime.functionobjects.ReduceCol): 5
ReduceRow (org.apache.sysml.runtime.functionobjects.ReduceRow): 5
Mean (org.apache.sysml.runtime.functionobjects.Mean): 4
ReduceDiag (org.apache.sysml.runtime.functionobjects.ReduceDiag): 4
ValueFunction (org.apache.sysml.runtime.functionobjects.ValueFunction): 4
CM_COV_Object (org.apache.sysml.runtime.instructions.cp.CM_COV_Object): 4
Multiply (org.apache.sysml.runtime.functionobjects.Multiply): 3
IOException (java.io.IOException): 2
ArrayList (java.util.ArrayList): 2
ExecutorService (java.util.concurrent.ExecutorService): 2
Future (java.util.concurrent.Future): 2
IndexFunction (org.apache.sysml.runtime.functionobjects.IndexFunction): 2