Search in sources:

Example 36 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class PostTrainModelProcessor, method updateAvgScores.

/**
 * Loads the per-bin average scores written by the post-train job and stores them on the matching
 * {@link ColumnConfig}.
 *
 * <p>Only files whose path contains {@code part-r-} are read. Each non-blank line is expected to be
 * {@code <columnNum>\t<score0>,<score1>,...}, where {@code columnNum} indexes
 * {@code columnConfigList}. Blank lines are skipped.
 *
 * @param source              storage type of the post-train output
 * @param postTrainOutputPath directory holding the post-train reducer output
 * @throws IOException if the output cannot be read, or a non-blank line has no tab-separated value part
 */
private void updateAvgScores(SourceType source, String postTrainOutputPath) throws IOException {
    List<Scanner> scanners = null;
    try {
        scanners = ShifuFileUtils.getDataScanners(postTrainOutputPath, source, new PathFilter() {

            @Override
            public boolean accept(Path path) {
                // only reducer output files carry the bin scores
                return path.toString().contains("part-r-");
            }
        });
        for (Scanner scanner : scanners) {
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine().trim();
                if (line.isEmpty()) {
                    // a blank line has no key/value; previously this crashed with ArrayIndexOutOfBoundsException
                    continue;
                }
                String[] keyValues = line.split("\t");
                if (keyValues.length < 2) {
                    throw new IOException("Malformed post-train output line, expected '<columnNum>\\t<scores>': "
                            + line);
                }
                ColumnConfig config = this.columnConfigList.get(Integer.parseInt(keyValues[0].trim()));
                String[] avgScores = keyValues[1].split(",");
                List<Integer> binAvgScores = new ArrayList<Integer>(avgScores.length);
                for (String avgScore : avgScores) {
                    binAvgScores.add(Integer.parseInt(avgScore.trim()));
                }
                config.setBinAvgScore(binAvgScores);
            }
        }
    } finally {
        // release file handles even if reading or parsing fails
        closeScanners(scanners);
    }
}
Also used : Path(org.apache.hadoop.fs.Path) Scanner(java.util.Scanner) PathFilter(org.apache.hadoop.fs.PathFilter) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList)

Example 37 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class TreeNodePmmlElementCreator, method convert.

/**
 * Recursively converts a Shifu tree {@link Node} into a PMML tree node.
 *
 * <p>The predicate on the produced node describes the branch of {@code split} (the parent's split)
 * that leads to {@code node}:
 * <ul>
 * <li>numerical column: {@code lessThan} the threshold on the left branch, {@code greaterOrEqual} on the right;</li>
 * <li>categorical column: {@code isIn} / {@code isNotIn} the split's category set.</li>
 * </ul>
 * Children are converted recursively unless the node has no split or is a real leaf.
 *
 * @param node   Shifu tree node to convert
 * @param isLeft whether {@code node} is reached through the left branch of {@code split}
 * @param split  the parent's split routing to {@code node}
 * @return the PMML node, including its converted children
 */
public org.dmg.pmml.tree.Node convert(Node node, boolean isLeft, Split split) {
    org.dmg.pmml.tree.Node pmmlNode = new org.dmg.pmml.tree.Node();
    pmmlNode.setId(String.valueOf(node.getId()));
    if (node.getPredict() != null) {
        pmmlNode.setScore(String.valueOf(treeModel.isClassification() ? node.getPredict().getClassValue()
                : node.getPredict().getPredict()));
    }
    pmmlNode.setDefaultChild(null);

    Predicate predicate = null;
    ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum());
    if (columnConfig.isNumerical()) {
        SimplePredicate p = new SimplePredicate();
        p.setValue(String.valueOf(split.getThreshold()));
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        p.setOperator(SimplePredicate.Operator.fromValue(isLeft ? "lessThan" : "greaterOrEqual"));
        predicate = p;
    } else if (columnConfig.isCategorical()) {
        SimpleSetPredicate p = new SimpleSetPredicate();
        Set<Short> childCategories = split.getLeftOrRightCategories();
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        List<String> valueList = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum());
        StringBuilder arrayStr = new StringBuilder();
        for (Short sh : childCategories) {
            // category indexes without a known name are emitted as an empty string token ("")
            String category = sh < valueList.size() ? valueList.get(sh) : "";
            arrayStr.append(' ').append(toArrayToken(category));
        }
        p.setArray(new Array(Array.Type.fromValue("string"), arrayStr.toString().trim()));
        // the split's category set belongs to the side indicated by split.isLeft(); the other side gets its complement
        p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue(isLeft == split.isLeft() ? "isIn"
                : "isNotIn"));
        predicate = p;
    }
    pmmlNode.setPredicate(predicate);

    if (node.getSplit() == null || node.isRealLeaf()) {
        return pmmlNode;
    }
    List<org.dmg.pmml.tree.Node> childList = pmmlNode.getNodes();
    org.dmg.pmml.tree.Node left = convert(node.getLeft(), true, node.getSplit());
    org.dmg.pmml.tree.Node right = convert(node.getRight(), false, node.getSplit());
    childList.add(left);
    childList.add(right);
    return pmmlNode;
}

