Search in sources :

Example 1 with KahanObject

use of org.apache.sysml.runtime.instructions.cp.KahanObject in project incubator-systemml by apache.

the class MVImputeAgent method mergeAndOutputTransformationMetadata.

/**
 * Merges the per-task (map output) transformation metadata of a single column and
 * writes the combined result through {@code writeTfMtd}.
 *
 * <p>Each incoming {@link DistinctValue} carries a word whose prefix identifies the
 * kind of record (partial mean, correction, constant replacement, valid count, total
 * count, partial variance, or binning min/max/number of bins). Records are first
 * collected, then combined into a global mean and a global variance, and finally up
 * to three metadata entries are written: one for imputation, one for binning and one
 * for scaling, depending on which record kinds were seen.
 *
 * @param values    per-task metadata records for the column
 * @param outputDir passed through to {@code writeTfMtd}
 * @param colID     column identifier; passed to {@code writeTfMtd} and {@code isApplicable}
 * @param fs        passed through to {@code writeTfMtd}
 * @param agents    passed through to {@code writeTfMtd}
 * @throws IOException      if a {@code DMLRuntimeException} is raised while combining
 *                          variances or computing the scaling metadata (wrapped as cause)
 * @throws RuntimeException if a record's word starts with none of the known prefixes
 */
