Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class DTWorker, method load.
@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<DTMasterParams, DTWorkerParams> context) {
this.count += 1;
if ((this.count) % 5000 == 0) {
LOG.info("Read {} records.", this.count);
}
// hashcode for fixed input split in train and validation
long hashcode = 0;
short[] inputs = new short[this.inputCount];
float ideal = 0f;
float significance = 1f;
// use guava Splitter so the line is iterated only once
// use NNConstants.NN_DEFAULT_COLUMN_SEPARATOR instead of getModelConfig().getDataSetDelimiter();
// this follows the corresponding function in akka mode
int index = 0, inputIndex = 0;
for (String input : this.splitter.split(currentValue.getWritable().toString())) {
if (index == this.columnConfigList.size()) {
// the extra trailing field is the weight column; process it below
if (StringUtils.isBlank(modelConfig.getWeightColumnName())) {
significance = 1f;
break;
}
// check emptiness first to avoid the performance cost of a failing NumberFormatUtils.getFloat(input, 1f)
significance = input.length() == 0 ? 1f : NumberFormatUtils.getFloat(input, 1f);
// if invalid weight, set it to 1f and warning in log
if (Float.compare(significance, 0f) < 0) {
LOG.warn("The {} record in current worker weight {} is less than 0f, it is invalid, set it to 1.", count, significance);
significance = 1f;
}
// the last field is significance, break here
break;
} else {
ColumnConfig columnConfig = this.columnConfigList.get(index);
if (columnConfig != null && columnConfig.isTarget()) {
ideal = getFloatValue(input);
} else {
if (!isAfterVarSelect) {
// before variable selection: take good candidates that are neither meta nor target
if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, this.hasCandidates)) {
if (columnConfig.isNumerical()) {
float floatValue = getFloatValue(input);
// cast is safe as we limit max bin to Short.MAX_VALUE
short binIndex = (short) getBinIndex(floatValue, columnConfig.getBinBoundary());
inputs[inputIndex] = binIndex;
if (!this.inputIndexMap.containsKey(columnConfig.getColumnNum())) {
this.inputIndexMap.put(columnConfig.getColumnNum(), inputIndex);
}
} else if (columnConfig.isCategorical()) {
// default to the missing-value bin (last index); cast is safe as we limit max bin to Short.MAX_VALUE
short shortValue = (short) (columnConfig.getBinCategory().size());
if (input.length() == 0) {
// empty
shortValue = (short) (columnConfig.getBinCategory().size());
} else {
Integer categoricalIndex = this.columnCategoryIndexMapping.get(columnConfig.getColumnNum()).get(input);
if (categoricalIndex == null) {
// invalid category, set to -1 for last index
shortValue = -1;
} else {
// cast is safe as we limit max bin to Short.MAX_VALUE
shortValue = (short) (categoricalIndex.intValue());
}
if (shortValue == -1) {
// not found
shortValue = (short) (columnConfig.getBinCategory().size());
}
}
inputs[inputIndex] = shortValue;
if (!this.inputIndexMap.containsKey(columnConfig.getColumnNum())) {
this.inputIndexMap.put(columnConfig.getColumnNum(), inputIndex);
}
}
hashcode = hashcode * 31 + input.hashCode();
inputIndex += 1;
}
} else {
// final select some variables but meta and target are not included
if (columnConfig != null && !columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
if (columnConfig.isNumerical()) {
float floatValue = getFloatValue(input);
// cast is safe as we limit max bin to Short.MAX_VALUE
short binIndex = (short) getBinIndex(floatValue, columnConfig.getBinBoundary());
inputs[inputIndex] = binIndex;
if (!this.inputIndexMap.containsKey(columnConfig.getColumnNum())) {
this.inputIndexMap.put(columnConfig.getColumnNum(), inputIndex);
}
} else if (columnConfig.isCategorical()) {
// cast is safe as we limit max bin to Short.MAX_VALUE
short shortValue = (short) (columnConfig.getBinCategory().size());
if (input.length() == 0) {
// empty
shortValue = (short) (columnConfig.getBinCategory().size());
} else {
Integer categoricalIndex = this.columnCategoryIndexMapping.get(columnConfig.getColumnNum()).get(input);
if (categoricalIndex == null) {
// invalid category, set to -1 for last index
shortValue = -1;
} else {
// cast is safe as we limit max bin to Short.MAX_VALUE
shortValue = (short) (categoricalIndex.intValue());
}
if (shortValue == -1) {
// not found
shortValue = (short) (columnConfig.getBinCategory().size());
}
}
inputs[inputIndex] = shortValue;
if (!this.inputIndexMap.containsKey(columnConfig.getColumnNum())) {
this.inputIndexMap.put(columnConfig.getColumnNum(), inputIndex);
}
}
hashcode = hashcode * 31 + input.hashCode();
inputIndex += 1;
}
}
}
}
index += 1;
}
// throwing an exception here helps to quickly surface delimiter or parsing inconsistencies
if (inputIndex != inputs.length) {
String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER, Constants.DEFAULT_DELIMITER);
throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: " + inputs.length + ", parsing size:" + inputIndex + ", delimiter:" + delimiter + ".");
}
if (this.isOneVsAll) {
// if one vs all, update target value according to index of target
ideal = updateOneVsAllTargetValue(ideal);
}
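// updateOneVsAllTargetValue is not shown in this snippet; a plausible sketch (an assumption, not
// verified against the shifu source) is: return Float.compare(ideal, trainerId) == 0 ? 1f : 0f;
// i.e. this worker's own class id becomes the positive label and every other class is negative.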
// sample negative only logic here
if (modelConfig.getTrain().getSampleNegOnly()) {
if (this.modelConfig.isFixInitialInput()) {
// if fixInitialInput, sample out negative records whose hashcode falls into the (1 - sampleRate) range
int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
// here BaggingSampleRate is the fraction of data used for training and validation; if it is 0.8,
// the remaining 1 - 0.8 defines endHashCode
int endHashCode = startHashCode + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
// for regression or one-vs-all, sample out negative records falling into the hash range
if ((modelConfig.isRegression() || this.isOneVsAll) && (int) (ideal + 0.01d) == 0
&& isInRange(hashcode, startHashCode, endHashCode)) {
return;
}
} else {
// otherwise, randomly sample out negative records
if ((modelConfig.isRegression() || this.isOneVsAll) && (int) (ideal + 0.01d) == 0
&& Double.compare(this.sampelNegOnlyRandom.nextDouble(), this.modelConfig.getBaggingSampleRate()) >= 0) {
return;
}
}
}
float output = ideal;
float predict = ideal;
// up-sampling logic: positive records just get more weight; the bagging sample rate itself is unchanged
if (modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal, 1d) == 0) {
// Double.compare(ideal, 1d) == 0 means a positive tag; sample() + 1 avoids a zero weight multiplier
significance = significance * (this.upSampleRng.sample() + 1);
}
Data data = new Data(inputs, predict, output, output, significance);
boolean isValidation = false;
if (context.getAttachment() != null && context.getAttachment() instanceof Boolean) {
isValidation = (Boolean) context.getAttachment();
}
// split into validation and training data set according to validation rate
boolean isInTraining = this.addDataPairToDataSet(hashcode, data, isValidation);
// do bagging sampling only for training data
if (isInTraining) {
data.subsampleWeights = sampleWeights(data.label);
// for GBDT only the first sampling weight is used; for RF the first weight carries enough information, so no need to check all
if (isPositive(data.label)) {
this.positiveSelectedTrainCount += data.subsampleWeights[0] * 1L;
} else {
this.negativeSelectedTrainCount += data.subsampleWeights[0] * 1L;
}
} else {
// for validation data, bagging sampling logic could apply as well; but since the validation set is
// only used to compute validation error, skipping real sampling here is fine
}
}
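For reference, the bin lookup used above (getBinIndex) maps a raw numeric value to the index of the bin it falls into, given the column's sorted bin boundaries. A minimal sketch of that kind of lookup, assuming binBoundary holds ascending lower bounds; the real shifu implementation may differ:

// Sketch only: binary search for the last boundary <= value.
// Assumes binBoundary is sorted ascending, with the first element as the lowest bound.
static int binIndexSketch(float value, java.util.List<Double> binBoundary) {
int low = 0, high = binBoundary.size() - 1;
while (low <= high) {
int mid = (low + high) >>> 1;
if (binBoundary.get(mid) <= value) {
low = mid + 1;
} else {
high = mid - 1;
}
}
return Math.max(low - 1, 0); // clamp values below the first boundary into bin 0
}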
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class DTWorker, method getAllValidFeatures.
private List<Integer> getAllValidFeatures() {
List<Integer> features = new ArrayList<Integer>();
boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
for (ColumnConfig config : columnConfigList) {
if (isAfterVarSelect) {
if (config.isFinalSelect() && !config.isTarget() && !config.isMeta()) {
// numerical feature with getBinBoundary().size() larger than 1, or categorical feature with getBinCategory().size() larger than 0
if ((config.isNumerical() && config.getBinBoundary().size() > 1) || (config.isCategorical() && config.getBinCategory().size() > 0)) {
features.add(config.getColumnNum());
}
}
} else {
if (!config.isMeta() && !config.isTarget() && CommonUtils.isGoodCandidate(config, hasCandidates)) {
// numerical feature with getBinBoundary().size() larger than 1, or categorical feature with getBinCategory().size() larger than 0
if ((config.isNumerical() && config.getBinBoundary().size() > 1) || (config.isCategorical() && config.getBinCategory().size() > 0)) {
features.add(config.getColumnNum());
}
}
}
}
return features;
}
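The size guards above encode splittability: a numerical variable with a single bin boundary yields only one bin, so a tree has no split point for it, and a categorical variable needs at least one category. A self-contained restatement of that rule (illustrative helper, not shifu API):

// A feature is a valid tree input only if it can actually be split on.
static boolean isSplittable(boolean numerical, int binBoundarySize, boolean categorical, int binCategorySize) {
return (numerical && binBoundarySize > 1) || (categorical && binCategorySize > 0);
}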
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class UpdateBinningInfoReducer, method reduce.
@Override
protected void reduce(IntWritable key, Iterable<BinningInfoWritable> values, Context context) throws IOException, InterruptedException {
long start = System.currentTimeMillis();
double sum = 0d;
double squaredSum = 0d;
double tripleSum = 0d;
double quarticSum = 0d;
double p25th = 0d;
double median = 0d;
double p75th = 0d;
long count = 0L, missingCount = 0L;
double min = Double.MAX_VALUE, max = -Double.MAX_VALUE; // -MAX_VALUE, not MIN_VALUE: Double.MIN_VALUE is the smallest positive double
List<Double> binBoundaryList = null;
List<String> binCategories = null;
long[] binCountPos = null;
long[] binCountNeg = null;
double[] binWeightPos = null;
double[] binWeightNeg = null;
long[] binCountTotal = null;
int columnConfigIndex = key.get() >= this.columnConfigList.size() ? key.get() % this.columnConfigList.size() : key.get();
ColumnConfig columnConfig = this.columnConfigList.get(columnConfigIndex);
HyperLogLogPlus hyperLogLogPlus = null;
Set<String> fis = new HashSet<String>();
long totalCount = 0, invalidCount = 0, validNumCount = 0;
int binSize = 0;
for (BinningInfoWritable info : values) {
if (info.isEmpty()) {
// mapper has no stats, skip it
continue;
}
CountAndFrequentItemsWritable cfiw = info.getCfiw();
totalCount += cfiw.getCount();
invalidCount += cfiw.getInvalidCount();
validNumCount += cfiw.getValidNumCount();
fis.addAll(cfiw.getFrequetItems());
if (hyperLogLogPlus == null) {
hyperLogLogPlus = HyperLogLogPlus.Builder.build(cfiw.getHyperBytes());
} else {
try {
hyperLogLogPlus = (HyperLogLogPlus) hyperLogLogPlus.merge(HyperLogLogPlus.Builder.build(cfiw.getHyperBytes()));
} catch (CardinalityMergeException e) {
throw new RuntimeException(e);
}
}
if (columnConfig.isHybrid() && binBoundaryList == null && binCategories == null) {
binBoundaryList = info.getBinBoundaries();
binCategories = info.getBinCategories();
binSize = binBoundaryList.size() + binCategories.size();
binCountPos = new long[binSize + 1];
binCountNeg = new long[binSize + 1];
binWeightPos = new double[binSize + 1];
binWeightNeg = new double[binSize + 1];
binCountTotal = new long[binSize + 1];
} else if (columnConfig.isNumerical() && binBoundaryList == null) {
binBoundaryList = info.getBinBoundaries();
binSize = binBoundaryList.size();
binCountPos = new long[binSize + 1];
binCountNeg = new long[binSize + 1];
binWeightPos = new double[binSize + 1];
binWeightNeg = new double[binSize + 1];
binCountTotal = new long[binSize + 1];
} else if (columnConfig.isCategorical() && binCategories == null) {
binCategories = info.getBinCategories();
binSize = binCategories.size();
binCountPos = new long[binSize + 1];
binCountNeg = new long[binSize + 1];
binWeightPos = new double[binSize + 1];
binWeightNeg = new double[binSize + 1];
binCountTotal = new long[binSize + 1];
}
count += info.getTotalCount();
missingCount += info.getMissingCount();
// for numerical columns these sums are valid; for categorical columns they are all 0 and are
// recomputed later from binCountPos and binCountNeg
sum += info.getSum();
squaredSum += info.getSquaredSum();
tripleSum += info.getTripleSum();
quarticSum += info.getQuarticSum();
if (Double.compare(max, info.getMax()) < 0) {
max = info.getMax();
}
if (Double.compare(min, info.getMin()) > 0) {
min = info.getMin();
}
for (int i = 0; i < (binSize + 1); i++) {
binCountPos[i] += info.getBinCountPos()[i];
binCountNeg[i] += info.getBinCountNeg()[i];
binWeightPos[i] += info.getBinWeightPos()[i];
binWeightNeg[i] += info.getBinWeightNeg()[i];
binCountTotal[i] += info.getBinCountPos()[i];
binCountTotal[i] += info.getBinCountNeg()[i];
}
}
if (columnConfig.isNumerical()) {
long p25Count = count / 4;
long medianCount = p25Count * 2;
long p75Count = p25Count * 3;
p25th = min;
median = min;
p75th = min;
int currentCount = 0;
for (int i = 0; i < binBoundaryList.size(); i++) {
double left = getCutoffBoundary(binBoundaryList.get(i), max, min);
double right = ((i == binBoundaryList.size() - 1) ? max : getCutoffBoundary(binBoundaryList.get(i + 1), max, min));
if (p25Count >= currentCount && p25Count < currentCount + binCountTotal[i]) {
p25th = ((p25Count - currentCount) / (double) binCountTotal[i]) * (right - left) + left;
}
if (medianCount >= currentCount && medianCount < currentCount + binCountTotal[i]) {
median = ((medianCount - currentCount) / (double) binCountTotal[i]) * (right - left) + left;
}
if (p75Count >= currentCount && p75Count < currentCount + binCountTotal[i]) {
p75th = ((p75Count - currentCount) / (double) binCountTotal[i]) * (right - left) + left;
// once the 75th percentile is found, stop
break;
}
currentCount += binCountTotal[i];
}
LOG.info("Coloumn num is {}, p25 value is {}, median value is {}, p75 value is {}", columnConfig.getColumnNum(), p25th, median, p75th);
}
LOG.info("Coloumn num is {}, columnType value is {}, cateMaxNumBin is {}, binCategory size is {}", columnConfig.getColumnNum(), columnConfig.getColumnType(), modelConfig.getStats().getCateMaxNumBin(), (CollectionUtils.isNotEmpty(columnConfig.getBinCategory()) ? columnConfig.getBinCategory().size() : 0));
// To merge categorical binning
if (columnConfig.isCategorical() && modelConfig.getStats().getCateMaxNumBin() > 0 && CollectionUtils.isNotEmpty(binCategories) && binCategories.size() > modelConfig.getStats().getCateMaxNumBin()) {
// rebin only when the category size is larger than the expected max bin number
CateBinningStats cateBinningStats = rebinCategoricalValues(new CateBinningStats(binCategories, binCountPos, binCountNeg, binWeightPos, binWeightNeg));
LOG.info("For variable - {}, {} bins is rebined to {} bins", columnConfig.getColumnName(), binCategories.size(), cateBinningStats.binCategories.size());
binCategories = cateBinningStats.binCategories;
binCountPos = cateBinningStats.binCountPos;
binCountNeg = cateBinningStats.binCountNeg;
binWeightPos = cateBinningStats.binWeightPos;
binWeightNeg = cateBinningStats.binWeightNeg;
}
double[] binPosRate;
if (modelConfig.isRegression()) {
binPosRate = computePosRate(binCountPos, binCountNeg);
} else {
// for multi-class classification, use the rate of categories to compute a value
binPosRate = computeRateForMultiClassfication(binCountPos);
}
String binBounString = null;
if (columnConfig.isHybrid()) {
if (binCategories.size() > this.maxCateSize) {
LOG.warn("Column {} {} with invalid bin category size.", key.get(), columnConfig.getColumnName(), binCategories.size());
return;
}
binBounString = binBoundaryList.toString();
binBounString += Constants.HYBRID_BIN_STR_DILIMETER + Base64Utils.base64Encode("[" + StringUtils.join(binCategories, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR) + "]");
} else if (columnConfig.isCategorical()) {
if (binCategories.size() > this.maxCateSize) {
LOG.warn("Column {} {} with invalid bin category size.", key.get(), columnConfig.getColumnName(), binCategories.size());
return;
}
binBounString = Base64Utils.base64Encode("[" + StringUtils.join(binCategories, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR) + "]");
// recompute such value for categorical variables
min = Double.MAX_VALUE;
max = -Double.MAX_VALUE;
sum = 0d;
squaredSum = 0d;
for (int i = 0; i < binPosRate.length; i++) {
if (!Double.isNaN(binPosRate[i])) {
if (Double.compare(max, binPosRate[i]) < 0) {
max = binPosRate[i];
}
if (Double.compare(min, binPosRate[i]) > 0) {
min = binPosRate[i];
}
long binCount = binCountPos[i] + binCountNeg[i];
sum += binPosRate[i] * binCount;
double squaredVal = binPosRate[i] * binPosRate[i];
squaredSum += squaredVal * binCount;
tripleSum += squaredVal * binPosRate[i] * binCount;
quarticSum += squaredVal * squaredVal * binCount;
}
}
} else {
if (binBoundaryList.size() == 0) {
LOG.warn("Column {} {} with invalid bin boundary size.", key.get(), columnConfig.getColumnName(), binBoundaryList.size());
return;
}
binBounString = binBoundaryList.toString();
}
ColumnMetrics columnCountMetrics = null;
ColumnMetrics columnWeightMetrics = null;
if (modelConfig.isRegression()) {
columnCountMetrics = ColumnStatsCalculator.calculateColumnMetrics(binCountNeg, binCountPos);
columnWeightMetrics = ColumnStatsCalculator.calculateColumnMetrics(binWeightNeg, binWeightPos);
}
// To make it be consistent with SPDT, missingCount is excluded to compute mean, stddev ...
long realCount = this.statsExcludeMissingValue ? (count - missingCount) : count;
double mean = sum / realCount;
double stdDev = Math.sqrt(Math.abs((squaredSum - (sum * sum) / realCount + EPS) / (realCount - 1)));
double aStdDev = Math.sqrt(Math.abs((squaredSum - (sum * sum) / realCount + EPS) / realCount));
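// The two stddev lines above follow from Var(X) = E[X^2] - E[X]^2: squaredSum - (sum * sum) / realCount
// is the centered sum of squares; dividing by (realCount - 1) gives the sample variance and by realCount
// the population variance. EPS together with Math.abs guards against small negative values caused by
// floating-point rounding.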
double skewness = ColumnStatsCalculator.computeSkewness(realCount, mean, aStdDev, sum, squaredSum, tripleSum);
double kurtosis = ColumnStatsCalculator.computeKurtosis(realCount, mean, aStdDev, sum, squaredSum, tripleSum, quarticSum);
// output fields, separated by Constants.DEFAULT_DELIMITER: column id, bin boundaries, bin counts,
// pos rate, count-based KS/IV, basic stats, weighted bins, WOE metrics, skewness/kurtosis,
// record counts, cardinality, frequent items and percentiles
sb.append(key.get()).append(Constants.DEFAULT_DELIMITER)
.append(binBounString).append(Constants.DEFAULT_DELIMITER)
.append(Arrays.toString(binCountNeg)).append(Constants.DEFAULT_DELIMITER)
.append(Arrays.toString(binCountPos)).append(Constants.DEFAULT_DELIMITER)
.append(Arrays.toString(new double[0])).append(Constants.DEFAULT_DELIMITER)
.append(Arrays.toString(binPosRate)).append(Constants.DEFAULT_DELIMITER)
.append(columnCountMetrics == null ? "" : df.format(columnCountMetrics.getKs())).append(Constants.DEFAULT_DELIMITER)
.append(columnCountMetrics == null ? "" : df.format(columnCountMetrics.getIv())).append(Constants.DEFAULT_DELIMITER)
.append(df.format(max)).append(Constants.DEFAULT_DELIMITER)
.append(df.format(min)).append(Constants.DEFAULT_DELIMITER)
.append(df.format(mean)).append(Constants.DEFAULT_DELIMITER)
.append(df.format(stdDev)).append(Constants.DEFAULT_DELIMITER)
.append(columnConfig.getColumnType().toString()).append(Constants.DEFAULT_DELIMITER)
.append(median).append(Constants.DEFAULT_DELIMITER)
.append(missingCount).append(Constants.DEFAULT_DELIMITER)
.append(count).append(Constants.DEFAULT_DELIMITER)
.append(missingCount * 1.0d / count).append(Constants.DEFAULT_DELIMITER)
.append(Arrays.toString(binWeightNeg)).append(Constants.DEFAULT_DELIMITER)
.append(Arrays.toString(binWeightPos)).append(Constants.DEFAULT_DELIMITER)
.append(columnCountMetrics == null ? "" : columnCountMetrics.getWoe()).append(Constants.DEFAULT_DELIMITER)
.append(columnWeightMetrics == null ? "" : columnWeightMetrics.getWoe()).append(Constants.DEFAULT_DELIMITER)
.append(columnWeightMetrics == null ? "" : columnWeightMetrics.getKs()).append(Constants.DEFAULT_DELIMITER)
.append(columnWeightMetrics == null ? "" : columnWeightMetrics.getIv()).append(Constants.DEFAULT_DELIMITER)
// bin count WOE
.append(columnCountMetrics == null ? Arrays.toString(new double[binSize + 1]) : columnCountMetrics.getBinningWoe().toString()).append(Constants.DEFAULT_DELIMITER)
// bin weighted WOE
.append(columnWeightMetrics == null ? Arrays.toString(new double[binSize + 1]) : columnWeightMetrics.getBinningWoe().toString()).append(Constants.DEFAULT_DELIMITER)
// skewness
.append(skewness).append(Constants.DEFAULT_DELIMITER)
// kurtosis
.append(kurtosis).append(Constants.DEFAULT_DELIMITER)
// total count
.append(totalCount).append(Constants.DEFAULT_DELIMITER)
// invalid count
.append(invalidCount).append(Constants.DEFAULT_DELIMITER)
// valid num count
.append(validNumCount).append(Constants.DEFAULT_DELIMITER)
// cardinality
.append(hyperLogLogPlus.cardinality()).append(Constants.DEFAULT_DELIMITER)
// frequent items
.append(Base64Utils.base64Encode(limitedFrequentItems(fis))).append(Constants.DEFAULT_DELIMITER)
// 25th and 75th percentile values
.append(p25th).append(Constants.DEFAULT_DELIMITER).append(p75th);
outputValue.set(sb.toString());
context.write(NullWritable.get(), outputValue);
sb.delete(0, sb.length());
LOG.debug("Time:{}", (System.currentTimeMillis() - start));
}
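The percentile section above interpolates linearly inside the bin that contains the target rank, assuming records are spread uniformly within each bin. A standalone sketch of the same computation (hypothetical helper, not part of the reducer):

// Estimate the value at rank targetRank from per-bin counts and bin edges.
static double quantileSketch(long targetRank, long[] binCounts, double[] lefts, double[] rights) {
long current = 0;
for (int i = 0; i < binCounts.length; i++) {
if (targetRank >= current && targetRank < current + binCounts[i]) {
double frac = (targetRank - current) / (double) binCounts[i];
return lefts[i] + frac * (rights[i] - lefts[i]);
}
current += binCounts[i];
}
return rights[rights.length - 1]; // rank beyond all bins: clamp to the max edge
}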
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class CorrelationMapper, method map.
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
String valueStr = value.toString();
if (valueStr == null || valueStr.length() == 0 || valueStr.trim().length() == 0) {
LOG.warn("Empty input.");
return;
}
double[] dValues = null;
if (!this.dataPurifier.isFilter(valueStr)) {
return;
}
long startO = System.currentTimeMillis();
context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CNT_AFTER_FILTER").increment(1L);
// make sampling work in correlation
if (Math.random() >= modelConfig.getStats().getSampleRate()) {
return;
}
context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CORRELATION_CNT").increment(1L);
dValues = getDoubleArrayByRawArray(CommonUtils.split(valueStr, this.dataSetDelimiter));
count += 1L;
if (count % 2000L == 0) {
LOG.info("Current records: {} in thread {}.", count, Thread.currentThread().getName());
}
for (int i = 0; i < columnConfigList.size(); i++) {
ColumnConfig columnConfig = columnConfigList.get(i);
if (columnConfig.getColumnFlag() == ColumnFlag.Meta || (hasCandidates && !ColumnFlag.Candidate.equals(columnConfig.getColumnFlag()))) {
continue;
}
CorrelationWritable cw = CorrelationMultithreadedMapper.finalCorrelationMap.get(columnConfig.getColumnNum());
synchronized (cw) {
cw.setColumnIndex(i);
cw.setCount(cw.getCount() + 1d);
cw.setSum(cw.getSum() + dValues[i]);
double squaredSum = dValues[i] * dValues[i];
cw.setSumSquare(cw.getSumSquare() + squaredSum);
double[] xySum = cw.getXySum();
if (xySum == null) {
xySum = new double[columnConfigList.size()];
cw.setXySum(xySum);
}
double[] xxSum = cw.getXxSum();
if (xxSum == null) {
xxSum = new double[columnConfigList.size()];
cw.setXxSum(xxSum);
}
double[] yySum = cw.getYySum();
if (yySum == null) {
yySum = new double[columnConfigList.size()];
cw.setYySum(yySum);
}
double[] adjustCount = cw.getAdjustCount();
if (adjustCount == null) {
adjustCount = new double[columnConfigList.size()];
cw.setAdjustCount(adjustCount);
}
double[] adjustSumX = cw.getAdjustSumX();
if (adjustSumX == null) {
adjustSumX = new double[columnConfigList.size()];
cw.setAdjustSumX(adjustSumX);
}
double[] adjustSumY = cw.getAdjustSumY();
if (adjustSumY == null) {
adjustSumY = new double[columnConfigList.size()];
cw.setAdjustSumY(adjustSumY);
}
for (int j = (this.isComputeAll ? 0 : i); j < columnConfigList.size(); j++) {
ColumnConfig otherColumnConfig = columnConfigList.get(j);
if ((otherColumnConfig.getColumnFlag() != ColumnFlag.Target) && ((otherColumnConfig.getColumnFlag() == ColumnFlag.Meta) || (hasCandidates && !ColumnFlag.Candidate.equals(otherColumnConfig.getColumnFlag())))) {
continue;
}
// only do stats when both values are valid
if (dValues[i] != Double.MIN_VALUE && dValues[j] != Double.MIN_VALUE) {
xySum[j] += dValues[i] * dValues[j];
xxSum[j] += squaredSum;
yySum[j] += dValues[j] * dValues[j];
adjustCount[j] += 1d;
adjustSumX[j] += dValues[i];
adjustSumY[j] += dValues[j];
}
}
}
LOG.debug("running time is {}ms in thread {}", (System.currentTimeMillis() - startO), Thread.currentThread().getName());
}
}
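The mapper only accumulates sufficient statistics; the Pearson correlation between columns i and j can later be recovered from the paired sums it maintains (adjustCount, adjustSumX, adjustSumY, xySum, xxSum, yySum). A sketch of the standard formula over those accumulators; how shifu's reducer actually combines them may differ:

// r = (Σxy - ΣxΣy/n) / sqrt((Σx² - (Σx)²/n) * (Σy² - (Σy)²/n))
static double pearsonSketch(double n, double sumX, double sumY, double sumXY, double sumXX, double sumYY) {
double cov = sumXY - sumX * sumY / n;
double varX = sumXX - sumX * sumX / n;
double varY = sumYY - sumY * sumY / n;
return cov / Math.sqrt(varX * varY);
}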
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
The class BinaryDTSerializer, method getColumnMapping.
private static Map<Integer, Integer> getColumnMapping(List<ColumnConfig> columnConfigList) {
Map<Integer, Integer> columnMapping = new HashMap<Integer, Integer>(columnConfigList.size(), 1f);
int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(columnConfigList);
boolean isAfterVarSelect = inputOutputIndex[3] == 1;
boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
int index = 0;
for (int i = 0; i < columnConfigList.size(); i++) {
ColumnConfig columnConfig = columnConfigList.get(i);
if (!isAfterVarSelect) {
if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
columnMapping.put(columnConfig.getColumnNum(), index);
index += 1;
}
} else {
if (columnConfig != null && !columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
columnMapping.put(columnConfig.getColumnNum(), index);
index += 1;
}
}
}
return columnMapping;
}
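A hypothetical usage of the returned mapping: project values keyed by column number into the dense, zero-based input order a serialized tree model expects. The method name and the missing-value policy below are illustrative assumptions, not shifu API:

// Build the dense input array from sparse per-column values.
static double[] toDenseInputs(java.util.Map<Integer, Double> valueByColumnNum, java.util.Map<Integer, Integer> columnMapping) {
double[] inputs = new double[columnMapping.size()];
for (java.util.Map.Entry<Integer, Integer> e : columnMapping.entrySet()) {
Double v = valueByColumnNum.get(e.getKey());
inputs[e.getValue()] = (v == null) ? 0d : v; // assumption: treat missing columns as 0
}
return inputs;
}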