/**
 * Formats one category value as a token inside a PMML string {@code Array}.
 *
 * <p>Array content is whitespace-separated. Values that are empty or contain any whitespace are
 * therefore wrapped in double quotes, so they survive as a single element. Embedded double quotes
 * are escaped as {@code \"}.
 */
private static String toArrayToken(String value) {
    String escaped = value.replace("\"", "\\\"");
    boolean needsQuotes = value.isEmpty();
    for (int i = 0; !needsQuotes && i < value.length(); i++) {
        needsQuotes = Character.isWhitespace(value.charAt(i));
    }
    return needsQuotes ? "\"" + escaped + "\"" : escaped;
}
Also used : Set(java.util.Set) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) Node(ml.shifu.shifu.core.dtrain.dt.Node) SimplePredicate(org.dmg.pmml.SimplePredicate) Predicate(org.dmg.pmml.Predicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) Array(org.dmg.pmml.Array) List(java.util.List) FieldName(org.dmg.pmml.FieldName)

Example 38 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class PostTrainMapper, method map.

/**
 * Scores one raw input record and accumulates per-bin score statistics.
 *
 * <p>A record is skipped in three cases:
 * <ul>
 * <li>it is blank;</li>
 * <li>{@code dataPurifier.isFilter} rejects it;</li>
 * <li>its tag is not in {@code tags}; such records are counted under the {@code INVALID_TAG} counter.</li>
 * </ul>
 * For an accepted record:
 * <ol>
 * <li>A delimited line is written through {@code mos} under {@code Constants.POST_TRAIN_OUTPUT_SCORE}.
 * It holds the average, max and min scores, each per-model score, then the values of the meta
 * columns.</li>
 * <li>For every final-selected, non-meta, non-target column, the record's average score is added to
 * the sum of the bin the column value falls into, and that bin's count is incremented.</li>
 * </ol>
 */
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
    String valueStr = value.toString();
    // StringUtils.isBlank is not used here to avoid import new jar
    if (valueStr == null || valueStr.length() == 0 || valueStr.trim().length() == 0) {
        LOG.warn("Empty input.");
        return;
    }
    if (!this.dataPurifier.isFilter(valueStr)) {
        return;
    }
    String[] units = CommonUtils.split(valueStr, this.modelConfig.getDataSetDelimiter());
    // tagColumnNum should be in units array, if not IndexOutofBoundException
    String tag = CommonUtils.trimTag(units[this.tagColumnNum]);
    if (!this.tags.contains(tag)) {
        // sample the warning roughly 1 in 20 times to avoid flooding the log
        if (System.currentTimeMillis() % 20 == 0) {
            LOG.warn("Data with invalid tag is ignored in post train, invalid tag: {}.", tag);
        }
        context.getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1L);
        return;
    }
    Map<String, String> rawDataMap = buildRawDataMap(units);
    CaseScoreResult csr = this.modelRunner.compute(rawDataMap);

    // store score value
    StringBuilder sb = new StringBuilder(500);
    sb.append(csr.getAvgScore()).append(Constants.DEFAULT_DELIMITER).append(csr.getMaxScore())
            .append(Constants.DEFAULT_DELIMITER).append(csr.getMinScore()).append(Constants.DEFAULT_DELIMITER);
    for (Double score : csr.getScores()) {
        sb.append(score).append(Constants.DEFAULT_DELIMITER);
    }
    List<String> metaList = modelConfig.getMetaColumnNames();
    for (String meta : metaList) {
        sb.append(rawDataMap.get(meta)).append(Constants.DEFAULT_DELIMITER);
    }
    // drop the whole trailing delimiter; deleteCharAt would remove only one char of a multi-char delimiter
    sb.setLength(sb.length() - Constants.DEFAULT_DELIMITER.length());
    this.outputValue.set(sb.toString());
    this.mos.write(Constants.POST_TRAIN_OUTPUT_SCORE, NullWritable.get(), this.outputValue);

    for (int i = 0; i < headers.length; i++) {
        ColumnConfig config = this.columnConfigList.get(i);
        if (!config.isMeta() && !config.isTarget() && config.isFinalSelect()) {
            int binNum = BinUtils.getBinNum(config, units[i]);
            List<BinStats> binStatsList = this.variableStatsMap.get(config.getColumnNum());
            BinStats bs = null;
            if (binNum == -1) {
                // if -1, means invalid numeric value like null or empty, last one is for empty stats.
                bs = binStatsList.get(binStatsList.size() - 1);
            } else {
                bs = binStatsList.get(binNum);
            }
            // bs should not be null as already initialized in setup
            bs.setBinSum(csr.getAvgScore() + bs.getBinSum());
            bs.setBinCnt(1L + bs.getBinCnt());
        }
    }
}
Also used : CaseScoreResult(ml.shifu.shifu.container.CaseScoreResult) BinStats(ml.shifu.shifu.core.posttrain.FeatureStatsWritable.BinStats) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 39 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class ModelStatsCreator, method build.

