Search in sources:

Example 6 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class BinaryDTSerializer — method save.

/**
 * Serializes a tree model (GBT/RF) into the Shifu binary tree format, GZIP-compressed.
 *
 * <p>Layout written, in order: format version, algorithm, loss, classification flags,
 * input count, numerical-mean mapping, column index-&gt;name mapping, categorical value
 * lists, column mapping and finally the bagging trees themselves. Readers must consume
 * fields in exactly this order.
 *
 * @param modelConfig      model configuration (algorithm, classification/one-vs-all flags)
 * @param columnConfigList all column configurations of the model
 * @param baggingTrees     tree ensembles; each inner list is one bagging member (RF/GBT)
 * @param loss             loss function name to persist
 * @param inputCount       number of model inputs
 * @param output           target stream; wrapped in a GZIPOutputStream and closed on return
 * @throws IOException if writing to {@code output} fails
 */
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, List<List<TreeNode>> baggingTrees, String loss, int inputCount, OutputStream output) throws IOException {
    DataOutputStream fos = null;
    try {
        fos = new DataOutputStream(new GZIPOutputStream(output));
        // version first so readers can dispatch on layout
        fos.writeInt(CommonConstants.TREE_FORMAT_VERSION);
        fos.writeUTF(modelConfig.getAlgorithm());
        fos.writeUTF(loss);
        fos.writeBoolean(modelConfig.isClassification());
        fos.writeBoolean(modelConfig.getTrain().isOneVsAll());
        fos.writeInt(inputCount);
        Map<Integer, String> columnIndexNameMapping = new HashMap<Integer, String>();
        Map<Integer, List<String>> columnIndexCategoricalListMapping = new HashMap<Integer, List<String>>();
        Map<Integer, Double> numericalMeanMapping = new HashMap<Integer, Double>();
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.isFinalSelect()) {
                columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName());
            }
            if (columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) {
                columnIndexCategoricalListMapping.put(columnConfig.getColumnNum(), columnConfig.getBinCategory());
            }
            if (columnConfig.isNumerical() && columnConfig.getMean() != null) {
                numericalMeanMapping.put(columnConfig.getColumnNum(), columnConfig.getMean());
            }
        }
        // no final-selected columns (e.g. variable selection not run yet): fall back to all good candidates
        if (columnIndexNameMapping.size() == 0) {
            boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
            for (ColumnConfig columnConfig : columnConfigList) {
                if (CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
                    columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName());
                }
            }
        }
        // serialize numericalMeanMapping
        fos.writeInt(numericalMeanMapping.size());
        for (Entry<Integer, Double> entry : numericalMeanMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            // defensive: a null mean (unselected feature) is written as 0d to avoid NPE on unboxing
            fos.writeDouble(entry.getValue() == null ? 0d : entry.getValue());
        }
        // serialize columnIndexNameMapping
        fos.writeInt(columnIndexNameMapping.size());
        for (Entry<Integer, String> entry : columnIndexNameMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            fos.writeUTF(entry.getValue());
        }
        // serialize columnIndexCategoricalListMapping
        fos.writeInt(columnIndexCategoricalListMapping.size());
        for (Entry<Integer, List<String>> entry : columnIndexCategoricalListMapping.entrySet()) {
            List<String> categories = entry.getValue();
            if (categories != null) {
                fos.writeInt(entry.getKey());
                fos.writeInt(categories.size());
                for (String category : categories) {
                    // writeUTF cannot encode strings longer than the limit, so long categories
                    // are written as a marker short followed by raw UTF-8 bytes; the reader must
                    // mirror this (readByte path instead of readUTF) when it sees the marker.
                    if (category.length() < Constants.MAX_CATEGORICAL_VAL_LEN) {
                        fos.writeUTF(category);
                    } else {
                        // marker here
                        fos.writeShort(UTF_BYTES_MARKER);
                        byte[] bytes = category.getBytes("UTF-8");
                        fos.writeInt(bytes.length);
                        for (int i = 0; i < bytes.length; i++) {
                            fos.writeByte(bytes[i]);
                        }
                    }
                }
            }
        }
        Map<Integer, Integer> columnMapping = getColumnMapping(columnConfigList);
        fos.writeInt(columnMapping.size());
        for (Entry<Integer, Integer> entry : columnMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            fos.writeInt(entry.getValue());
        }
        // after model version 4 (>=4), IndependentTreeModel support bagging, here write a default RF/GBT size 1
        fos.writeInt(baggingTrees.size());
        for (int i = 0; i < baggingTrees.size(); i++) {
            List<TreeNode> trees = baggingTrees.get(i);
            int treeLength = trees.size();
            fos.writeInt(treeLength);
            for (TreeNode treeNode : trees) {
                treeNode.write(fos);
            }
        }
    } catch (IOException e) {
        LOG.error("Error in writing output.", e);
        // FIX: previously the exception was swallowed after logging, leaving a silently
        // truncated model file. The method already declares IOException, so rethrow to
        // let callers detect the failed serialization.
        throw e;
    } finally {
        IOUtils.closeStream(fos);
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) HashMap(java.util.HashMap) DataOutputStream(java.io.DataOutputStream) IOException(java.io.IOException) GZIPOutputStream(java.util.zip.GZIPOutputStream) List(java.util.List)

