Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class PostTrainModelProcessor, method updateAvgScores:
/**
 * Reads the per-column bin average scores written by the post-train job and stores them on the
 * corresponding {@link ColumnConfig} via {@code setBinAvgScore}.
 *
 * <p>Only files whose path contains {@code "part-r-"} are read. Each non-blank line is expected to look like
 * {@code <columnNum>\t<score0>,<score1>,...}. The column number is used as an index into
 * {@code columnConfigList}, and every score is parsed as an {@code int}.
 *
 * @param source
 *            source type handed through to {@code ShifuFileUtils.getDataScanners}
 * @param postTrainOutputPath
 *            path of the post-train job output
 * @throws IOException
 *             if the output files cannot be listed or opened
 * @throws IllegalStateException
 *             if a non-blank line does not contain a tab-separated key and value
 */
private void updateAvgScores(SourceType source, String postTrainOutputPath) throws IOException {
    List<Scanner> scanners = null;
    try {
        scanners = ShifuFileUtils.getDataScanners(postTrainOutputPath, source, new PathFilter() {
            @Override
            public boolean accept(Path path) {
                // only consume files whose name marks them as reduce-task output
                return path.toString().contains("part-r-");
            }
        });
        for (Scanner scanner : scanners) {
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine().trim();
                if (line.length() == 0) {
                    // a blank line carries no "<key>\t<value>" pair; skip it instead of failing on keyValues[1]
                    continue;
                }
                String[] keyValues = line.split("\t");
                if (keyValues.length < 2) {
                    throw new IllegalStateException(
                            "Malformed post-train output line (expected '<columnNum>\\t<scores>'): " + line);
                }
                ColumnConfig config = this.columnConfigList.get(Integer.parseInt(keyValues[0].trim()));
                String[] avgScores = keyValues[1].split(",");
                List<Integer> binAvgScores = new ArrayList<Integer>(avgScores.length);
                for (String avgScore : avgScores) {
                    binAvgScores.add(Integer.parseInt(avgScore.trim()));
                }
                config.setBinAvgScore(binAvgScores);
            }
        }
    } finally {
        // release scanners even when listing or parsing fails part-way
        closeScanners(scanners);
    }
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class TreeNodePmmlElementCreator, method convert:
/**
 * Converts a Shifu tree node, and recursively its children, into a PMML {@code org.dmg.pmml.tree.Node}.
 *
 * <p>The predicate attached to the returned node is derived from {@code split}, the split of the parent
 * that routes records to {@code node}:
 * <ul>
 * <li>numerical column: {@code lessThan} the threshold for the left child, {@code greaterOrEqual} for the
 * right child;</li>
 * <li>categorical column: {@code isIn} over the split's categories when this child is on the side given by
 * {@code split.isLeft()}, otherwise {@code isNotIn}.</li>
 * </ul>
 * Children are converted only when {@code node} has a split and is not a real leaf.
 *
 * @param node
 *            the Shifu tree node to convert
 * @param isLeft
 *            whether {@code node} is its parent's left child
 * @param split
 *            the parent's split; must not be null since its column number is always read
 * @return the converted PMML node
 */
