Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class BinaryDTSerializer, method save.
/**
 * Serializes bagging trees (GBT/RF) together with the column metadata needed at scoring time
 * into a GZIP-compressed binary stream.
 *
 * <p>Write order (the reader must consume fields in exactly this order): format version,
 * algorithm, loss, classification flags, input count, numerical mean mapping, column
 * index-to-name mapping, categorical value lists, column mapping, then the trees.
 *
 * @param modelConfig      model configuration (algorithm, classification/one-vs-all flags)
 * @param columnConfigList per-column statistics and selection flags
 * @param baggingTrees     tree ensembles, one inner list per bagging round
 * @param loss             loss function name to record in the header
 * @param inputCount       number of model inputs to record in the header
 * @param output           destination stream; wrapped in GZIP and closed on exit
 * @throws IOException if writing to the underlying stream fails
 */
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
        List<List<TreeNode>> baggingTrees, String loss, int inputCount, OutputStream output)
        throws IOException {
    DataOutputStream fos = null;
    try {
        fos = new DataOutputStream(new GZIPOutputStream(output));

        // header: format version and model-level flags
        fos.writeInt(CommonConstants.TREE_FORMAT_VERSION);
        fos.writeUTF(modelConfig.getAlgorithm());
        fos.writeUTF(loss);
        fos.writeBoolean(modelConfig.isClassification());
        fos.writeBoolean(modelConfig.getTrain().isOneVsAll());
        fos.writeInt(inputCount);

        // collect the column metadata to serialize: names of selected columns, categorical
        // value lists, and numerical means
        Map<Integer, String> columnIndexNameMapping = new HashMap<Integer, String>();
        Map<Integer, List<String>> columnIndexCategoricalListMapping = new HashMap<Integer, List<String>>();
        Map<Integer, Double> numericalMeanMapping = new HashMap<Integer, Double>();
        for(ColumnConfig columnConfig: columnConfigList) {
            if(columnConfig.isFinalSelect()) {
                columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName());
            }
            if(columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) {
                columnIndexCategoricalListMapping.put(columnConfig.getColumnNum(), columnConfig.getBinCategory());
            }
            if(columnConfig.isNumerical() && columnConfig.getMean() != null) {
                numericalMeanMapping.put(columnConfig.getColumnNum(), columnConfig.getMean());
            }
        }

        // no final-selected columns at all: fall back to every good candidate column
        if(columnIndexNameMapping.size() == 0) {
            boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
            for(ColumnConfig columnConfig: columnConfigList) {
                if(CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
                    columnIndexNameMapping.put(columnConfig.getColumnNum(), columnConfig.getColumnName());
                }
            }
        }

        // serialize numericalMeanMapping
        fos.writeInt(numericalMeanMapping.size());
        for(Entry<Integer, Double> entry: numericalMeanMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            // for some feature, it is null mean value, it is not selected, just set to 0d to avoid NPE
            fos.writeDouble(entry.getValue() == null ? 0d : entry.getValue());
        }

        // serialize columnIndexNameMapping
        fos.writeInt(columnIndexNameMapping.size());
        for(Entry<Integer, String> entry: columnIndexNameMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            fos.writeUTF(entry.getValue());
        }

        // serialize columnIndexCategoricalListMapping; long category values are written as raw
        // UTF-8 bytes behind UTF_BYTES_MARKER because DataOutputStream.writeUTF is limited to
        // 65535 encoded bytes
        fos.writeInt(columnIndexCategoricalListMapping.size());
        for(Entry<Integer, List<String>> entry: columnIndexCategoricalListMapping.entrySet()) {
            List<String> categories = entry.getValue();
            if(categories != null) {
                fos.writeInt(entry.getKey());
                fos.writeInt(categories.size());
                for(String category: categories) {
                    // in read part logic should be changed also to readByte not readUTF according to the marker
                    if(category.length() < Constants.MAX_CATEGORICAL_VAL_LEN) {
                        fos.writeUTF(category);
                    } else {
                        // marker here
                        fos.writeShort(UTF_BYTES_MARKER);
                        byte[] bytes = category.getBytes("UTF-8");
                        fos.writeInt(bytes.length);
                        for(int i = 0; i < bytes.length; i++) {
                            fos.writeByte(bytes[i]);
                        }
                    }
                }
            }
        }

        // serialize the column-number to model-input-index mapping
        Map<Integer, Integer> columnMapping = getColumnMapping(columnConfigList);
        fos.writeInt(columnMapping.size());
        for(Entry<Integer, Integer> entry: columnMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            fos.writeInt(entry.getValue());
        }

        // after model version 4 (>=4), IndependentTreeModel support bagging, here write a default RF/GBT size 1
        fos.writeInt(baggingTrees.size());
        for(int i = 0; i < baggingTrees.size(); i++) {
            List<TreeNode> trees = baggingTrees.get(i);
            int treeLength = trees.size();
            fos.writeInt(treeLength);
            for(TreeNode treeNode: trees) {
                treeNode.write(fos);
            }
        }
    } catch (IOException e) {
        LOG.error("Error in writing output.", e);
        // rethrow instead of swallowing: the method declares IOException, and failing
        // silently here would leave a truncated, unreadable model file behind
        throw e;
    } finally {
        IOUtils.closeStream(fos);
    }
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class DTMaster, method getStatsMem.
/**
 * Estimates the memory (in bytes) the master needs to hold per-feature impurity statistics
 * for the given feature subset, used as an OOM guard when aggregating worker results.
 *
 * @param subsetFeatures column numbers of the features to compute stats for; an empty list
 *                       means "all usable features" and triggers a fallback to the full list
 * @return estimated statistics memory in bytes, scaled by the worker count
 */