/**
 * Builds the PMML {@link ModelStats} section from the final-selected columns.
 *
 * <p>Which columns are included depends on the model:
 * <ul>
 * <li>If {@code basicML} is a {@link BasicFloatNetwork} with a non-empty feature set, only
 * final-selected columns whose column number is in that set are included.</li>
 * <li>Otherwise every final-selected column is included.</li>
 * </ul>
 *
 * @param basicML the trained model
 * @return the populated model stats, one univariate-stats entry per included column
 */
@Override
public ModelStats build(BasicML basicML) {
    ModelStats modelStats = new ModelStats();
    // only BasicFloatNetwork exposes a feature subset; null/empty means "no restriction"
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }
    boolean restrictToFeatureSet = featureSet != null && !featureSet.isEmpty();

    for (ColumnConfig columnConfig : columnConfigList) {
        if (!columnConfig.isFinalSelect()) {
            continue;
        }
        if (restrictToFeatureSet && !featureSet.contains(columnConfig.getColumnNum())) {
            continue;
        }
        modelStats.addUnivariateStats(toUnivariateStats(columnConfig));
    }
    return modelStats;
}

/**
 * Creates the PMML univariate stats entry for one column. The entry's content depends on the column
 * type:
 * <ul>
 * <li>Categorical columns get a {@link DiscrStats} with a count array, plus extensions unless
 * {@code isConcise}.</li>
 * <li>Other columns get numeric info, plus continuous stats unless {@code isConcise}.</li>
 * </ul>
 */
private UnivariateStats toUnivariateStats(ColumnConfig columnConfig) {
    UnivariateStats univariateStats = new UnivariateStats();
    // here, no need to consider if column is in segment expansion as we need to address new stats
    // variable; set simple column name in PMML
    univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
    if (columnConfig.isCategorical()) {
        DiscrStats discrStats = new DiscrStats();
        Array countArray = createCountArray(columnConfig);
        discrStats.addArrays(countArray);
        if (!isConcise) {
            List<Extension> extensions = createExtensions(columnConfig);
            discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
        }
        univariateStats.setDiscrStats(discrStats);
    } else {
        // numerical column
        univariateStats.setNumericInfo(createNumericInfo(columnConfig));
        if (!isConcise) {
            univariateStats.setContStats(createConStats(columnConfig));
        }
    }
    return univariateStats;
}
Also used : Array(org.dmg.pmml.Array) Extension(org.dmg.pmml.Extension) DiscrStats(org.dmg.pmml.DiscrStats) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) UnivariateStats(org.dmg.pmml.UnivariateStats) ModelStats(org.dmg.pmml.ModelStats) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 40 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From the class FeatureImportanceMapper, method map.