public org.dmg.pmml.tree.Node convert(Node node, boolean isLeft, Split split) {
    org.dmg.pmml.tree.Node pmmlNode = new org.dmg.pmml.tree.Node();
    pmmlNode.setId(String.valueOf(node.getId()));
    if (node.getPredict() != null) {
        pmmlNode.setScore(String.valueOf(treeModel.isClassification() ? node.getPredict().getClassValue()
                : node.getPredict().getPredict()));
    }
    pmmlNode.setDefaultChild(null);

    Predicate predicate = null;
    ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum());
    if (columnConfig.isNumerical()) {
        SimplePredicate p = new SimplePredicate();
        p.setValue(String.valueOf(split.getThreshold()));
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        p.setOperator(SimplePredicate.Operator.fromValue(isLeft ? "lessThan" : "greaterOrEqual"));
        predicate = p;
    } else if (columnConfig.isCategorical()) {
        SimpleSetPredicate p = new SimpleSetPredicate();
        Set<Short> childCategories = split.getLeftOrRightCategories();
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        List<String> valueList = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum());
        StringBuilder arrayStr = new StringBuilder();
        for (Short sh : childCategories) {
            if (arrayStr.length() > 0) {
                arrayStr.append(' ');
            }
            // a category index past the known values is emitted as an empty-string token
            String category = sh >= valueList.size() ? "" : valueList.get(sh);
            arrayStr.append(toPmmlStringArrayToken(category));
        }
        p.setArray(new Array(Array.Type.fromValue("string"), arrayStr.toString()));
        // the split's category set belongs to the side reported by split.isLeft(); the child on that side
        // matches "isIn", the other child matches "isNotIn"
        boolean sameSideAsCategories = (isLeft == split.isLeft());
        p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue(sameSideAsCategories ? "isIn" : "isNotIn"));
        predicate = p;
    }
    pmmlNode.setPredicate(predicate);

    if (node.getSplit() == null || node.isRealLeaf()) {
        return pmmlNode;
    }
    List<org.dmg.pmml.tree.Node> childList = pmmlNode.getNodes();
    org.dmg.pmml.tree.Node left = convert(node.getLeft(), true, node.getSplit());
    org.dmg.pmml.tree.Node right = convert(node.getRight(), false, node.getSplit());
    childList.add(left);
    childList.add(right);
    return pmmlNode;
}

/**
 * Formats one category value as a token of a whitespace-separated PMML string {@code Array}.
 *
 * <p>A value is wrapped in double quotes when one of the following holds; embedded double quotes are then
 * escaped as {@code \"}:
 * <ul>
 * <li>it is empty, which would otherwise vanish;</li>
 * <li>it contains any whitespace, which would otherwise split it into several tokens;</li>
 * <li>it contains a double quote, which would otherwise start a quoted token.</li>
 * </ul>
 * Any other value is returned unchanged.
 */