private long getStatsMem(List<Integer> subsetFeatures) {
long statsMem = 0L;
List<Integer> tempFeatures = subsetFeatures;
if (subsetFeatures.size() == 0) {
// empty subset means "all features": fall back to the full usable feature list
tempFeatures = getAllFeatureList(this.columnConfigList, this.isAfterVarSelect);
}
for (Integer columnNum : tempFeatures) {
ColumnConfig config = this.columnConfigList.get(columnNum);
// 2 is overhead to avoid oom
if (config.isNumerical()) {
// bins * stats entries per bin * 8 bytes per double * 2x overhead
statsMem += config.getBinBoundary().size() * this.impurity.getStatsSize() * 8L * 2;
} else if (config.isCategorical()) {
// +1 bin for the missing-value category
statsMem += (config.getBinCategory().size() + 1) * this.impurity.getStatsSize() * 8L * 2;
}
}
// scale by worker number because combinable DTWorkerParams accumulate in the master;
// NOTE(review): the comment here said "one third of worker number" but the code divides
// by 2 (i.e. half) — confirm which is intended
statsMem = statsMem * this.workerNumber / 2;
return statsMem;
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class DTMaster, method getAllFeatureList.
/**
 * Collects the column numbers of all features usable for tree training.
 *
 * <p>Before variable selection, any good candidate column (excluding meta and target)
 * qualifies; after variable selection, only final-selected columns (excluding meta and
 * target) qualify. In both cases a feature must have usable bins: a numerical feature needs
 * more than one bin boundary and a categorical feature needs at least one category.
 *
 * @param columnConfigList all column configurations
 * @param isAfterVarSelect whether variable selection has already been performed
 * @return column numbers of usable features, in column order
 */
private List<Integer> getAllFeatureList(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) {
    List<Integer> features = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for(ColumnConfig config: columnConfigList) {
        // the selection rule depends on whether variable selection already happened
        boolean selectable = isAfterVarSelect
                ? (config.isFinalSelect() && !config.isTarget() && !config.isMeta())
                : (!config.isMeta() && !config.isTarget() && CommonUtils.isGoodCandidate(config, hasCandidates));
        if(!selectable) {
            continue;
        }
        // numerical feature with more than one bin boundary,
        // or categorical feature with getBinCategory().size() larger than 0
        if((config.isNumerical() && config.getBinBoundary().size() > 1)
                || (config.isCategorical() && config.getBinCategory().size() > 0)) {
            features.add(config.getColumnNum());
        }
    }
    return features;
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class LogisticRegressionWorker, method load.
/**
 * Parses one normalized input record into input/output float arrays, applies negative-only
 * sampling and up-sampling, and adds the resulting Data record to the training or
 * validation set.
 *
 * <p>Two cursors are maintained while scanning: {@code index} walks the column config list
 * (one step per logical column), while {@code pos} walks the raw fields — ONEHOT-normalized
 * columns expand one logical column into several fields. The last field of every record is
 * the weight/significance column.
 */
@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<LogisticRegressionParams, LogisticRegressionParams> context) {
++this.count;
// periodic progress logging
if ((this.count) % 100000 == 0) {
LOG.info("Read {} records.", this.count);
}
String line = currentValue.getWritable().toString();
float[] inputData = new float[inputNum];
float[] outputData = new float[outputNum];
// index: cursor over columnConfigList; inputIndex/outputIndex: fill cursors into the arrays
int index = 0, inputIndex = 0, outputIndex = 0;
// hashcode over the consumed input values; used below for fixed-input bagging sampling
long hashcode = 0;
double significance = CommonConstants.DEFAULT_SIGNIFICANCE_VALUE;
boolean hasCandidates = CommonUtils.hasCandidateColumns(this.columnConfigList);
String[] fields = Lists.newArrayList(this.splitter.split(line)).toArray(new String[0]);
// pos: cursor over raw fields; advanced inside the loop body, not in the for header
int pos = 0;
for (pos = 0; pos < fields.length; ) {
String unit = fields[pos];
// check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 0f)
float floatValue = unit.length() == 0 ? 0f : NumberFormatUtils.getFloat(unit, 0f);
// no idea about why NaN in input data, we should process it as missing value TODO , according to norm type
floatValue = (Float.isNaN(floatValue) || Double.isNaN(floatValue)) ? 0f : floatValue;
if (pos == fields.length - 1) {
// weight, how to process???
if (StringUtils.isBlank(modelConfig.getWeightColumnName())) {
significance = 1d;
// break here if we reach weight column which is last column
break;
}
// check here to avoid bad performance in failed NumberFormatUtils.getDouble(input, 1)
significance = unit.length() == 0 ? 1f : NumberFormatUtils.getDouble(unit, 1d);
// if invalid weight, set it to 1f and warning in log
if (Double.compare(significance, 0d) < 0) {
LOG.warn("The {} record in current worker weight {} is less than 0f, it is invalid, set it to 1.", count, significance);
significance = 1d;
}
// the last field is significance, break here
break;
} else {
ColumnConfig columnConfig = this.columnConfigList.get(index);
if (columnConfig != null && columnConfig.isTarget()) {
outputData[outputIndex++] = floatValue;
pos++;
} else {
// NOTE(review): columnConfig was null-checked in the target branch above but is
// dereferenced below without a check — confirm null cannot occur past this point
if (this.inputNum == this.candidateNum) {
// no variable selected, good candidate but not meta and not target choosed
if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
inputData[inputIndex++] = floatValue;
hashcode = hashcode * 31 + Float.valueOf(floatValue).hashCode();
}
pos++;
} else {
if (columnConfig.isFinalSelect()) {
// ONEHOT numerical column expands into (bin boundary count + 1) raw fields
if (columnConfig.isNumerical() && modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT)) {
for (int k = 0; k < columnConfig.getBinBoundary().size() + 1; k++) {
String tval = fields[pos];
// check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 0f)
float fval = tval.length() == 0 ? 0f : NumberFormatUtils.getFloat(tval, 0f);
// no idea about why NaN in input data, we should process it as missing value TODO ,
// according to norm type
fval = (Float.isNaN(fval) || Double.isNaN(fval)) ? 0f : fval;
inputData[inputIndex++] = fval;
pos++;
}
// ONEHOT/ZSCALE_ONEHOT categorical column expands into (category count + 1) raw fields
} else if (columnConfig.isCategorical() && (modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ZSCALE_ONEHOT) || modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT))) {
for (int k = 0; k < columnConfig.getBinCategory().size() + 1; k++) {
String tval = fields[pos];
// check here to avoid bad performance in failed NumberFormatUtils.getFloat(input, 0f)
float fval = tval.length() == 0 ? 0f : NumberFormatUtils.getFloat(tval, 0f);
// no idea about why NaN in input data, we should process it as missing value TODO ,
// according to norm type
fval = (Float.isNaN(fval) || Double.isNaN(fval)) ? 0f : fval;
inputData[inputIndex++] = fval;
pos++;
}
} else {
// plain (single-field) normalized column
inputData[inputIndex++] = floatValue;
pos++;
}
hashcode = hashcode * 31 + Double.valueOf(floatValue).hashCode();
} else {
// column not selected: skip the matching number of raw fields without consuming input
if (!CommonUtils.isToNormVariable(columnConfig, hasCandidates, modelConfig.isRegression())) {
pos += 1;
} else if (columnConfig.isNumerical() && modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT) && columnConfig.getBinBoundary() != null && columnConfig.getBinBoundary().size() > 0) {
pos += (columnConfig.getBinBoundary().size() + 1);
} else if (columnConfig.isCategorical() && (modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ZSCALE_ONEHOT) || modelConfig.getNormalizeType().equals(ModelNormalizeConf.NormType.ONEHOT)) && columnConfig.getBinCategory().size() > 0) {
pos += (columnConfig.getBinCategory().size() + 1);
} else {
pos += 1;
}
}
}
}
}
// one logical column consumed per iteration, regardless of how many raw fields it spanned
index += 1;
}
// sanity check: every logical column and every raw field (except the weight) must be consumed
if (index != this.columnConfigList.size() || pos != fields.length - 1) {
throw new RuntimeException("Wrong data indexing. ColumnConfig index = " + index + ", while it should be " + columnConfigList.size() + ". " + "Data Pos = " + pos + ", while it should be " + (fields.length - 1));
}
// is helped to quick find such issue.
if (inputIndex != inputData.length) {
String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER, Constants.DEFAULT_DELIMITER);
throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: " + inputData.length + ", parsing size:" + inputIndex + ", delimiter:" + delimiter + ".");
}
// sample negative only logic here
if (modelConfig.getTrain().getSampleNegOnly()) {
if (this.modelConfig.isFixInitialInput()) {
// if fixInitialInput, sample hashcode in 1-sampleRate range out if negative records
int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
// here BaggingSampleRate means how many data will be used in training and validation, if it is 0.8, we
// should take 1-0.8 to check endHashCode
int endHashCode = startHashCode + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
if ((modelConfig.isRegression() || // regression or
(modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) && // onevsall
(int) (outputData[0] + 0.01d) == // negative record
0 && isInRange(hashcode, startHashCode, endHashCode)) {
return;
}
} else {
// if negative record
if ((modelConfig.isRegression() || // regression or
(modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll())) && // onevsall
(int) (outputData[0] + 0.01d) == // negative record
0 && Double.compare(Math.random(), this.modelConfig.getBaggingSampleRate()) >= 0) {
return;
}
}
}
Data data = new Data(inputData, outputData, significance);
// up sampling logic, just add more weights while bagging sampling rate is still not changed
if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(outputData[0], 1d) == 0) {
// Double.compare(ideal[0], 1d) == 0 means positive tags; sample + 1 to avoids sample count to 0
data.setSignificance(data.significance * (this.upSampleRng.sample() + 1));
}
// the attachment flags whether this record belongs to the validation split
boolean isValidation = false;
if (context.getAttachment() != null && context.getAttachment() instanceof Boolean) {
isValidation = (Boolean) context.getAttachment();
}
boolean isInTraining = addDataPairToDataSet(hashcode, data, isValidation);
// do bagging sampling only for training data
if (isInTraining) {
float subsampleWeights = sampleWeights(outputData[0]);
if (isPositive(outputData[0])) {
this.positiveSelectedTrainCount += subsampleWeights * 1L;
} else {
this.negativeSelectedTrainCount += subsampleWeights * 1L;
}
// set weights to significance, if 0, significance will be 0, that is bagging sampling
data.setSignificance(data.significance * subsampleWeights);
} else {
// for validation data, according bagging sampling logic, we may need to sampling validation data set, while
// validation data set are only used to compute validation error, not to do real sampling is ok.
}
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class FastCorrelationMapper, method map.
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
String valueStr = value.toString();
if (valueStr == null || valueStr.length() == 0 || valueStr.trim().length() == 0) {
LOG.warn("Empty input.");
return;
}
double[] dValues = null;
if (!this.dataPurifier.isFilter(valueStr)) {
return;
}
context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CNT_AFTER_FILTER").increment(1L);
// make sampling work in correlation
if (Math.random() >= modelConfig.getStats().getSampleRate()) {
return;
}
context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CORRELATION_CNT").increment(1L);
dValues = getDoubleArrayByRawArray(CommonUtils.split(valueStr, this.dataSetDelimiter));
count += 1L;
if (count % 2000L == 0) {
LOG.info("Current records: {} in thread {}.", count, Thread.currentThread().getName());
}
long startO = System.currentTimeMillis();
for (int i = 0; i < columnConfigList.size(); i++) {
long start = System.currentTimeMillis();
ColumnConfig columnConfig = columnConfigList.get(i);
if (columnConfig.getColumnFlag() == ColumnFlag.Meta || (hasCandidates && !ColumnFlag.Candidate.equals(columnConfig.getColumnFlag()))) {
continue;
}
CorrelationWritable cw = this.correlationMap.get(columnConfig.getColumnNum());
if (cw == null) {
cw = new CorrelationWritable();
this.correlationMap.put(columnConfig.getColumnNum(), cw);
}
cw.setColumnIndex(i);
cw.setCount(cw.getCount() + 1d);
cw.setSum(cw.getSum() + dValues[i]);
double squaredSum = dValues[i] * dValues[i];
cw.setSumSquare(cw.getSumSquare() + squaredSum);
double[] xySum = cw.getXySum();
if (xySum == null) {
xySum = new double[columnConfigList.size()];
cw.setXySum(xySum);
}
double[] xxSum = cw.getXxSum();
if (xxSum == null) {
xxSum = new double[columnConfigList.size()];
cw.setXxSum(xxSum);
}
double[] yySum = cw.getYySum();
if (yySum == null) {
yySum = new double[columnConfigList.size()];
cw.setYySum(yySum);
}
double[] adjustCount = cw.getAdjustCount();
if (adjustCount == null) {
adjustCount = new double[columnConfigList.size()];
cw.setAdjustCount(adjustCount);
}
double[] adjustSumX = cw.getAdjustSumX();
if (adjustSumX == null) {
adjustSumX = new double[columnConfigList.size()];
cw.setAdjustSumX(adjustSumX);
}
double[] adjustSumY = cw.getAdjustSumY();
if (adjustSumY == null) {
adjustSumY = new double[columnConfigList.size()];
cw.setAdjustSumY(adjustSumY);
}
if (i % 1000 == 0) {
LOG.debug("running time 1 is {}ms in thread {}", (System.currentTimeMillis() - start), Thread.currentThread().getName());
}
start = System.currentTimeMillis();
for (int j = (this.isComputeAll ? 0 : i); j < columnConfigList.size(); j++) {
ColumnConfig otherColumnConfig = columnConfigList.get(j);
if ((otherColumnConfig.getColumnFlag() != ColumnFlag.Target) && ((otherColumnConfig.getColumnFlag() == ColumnFlag.Meta) || (hasCandidates && !ColumnFlag.Candidate.equals(otherColumnConfig.getColumnFlag())))) {
continue;
}
// only do stats on both valid values
if (dValues[i] != Double.MIN_VALUE && dValues[j] != Double.MIN_VALUE) {
xySum[j] += dValues[i] * dValues[j];
xxSum[j] += squaredSum;
yySum[j] += dValues[j] * dValues[j];
adjustCount[j] += 1d;
adjustSumX[j] += dValues[i];
adjustSumY[j] += dValues[j];
}
}
if (i % 1000 == 0) {
LOG.debug("running time 2 is {}ms in thread {}", (System.currentTimeMillis() - start), Thread.currentThread().getName());
}
}
LOG.debug("running time is {}ms in thread {}", (System.currentTimeMillis() - startO), Thread.currentThread().getName());
}
Aggregations