@Override
public void mergeAndOutputTransformationMetadata(Iterator<DistinctValue> values, String outputDir, int colID, FileSystem fs, TfUtils agents) throws IOException {
    // running extrema over all binning records; start values guarantee the first record wins
    double min = Double.MAX_VALUE;
    double max = -Double.MAX_VALUE;
    int nbins = 0;
    double d;
    long totalRecordCount = 0, totalValidCount = 0;
    String mvConstReplacement = null;
    DistinctValue val = new DistinctValue();
    String w = null;
    // Per-task partial statistics of the mean, keyed by task id in mapMeans below.
    class MeanObject {

        double mean, correction;

        long count;

        MeanObject() {
        }

        public String toString() {
            return mean + "," + correction + "," + count;
        }
    }
    ;
    HashMap<Integer, MeanObject> mapMeans = new HashMap<Integer, MeanObject>();
    HashMap<Integer, CM_COV_Object> mapVars = new HashMap<Integer, CM_COV_Object>();
    // flags recording which kinds of metadata have to be written at the end
    boolean isImputed = false;
    boolean isScaled = false;
    boolean isBinned = false;
    while (values.hasNext()) {
        // reset() is invoked on the previously held object, which is then replaced by the next record
        val.reset();
        val = values.next();
        w = val.getWord();
        if (w.startsWith(MEAN_PREFIX)) {
            // word is split on "_": parts[1] is the task id, parts[2] is "<mean>,<flag>"
            String[] parts = w.split("_");
            int taskID = UtilFunctions.parseToInt(parts[1]);
            MeanObject mo = mapMeans.get(taskID);
            if (mo == null)
                mo = new MeanObject();
            mo.mean = UtilFunctions.parseToDouble(parts[2].split(",")[0]);
            // check if this attribute is scaled
            // "scmv" -> scaled and imputed, "scnomv" -> scaled only, anything else -> imputed only
            String s = parts[2].split(",")[1];
            if (s.equalsIgnoreCase("scmv"))
                isScaled = isImputed = true;
            else if (s.equalsIgnoreCase("scnomv"))
                isScaled = true;
            else
                isImputed = true;
            mapMeans.put(taskID, mo);
        } else if (w.startsWith(CORRECTION_PREFIX)) {
            // parts[1] is the task id, parts[2] the correction term of that task's partial mean
            // NOTE(review): the stored correction is not read by the global-mean loop below — confirm intended
            String[] parts = w.split("_");
            int taskID = UtilFunctions.parseToInt(parts[1]);
            MeanObject mo = mapMeans.get(taskID);
            if (mo == null)
                mo = new MeanObject();
            mo.correction = UtilFunctions.parseToDouble(parts[2]);
            mapMeans.put(taskID, mo);
        } else if (w.startsWith(CONSTANT_PREFIX)) {
            // parts[1] is the constant used to replace missing values
            isImputed = true;
            String[] parts = w.split("_");
            mvConstReplacement = parts[1];
        } else if (w.startsWith(COUNT_PREFIX)) {
            // parts[1] is the task id, parts[2] that task's count; summed into totalValidCount
            String[] parts = w.split("_");
            int taskID = UtilFunctions.parseToInt(parts[1]);
            MeanObject mo = mapMeans.get(taskID);
            if (mo == null)
                mo = new MeanObject();
            mo.count = UtilFunctions.parseToLong(parts[2]);
            totalValidCount += mo.count;
            mapMeans.put(taskID, mo);
        } else if (w.startsWith(TOTAL_COUNT_PREFIX)) {
            // only the count in parts[2] is needed here; the task id in parts[1] is ignored
            String[] parts = w.split("_");
            //int taskID = UtilFunctions.parseToInt(parts[1]);
            totalRecordCount += UtilFunctions.parseToLong(parts[2]);
        } else if (w.startsWith(VARIANCE_PREFIX)) {
            // parts[1] is the task id, parts[2] an encoded CM_COV_Object (see decodeCMObj)
            isScaled = true;
            String[] parts = w.split("_");
            int taskID = UtilFunctions.parseToInt(parts[1]);
            CM_COV_Object cm = decodeCMObj(parts[2]);
            mapVars.put(taskID, cm);
        } else if (w.startsWith(BinAgent.MIN_PREFIX)) {
            // binning records carry their value directly after the prefix
            isBinned = true;
            d = UtilFunctions.parseToDouble(w.substring(BinAgent.MIN_PREFIX.length()));
            if (d < min)
                min = d;
        } else if (w.startsWith(BinAgent.MAX_PREFIX)) {
            isBinned = true;
            d = UtilFunctions.parseToDouble(w.substring(BinAgent.MAX_PREFIX.length()));
            if (d > max)
                max = d;
        } else if (w.startsWith(BinAgent.NBINS_PREFIX)) {
            // the last NBINS record seen wins; the value is narrowed from long to int
            isBinned = true;
            nbins = (int) UtilFunctions.parseToLong(w.substring(BinAgent.NBINS_PREFIX.length()));
        } else
            throw new RuntimeException("MVImputeAgent: Invalid prefix while merging map output: " + w);
    }
    // compute global mean across all map outputs
    // incremental weighted update: gmean += (mean_k - gmean) * count_k / (count_1 + ... + count_k),
    // where the addition is performed by KahanPlus on the KahanObject
    KahanObject gmean = new KahanObject(0, 0);
    KahanPlus kp = KahanPlus.getKahanPlusFnObject();
    long gcount = 0;
    for (MeanObject mo : mapMeans.values()) {
        gcount = gcount + mo.count;
        // skip the update while no values have been counted yet (avoids dividing by zero)
        if (gcount > 0) {
            double delta = mo.mean - gmean._sum;
            kp.execute2(gmean, delta * mo.count / gcount);
        //_meanFn.execute2(gmean, mo.mean*mo.count, gcount);
        }
    }
    // compute global variance across all map outputs
    CM_COV_Object gcm = new CM_COV_Object();
    try {
        for (CM_COV_Object cm : mapVars.values()) gcm = (CM_COV_Object) _varFn.execute(gcm, cm);
    } catch (DMLRuntimeException e) {
        throw new IOException(e);
    }
    // If the column is imputed with a constant, then adjust min and max based the value of the constant.
    if (isImputed && isBinned && mvConstReplacement != null) {
        double cst = UtilFunctions.parseToDouble(mvConstReplacement);
        if (cst < min)
            min = cst;
        if (cst > max)
            max = cst;
    }
    // write merged metadata
    if (isImputed) {
        // a constant replacement takes precedence; otherwise the global mean (0.0 if nothing was counted)
        String imputedValue = null;
        if (mvConstReplacement != null)
            imputedValue = mvConstReplacement;
        else
            imputedValue = Double.toString(gcount == 0 ? 0.0 : gmean._sum);
        writeTfMtd(colID, imputedValue, outputDir, fs, agents);
    }
    if (isBinned) {
        // NOTE(review): nbins is still 0 if no NBINS record was seen; the floating-point division
        // then yields Infinity or NaN instead of throwing — confirm callers tolerate that
        double binwidth = (max - min) / nbins;
        writeTfMtd(colID, Double.toString(min), Double.toString(max), Double.toString(binwidth), Integer.toString(nbins), outputDir, fs, agents);
    }
    if (isScaled) {
        try {
            if (totalValidCount != totalRecordCount) {
                // In the presence of missing values, the variance needs to be adjusted.
                // The mean does not need to be adjusted, when mv impute method is global_mean, 
                // since missing values themselves are replaced with gmean.
                long totalMissingCount = (totalRecordCount - totalValidCount);
                int idx = isApplicable(colID);
                // for constant imputation the replacement value is folded into the mean first
                if (idx != -1 && _mvMethodList[idx] == MVMethod.CONSTANT)
                    _meanFn.execute(gmean, UtilFunctions.parseToDouble(_replacementList[idx]), totalRecordCount);
                _varFn.execute(gcm, gmean._sum, totalMissingCount);
            }
            double mean = (gcount == 0 ? 0.0 : gmean._sum);
            double var = gcm.getRequiredResult(new CMOperator(_varFn, AggregateOperationTypes.VARIANCE));
            // -1.0 is written as the standard deviation when no variance records were received
            double sdev = (mapVars.size() > 0 ? Math.sqrt(var) : -1.0);
            writeTfMtd(colID, Double.toString(mean), Double.toString(sdev), outputDir, fs, agents);
        } catch (DMLRuntimeException e) {
            throw new IOException(e);
        }
    }
}
Also used : CM_COV_Object(org.apache.sysml.runtime.instructions.cp.CM_COV_Object) HashMap(java.util.HashMap) IOException(java.io.IOException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) KahanPlus(org.apache.sysml.runtime.functionobjects.KahanPlus) CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator)