/**
 * Awards importance points to the columns that contribute the highest bin scores for one record.
 *
 * <p>Blank records, records rejected by {@code dataPurifier.isFilter}, and records with a tag not in
 * {@code tags} are skipped; invalid tags are counted under {@code INVALID_TAG}.
 *
 * <p>For an accepted record, each final-selected, non-meta, non-target column is scored with the
 * average score of the bin its value falls into, taken from {@code getBinAvgScore()}. Columns are
 * ranked by that score in descending order. The top {@code k = min(3, #columns)} columns receive
 * {@code k, k-1, ..., 1} points, added to their running totals in {@code variableStatsMap}.
 */
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
    String valueStr = value.toString();
    // StringUtils.isBlank is not used here to avoid import new jar
    if (valueStr == null || valueStr.length() == 0 || valueStr.trim().length() == 0) {
        LOG.warn("Empty input.");
        return;
    }
    if (!this.dataPurifier.isFilter(valueStr)) {
        return;
    }
    String[] units = CommonUtils.split(valueStr, this.modelConfig.getDataSetDelimiter());
    // tagColumnNum should be in units array, if not IndexOutofBoundException
    String tag = CommonUtils.trimTag(units[this.tagColumnNum]);
    if (!this.tags.contains(tag)) {
        // sample the warning roughly 1 in 20 times to avoid flooding the log
        if (System.currentTimeMillis() % 20 == 0) {
            LOG.warn("Data with invalid tag is ignored in post train, invalid tag: {}.", tag);
        }
        context.getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1L);
        return;
    }

    List<FeatureScore> featureScores = new ArrayList<FeatureImportanceMapper.FeatureScore>();
    for (int i = 0; i < headers.length; i++) {
        ColumnConfig config = this.columnConfigList.get(i);
        if (!config.isMeta() && !config.isTarget() && config.isFinalSelect()) {
            int binNum = BinUtils.getBinNum(config, units[i]);
            List<Integer> binAvgScores = config.getBinAvgScore();
            int binScore = 0;
            if (binNum == -1) {
                // -1 marks a value BinUtils could not bin; its score is kept in the last entry
                binScore = binAvgScores.get(binAvgScores.size() - 1);
            } else {
                binScore = binAvgScores.get(binNum);
            }
            featureScores.add(new FeatureScore(config.getColumnNum(), binScore));
        }
    }

    // highest bin score first
    Collections.sort(featureScores, new Comparator<FeatureScore>() {

        @Override
        public int compare(FeatureScore fs1, FeatureScore fs2) {
            if (fs1.binAvgScore < fs2.binAvgScore) {
                return 1;
            }
            if (fs1.binAvgScore > fs2.binAvgScore) {
                return -1;
            }
            return 0;
        }
    });

    int size = Math.min(3, featureScores.size());
    for (int i = 0; i < size; i++) {
        FeatureScore featureScore = featureScores.get(i);
        Double currValue = this.variableStatsMap.get(featureScore.columnNum);
        // a column without an entry yet starts from 0 instead of failing on null unboxing
        double base = currValue == null ? 0d : currValue.doubleValue();
        this.variableStatsMap.put(featureScore.columnNum, base + (size - i));
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)131 ArrayList (java.util.ArrayList)36 Test (org.testng.annotations.Test)17 IOException (java.io.IOException)16 HashMap (java.util.HashMap)12 Tuple (org.apache.pig.data.Tuple)10 File (java.io.File)8 NSColumn (ml.shifu.shifu.column.NSColumn)8 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)8 ShifuException (ml.shifu.shifu.exception.ShifuException)8 Path (org.apache.hadoop.fs.Path)8 List (java.util.List)7 Scanner (java.util.Scanner)7 DataBag (org.apache.pig.data.DataBag)7 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)5 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)5 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)5 BasicMLData (org.encog.ml.data.basic.BasicMLData)5 BufferedWriter (java.io.BufferedWriter)3 FileInputStream (java.io.FileInputStream)3