private static String toPmmlStringArrayToken(String value) {
    boolean needsQuotes = value.length() == 0 || value.indexOf('"') >= 0;
    for (int i = 0; !needsQuotes && i < value.length(); i++) {
        needsQuotes = Character.isWhitespace(value.charAt(i));
    }
    if (!needsQuotes) {
        return value;
    }
    return "\"" + value.replace("\"", "\\\"") + "\"";
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class PostTrainMapper, method map:
/**
 * Scores one input record and accumulates per-bin statistics.
 *
 * <p>The record is scored with {@code modelRunner}. The method then:
 * <ul>
 * <li>writes {@code avg, max, min, each model score, each meta column value}, joined by
 * {@code Constants.DEFAULT_DELIMITER}, to the {@code Constants.POST_TRAIN_OUTPUT_SCORE} named output;</li>
 * <li>adds the record's average score to the bin statistics of every non-meta, non-target, final-selected
 * column.</li>
 * </ul>
 *
 * <p>Records are skipped in these cases:
 * <ul>
 * <li>the record is blank;</li>
 * <li>{@code dataPurifier.isFilter} returns false for it;</li>
 * <li>its tag is not in {@code tags}. These are also counted under the {@code INVALID_TAG} counter.</li>
 * </ul>
 */
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
    String valueStr = value.toString();
    // StringUtils.isBlank is not used here to avoid import new jar
    if (valueStr == null || valueStr.trim().length() == 0) {
        LOG.warn("Empty input.");
        return;
    }
    if (!this.dataPurifier.isFilter(valueStr)) {
        return;
    }

    String[] units = CommonUtils.split(valueStr, this.modelConfig.getDataSetDelimiter());
    // tagColumnNum should be in units array, if not IndexOutofBoundException
    String tag = CommonUtils.trimTag(units[this.tagColumnNum]);
    if (!this.tags.contains(tag)) {
        // sample the warning (roughly 1 in 20 by wall-clock) to avoid flooding task logs
        if (System.currentTimeMillis() % 20 == 0) {
            LOG.warn("Data with invalid tag is ignored in post train, invalid tag: {}.", tag);
        }
        context.getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1L);
        return;
    }

    Map<String, String> rawDataMap = buildRawDataMap(units);
    CaseScoreResult csr = this.modelRunner.compute(rawDataMap);

    // store score value
    StringBuilder sb = new StringBuilder(500);
    sb.append(csr.getAvgScore()).append(Constants.DEFAULT_DELIMITER)
            .append(csr.getMaxScore()).append(Constants.DEFAULT_DELIMITER)
            .append(csr.getMinScore()).append(Constants.DEFAULT_DELIMITER);
    for (Double score : csr.getScores()) {
        sb.append(score).append(Constants.DEFAULT_DELIMITER);
    }
    List<String> metaList = modelConfig.getMetaColumnNames();
    for (String meta : metaList) {
        sb.append(rawDataMap.get(meta)).append(Constants.DEFAULT_DELIMITER);
    }
    // drop the whole trailing delimiter; deleteCharAt would remove only one character of a multi-char delimiter
    sb.setLength(sb.length() - Constants.DEFAULT_DELIMITER.length());
    this.outputValue.set(sb.toString());
    this.mos.write(Constants.POST_TRAIN_OUTPUT_SCORE, NullWritable.get(), this.outputValue);

    for (int i = 0; i < headers.length; i++) {
        ColumnConfig config = this.columnConfigList.get(i);
        if (config.isMeta() || config.isTarget() || !config.isFinalSelect()) {
            continue;
        }
        int binNum = BinUtils.getBinNum(config, units[i]);
        List<BinStats> feaureStatistics = this.variableStatsMap.get(config.getColumnNum());
        // -1 means an invalid value (e.g. null or empty); the last BinStats entry holds those stats
        BinStats bs = binNum == -1 ? feaureStatistics.get(feaureStatistics.size() - 1)
                : feaureStatistics.get(binNum);
        // bs should not be null as already initialized in setup
        bs.setBinSum(csr.getAvgScore() + bs.getBinSum());
        bs.setBinCnt(1L + bs.getBinCnt());
    }
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class ModelStatsCreator, method build:
/**
 * Builds the PMML {@link ModelStats} element with one {@link UnivariateStats} per final-selected column.
 *
 * <p>When {@code basicML} is a {@code BasicFloatNetwork} whose feature set is non-empty, only columns in
 * that feature set are included. Otherwise every final-selected column is included.
 *
 * @param basicML
 *            the trained model whose statistics are being exported
 * @return the populated model stats
 */
@Override
public ModelStats build(BasicML basicML) {
    ModelStats modelStats = new ModelStats();
    // null for non-BasicFloatNetwork models; CollectionUtils.isEmpty(null) then admits every column
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isFinalSelect()
                && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(columnConfig.getColumnNum()))) {
            modelStats.addUnivariateStats(buildColumnUnivariateStats(columnConfig));
        }
    }
    return modelStats;
}

/**
 * Creates the {@link UnivariateStats} for one column.
 *
 * <p>Categorical columns get {@link DiscrStats} with a count array. Numerical columns get numeric info.
 * Unless {@code isConcise} is set, categorical columns also get extensions and numerical columns also get
 * continuous stats.
 */
