Search in sources :

Example 91 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTrainUtils method getNumericalIds.

/**
 * Collects the column numbers of the numerical feature columns.
 *
 * Target and meta columns are never returned. When {@code isAfterVarSelect} is true only
 * final-selected columns qualify; otherwise a column qualifies when
 * {@code CommonUtils.isGoodCandidate} accepts it.
 *
 * @param columnConfigList
 *            column configurations to scan
 * @param isAfterVarSelect
 *            whether to filter on the final-select flag instead of the candidate check
 * @return column numbers of the matching numerical columns, in list order
 */
public static List<Integer> getNumericalIds(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) {
    boolean anyCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    List<Integer> ids = new ArrayList<>();
    for (ColumnConfig column : columnConfigList) {
        if (!column.isNumerical()) {
            // only numerical columns are of interest here
            continue;
        }
        boolean eligible;
        if (isAfterVarSelect) {
            eligible = column.isFinalSelect() && !column.isTarget() && !column.isMeta();
        } else {
            eligible = !column.isTarget() && !column.isMeta()
                    && CommonUtils.isGoodCandidate(column, anyCandidates);
        }
        if (eligible) {
            ids.add(column.getColumnNum());
        }
    }
    return ids;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 92 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTrainUtils method getCategoricalIds.

/**
 * Collects the column numbers of the categorical feature columns.
 *
 * Target and meta columns are never returned. When {@code isAfterVarSelect} is true only
 * final-selected columns qualify; otherwise a column qualifies when
 * {@code CommonUtils.isGoodCandidate} accepts it.
 *
 * @param columnConfigList
 *            column configurations to scan
 * @param isAfterVarSelect
 *            whether to filter on the final-select flag instead of the candidate check
 * @return column numbers of the matching categorical columns, in list order
 */
public static List<Integer> getCategoricalIds(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) {
    boolean anyCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    List<Integer> categoricalIds = new ArrayList<>();
    for (ColumnConfig column : columnConfigList) {
        // mode-specific part of the filter; the categorical check stays last
        boolean eligible = isAfterVarSelect
                ? column.isFinalSelect() && !column.isTarget() && !column.isMeta()
                : !column.isTarget() && !column.isMeta() && CommonUtils.isGoodCandidate(column, anyCandidates);
        if (eligible && column.isCategorical()) {
            categoricalIds.add(column.getColumnNum());
        }
    }
    return categoricalIds;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 93 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class DTMaster method doCompute.

/**
 * Master computation for one iteration.
 *
 * Steps, in order:
 * <ol>
 * <li>first iteration: return the initial master params;</li>
 * <li>if a checkpointed master result is pending, rebuild and return it once;</li>
 * <li>merge node statistics and accumulate train/validation error and counts from all workers;</li>
 * <li>for every node with statistics, pick the split with maximal gain and either mark the node as
 * leaf, split it (level-wise) or queue it for splitting (leaf-wise);</li>
 * <li>leaf-wise only: split at most {@code maxBatchSplitSize} queued nodes;</li>
 * <li>decide whether to halt, switch to the next tree (GBDT) or send the next group of todo
 * nodes;</li>
 * <li>attach the trees to send, run the checkpoint hook and return the result.</li>
 * </ol>
 *
 * @param context
 *            master context giving access to worker results and the current iteration
 * @return master params for the next iteration
 */
@Override
public DTMasterParams doCompute(MasterContext<DTMasterParams, DTWorkerParams> context) {
    if (context.isFirstIteration()) {
        return buildInitialMasterParams();
    }
    if (this.cpMasterParams != null) {
        DTMasterParams tmpMasterParams = rebuildRecoverMasterResultDepthList();
        // set it to null to avoid send it in next iteration
        this.cpMasterParams = null;
        if (this.isGBDT) {
            // don't need to send full trees because worker will get existing models from HDFS
            // only set last tree to do node stats, no need check switch to next tree because of message may be send
            // to worker already
            tmpMasterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
            // set tmp trees for DTOutput
            tmpMasterParams.setTmpTrees(this.trees);
        }
        return tmpMasterParams;
    }
    // the first worker's stats map is reused as the accumulator, later ones are merged into it
    boolean isFirst = false;
    Map<Integer, NodeStats> nodeStatsMap = null;
    double trainError = 0d, validationError = 0d;
    double weightedTrainCount = 0d, weightedValidationCount = 0d;
    for (DTWorkerParams params : context.getWorkerResults()) {
        if (!isFirst) {
            isFirst = true;
            nodeStatsMap = params.getNodeStatsMap();
        } else {
            Map<Integer, NodeStats> currNodeStatsmap = params.getNodeStatsMap();
            for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
                NodeStats resultNodeStats = entry.getValue();
                mergeNodeStats(resultNodeStats, currNodeStatsmap.get(entry.getKey()));
            }
            // set to null after merging, release memory at the earliest stage
            params.setNodeStatsMap(null);
        }
        trainError += params.getTrainError();
        validationError += params.getValidationError();
        weightedTrainCount += params.getTrainCount();
        weightedValidationCount += params.getValidationCount();
    }
    // NOTE(review): nodeStatsMap stays null when there are no worker results; the loop below assumes
    // at least one worker reported - confirm this is guaranteed by the framework.
    for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
        NodeStats nodeStats = entry.getValue();
        int treeId = nodeStats.getTreeId();
        Node doneNode = Node.getNode(trees.get(treeId).getNode(), nodeStats.getNodeId());
        // doneNode, NodeStats
        Map<Integer, double[]> statistics = nodeStats.getFeatureStatistics();
        List<GainInfo> gainList = new ArrayList<GainInfo>();
        for (Entry<Integer, double[]> gainEntry : statistics.entrySet()) {
            int columnNum = gainEntry.getKey();
            ColumnConfig config = this.columnConfigList.get(columnNum);
            double[] statsArray = gainEntry.getValue();
            GainInfo gainInfo = this.impurity.computeImpurity(statsArray, config);
            if (gainInfo != null) {
                gainList.add(gainInfo);
            }
        }
        GainInfo maxGainInfo = GainInfo.getGainInfoByMaxGain(gainList);
        if (maxGainInfo == null) {
            // null gain info, set to leaf and continue next stats
            doneNode.setLeaf(true);
            continue;
        }
        populateGainInfoToNode(treeId, doneNode, maxGainInfo);
        if (this.isLeafWise) {
            // leaf-wise: only queue nodes with positive gain, the actual split happens after this loop
            boolean isNotSplit = maxGainInfo.getGain() <= 0d;
            if (!isNotSplit) {
                this.toSplitQueue.offer(new TreeNode(treeId, doneNode));
            } else {
                LOG.info("Node {} in tree {} is not to be split", doneNode.getId(), treeId);
            }
        } else {
            boolean isLeaf = maxGainInfo.getGain() <= 0d || Node.indexToLevel(doneNode.getId()) == this.maxDepth;
            doneNode.setLeaf(isLeaf);
            // level-wise is to split node when stats is ready
            splitNodeForLevelWisedTree(isLeaf, treeId, doneNode);
        }
    }
    if (this.isLeafWise) {
        // get node in toSplitQueue and split
        int currSplitIndex = 0;
        while (!toSplitQueue.isEmpty() && currSplitIndex < this.maxBatchSplitSize) {
            TreeNode treeNode = this.toSplitQueue.poll();
            splitNodeForLeafWisedTree(treeNode.getTreeId(), treeNode.getNode());
            // count the split so that the maxBatchSplitSize bound is actually enforced; without this
            // increment the whole queue was drained in every iteration
            currSplitIndex += 1;
        }
    }
    Map<Integer, TreeNode> todoNodes = new HashMap<Integer, TreeNode>();
    double averageValidationError = validationError / weightedValidationCount;
    if (this.isGBDT && this.dtEarlyStopDecider != null && averageValidationError > 0) {
        // replace the raw value by the decider's current average value
        this.dtEarlyStopDecider.add(averageValidationError);
        averageValidationError = this.dtEarlyStopDecider.getCurrentAverageValue();
    }
    boolean vtTriggered = false;
    // if validationTolerance == 0d, means vt check is not enabled
    if (validationTolerance > 0d && Math.abs(this.bestValidationError - averageValidationError) < this.validationTolerance * averageValidationError) {
        LOG.debug("Debug: bestValidationError {}, averageValidationError {}, validationTolerance {}", bestValidationError, averageValidationError, validationTolerance);
        vtTriggered = true;
    }
    if (averageValidationError < this.bestValidationError) {
        this.bestValidationError = averageValidationError;
    }
    // validation error is averageValidationError * weightedValidationCount because of here averageValidationError
    // is divided by validation count.
    DTMasterParams masterParams = new DTMasterParams(weightedTrainCount, trainError, weightedValidationCount, averageValidationError * weightedValidationCount);
    if (toDoQueue.isEmpty()) {
        if (this.isGBDT) {
            TreeNode treeNode = this.trees.get(this.trees.size() - 1);
            Node node = treeNode.getNode();
            if (this.trees.size() >= this.treeNum) {
                // if all trees including trees read from existing model over treeNum, stop the whole process.
                masterParams.setHalt(true);
                LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (node.getLeft() == null && node.getRight() == null) {
                // if very good performance, here can be some issues, say you'd like to get 5 trees, but in the 2nd
                // tree, you get one perfect tree, no need continue but warn users about such issue: set
                // BaggingSampleRate not to 1 can solve such issue to avoid overfit
                masterParams.setHalt(true);
                LOG.warn("Tree is learned 100% well, there must be overfit here, please tune BaggingSampleRate, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (this.dtEarlyStopDecider != null && (this.enableEarlyStop && this.dtEarlyStopDecider.canStop())) {
                masterParams.setHalt(true);
                LOG.info("Early stop identified, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (vtTriggered) {
                LOG.info("Early stop training by validation tolerance.");
                masterParams.setHalt(true);
            } else {
                // set first tree to true even after ROOT node is set in next tree
                masterParams.setFirstTree(this.trees.size() == 1);
                // finish current tree, no need features information
                treeNode.setFeatures(null);
                TreeNode newRootNode = new TreeNode(this.trees.size(), new Node(Node.ROOT_INDEX), this.learningRate);
                LOG.info("The {} tree is to be built.", this.trees.size());
                this.trees.add(newRootNode);
                newRootNode.setFeatures(getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate));
                // only one node
                todoNodes.put(0, newRootNode);
                masterParams.setTodoNodes(todoNodes);
                // set switch flag
                masterParams.setSwitchToNextTree(true);
            }
        } else {
            // for rf
            masterParams.setHalt(true);
            LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
        }
    } else {
        int nodeIndexInGroup = 0;
        long currMem = 0L;
        // depth per tree reached by the nodes sent in this iteration; -1 means no node of that tree is sent
        List<Integer> depthList = new ArrayList<Integer>();
        if (this.isGBDT) {
            depthList.add(-1);
        }
        if (isRF) {
            for (int i = 0; i < this.trees.size(); i++) {
                depthList.add(-1);
            }
        }
        // take nodes from the todo queue until the accumulated stats memory exceeds maxStatsMemory
        while (!toDoQueue.isEmpty() && currMem <= this.maxStatsMemory) {
            TreeNode node = this.toDoQueue.poll();
            int treeId = node.getTreeId();
            int oldDepth = this.isGBDT ? depthList.get(0) : depthList.get(treeId);
            int currDepth = Node.indexToLevel(node.getNode().getId());
            if (currDepth > oldDepth) {
                if (this.isGBDT) {
                    // gbdt only for last depth
                    depthList.set(0, currDepth);
                } else {
                    depthList.set(treeId, currDepth);
                }
            }
            List<Integer> subsetFeatures = getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate);
            node.setFeatures(subsetFeatures);
            currMem += getStatsMem(subsetFeatures);
            todoNodes.put(nodeIndexInGroup, node);
            nodeIndexInGroup += 1;
        }
        masterParams.setTreeDepth(depthList);
        masterParams.setTodoNodes(todoNodes);
        masterParams.setSwitchToNextTree(false);
        masterParams.setContinuousRunningStart(false);
        masterParams.setFirstTree(this.trees.size() == 1);
        LOG.info("Todo node size is {}", todoNodes.size());
    }
    if (this.isGBDT) {
        if (masterParams.isSwitchToNextTree()) {
            // send last full growth tree and current todo ROOT node tree
            masterParams.setTrees(trees.subList(trees.size() - 2, trees.size()));
        } else {
            // only send current trees
            masterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
        }
    }
    if (this.isRF) {
        // elements in todoTrees are also the same reference in this.trees, reuse the same object to save memory.
        if (masterParams.getTreeDepth().size() == this.trees.size()) {
            // if normal iteration
            List<TreeNode> todoTrees = new ArrayList<TreeNode>();
            for (int i = 0; i < trees.size(); i++) {
                if (masterParams.getTreeDepth().get(i) >= 0) {
                    // such tree in current iteration treeDepth is not -1, add it to todoTrees.
                    todoTrees.add(trees.get(i));
                } else {
                    // mock a TreeNode instance to make sure no surprise in further serialization. In fact
                    // meaningless.
                    todoTrees.add(new TreeNode(i, new Node(Node.INVALID_INDEX), 1d));
                }
            }
            masterParams.setTrees(todoTrees);
        } else {
            // if last iteration without maxDepthList
            masterParams.setTrees(trees);
        }
    }
    if (this.isGBDT) {
        // set tmp trees to DTOutput
        masterParams.setTmpTrees(this.trees);
    }
    if (context.getCurrentIteration() % 100 == 0) {
        // every 100 iterations do gc explicitly to avoid one case:
        // mapper memory is 2048M and final in our cluster, if -Xmx is 2G, then occasionally oom issue.
        // to fix this issue: 1. set -Xmx to 1800m; 2. call gc to drop unused memory at early stage.
        // this is ugly and if it is stable with 1800m, this line should be removed
        Thread gcThread = new Thread(new Runnable() {

            @Override
            public void run() {
                System.gc();
            }
        });
        gcThread.setDaemon(true);
        gcThread.start();
    }
    // before master result, do checkpoint according to n iteration set by user
    doCheckPoint(context, masterParams, context.getCurrentIteration());
    LOG.debug("weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}", weightedTrainCount, weightedValidationCount, trainError, validationError);
    return masterParams;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) NodeStats(ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats)

Example 94 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class CorrelationMapper method setup.

/**
 * Prepares the mapper: loads the config files, reads the data set delimiter and the compute-all
 * switch, builds the category-value-to-bin-index map for every categorical column and copies the
 * tag collections from the model config.
 *
 * @param context
 *            mapper context providing the job configuration
 * @throws IOException
 *             declared by the overridden method
 * @throws InterruptedException
 *             declared by the overridden method
 */
@Override
protected void setup(Context context) throws IOException, InterruptedException {
    loadConfigFiles(context);
    this.dataSetDelimiter = modelConfig.getDataSetDelimiter();
    this.dataPurifier = new DataPurifier(modelConfig, false);
    String computeAllFlag = context.getConfiguration().get(Constants.SHIFU_CORRELATION_COMPUTE_ALL, "false");
    this.isComputeAll = Boolean.valueOf(computeAllFlag);

    for (ColumnConfig column : columnConfigList) {
        if (!column.isCategorical()) {
            continue;
        }
        // every flattened category value maps to the index of the bin it belongs to
        Map<String, Integer> valueToBinIndex = new HashMap<String, Integer>();
        if (column.getBinCategory() != null) {
            int binIndex = 0;
            while (binIndex < column.getBinCategory().size()) {
                for (String categoryValue : CommonUtils.flattenCatValGrp(column.getBinCategory().get(binIndex))) {
                    valueToBinIndex.put(categoryValue, binIndex);
                }
                binIndex++;
            }
        }
        this.categoricalIndexMap.put(column.getColumnNum(), valueToBinIndex);
    }

    if (modelConfig != null) {
        // tag sets are only replaced when the model config actually provides them
        if (modelConfig.getPosTags() != null) {
            this.posTagSet = new HashSet<String>(modelConfig.getPosTags());
        }
        if (modelConfig.getNegTags() != null) {
            this.negTagSet = new HashSet<String>(modelConfig.getNegTags());
        }
        if (modelConfig.getFlattenTags() != null) {
            this.tagSet = new HashSet<String>(modelConfig.getFlattenTags());
        }
        this.tags = modelConfig.getSetTags();
    }
}
Also used : DataPurifier(ml.shifu.shifu.core.DataPurifier) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) HashMap(java.util.HashMap)

