use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class VarSelectReducer method cleanup.
/**
 * Ranks the collected results and writes them out when the reducer finishes.
 * <p>
 * Results are sorted by RMS in descending order. The keys of the first {@code candidates} entries are written
 * to the regular job output; every entry is additionally written, together with its column name, mean, RMS
 * and variance (tab-separated), to the {@code Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME} named output.
 * <p>
 * {@code candidates} is {@code filterNum} when that is positive; otherwise it is derived from
 * {@code inputNodeCount} and {@code filterOutRatio}, depending on {@code filterBy}.
 * <p>
 * The named-output writer is closed in a {@code finally} block so that it is released even when sorting,
 * column lookup or one of the writes fails.
 *
 * @param context reducer context that receives the keys of the selected candidates
 * @throws IOException if writing to (or closing) one of the outputs fails
 * @throws InterruptedException if a write is interrupted
 */
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
    try {
        // Largest RMS first: note the swapped arguments in compare().
        Collections.sort(this.results, new Comparator<Pair>() {
            @Override
            public int compare(Pair o1, Pair o2) {
                return Double.compare(o2.value.getRms(), o1.value.getRms());
            }
        });
        LOG.debug("Final Results:{}", this.results);

        int candidates = this.filterNum;
        if (candidates <= 0) {
            // No explicit count configured: fall back to a ratio of the input node count.
            if (Constants.FILTER_BY_ST.equalsIgnoreCase(this.filterBy)
                    || Constants.FILTER_BY_SE.equalsIgnoreCase(this.filterBy)) {
                candidates = (int) (this.inputNodeCount * (1.0f - this.filterOutRatio));
            } else {
                // wrapper by A
                candidates = (int) (this.inputNodeCount * (this.filterOutRatio));
            }
        }
        LOG.info("Candidates count is {}", candidates);

        for (int i = 0; i < this.results.size(); i++) {
            Pair pair = this.results.get(i);
            this.outputKey.set(pair.key + "");
            // Only the top-ranked entries go to the main output.
            if (i < candidates) {
                context.write(this.outputKey, OUTPUT_VALUE);
            }
            // for thousands of features, here using 'new' ok
            StringBuilder sb = new StringBuilder(100);
            // After segment support the column list may be expanded, so the column id is not necessarily the
            // position in columnConfigList; look the column up by id instead of by index.
            ColumnConfig columnConfig = CommonUtils.getColumnConfig(this.columnConfigList, (int) pair.key);
            sb.append(columnConfig.getColumnName()).append("\t").append(pair.value.getMean()).append("\t")
                    .append(pair.value.getRms()).append("\t").append(pair.value.getVariance());
            this.outputValue.set(sb.toString());
            // Every entry, selected or not, is recorded in the named output.
            this.mos.write(Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME, this.outputKey, this.outputValue);
        }
    } finally {
        // Always release the named-output writer, even if something above threw.
        this.mos.close();
    }
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class VarSelectMapper method map.
/**
 * Processes one input record and accumulates, per selected column, the score differences used for
 * sensitivity-based variable selection.
 * <p>
 * The record is split into fields; fields of target columns fill {@code outputs}, fields of columns contained
 * in {@code featureSet} fill {@code inputs} (their column numbers are kept in {@code columnIndexes}). The
 * network is then scored once with all inputs and once per input index, and the absolute and squared
 * differences are added to the {@code ColumnInfo} stored in {@code results} for that column.
 * <p>
 * Entries of {@code columnConfigList} that are {@code null} are skipped rather than dereferenced.
 *
 * @param key input key; not used by this method
 * @param value one delimited input record
 * @param context mapper context; nothing is written to it in this method
 * @throws IOException declared by the overridden method
 * @throws InterruptedException declared by the overridden method
 */
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
    recordCount += 1L;
    int index = 0, inputsIndex = 0, outputsIndex = 0;
    for (String input : this.splitter.split(value.toString())) {
        // Ignore any trailing fields beyond the configured columns.
        if (index == columnConfigList.size()) {
            break;
        }
        ColumnConfig columnConfig = columnConfigList.get(index);
        // A null config carries no target/feature information, so the field is skipped entirely.
        if (columnConfig != null) {
            // Fields that cannot be parsed fall back to 0.0.
            double doubleValue = NumberFormatUtils.getDouble(input.trim(), 0.0d);
            if (columnConfig.isTarget()) {
                this.outputs[outputsIndex++] = doubleValue;
            } else if (this.featureSet != null && this.featureSet.contains(columnConfig.getColumnNum())) {
                inputs[inputsIndex] = doubleValue;
                columnIndexes[inputsIndex++] = columnConfig.getColumnNum();
            }
        }
        index++;
    }
    this.inputsMLData.setData(this.inputs);
    // compute candidate model score , cache first layer of sum values in such call method, cache flag here is true
    double candidateModelScore = cacheNetwork.compute(inputsMLData, true, -1).getData()[0];
    for (int i = 0; i < this.inputs.length; i++) {
        // cache flag is false to reuse cache sum of first layer of values.
        double currentModelScore = cacheNetwork.compute(inputsMLData, false, i).getData()[0];
        double diff = 0d;
        if (Constants.FILTER_BY_ST.equalsIgnoreCase(this.filterBy)) {
            // ST: compare against the first target value of the record
            diff = this.outputs[0] - currentModelScore;
        } else {
            // SE: compare against the score computed with all inputs
            diff = candidateModelScore - currentModelScore;
        }
        ColumnInfo columnInfo = this.results.get(this.columnIndexes[i]);
        if (columnInfo == null) {
            // First record seen for this column: start the running sums.
            columnInfo = new ColumnInfo();
            columnInfo.setSumScoreDiff(Math.abs(diff));
            columnInfo.setSumSquareScoreDiff(power2(diff));
        } else {
            columnInfo.setSumScoreDiff(columnInfo.getSumScoreDiff() + Math.abs(diff));
            columnInfo.setSumSquareScoreDiff(columnInfo.getSumSquareScoreDiff() + power2(diff));
        }
        this.results.put(this.columnIndexes[i], columnInfo);
    }
    // Progress logging every 1000 records.
    if (this.recordCount % 1000 == 0) {
        LOG.info("Finish to process {} records.", this.recordCount);
    }
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class AddColumnNumAndFilterUDF method exec.
/**
 * Validates and samples one input tuple, then expands it into a bag holding one tuple per accepted column.
 * <p>
 * {@code null} is returned when the input is {@code null}, when its size is zero or differs from the column
 * configuration size, when the tag is missing or not acceptable, or when the record is dropped by sampling.
 * When expression filters are enabled, an additional tuple is added per column for every data purifier whose
 * filter result is true, using the column number {@code (j + 1) * size + i}.
 *
 * @param input the raw record
 * @return the bag of per-column tuples, or {@code null} if the record is rejected
 * @throws IOException if reading from the tuple fails
 * @throws ShifuException if more than {@code MAX_MISMATCH_CNT} records had a mismatching size
 */