private UnivariateStats buildColumnUnivariateStats(ColumnConfig columnConfig) {
    UnivariateStats univariateStats = new UnivariateStats();
    // here, no need to consider if column is in segment expansion as we need to address new stats variable;
    // set simple column name in PMML
    univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
    if (columnConfig.isCategorical()) {
        DiscrStats discrStats = new DiscrStats();
        Array countArray = createCountArray(columnConfig);
        discrStats.addArrays(countArray);
        if (!isConcise) {
            List<Extension> extensions = createExtensions(columnConfig);
            discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
        }
        univariateStats.setDiscrStats(discrStats);
    } else {
        // numerical column
        univariateStats.setNumericInfo(createNumericInfo(columnConfig));
        if (!isConcise) {
            univariateStats.setContStats(createConStats(columnConfig));
        }
    }
    return univariateStats;
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class FeatureImportanceMapper, method map:
/**
 * Ranks the record's features by bin average score and credits the top three in {@code variableStatsMap}.
 *
 * <p>Each non-meta, non-target, final-selected column is scored as follows:
 * <ul>
 * <li>the score is the {@code binAvgScore} of the bin the record's value falls into;</li>
 * <li>when the value is invalid ({@code getBinNum == -1}), the last entry of {@code binAvgScore} is used.</li>
 * </ul>
 * The highest-scoring {@code k = min(3, #features)} columns receive {@code k, k-1, ..., 1} points.
 *
 * <p>Records are skipped in these cases:
 * <ul>
 * <li>the record is blank;</li>
 * <li>{@code dataPurifier.isFilter} returns false for it;</li>
 * <li>its tag is not in {@code tags}. These are also counted under the {@code INVALID_TAG} counter.</li>
 * </ul>
 */
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
    String valueStr = value.toString();
    // StringUtils.isBlank is not used here to avoid import new jar
    if (valueStr == null || valueStr.trim().length() == 0) {
        LOG.warn("Empty input.");
        return;
    }
    if (!this.dataPurifier.isFilter(valueStr)) {
        return;
    }
    String[] units = CommonUtils.split(valueStr, this.modelConfig.getDataSetDelimiter());
    // tagColumnNum should be in units array, if not IndexOutofBoundException
    String tag = CommonUtils.trimTag(units[this.tagColumnNum]);
    if (!this.tags.contains(tag)) {
        // sample the warning (roughly 1 in 20 by wall-clock) to avoid flooding task logs
        if (System.currentTimeMillis() % 20 == 0) {
            LOG.warn("Data with invalid tag is ignored in post train, invalid tag: {}.", tag);
        }
        context.getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1L);
        return;
    }

    List<FeatureScore> featureScores = new ArrayList<FeatureImportanceMapper.FeatureScore>();
    for (int i = 0; i < headers.length; i++) {
        ColumnConfig config = this.columnConfigList.get(i);
        if (config.isMeta() || config.isTarget() || !config.isFinalSelect()) {
            continue;
        }
        int binNum = BinUtils.getBinNum(config, units[i]);
        List<Integer> binAvgScores = config.getBinAvgScore();
        // -1 means an invalid value; its score is kept in the last entry
        int binScore = binNum == -1 ? binAvgScores.get(binAvgScores.size() - 1) : binAvgScores.get(binNum);
        featureScores.add(new FeatureScore(config.getColumnNum(), binScore));
    }

    // highest bin average score first; Collections.sort is stable, so ties keep column order
    Collections.sort(featureScores, new Comparator<FeatureScore>() {
        @Override
        public int compare(FeatureScore fs1, FeatureScore fs2) {
            if (fs1.binAvgScore < fs2.binAvgScore) {
                return 1;
            }
            if (fs1.binAvgScore > fs2.binAvgScore) {
                return -1;
            }
            return 0;
        }
    });

    int size = Math.min(3, featureScores.size());
    for (int i = 0; i < size; i++) {
        FeatureScore featureScore = featureScores.get(i);
        Double currValue = this.variableStatsMap.get(featureScore.columnNum);
        // a missing entry counts as 0 instead of throwing an NPE on unboxing
        double base = currValue == null ? 0d : currValue.doubleValue();
        this.variableStatsMap.put(featureScore.columnNum, base + (size - i));
    }
}
Aggregations