Example 2 with KahanObject

use of org.apache.sysml.runtime.instructions.cp.KahanObject in project incubator-systemml by apache.

the class MVImputeAgent method parseMethodsAndReplacments.

/**
 * Reads the imputation entries stored under {@code TfUtils.TXMETHOD_IMPUTE} in the
 * parsed specification and (re)allocates the per-entry state arrays.
 *
 * <p>For every entry the imputation method is resolved from its {@code "method"}
 * field (upper-cased before the enum lookup), the {@code "value"} field is recorded
 * as replacement when the method is {@code CONSTANT}, and a zero-initialized
 * {@link KahanObject} is stored as the running mean. The count array is allocated
 * but left at its default values.
 *
 * @param parsedSpec parsed transformation specification
 * @throws JSONException if a JSON accessor used here fails
 */
private void parseMethodsAndReplacments(JSONObject parsedSpec) throws JSONException {
    final JSONArray imputeSpecs = (JSONArray) parsedSpec.get(TfUtils.TXMETHOD_IMPUTE);
    final int numSpecs = imputeSpecs.size();

    _mvMethodList = new MVMethod[numSpecs];
    _replacementList = new String[numSpecs];
    _meanList = new KahanObject[numSpecs];
    _countList = new long[numSpecs];

    for (int k = 0; k < numSpecs; k++) {
        final JSONObject entry = (JSONObject) imputeSpecs.get(k);
        final String methodName = entry.get("method").toString().toUpperCase();
        final MVMethod method = MVMethod.valueOf(methodName);
        _mvMethodList[k] = method;

        // only constant imputation carries an explicit replacement value
        if (method == MVMethod.CONSTANT) {
            _replacementList[k] = entry.getString("value").toString();
        }

        _meanList[k] = new KahanObject(0, 0);
    }
}
Also used : JSONObject(org.apache.wink.json4j.JSONObject) JSONArray(org.apache.wink.json4j.JSONArray) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject)

Example 3 with KahanObject

use of org.apache.sysml.runtime.instructions.cp.KahanObject in project incubator-systemml by apache.

the class MVImputeAgent method outputTransformationMetadata.

/**
 * Writes the transformation metadata of all columns handled by this agent through
 * {@code writeTfMtd}.
 *
 * <p>Two groups of columns are processed: the columns in {@code _colList}
 * (missing-value imputation, optionally combined with scaling) and the columns in
 * {@code _scnomvList} (scaling only). Either list may be {@code null}, in which case
 * that group is skipped.
 *
 * @param outputDir passed through to {@code writeTfMtd}
 * @param fs        passed through to {@code writeTfMtd}
 * @param agents    passed through to {@code writeTfMtd}; also supplies {@code getValid()}
 * @throws IOException if a {@code DMLRuntimeException} occurs (wrapped as cause)
 */