@SuppressWarnings("deprecation")
@Override
public DataBag exec(Tuple input) throws IOException {
    DataBag bag = BagFactory.getInstance().newDefaultBag();
    TupleFactory tupleFactory = TupleFactory.getInstance();
    if (input == null) {
        return null;
    }

    final int size = input.size();
    if (size == 0 || size != this.columnConfigList.size()) {
        log.error("the input size - " + size + ", while column size - " + columnConfigList.size());
        this.mismatchCnt++;
        // A limited number of malformed records is tolerated and skipped; beyond that the job fails.
        if (this.mismatchCnt > MAX_MISMATCH_CNT) {
            throw new ShifuException(ShifuErrorCode.ERROR_NO_EQUAL_COLCONFIG);
        }
        return null;
    }

    final Object rawTag = input.get(tagColumnNum);
    if (rawTag == null) {
        log.error("tagColumnNum is " + tagColumnNum + "; input size is " + size + "; columnConfigList.size() is " + columnConfigList.size() + "; tuple is" + input.toDelimitedString("|") + "; tag is " + rawTag);
        if (isPigEnabled(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG")) {
            PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1);
        }
        return null;
    }

    final String tag = CommonUtils.trimTag(rawTag.toString());
    // Linear targets need a numeric tag; otherwise the tag has to be one of the known tags.
    final boolean invalidTag = this.isLinearTarget ? !NumberUtils.isNumber(tag) : !super.tagSet.contains(tag);
    if (invalidTag) {
        if (isPigEnabled(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG")) {
            PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1);
        }
        return null;
    }

    Double rate = modelConfig.getBinningSampleRate();
    final boolean sampleNegativesOnly = !this.isLinearTarget && !modelConfig.isClassification()
            && modelConfig.isBinningSampleNegOnly();
    // In negative-only mode only negative-tag records are sampled; otherwise every record is.
    if ((!sampleNegativesOnly || super.negTagSet.contains(tag)) && random.nextDouble() > rate) {
        return null;
    }

    // Evaluate each purifier once per record; the results are reused for every column below.
    List<Boolean> filterResultList = null;
    if (this.isForExpressions) {
        filterResultList = new ArrayList<Boolean>();
        for (DataPurifier purifier : this.dataPurifiers) {
            filterResultList.add(purifier.isFilter(input));
        }
    }

    boolean isPositiveInst = (modelConfig.isRegression() && super.posTagSet.contains(tag));
    for (int col = 0; col < size; col++) {
        ColumnConfig config = columnConfigList.get(col);
        if (isValidRecord(modelConfig.isRegression(), isPositiveInst, config)) {
            bag.add(buildTuple(input, tupleFactory, tag, col, col));
            if (this.isForExpressions) {
                for (int j = 0; j < this.dataPurifiers.size(); j++) {
                    if (Boolean.TRUE.equals(filterResultList.get(j))) {
                        // Offset the column number by a whole block of 'size' per purifier.
                        bag.add(buildTuple(input, tupleFactory, tag, col, (j + 1) * size + col));
                    }
                }
            }
        }
    }
    return bag;
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class CommonUtilsTest method getDerivedColumnNamesTest.
/**
 * Checks that, for columns named "a", "derived_c" and "d", the first name returned by
 * {@code CommonUtils.getDerivedColumnNames} is "derived_c".
 */
@Test
public void getDerivedColumnNamesTest() {
    List<ColumnConfig> columns = new ArrayList<ColumnConfig>();
    // Build one config per name, in this order.
    for (String columnName : new String[] { "a", "derived_c", "d" }) {
        ColumnConfig column = new ColumnConfig();
        column.setColumnName(columnName);
        columns.add(column);
    }

    List<String> derivedNames = CommonUtils.getDerivedColumnNames(columns);

    Assert.assertEquals(derivedNames.get(0), "derived_c");
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class CommonUtilsTest method hasNoCandidateTest.
/**
 * Checks that {@code CommonUtils.hasCandidateColumns} returns false for a list whose only column
 * has {@code finalSelect} set to false.
 */
@Test
public void hasNoCandidateTest() {
    ColumnConfig unselected = new ColumnConfig();
    unselected.setColumnName("A");
    unselected.setFinalSelect(false);

    List<ColumnConfig> columns = new ArrayList<ColumnConfig>();
    columns.add(unselected);

    Assert.assertFalse(CommonUtils.hasCandidateColumns(columns));
}
Aggregations