Example 7 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class DTMaster — method getStatsMem.

/**
 * Estimates the master-side memory (in bytes) required to hold per-feature split
 * statistics for the given feature subset.
 *
 * @param subsetFeatures selected feature column numbers; if empty, all usable
 *                       features from {@code columnConfigList} are counted instead
 * @return estimated statistics memory in bytes, scaled by the worker number
 */
private long getStatsMem(List<Integer> subsetFeatures) {
    long statsMem = 0L;
    List<Integer> tempFeatures = subsetFeatures;
    if (subsetFeatures.size() == 0) {
        tempFeatures = getAllFeatureList(this.columnConfigList, this.isAfterVarSelect);
    }
    for (Integer columnNum : tempFeatures) {
        ColumnConfig config = this.columnConfigList.get(columnNum);
        // 8 bytes per stats entry; the trailing "* 2" is overhead head-room to avoid OOM
        if (config.isNumerical()) {
            statsMem += config.getBinBoundary().size() * this.impurity.getStatsSize() * 8L * 2;
        } else if (config.isCategorical()) {
            // "+ 1": one extra bin beyond the listed categories (presumably the
            // missing/invalid-value bin — confirm against binning logic)
            statsMem += (config.getBinCategory().size() + 1) * this.impurity.getStatsSize() * 8L * 2;
        }
    }
    // scale by worker number since DTWorkerParams are combinable in master;
    // NOTE(review): the original comment said "one third of worker number" but the
    // code divides by 2 (i.e. half) — confirm which is intended.
    statsMem = statsMem * this.workerNumber / 2;
    return statsMem;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 8 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class DTMaster — method getAllFeatureList.

/**
 * Collects the column numbers of all features usable for tree building.
 *
 * <p>Before variable selection, any good candidate column (non-meta, non-target)
 * qualifies; after variable selection only final-selected columns do. In both cases
 * the column must also have usable bins: a numerical column with more than one bin
 * boundary, or a categorical column with at least one category.
 *
 * @param columnConfigList all column configurations
 * @param isAfterVarSelect whether variable selection has already been applied
 * @return column numbers of usable features
 */
private List<Integer> getAllFeatureList(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) {
    List<Integer> features = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig config : columnConfigList) {
        // eligibility differs before/after variable selection...
        boolean eligible = isAfterVarSelect
                ? (config.isFinalSelect() && !config.isTarget() && !config.isMeta())
                : (!config.isMeta() && !config.isTarget() && CommonUtils.isGoodCandidate(config, hasCandidates));
        // ...but the bin-usability check is identical in both cases (previously duplicated
        // verbatim in each branch): numerical with >1 bin boundary, or categorical with >0 bins
        if (eligible && ((config.isNumerical() && config.getBinBoundary().size() > 1)
                || (config.isCategorical() && config.getBinCategory().size() > 0))) {
            features.add(config.getColumnNum());
        }
    }
    return features;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList)

Example 9 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class LogisticRegressionWorker — method load.

/**
 * Parses one normalized input record and adds it to the training or validation data set.
 *
 * <p>Two cursors walk the record: {@code index} walks {@code columnConfigList} (one step
 * per logical column) while {@code pos} walks the raw {@code fields} array, which may be
 * wider than the column list because one-hot normalized columns expand into several
 * fields. The last field is always the record weight. After parsing, bagging /
 * negative-only sampling and up-sampling adjust the record's significance.
 */