public void outputTransformationMetadata(String outputDir, FileSystem fs, TfUtils agents) throws IOException {
    try {
        // columns with missing-value imputation
        if (_colList != null) {
            for (int k = 0; k < _colList.length; k++) {
                final int column = _colList[k];
                double fillValue = Double.NaN;
                KahanObject adjustedMean = null;

                final MVMethod method = _mvMethodList[k];
                if (method == MVMethod.GLOBAL_MEAN) {
                    adjustedMean = _meanList[k];
                    fillValue = adjustedMean._sum;

                    // report 0.0 when no values were counted for this column
                    double reportedMean;
                    if (_countList[k] == 0) {
                        reportedMean = 0.0;
                    } else {
                        reportedMean = adjustedMean._sum;
                    }
                    writeTfMtd(column, Double.toString(reportedMean), outputDir, fs, agents);
                } else if (method == MVMethod.CONSTANT) {
                    final String replacement = _replacementList[k];
                    writeTfMtd(column, replacement, outputDir, fs, agents);

                    if (_isMVScaled.get(k)) {
                        fillValue = UtilFunctions.parseToDouble(replacement);
                        // fold the replacement value into a copy of the running mean;
                        // NOTE(review): the weight passed is agents.getValid() — confirm this is the intended weight
                        final KahanObject partial = _meanList[k];
                        adjustedMean = new KahanObject(partial._sum, partial._correction);
                        _meanFn.execute(adjustedMean, fillValue, agents.getValid());
                    }
                }

                if (_isMVScaled.get(k)) {
                    double stdDev = -1.0;
                    if (_mvscMethodList[k] == MVMethod.GLOBAL_MODE) {
                        // account for the imputed values in the variance
                        final long missing = agents.getValid() - _countList[k];
                        _varFn.execute(_varList[k], fillValue, missing);
                        stdDev = Math.sqrt(_varList[k].getRequiredResult(new CMOperator(_varFn, AggregateOperationTypes.VARIANCE)));
                    }
                    // NOTE(review): adjustedMean is still null here when the method is neither
                    // GLOBAL_MEAN nor CONSTANT — verify that scaled columns never reach this point in that state
                    writeTfMtd(column, Double.toString(adjustedMean._sum), Double.toString(stdDev), outputDir, fs, agents);
                }
            }
        }

        // columns that are scaled without missing-value imputation
        if (_scnomvList != null) {
            for (int k = 0; k < _scnomvList.length; k++) {
                final int column = _scnomvList[k];

                double reportedMean;
                if (_scnomvCountList[k] == 0) {
                    reportedMean = 0.0;
                } else {
                    reportedMean = _scnomvMeanList[k]._sum;
                }

                double stdDev = -1.0;
                if (_scnomvMethodList[k] == MVMethod.GLOBAL_MODE) {
                    stdDev = Math.sqrt(_scnomvVarList[k].getRequiredResult(new CMOperator(_varFn, AggregateOperationTypes.VARIANCE)));
                }
                writeTfMtd(column, Double.toString(reportedMean), Double.toString(stdDev), outputDir, fs, agents);
            }
        }
    } catch (DMLRuntimeException ex) {
        throw new IOException(ex);
    }
}
Also used : KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject) IOException(java.io.IOException) CMOperator(org.apache.sysml.runtime.matrix.operators.CMOperator) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 4 with KahanObject

use of org.apache.sysml.runtime.instructions.cp.KahanObject in project incubator-systemml by apache.

the class SpoofCellwise method executeDenseAggSum.

/**
 * Aggregates the generated cell values of rows {@code [rl, ru)} of a dense,
 * row-major input into a single sum.
 *
 * <p>Each cell value is produced by {@code genexec} and added to a
 * {@link KahanObject} accumulator via the {@link KahanFunction} returned by
 * {@code getAggFunction()}. When {@code a} is {@code null}, every input cell is
 * treated as 0. Cells whose input value is 0 are skipped if {@code sparseSafe} is set.
 *
 * @param a          dense input values indexed as {@code row * n + col}, or {@code null}
 * @param b          side inputs, forwarded to {@code genexec}
 * @param scalars    scalar inputs, forwarded to {@code genexec}
 * @param m          forwarded to {@code genexec}
 * @param n          number of cells per row
 * @param sparseSafe whether zero input cells may be skipped
 * @param rl         first row (inclusive)
 * @param ru         last row (exclusive)
 * @return the {@code _sum} field of the accumulator after all cells were processed
 * @throws DMLRuntimeException declared by the original signature; propagated from the calls made here
 */