Example 95 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class PostTrainModelProcessor method updateColumnConfigWithBinAvgScore.

/**
 * Read the bin average scores and update them into the column config list.
 *
 * Every input line is split on {@code |}: the first field is the column number, the remaining
 * fields are the integer scores set as bin average score of that column. The scanners are closed
 * even when a line cannot be parsed or a column number is out of range.
 *
 * @param columnConfigList
 *            input column config list
 * @return updated column config list (the same list instance that was passed in)
 * @throws IOException
 *             for any io exception
 */
private List<ColumnConfig> updateColumnConfigWithBinAvgScore(List<ColumnConfig> columnConfigList) throws IOException {
    List<Scanner> scanners = ShifuFileUtils.getDataScanners(pathFinder.getBinAvgScorePath(), modelConfig.getDataSet().getSource());
    try {
        for (Scanner scanner : scanners) {
            while (scanner.hasNextLine()) {
                List<Integer> scores = new ArrayList<Integer>();
                String[] raw = scanner.nextLine().split("\\|");
                int columnNum = Integer.parseInt(raw[0]);
                for (int i = 1; i < raw.length; i++) {
                    scores.add(Integer.valueOf(raw[i]));
                }
                ColumnConfig config = columnConfigList.get(columnNum);
                config.setBinAvgScore(scores);
            }
        }
    } finally {
        // release the scanners on every path, including parse failures, to avoid leaking file handles
        closeScanners(scanners);
    }
    return columnConfigList;
}
Also used : Scanner(java.util.Scanner) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig) 131, ArrayList (java.util.ArrayList) 36, Test (org.testng.annotations.Test) 17, IOException (java.io.IOException) 16, HashMap (java.util.HashMap) 12, Tuple (org.apache.pig.data.Tuple) 10, File (java.io.File) 8, NSColumn (ml.shifu.shifu.column.NSColumn) 8, ModelConfig (ml.shifu.shifu.container.obj.ModelConfig) 8, ShifuException (ml.shifu.shifu.exception.ShifuException) 8, Path (org.apache.hadoop.fs.Path) 8, List (java.util.List) 7, Scanner (java.util.Scanner) 7, DataBag (org.apache.pig.data.DataBag) 7, SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType) 5, BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) 5, TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) 5, BasicMLData (org.encog.ml.data.basic.BasicMLData) 5, BufferedWriter (java.io.BufferedWriter) 3, FileInputStream (java.io.FileInputStream) 3