@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<LogisticRegressionParams, LogisticRegressionParams> context) {
    ++this.count;
    if ((this.count) % 100000 == 0) {
        LOG.info("Read {} records.", this.count);
    }
    String line = currentValue.getWritable().toString();
    float[] inputData = new float[inputNum];
    float[] outputData = new float[outputNum];
    // index: position in columnConfigList; inputIndex/outputIndex: write cursors into the arrays
    int index = 0, inputIndex = 0, outputIndex = 0;
    // rolling hash over input values, used later for deterministic bagging/sampling decisions
    long hashcode = 0;
    double significance = CommonConstants.DEFAULT_SIGNIFICANCE_VALUE;
    boolean hasCandidates = CommonUtils.hasCandidateColumns(this.columnConfigList);
    String[] fields = Lists.newArrayList(this.splitter.split(line)).toArray(new String[0]);
    int pos = 0;
    for (pos = 0; pos < fields.length; ) {
        String unit = fields[pos];
        // check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 0f)
        float floatValue = unit.length() == 0 ? 0f : NumberFormatUtils.getFloat(unit, 0f);
        // NaN in input data is treated as missing and mapped to 0f
        // (TODO: should ideally depend on norm type)
        floatValue = (Float.isNaN(floatValue) || Double.isNaN(floatValue)) ? 0f : floatValue;
        if (pos == fields.length - 1) {
            // last field is the record weight
            if (StringUtils.isBlank(modelConfig.getWeightColumnName())) {
                significance = 1d;
                // no weight column configured: default weight 1 and stop parsing
                break;
            }
            // check here to avoid bad performance in failed NumberFormatUtils.getDouble(input, 1)
            significance = unit.length() == 0 ? 1f : NumberFormatUtils.getDouble(unit, 1d);
            // negative weight is invalid: reset to 1 with a warning
            if (Double.compare(significance, 0d) < 0) {
                LOG.warn("The {} record in current worker weight {} is less than 0f, it is invalid, set it to 1.", count, significance);
                significance = 1d;
            }
            // the last field is significance, break here
            break;
        } else {
            ColumnConfig columnConfig = this.columnConfigList.get(index);
            if (columnConfig != null && columnConfig.isTarget()) {
                outputData[outputIndex++] = floatValue;
                pos++;
            } else {
                if (this.inputNum == this.candidateNum) {
                    // no variable selection done: take every good candidate that is not meta/target
                    if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
                        inputData[inputIndex++] = floatValue;
                        hashcode = hashcode * 31 + Float.valueOf(floatValue).hashCode();
                    }
                    pos++;
                } else {
                    if (columnConfig.isFinalSelect()) {
                        // one-hot normalized numerical column: consume (binBoundary.size() + 1) fields
                        if (columnConfig.isNumerical() && modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT)) {
                            for (int k = 0; k < columnConfig.getBinBoundary().size() + 1; k++) {
                                String tval = fields[pos];
                                // check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 0f)
                                float fval = tval.length() == 0 ? 0f : NumberFormatUtils.getFloat(tval, 0f);
                                // NaN treated as missing, mapped to 0f (TODO: should depend on norm type)
                                fval = (Float.isNaN(fval) || Double.isNaN(fval)) ? 0f : fval;
                                inputData[inputIndex++] = fval;
                                pos++;
                            }
                        } else if (columnConfig.isCategorical() && (modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ZSCALE_ONEHOT) || modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT))) {
                            // one-hot categorical column: consume (binCategory.size() + 1) fields
                            for (int k = 0; k < columnConfig.getBinCategory().size() + 1; k++) {
                                String tval = fields[pos];
                                // check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 0f)
                                float fval = tval.length() == 0 ? 0f : NumberFormatUtils.getFloat(tval, 0f);
                                // NaN treated as missing, mapped to 0f (TODO: should depend on norm type)
                                fval = (Float.isNaN(fval) || Double.isNaN(fval)) ? 0f : fval;
                                inputData[inputIndex++] = fval;
                                pos++;
                            }
                        } else {
                            // plain (z-scale etc.) column: exactly one field
                            inputData[inputIndex++] = floatValue;
                            pos++;
                        }
                        hashcode = hashcode * 31 + Double.valueOf(floatValue).hashCode();
                    } else {
                        // column not selected: skip its field(s), honoring one-hot expansion widths
                        if (!CommonUtils.isToNormVariable(columnConfig, hasCandidates, modelConfig.isRegression())) {
                            pos += 1;
                        } else if (columnConfig.isNumerical() && modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT) && columnConfig.getBinBoundary() != null && columnConfig.getBinBoundary().size() > 0) {
                            pos += (columnConfig.getBinBoundary().size() + 1);
                        } else if (columnConfig.isCategorical() && (modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ZSCALE_ONEHOT) || modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT)) && columnConfig.getBinCategory().size() > 0) {
                            pos += (columnConfig.getBinCategory().size() + 1);
                        } else {
                            pos += 1;
                        }
                    }
                }
            }
        }
        index += 1;
    }
    // sanity check: both cursors must land exactly at the end (weight field excluded from pos)
    if (index != this.columnConfigList.size() || pos != fields.length - 1) {
        throw new RuntimeException("Wrong data indexing. ColumnConfig index = " + index + ", while it should be " + columnConfigList.size() + ". " + "Data Pos = " + pos + ", while it should be " + (fields.length - 1));
    }
    // mismatch usually means a wrong delimiter or inconsistent normalization config;
    // failing fast here makes such issues easy to find
    if (inputIndex != inputData.length) {
        String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER, Constants.DEFAULT_DELIMITER);
        throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: " + inputData.length + ", parsing size:" + inputIndex + ", delimiter:" + delimiter + ".");
    }
    // sample-negative-only logic: optionally drop a fraction of negative records
    if (modelConfig.getTrain().getSampleNegOnly()) {
        if (this.modelConfig.isFixInitialInput()) {
            // if fixInitialInput, sample negatives out deterministically by hashcode range
            int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
            // BaggingSampleRate is the fraction kept for training/validation; e.g. rate 0.8
            // means the (1 - 0.8) range of hash codes is dropped
            int endHashCode = startHashCode + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
            if ((modelConfig.isRegression() || // regression or
            (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) && // onevsall
            (int) (outputData[0] + 0.01d) == // negative record
            0 && isInRange(hashcode, startHashCode, endHashCode)) {
                return;
            }
        } else {
            // otherwise sample negatives out randomly at (1 - baggingSampleRate)
            if ((modelConfig.isRegression() || // regression or
            (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) && // onevsall
            (int) (outputData[0] + 0.01d) == // negative record
            0 && Double.compare(Math.random(), this.modelConfig.getBaggingSampleRate()) >= 0) {
                return;
            }
        }
    }
    Data data = new Data(inputData, outputData, significance);
    // up-sampling: positives get extra weight while bagging sampling rate stays unchanged
    if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(outputData[0], 1d) == 0) {
        // Double.compare(outputData[0], 1d) == 0 means positive tag; "+ 1" avoids a zero sample count
        data.setSignificance(data.significance * (this.upSampleRng.sample() + 1));
    }
    // a Boolean attachment marks this record as belonging to the validation set
    boolean isValidation = false;
    if (context.getAttachment() != null && context.getAttachment() instanceof Boolean) {
        isValidation = (Boolean) context.getAttachment();
    }
    boolean isInTraining = addDataPairToDataSet(hashcode, data, isValidation);
    // bagging sampling applies to training data only
    if (isInTraining) {
        float subsampleWeights = sampleWeights(outputData[0]);
        if (isPositive(outputData[0])) {
            this.positiveSelectedTrainCount += subsampleWeights * 1L;
        } else {
            this.negativeSelectedTrainCount += subsampleWeights * 1L;
        }
        // fold sampling weight into significance; weight 0 effectively drops the record
        data.setSignificance(data.significance * subsampleWeights);
    } else {
    // validation data is only used to compute validation error, so no sampling is
    // applied here even though bagging logic could in principle sample it too
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 10 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class FastCorrelationMapper — method map.

/**
 * Accumulates pairwise correlation statistics (counts, sums, squared sums and
 * cross-products) for one input record into {@code correlationMap}, keyed by
 * column number. Records are filtered and sampled before being counted.
 *
 * @param key     input offset (unused)
 * @param value   one raw data line
 * @param context Hadoop mapper context, used for counters
 */
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
    String valueStr = value.toString();
    if (valueStr == null || valueStr.length() == 0 || valueStr.trim().length() == 0) {
        LOG.warn("Empty input.");
        return;
    }
    double[] dValues = null;
    // drop records excluded by the configured data filter
    if (!this.dataPurifier.isFilter(valueStr)) {
        return;
    }
    context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CNT_AFTER_FILTER").increment(1L);
    // make sampling work in correlation
    if (Math.random() >= modelConfig.getStats().getSampleRate()) {
        return;
    }
    context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CORRELATION_CNT").increment(1L);
    dValues = getDoubleArrayByRawArray(CommonUtils.split(valueStr, this.dataSetDelimiter));
    count += 1L;
    if (count % 2000L == 0) {
        LOG.info("Current records: {} in thread {}.", count, Thread.currentThread().getName());
    }
    long startO = System.currentTimeMillis();
    for (int i = 0; i < columnConfigList.size(); i++) {
        long start = System.currentTimeMillis();
        ColumnConfig columnConfig = columnConfigList.get(i);
        // skip meta columns; when candidate columns exist, only candidates are correlated
        if (columnConfig.getColumnFlag() == ColumnFlag.Meta || (hasCandidates && !ColumnFlag.Candidate.equals(columnConfig.getColumnFlag()))) {
            continue;
        }
        // lazily create the per-column accumulator on first sight of this column
        CorrelationWritable cw = this.correlationMap.get(columnConfig.getColumnNum());
        if (cw == null) {
            cw = new CorrelationWritable();
            this.correlationMap.put(columnConfig.getColumnNum(), cw);
        }
        cw.setColumnIndex(i);
        cw.setCount(cw.getCount() + 1d);
        cw.setSum(cw.getSum() + dValues[i]);
        double squaredSum = dValues[i] * dValues[i];
        cw.setSumSquare(cw.getSumSquare() + squaredSum);
        // the per-pair arrays below (sized columnConfigList.size()) are also created
        // lazily, one block per statistic
        double[] xySum = cw.getXySum();
        if (xySum == null) {
            xySum = new double[columnConfigList.size()];
            cw.setXySum(xySum);
        }
        double[] xxSum = cw.getXxSum();
        if (xxSum == null) {
            xxSum = new double[columnConfigList.size()];
            cw.setXxSum(xxSum);
        }
        double[] yySum = cw.getYySum();
        if (yySum == null) {
            yySum = new double[columnConfigList.size()];
            cw.setYySum(yySum);
        }
        double[] adjustCount = cw.getAdjustCount();
        if (adjustCount == null) {
            adjustCount = new double[columnConfigList.size()];
            cw.setAdjustCount(adjustCount);
        }
        double[] adjustSumX = cw.getAdjustSumX();
        if (adjustSumX == null) {
            adjustSumX = new double[columnConfigList.size()];
            cw.setAdjustSumX(adjustSumX);
        }
        double[] adjustSumY = cw.getAdjustSumY();
        if (adjustSumY == null) {
            adjustSumY = new double[columnConfigList.size()];
            cw.setAdjustSumY(adjustSumY);
        }
        // periodic timing log for the per-column setup phase
        if (i % 1000 == 0) {
            LOG.debug("running time 1 is {}ms in thread {}", (System.currentTimeMillis() - start), Thread.currentThread().getName());
        }
        start = System.currentTimeMillis();
        // inner loop over partner columns; when not computing all pairs, start at i to
        // cover only the upper triangle (the matrix is symmetric)
        for (int j = (this.isComputeAll ? 0 : i); j < columnConfigList.size(); j++) {
            ColumnConfig otherColumnConfig = columnConfigList.get(j);
            // target columns are always paired; others follow the same meta/candidate filter as above
            if ((otherColumnConfig.getColumnFlag() != ColumnFlag.Target) && ((otherColumnConfig.getColumnFlag() == ColumnFlag.Meta) || (hasCandidates && !ColumnFlag.Candidate.equals(otherColumnConfig.getColumnFlag())))) {
                continue;
            }
            // only accumulate when both values are valid; Double.MIN_VALUE appears to be
            // the missing-value sentinel (presumably set by getDoubleArrayByRawArray — confirm)
            if (dValues[i] != Double.MIN_VALUE && dValues[j] != Double.MIN_VALUE) {
                xySum[j] += dValues[i] * dValues[j];
                xxSum[j] += squaredSum;
                yySum[j] += dValues[j] * dValues[j];
                adjustCount[j] += 1d;
                adjustSumX[j] += dValues[i];
                adjustSumY[j] += dValues[j];
            }
        }
        // periodic timing log for the pairwise accumulation phase
        if (i % 1000 == 0) {
            LOG.debug("running time 2 is {}ms in thread {}", (System.currentTimeMillis() - start), Thread.currentThread().getName());
        }
    }
    LOG.debug("running time is {}ms in thread {}", (System.currentTimeMillis() - startO), Thread.currentThread().getName());
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)131 ArrayList (java.util.ArrayList)36 Test (org.testng.annotations.Test)17 IOException (java.io.IOException)16 HashMap (java.util.HashMap)12 Tuple (org.apache.pig.data.Tuple)10 File (java.io.File)8 NSColumn (ml.shifu.shifu.column.NSColumn)8 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)8 ShifuException (ml.shifu.shifu.exception.ShifuException)8 Path (org.apache.hadoop.fs.Path)8 List (java.util.List)7 Scanner (java.util.Scanner)7 DataBag (org.apache.pig.data.DataBag)7 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)5 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)5 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)5 BasicMLData (org.encog.ml.data.basic.BasicMLData)5 BufferedWriter (java.io.BufferedWriter)3 FileInputStream (java.io.FileInputStream)3