private double executeDenseAggSum(double[] a, SideInput[] b, double[] scalars, int m, int n, boolean sparseSafe, int rl, int ru) throws DMLRuntimeException {
    final KahanFunction aggFn = (KahanFunction) getAggFunction();
    final KahanObject acc = new KahanObject(0, 0);

    int pos = rl * n;
    for (int row = rl; row < ru; row++) {
        for (int col = 0; col < n; col++, pos++) {
            double cell = 0;
            if (a != null) {
                cell = a[pos];
            }
            // zero cells contribute nothing when the operation is sparse-safe
            if (sparseSafe && cell == 0) {
                continue;
            }
            aggFn.execute2(acc, genexec(cell, b, scalars, m, n, row, col));
        }
    }
    return acc._sum;
}
Also used : KahanFunction(org.apache.sysml.runtime.functionobjects.KahanFunction) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject)

Example 5 with KahanObject

use of org.apache.sysml.runtime.instructions.cp.KahanObject in project incubator-systemml by apache.

the class SpoofCellwise method executeDenseRowAggSum.

/**
 * Aggregates the generated cell values of each row in {@code [rl, ru)} of a dense,
 * row-major input into one sum per row and stores it in {@code c}.
 *
 * <p>Each cell value is produced by {@code genexec} and added to a
 * {@link KahanObject} accumulator (reset at the start of every row) via the
 * {@link KahanFunction} returned by {@code getAggFunction()}. When {@code a} is
 * {@code null}, every input cell is treated as 0. Cells whose input value is 0 are
 * skipped if {@code sparseSafe} is set.
 *
 * @param a          dense input values indexed as {@code row * n + col}, or {@code null}
 * @param b          side inputs, forwarded to {@code genexec}
 * @param scalars    scalar inputs, forwarded to {@code genexec}
 * @param c          output array; {@code c[row]} receives the sum of that row
 * @param m          forwarded to {@code genexec}
 * @param n          number of cells per row
 * @param sparseSafe whether zero input cells may be skipped
 * @param rl         first row (inclusive)
 * @param ru         last row (exclusive)
 * @return the number of processed rows whose sum is non-zero
 * @throws DMLRuntimeException declared by the original signature; propagated from the calls made here
 */
private long executeDenseRowAggSum(double[] a, SideInput[] b, double[] scalars, double[] c, int m, int n, boolean sparseSafe, int rl, int ru) throws DMLRuntimeException {
    final KahanFunction aggFn = (KahanFunction) getAggFunction();
    final KahanObject acc = new KahanObject(0, 0);

    long nonZeroRows = 0;
    int pos = rl * n;
    for (int row = rl; row < ru; row++) {
        // start every row from a clean accumulator
        acc.set(0, 0);
        for (int col = 0; col < n; col++, pos++) {
            double cell = 0;
            if (a != null) {
                cell = a[pos];
            }
            // zero cells contribute nothing when the operation is sparse-safe
            if (sparseSafe && cell == 0) {
                continue;
            }
            aggFn.execute2(acc, genexec(cell, b, scalars, m, n, row, col));
        }
        c[row] = acc._sum;
        if (c[row] != 0) {
            nonZeroRows += 1;
        }
    }
    return nonZeroRows;
}
Also used : KahanFunction(org.apache.sysml.runtime.functionobjects.KahanFunction) KahanObject(org.apache.sysml.runtime.instructions.cp.KahanObject)

Aggregations

KahanObject (org.apache.sysml.runtime.instructions.cp.KahanObject)115 KahanPlus (org.apache.sysml.runtime.functionobjects.KahanPlus)49 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)28 KahanFunction (org.apache.sysml.runtime.functionobjects.KahanFunction)28 CM_COV_Object (org.apache.sysml.runtime.instructions.cp.CM_COV_Object)15 CM (org.apache.sysml.runtime.functionobjects.CM)14 Builtin (org.apache.sysml.runtime.functionobjects.Builtin)12 ReduceAll (org.apache.sysml.runtime.functionobjects.ReduceAll)10 DenseBlock (org.apache.sysml.runtime.matrix.data.DenseBlock)10 CMOperator (org.apache.sysml.runtime.matrix.operators.CMOperator)10 IOException (java.io.IOException)8 WeightedCell (org.apache.sysml.runtime.matrix.data.WeightedCell)8 AggregateOperator (org.apache.sysml.runtime.matrix.operators.AggregateOperator)8 KahanPlusSq (org.apache.sysml.runtime.functionobjects.KahanPlusSq)6 ReduceCol (org.apache.sysml.runtime.functionobjects.ReduceCol)6 ValueFunction (org.apache.sysml.runtime.functionobjects.ValueFunction)6 IJV (org.apache.sysml.runtime.matrix.data.IJV)6 ArrayList (java.util.ArrayList)4 ExecutorService (java.util.concurrent.ExecutorService)4 Future (java.util.concurrent.Future)4