Search in sources :

Example 91 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTrainUtils method getNumericalIds.

public static List<Integer> getNumericalIds(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) {
    List<Integer> numericalIds = new ArrayList<>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig config : columnConfigList) {
        if (isAfterVarSelect) {
            if (config.isNumerical() && config.isFinalSelect() && !config.isTarget() && !config.isMeta()) {
        } else {
            if (config.isNumerical() && !config.isTarget() && !config.isMeta() && CommonUtils.isGoodCandidate(config, hasCandidates)) {
    return numericalIds;
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 92 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTrainUtils method getCategoricalIds.

public static List<Integer> getCategoricalIds(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) {
    List<Integer> results = new ArrayList<>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig config : columnConfigList) {
        if (isAfterVarSelect) {
            if (config.isFinalSelect() && !config.isTarget() && !config.isMeta() && config.isCategorical()) {
        } else {
            if (!config.isTarget() && !config.isMeta() && CommonUtils.isGoodCandidate(config, hasCandidates) && config.isCategorical()) {
    return results;
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 93 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTMaster method doCompute.

/**
 * Master-side computation for one iteration: merges the node statistics reported by the workers, computes the
 * best gain per node, selects the tree nodes to send out next and returns the resulting master parameters.
 *
 * NOTE(review): this listing is truncated - closing braces, several loop/branch bodies and the call prefix of
 * some log statements are missing - so the comments below describe only the statements that are visible.
 *
 * @param context
 *            master context; used here for the first-iteration flag, the worker results and the current
 *            iteration number
 * @return the initial parameters on the first iteration, the rebuilt parameters when cpMasterParams is set,
 *         otherwise the parameters assembled from this iteration's worker results
 */
public DTMasterParams doCompute(MasterContext<DTMasterParams, DTWorkerParams> context) {
    // first iteration: no worker results to merge yet
    if (context.isFirstIteration()) {
        return buildInitialMasterParams();
    // NOTE(review): cpMasterParams presumably holds state restored from a checkpoint - confirm
    if (this.cpMasterParams != null) {
        DTMasterParams tmpMasterParams = rebuildRecoverMasterResultDepthList();
        // set it to null to avoid send it in next iteration
        this.cpMasterParams = null;
        if (this.isGBDT) {
            // don't need to send full trees because worker will get existing models from HDFS
            // only set last tree to do node stats, no need check switch to next tree because of message may be send
            // to worker already
            tmpMasterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
            // set tmp trees for DTOutput
        return tmpMasterParams;
    // despite its name, isFirst turns true once the first worker result has been taken as the merge base
    boolean isFirst = false;
    Map<Integer, NodeStats> nodeStatsMap = null;
    double trainError = 0d, validationError = 0d;
    double weightedTrainCount = 0d, weightedValidationCount = 0d;
    for (DTWorkerParams params : context.getWorkerResults()) {
        if (!isFirst) {
            isFirst = true;
            nodeStatsMap = params.getNodeStatsMap();
        } else {
            // merge this worker's stats into the base map, matched by key
            Map<Integer, NodeStats> currNodeStatsmap = params.getNodeStatsMap();
            for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
                NodeStats resultNodeStats = entry.getValue();
                mergeNodeStats(resultNodeStats, currNodeStatsmap.get(entry.getKey()));
            // set to null after merging, release memory at the earliest stage
        trainError += params.getTrainError();
        validationError += params.getValidationError();
        weightedTrainCount += params.getTrainCount();
        weightedValidationCount += params.getValidationCount();
    // NOTE(review): nodeStatsMap is still null here if there were no worker results - confirm that cannot happen
    for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
        NodeStats nodeStats = entry.getValue();
        int treeId = nodeStats.getTreeId();
        Node doneNode = Node.getNode(trees.get(treeId).getNode(), nodeStats.getNodeId());
        // doneNode, NodeStats
        Map<Integer, double[]> statistics = nodeStats.getFeatureStatistics();
        List<GainInfo> gainList = new ArrayList<GainInfo>();
        // one GainInfo per feature, computed from that feature's statistics array
        for (Entry<Integer, double[]> gainEntry : statistics.entrySet()) {
            int columnNum = gainEntry.getKey();
            ColumnConfig config = this.columnConfigList.get(columnNum);
            double[] statsArray = gainEntry.getValue();
            GainInfo gainInfo = this.impurity.computeImpurity(statsArray, config);
            if (gainInfo != null) {
        GainInfo maxGainInfo = GainInfo.getGainInfoByMaxGain(gainList);
        if (maxGainInfo == null) {
            // null gain info, set to leaf and continue next stats
        populateGainInfoToNode(treeId, doneNode, maxGainInfo);
        if (this.isLeafWise) {
            // leaf-wise: nodes with positive gain are queued and split later in batches
            boolean isNotSplit = maxGainInfo.getGain() <= 0d;
            if (!isNotSplit) {
                this.toSplitQueue.offer(new TreeNode(treeId, doneNode));
            } else {
      "Node {} in tree {} is not to be split", doneNode.getId(), treeId);
        } else {
            boolean isLeaf = maxGainInfo.getGain() <= 0d || Node.indexToLevel(doneNode.getId()) == this.maxDepth;
            // level-wise is to split node when stats is ready
            splitNodeForLevelWisedTree(isLeaf, treeId, doneNode);
    if (this.isLeafWise) {
        // get node in toSplitQueue and split
        // NOTE(review): no increment of currSplitIndex is visible in this listing - confirm the batch limit applies
        int currSplitIndex = 0;
        while (!toSplitQueue.isEmpty() && currSplitIndex < this.maxBatchSplitSize) {
            TreeNode treeNode = this.toSplitQueue.poll();
            splitNodeForLeafWisedTree(treeNode.getTreeId(), treeNode.getNode());
    Map<Integer, TreeNode> todoNodes = new HashMap<Integer, TreeNode>();
    // NOTE(review): no guard for weightedValidationCount == 0 is visible; the division would then give NaN or
    // Infinity
    double averageValidationError = validationError / weightedValidationCount;
    if (this.isGBDT && this.dtEarlyStopDecider != null && averageValidationError > 0) {
        averageValidationError = this.dtEarlyStopDecider.getCurrentAverageValue();
    boolean vtTriggered = false;
    // if validationTolerance == 0d, means vt check is not enabled
    if (validationTolerance > 0d && Math.abs(this.bestValidationError - averageValidationError) < this.validationTolerance * averageValidationError) {
        LOG.debug("Debug: bestValidationError {}, averageValidationError {}, validationTolerance {}", bestValidationError, averageValidationError, validationTolerance);
        vtTriggered = true;
    // track the lowest average validation error seen so far
    if (averageValidationError < this.bestValidationError) {
        this.bestValidationError = averageValidationError;
    // validation error is averageValidationError * weightedValidationCount because of here averageValidationError
    // is divided by validation count.
    DTMasterParams masterParams = new DTMasterParams(weightedTrainCount, trainError, weightedValidationCount, averageValidationError * weightedValidationCount);
    if (toDoQueue.isEmpty()) {
        if (this.isGBDT) {
            TreeNode treeNode = this.trees.get(this.trees.size() - 1);
            Node node = treeNode.getNode();
            if (this.trees.size() >= this.treeNum) {
                // if all trees including trees read from existing model over treeNum, stop the whole process.
      "Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (node.getLeft() == null && node.getRight() == null) {
                // if very good performance, here can be some issues, say you'd like to get 5 trees, but in the 2nd
                // tree, you get one perfect tree, no need continue but warn users about such issue: set
                // BaggingSampleRate not to 1 can solve such issue to avoid overfit
                LOG.warn("Tree is learned 100% well, there must be overfit here, please tune BaggingSampleRate, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (this.dtEarlyStopDecider != null && (this.enableEarlyStop && this.dtEarlyStopDecider.canStop())) {
      "Early stop identified, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (vtTriggered) {
      "Early stop training by validation tolerance.");
            } else {
                // set first tree to true even after ROOT node is set in next tree
                masterParams.setFirstTree(this.trees.size() == 1);
                // finish current tree, no need features information
                TreeNode newRootNode = new TreeNode(this.trees.size(), new Node(Node.ROOT_INDEX), this.learningRate);
      "The {} tree is to be built.", this.trees.size());
                newRootNode.setFeatures(getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate));
                // only one node
                todoNodes.put(0, newRootNode);
                // set switch flag
        } else {
            // for rf
  "Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
    } else {
        // queue not empty: poll nodes into todoNodes until the queue is empty or currMem exceeds maxStatsMemory
        int nodeIndexInGroup = 0;
        long currMem = 0L;
        List<Integer> depthList = new ArrayList<Integer>();
        if (this.isGBDT) {
        if (isRF) {
            for (int i = 0; i < this.trees.size(); i++) {
        while (!toDoQueue.isEmpty() && currMem <= this.maxStatsMemory) {
            TreeNode node = this.toDoQueue.poll();
            int treeId = node.getTreeId();
            // GBDT reads and writes a single depth at index 0; otherwise depthList is indexed by tree id
            int oldDepth = this.isGBDT ? depthList.get(0) : depthList.get(treeId);
            int currDepth = Node.indexToLevel(node.getNode().getId());
            if (currDepth > oldDepth) {
                if (this.isGBDT) {
                    // gbdt only for last depth
                    depthList.set(0, currDepth);
                } else {
                    depthList.set(treeId, currDepth);
            List<Integer> subsetFeatures = getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate);
            currMem += getStatsMem(subsetFeatures);
            todoNodes.put(nodeIndexInGroup, node);
            nodeIndexInGroup += 1;
        masterParams.setFirstTree(this.trees.size() == 1);"Todo node size is {}", todoNodes.size());
    if (this.isGBDT) {
        if (masterParams.isSwitchToNextTree()) {
            // send last full growth tree and current todo ROOT node tree
            // NOTE(review): subList from size() - 2 needs at least two trees - confirm that holds on this path
            masterParams.setTrees(trees.subList(trees.size() - 2, trees.size()));
        } else {
            // only send current trees
            masterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
    if (this.isRF) {
        // elements in todoTrees are also the same reference in this.trees, reuse the same object to save memory.
        if (masterParams.getTreeDepth().size() == this.trees.size()) {
            // if normal iteration
            List<TreeNode> todoTrees = new ArrayList<TreeNode>();
            for (int i = 0; i < trees.size(); i++) {
                if (masterParams.getTreeDepth().get(i) >= 0) {
                    // such tree in current iteration treeDepth is not -1, add it to todoTrees.
                } else {
                    // mock a TreeNode instance to make sure no surprise in further serialization. In fact
                    // meaningless.
                    todoTrees.add(new TreeNode(i, new Node(Node.INVALID_INDEX), 1d));
        } else {
            // if last iteration without maxDepthList
    if (this.isGBDT) {
        // set tmp trees to DTOutput
    if (context.getCurrentIteration() % 100 == 0) {
        // every 100 iterations do gc explicitly to avoid one case:
        // mapper memory is 2048M and final in our cluster, if -Xmx is 2G, then occasionally oom issue.
        // to fix this issue: 1. set -Xmx to 1800m; 2. call gc to drop unused memory at early stage.
        // this is ugly and if it is stable with 1800m, this line should be removed
        Thread gcThread = new Thread(new Runnable() {

            public void run() {
    // before master result, do checkpoint according to n iteration set by user
    doCheckPoint(context, masterParams, context.getCurrentIteration());
    LOG.debug("weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}", weightedTrainCount, weightedValidationCount, trainError, validationError);
    return masterParams;
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) NodeStats(ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats)

Example 94 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class CorrelationMapper method setup.

/**
 * Initializes mapper state before records are processed: the data set delimiter, the data purifier, the
 * compute-all flag, a category-value-to-bin-index map per categorical column, and the tag sets.
 *
 * NOTE(review): this listing is truncated (closing braces are missing and statements may have been dropped),
 * so only the visible statements are described.
 *
 * @param context
 *            mapper context; its configuration supplies the Constants.SHIFU_CORRELATION_COMPUTE_ALL flag
 * @throws IOException
 *             declared by the signature; no throwing statement is visible in this listing
 * @throws InterruptedException
 *             declared by the signature; no throwing statement is visible in this listing
 */
protected void setup(Context context) throws IOException, InterruptedException {
    // NOTE(review): modelConfig is dereferenced here without a null check although it is null-checked further
    // down - one of the two is inconsistent; confirm where modelConfig is loaded
    this.dataSetDelimiter = modelConfig.getDataSetDelimiter();
    this.dataPurifier = new DataPurifier(modelConfig, false);
    // flag defaults to "false" when the configuration key is absent
    this.isComputeAll = Boolean.valueOf(context.getConfiguration().get(Constants.SHIFU_CORRELATION_COMPUTE_ALL, "false"));
    for (ColumnConfig config : columnConfigList) {
        if (config.isCategorical()) {
            Map<String, Integer> map = new HashMap<String, Integer>();
            if (config.getBinCategory() != null) {
                for (int i = 0; i < config.getBinCategory().size(); i++) {
                    // every value flattened out of bin category i maps to bin index i
                    List<String> cvals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
                    for (String cval : cvals) {
                        map.put(cval, i);
            // keyed by column number
            this.categoricalIndexMap.put(config.getColumnNum(), map);
    // copy the tag lists into hash sets
    if (modelConfig != null && modelConfig.getPosTags() != null) {
        this.posTagSet = new HashSet<String>(modelConfig.getPosTags());
    if (modelConfig != null && modelConfig.getNegTags() != null) {
        this.negTagSet = new HashSet<String>(modelConfig.getNegTags());
    if (modelConfig != null && modelConfig.getFlattenTags() != null) {
        this.tagSet = new HashSet<String>(modelConfig.getFlattenTags());
    if (modelConfig != null) {
        this.tags = modelConfig.getSetTags();
Also used : DataPurifier(ml.shifu.shifu.core.DataPurifier) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) HashMap(java.util.HashMap)

Example 95 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class PostTrainModelProcessor method updateColumnConfigWithBinAvgScore.

/**
 * Read the bin average scores and update them into the column config list.
 *
 * @param columnConfigList
 *            input column config list
 * @return updated column config list
 * @throws IOException
 *             for any io exception
 */
// NOTE(review): this listing is truncated - closing braces are missing and some statements appear to have been
// dropped - so the comments below describe only what is visible.
private List<ColumnConfig> updateColumnConfigWithBinAvgScore(List<ColumnConfig> columnConfigList) throws IOException {
    // scanners over the bin average score path for the configured data set source
    List<Scanner> scanners = ShifuFileUtils.getDataScanners(pathFinder.getBinAvgScorePath(), modelConfig.getDataSet().getSource());
    // CommonUtils.getDataScanners(pathFinder.getBinAvgScorePath(), modelConfig.getDataSet().getSource());
    for (Scanner scanner : scanners) {
        while (scanner.hasNextLine()) {
            List<Integer> scores = new ArrayList<Integer>();
            // each line is split on '|'; the first field is parsed as the column number
            String[] raw = scanner.nextLine().split("\\|");
            int columnNum = Integer.parseInt(raw[0]);
            // NOTE(review): the body of this loop over the remaining fields is missing, so scores is never
            // filled in the visible code
            for (int i = 1; i < raw.length; i++) {
            // NOTE(review): columnNum indexes the list directly; no statement storing scores on config is visible
            ColumnConfig config = columnConfigList.get(columnNum);
    // release
    // NOTE(review): no close() call on the scanners is visible - confirm they are released
    return columnConfigList;
Also used : Scanner(java.util.Scanner) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList)


ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)131 ArrayList (java.util.ArrayList)36 Test (org.testng.annotations.Test)17 IOException (java.io.IOException) HashMap (java.util.HashMap)12 Tuple (org.apache.pig.data.Tuple) File (java.io.File) NSColumn (ml.shifu.shifu.column.NSColumn)8 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)8 ShifuException (ml.shifu.shifu.exception.ShifuException)8 Path (org.apache.hadoop.fs.Path)8 List (java.util.List)7 Scanner (java.util.Scanner)7 DataBag (org.apache.pig.data.DataBag) SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)5 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)5 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)5 BasicMLData (org.encog.ml.data.basic.BasicMLData) BufferedWriter (java.io.BufferedWriter) FileInputStream (java.io.FileInputStream)