Search in sources:

Example 1 with NodeStats

Use of ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats in project shifu by ShifuML.

Class DTWorker, method doCompute.

/**
 * Computes this worker's contribution for one guagua iteration.
 * <p>
 * In the first iteration, or when the master sent no todo nodes, an empty {@link DTWorkerParams} is returned.
 * Otherwise the method:
 * <ol>
 * <li>walks the training data (and the validation data, if present) once to accumulate weighted train/validation
 * error; for GBDT it also advances each record's running prediction ({@code data.predict}) and gradient target
 * ({@code data.output}) when the master switches to a new tree;</li>
 * <li>splits the training data into contiguous, non-overlapping index ranges, one per worker thread, and computes
 * per-feature bin statistics for every todo node on the thread pool;</li>
 * <li>merges the per-thread statistics and returns them together with the error sums.</li>
 * </ol>
 *
 * @param context
 *            the worker context holding the last master result
 * @return weighted train/validation counts, error sums and merged node statistics
 * @throws RuntimeException
 *             if a statistics task fails, or if the thread is interrupted while waiting for the tasks (the
 *             interrupt flag is restored before throwing)
 * @see ml.shifu.guagua.worker.AbstractWorkerComputable#doCompute(ml.shifu.guagua.worker.WorkerContext)
 */
@Override
public DTWorkerParams doCompute(WorkerContext<DTMasterParams, DTWorkerParams> context) {
    if (context.isFirstIteration()) {
        return new DTWorkerParams();
    }
    DTMasterParams lastMasterResult = context.getLastMasterResult();
    final List<TreeNode> trees = lastMasterResult.getTrees();
    final Map<Integer, TreeNode> todoNodes = lastMasterResult.getTodoNodes();
    if (todoNodes == null) {
        return new DTWorkerParams();
    }
    LOG.info("Start to work: todoNodes size is {}", todoNodes.size());
    // zero-filled statistics; returned as-is when no statistics task is submitted below
    Map<Integer, NodeStats> statistics = initTodoNodeStats(todoNodes);
    double trainError = 0d, validationError = 0d;
    double weightedTrainCount = 0d, weightedValidationCount = 0d;
    // renew random seed
    if (this.isGBDT && !this.gbdtSampleWithReplacement && lastMasterResult.isSwitchToNextTree()) {
        this.baggingRandomMap = new HashMap<Integer, Random>();
    }
    long start = System.nanoTime();
    for (Data data : this.trainingData) {
        if (this.isRF) {
            for (TreeNode treeNode : trees) {
                if (treeNode.getNode().getId() == Node.INVALID_INDEX) {
                    continue;
                }
                Node predictNode = predictNodeIndex(treeNode.getNode(), data, true);
                if (predictNode.getPredict() != null) {
                    // only update when not in first node, for treeNode, no predict statistics at that time
                    float weight = data.subsampleWeights[treeNode.getTreeId()];
                    if (Float.compare(weight, 0f) == 0) {
                        // oob data, no need to do weighting
                        validationError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedValidationCount += data.significance;
                    } else {
                        trainError += weight * data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedTrainCount += weight * data.significance;
                    }
                }
            }
        }
        if (this.isGBDT) {
            if (this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) {
                recoverGBTData(context, data.output, data.predict, data, false);
                trainError += data.significance * loss.computeError(data.predict, data.label);
                weightedTrainCount += data.significance;
            } else {
                if (isNeedRecoverGBDTPredict) {
                    if (this.recoverTrees == null) {
                        this.recoverTrees = recoverCurrentTrees();
                    }
                    // recover gbdt data for fail over
                    recoverGBTData(context, data.output, data.predict, data, true);
                }
                int currTreeIndex = trees.size() - 1;
                if (lastMasterResult.isSwitchToNextTree()) {
                    if (currTreeIndex >= 1) {
                        Node node = trees.get(currTreeIndex - 1).getNode();
                        Node predictNode = predictNodeIndex(node, data, false);
                        if (predictNode.getPredict() != null) {
                            double predict = predictNode.getPredict().getPredict();
                            // sending
                            if (context.getLastMasterResult().isFirstTree()) {
                                data.predict = (float) predict;
                            } else {
                                // random drop
                                boolean drop = (this.dropOutRate > 0.0 && dropOutRandom.nextDouble() < this.dropOutRate);
                                if (!drop) {
                                    data.predict += (float) (this.learningRate * predict);
                                }
                            }
                            data.output = -1f * loss.computeGradient(data.predict, data.label);
                        }
                        // if not sampling with replacement in gbdt, renew bagging sample rate in next tree
                        if (!this.gbdtSampleWithReplacement) {
                            Random random = null;
                            int classValue = (int) (data.label + 0.01f);
                            if (this.isStratifiedSampling) {
                                random = baggingRandomMap.get(classValue);
                                if (random == null) {
                                    random = DTrainUtils.generateRandomBySampleSeed(modelConfig.getTrain().getBaggingSampleSeed(), CommonConstants.NOT_CONFIGURED_BAGGING_SEED);
                                    baggingRandomMap.put(classValue, random);
                                }
                            } else {
                                random = baggingRandomMap.get(0);
                                if (random == null) {
                                    random = DTrainUtils.generateRandomBySampleSeed(modelConfig.getTrain().getBaggingSampleSeed(), CommonConstants.NOT_CONFIGURED_BAGGING_SEED);
                                    baggingRandomMap.put(0, random);
                                }
                            }
                            if (random.nextDouble() <= modelConfig.getTrain().getBaggingSampleRate()) {
                                data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 1f;
                            } else {
                                data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 0f;
                            }
                        }
                    }
                }
                if (context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) {
                    Node currTree = trees.get(currTreeIndex).getNode();
                    Node predictNode = predictNodeIndex(currTree, data, true);
                    if (predictNode.getPredict() != null) {
                        trainError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedTrainCount += data.significance;
                    }
                } else {
                    trainError += data.significance * loss.computeError(data.predict, data.label);
                    weightedTrainCount += data.significance;
                }
            }
        }
    }
    LOG.debug("Compute train error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
    if (validationData != null) {
        start = System.nanoTime();
        for (Data data : this.validationData) {
            if (this.isRF) {
                for (TreeNode treeNode : trees) {
                    if (treeNode.getNode().getId() == Node.INVALID_INDEX) {
                        continue;
                    }
                    Node predictNode = predictNodeIndex(treeNode.getNode(), data, true);
                    if (predictNode.getPredict() != null) {
                        // only update when not in first node, for treeNode, no predict statistics at that time
                        validationError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedValidationCount += data.significance;
                    }
                }
            }
            if (this.isGBDT) {
                if (this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) {
                    recoverGBTData(context, data.output, data.predict, data, false);
                    validationError += data.significance * loss.computeError(data.predict, data.label);
                    weightedValidationCount += data.significance;
                } else {
                    if (isNeedRecoverGBDTPredict) {
                        if (this.recoverTrees == null) {
                            this.recoverTrees = recoverCurrentTrees();
                        }
                        // recover gbdt data for fail over
                        recoverGBTData(context, data.output, data.predict, data, true);
                    }
                    int currTreeIndex = trees.size() - 1;
                    if (lastMasterResult.isSwitchToNextTree()) {
                        if (currTreeIndex >= 1) {
                            Node node = trees.get(currTreeIndex - 1).getNode();
                            Node predictNode = predictNodeIndex(node, data, false);
                            if (predictNode.getPredict() != null) {
                                double predict = predictNode.getPredict().getPredict();
                                if (context.getLastMasterResult().isFirstTree()) {
                                    data.predict = (float) predict;
                                } else {
                                    data.predict += (float) (this.learningRate * predict);
                                }
                                data.output = -1f * loss.computeGradient(data.predict, data.label);
                            }
                        }
                    }
                    if (context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) {
                        Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, true);
                        if (predictNode.getPredict() != null) {
                            validationError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
                            weightedValidationCount += data.significance;
                        }
                    } else {
                        validationError += data.significance * loss.computeError(data.predict, data.label);
                        weightedValidationCount += data.significance;
                    }
                }
            }
        }
        LOG.debug("Compute val error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
    }
    if (this.isGBDT) {
        // reset trees to null to save memory
        this.recoverTrees = null;
        if (this.isNeedRecoverGBDTPredict) {
            // no need recover again
            this.isNeedRecoverGBDTPredict = false;
        }
    }
    start = System.nanoTime();
    CompletionService<Map<Integer, NodeStats>> completionService = new ExecutorCompletionService<Map<Integer, NodeStats>>(this.threadPool);
    int realThreadCount = 0;
    LOG.debug("while todo size {}", todoNodes.size());
    int realRecords = this.trainingData.size();
    int realThreads = this.workerThreadCount > realRecords ? realRecords : this.workerThreadCount;
    int[] trainLows = new int[realThreads];
    int[] trainHighs = new int[realThreads];
    if (realThreads > 0) {
        // Split [0, realRecords) into realThreads contiguous ranges whose sizes differ by at most one: the first
        // (realRecords % realThreads) ranges take one extra record. This keeps every range inside the data
        // bounds; with no training data no range (and no task) is created at all.
        int baseStep = realRecords / realThreads;
        int remainder = realRecords % realThreads;
        int nextLow = 0;
        for (int i = 0; i < realThreads; i++) {
            int rangeSize = baseStep + (i < remainder ? 1 : 0);
            trainLows[i] = nextLow;
            trainHighs[i] = nextLow + rangeSize - 1;
            nextLow += rangeSize;
        }
    }
    for (int i = 0; i < realThreads; i++) {
        final Map<Integer, TreeNode> localTodoNodes = new HashMap<Integer, TreeNode>(todoNodes);
        final Map<Integer, NodeStats> localStatistics = initTodoNodeStats(todoNodes);
        final int startIndex = trainLows[i];
        final int endIndex = trainHighs[i];
        LOG.info("Thread {} todo size {} stats size {} start index {} end index {}", i, localTodoNodes.size(), localStatistics.size(), startIndex, endIndex);
        if (localTodoNodes.size() == 0) {
            continue;
        }
        realThreadCount += 1;
        completionService.submit(new Callable<Map<Integer, NodeStats>>() {

            @Override
            public Map<Integer, NodeStats> call() throws Exception {
                long start = System.nanoTime();
                List<Integer> nodeIndexes = new ArrayList<Integer>(trees.size());
                for (int j = startIndex; j <= endIndex; j++) {
                    Data data = DTWorker.this.trainingData.get(j);
                    nodeIndexes.clear();
                    if (DTWorker.this.isRF) {
                        for (TreeNode treeNode : trees) {
                            if (treeNode.getNode().getId() == Node.INVALID_INDEX) {
                                nodeIndexes.add(Node.INVALID_INDEX);
                            } else {
                                Node predictNode = predictNodeIndex(treeNode.getNode(), data, false);
                                nodeIndexes.add(predictNode.getId());
                            }
                        }
                    }
                    if (DTWorker.this.isGBDT) {
                        int currTreeIndex = trees.size() - 1;
                        Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, false);
                        // update node index
                        nodeIndexes.add(predictNode.getId());
                    }
                    for (Map.Entry<Integer, TreeNode> entry : localTodoNodes.entrySet()) {
                        // only do statistics on effective data
                        Node todoNode = entry.getValue().getNode();
                        int treeId = entry.getValue().getTreeId();
                        int currPredictIndex = 0;
                        if (DTWorker.this.isRF) {
                            currPredictIndex = nodeIndexes.get(entry.getValue().getTreeId());
                        }
                        if (DTWorker.this.isGBDT) {
                            currPredictIndex = nodeIndexes.get(0);
                        }
                        if (todoNode.getId() == currPredictIndex) {
                            List<Integer> features = entry.getValue().getFeatures();
                            if (features.isEmpty()) {
                                features = getAllValidFeatures();
                            }
                            for (Integer columnNum : features) {
                                double[] featuerStatistic = localStatistics.get(entry.getKey()).getFeatureStatistics().get(columnNum);
                                float weight = data.subsampleWeights[treeId % data.subsampleWeights.length];
                                if (Float.compare(weight, 0f) != 0) {
                                    // only compute weight is not 0
                                    short binIndex = data.inputs[DTWorker.this.inputIndexMap.get(columnNum)];
                                    DTWorker.this.impurity.featureUpdate(featuerStatistic, binIndex, data.output, data.significance, weight);
                                }
                            }
                        }
                    }
                }
                LOG.debug("Thread computing stats time is {}ms in thread {}", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start), Thread.currentThread().getName());
                return localStatistics;
            }
        });
    }
    int rCnt = 0;
    while (rCnt < realThreadCount) {
        try {
            Map<Integer, NodeStats> currNodeStatsmap = completionService.take().get();
            if (rCnt == 0) {
                // first finished task: reuse its map as the merge target
                statistics = currNodeStatsmap;
            } else {
                for (Entry<Integer, NodeStats> entry : statistics.entrySet()) {
                    NodeStats resultNodeStats = entry.getValue();
                    mergeNodeStats(resultNodeStats, currNodeStatsmap.get(entry.getKey()));
                }
            }
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        } catch (InterruptedException e) {
            // Incomplete statistics must not be sent to the master: once the interrupt flag is set every further
            // take() fails immediately, so restore the flag and abort instead of silently skipping results.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while waiting for node statistics tasks.", e);
        }
        rCnt += 1;
    }
    LOG.debug("Compute stats time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
    LOG.info("worker count is {}, error is {}, and stats size is {}. weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}", count, trainError, statistics.size(), weightedTrainCount, weightedValidationCount, trainError, validationError);
    return new DTWorkerParams(weightedTrainCount, weightedValidationCount, trainError, validationError, statistics);
}
Also used: ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap), HashMap (java.util.HashMap), ExecutorCompletionService (java.util.concurrent.ExecutorCompletionService), NodeStats (ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats), Entry (java.util.Map.Entry), GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException), Random (java.util.Random), List (java.util.List), ArrayList (java.util.ArrayList), MemoryLimitedList (ml.shifu.guagua.util.MemoryLimitedList), ExecutionException (java.util.concurrent.ExecutionException), IOException (java.io.IOException), Map (java.util.Map), ConcurrentMap (java.util.concurrent.ConcurrentMap)

Example 2 with NodeStats

Use of ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats in project shifu by ShifuML.

Class DTWorker, method initTodoNodeStats.

/**
 * Builds a zero-filled statistics holder for every todo node.
 * <p>
 * For each node, the node's own feature subset is used; when that subset is empty, all valid features are used
 * instead. Each numerical column gets {@code binBoundary.size() * statsSize} slots. Each categorical column gets
 * {@code (binCategory.size() + 1) * statsSize} slots, where the extra bin holds invalid categories. Columns that
 * are neither numerical nor categorical get no entry.
 *
 * @param todoNodes
 *            todo nodes keyed by their index in the current group
 * @return a new map, keyed like {@code todoNodes}, of empty {@link NodeStats}
 */
private Map<Integer, NodeStats> initTodoNodeStats(Map<Integer, TreeNode> todoNodes) {
    Map<Integer, NodeStats> nodeStatsByIndex = new HashMap<Integer, NodeStats>(todoNodes.size(), 1f);
    for (Map.Entry<Integer, TreeNode> todo : todoNodes.entrySet()) {
        TreeNode treeNode = todo.getValue();
        List<Integer> columns = treeNode.getFeatures();
        if (columns.isEmpty()) {
            columns = getAllValidFeatures();
        }
        Map<Integer, double[]> statsByColumn = new HashMap<Integer, double[]>(columns.size(), 1f);
        for (Integer column : columns) {
            ColumnConfig config = this.columnConfigList.get(column);
            int binCount;
            if (config.isNumerical()) {
                // TODO, how to process null bin
                binCount = config.getBinBoundary().size();
            } else if (config.isCategorical()) {
                // one extra bin for invalid category values such as ?, *, ...
                binCount = config.getBinCategory().size() + 1;
            } else {
                continue;
            }
            statsByColumn.put(column, new double[binCount * this.impurity.getStatsSize()]);
        }
        nodeStatsByIndex.put(todo.getKey(), new NodeStats(treeNode.getTreeId(), treeNode.getNode().getId(), statsByColumn));
    }
    return nodeStatsByIndex;
}
Also used: NodeStats (ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats), ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap), HashMap (java.util.HashMap), Map (java.util.Map), ConcurrentMap (java.util.concurrent.ConcurrentMap)

Example 3 with NodeStats

Use of ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats in project shifu by ShifuML.

Class DTMaster, method doCompute.

/**
 * Aggregates worker results for one iteration, grows the trees and selects the nodes workers compute next.
 * <p>
 * The first iteration returns the initial params. After a checkpoint recovery, the rebuilt params are returned
 * once. Otherwise the method:
 * <ol>
 * <li>merges worker node statistics and error sums;</li>
 * <li>picks the max-gain split for every finished node;</li>
 * <li>splits nodes: level-wise trees split immediately, while leaf-wise trees split at most
 * {@code maxBatchSplitSize} nodes from the split queue per iteration;</li>
 * <li>evaluates the halt conditions: tree count reached, root left unsplit, early stop, or validation
 * tolerance for GBDT, and an empty queue for RF;</li>
 * <li>selects the next todo nodes until {@code maxStatsMemory} is exceeded.</li>
 * </ol>
 *
 * @param context
 *            the master context holding the worker results
 * @return the params sent to workers for the next iteration
 */
@Override
public DTMasterParams doCompute(MasterContext<DTMasterParams, DTWorkerParams> context) {
    if (context.isFirstIteration()) {
        return buildInitialMasterParams();
    }
    if (this.cpMasterParams != null) {
        DTMasterParams tmpMasterParams = rebuildRecoverMasterResultDepthList();
        // set it to null to avoid send it in next iteration
        this.cpMasterParams = null;
        if (this.isGBDT) {
            // don't need to send full trees because worker will get existing models from HDFS
            // only set last tree to do node stats, no need check switch to next tree because of message may be send
            // to worker already
            tmpMasterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
            // set tmp trees for DTOutput
            tmpMasterParams.setTmpTrees(this.trees);
        }
        return tmpMasterParams;
    }
    boolean isFirst = false;
    Map<Integer, NodeStats> nodeStatsMap = null;
    double trainError = 0d, validationError = 0d;
    double weightedTrainCount = 0d, weightedValidationCount = 0d;
    for (DTWorkerParams params : context.getWorkerResults()) {
        if (!isFirst) {
            isFirst = true;
            nodeStatsMap = params.getNodeStatsMap();
        } else {
            Map<Integer, NodeStats> currNodeStatsmap = params.getNodeStatsMap();
            for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
                NodeStats resultNodeStats = entry.getValue();
                mergeNodeStats(resultNodeStats, currNodeStatsmap.get(entry.getKey()));
            }
            // set to null after merging, release memory at the earliest stage
            params.setNodeStatsMap(null);
        }
        trainError += params.getTrainError();
        validationError += params.getValidationError();
        weightedTrainCount += params.getTrainCount();
        weightedValidationCount += params.getValidationCount();
    }
    for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
        NodeStats nodeStats = entry.getValue();
        int treeId = nodeStats.getTreeId();
        Node doneNode = Node.getNode(trees.get(treeId).getNode(), nodeStats.getNodeId());
        // doneNode, NodeStats
        Map<Integer, double[]> statistics = nodeStats.getFeatureStatistics();
        List<GainInfo> gainList = new ArrayList<GainInfo>();
        for (Entry<Integer, double[]> gainEntry : statistics.entrySet()) {
            int columnNum = gainEntry.getKey();
            ColumnConfig config = this.columnConfigList.get(columnNum);
            double[] statsArray = gainEntry.getValue();
            GainInfo gainInfo = this.impurity.computeImpurity(statsArray, config);
            if (gainInfo != null) {
                gainList.add(gainInfo);
            }
        }
        GainInfo maxGainInfo = GainInfo.getGainInfoByMaxGain(gainList);
        if (maxGainInfo == null) {
            // null gain info, set to leaf and continue next stats
            doneNode.setLeaf(true);
            continue;
        }
        populateGainInfoToNode(treeId, doneNode, maxGainInfo);
        if (this.isLeafWise) {
            boolean isNotSplit = maxGainInfo.getGain() <= 0d;
            if (!isNotSplit) {
                this.toSplitQueue.offer(new TreeNode(treeId, doneNode));
            } else {
                LOG.info("Node {} in tree {} is not to be split", doneNode.getId(), treeId);
            }
        } else {
            boolean isLeaf = maxGainInfo.getGain() <= 0d || Node.indexToLevel(doneNode.getId()) == this.maxDepth;
            doneNode.setLeaf(isLeaf);
            // level-wise is to split node when stats is ready
            splitNodeForLevelWisedTree(isLeaf, treeId, doneNode);
        }
    }
    if (this.isLeafWise) {
        // get node in toSplitQueue and split, at most maxBatchSplitSize nodes per iteration
        int currSplitIndex = 0;
        while (!toSplitQueue.isEmpty() && currSplitIndex < this.maxBatchSplitSize) {
            TreeNode treeNode = this.toSplitQueue.poll();
            splitNodeForLeafWisedTree(treeNode.getTreeId(), treeNode.getNode());
            // count each split; without this the batch limit never applies and the whole queue is drained
            currSplitIndex += 1;
        }
    }
    Map<Integer, TreeNode> todoNodes = new HashMap<Integer, TreeNode>();
    double averageValidationError = validationError / weightedValidationCount;
    if (this.isGBDT && this.dtEarlyStopDecider != null && averageValidationError > 0) {
        this.dtEarlyStopDecider.add(averageValidationError);
        averageValidationError = this.dtEarlyStopDecider.getCurrentAverageValue();
    }
    boolean vtTriggered = false;
    // if validationTolerance == 0d, means vt check is not enabled
    if (validationTolerance > 0d && Math.abs(this.bestValidationError - averageValidationError) < this.validationTolerance * averageValidationError) {
        LOG.debug("Debug: bestValidationError {}, averageValidationError {}, validationTolerance {}", bestValidationError, averageValidationError, validationTolerance);
        vtTriggered = true;
    }
    if (averageValidationError < this.bestValidationError) {
        this.bestValidationError = averageValidationError;
    }
    // validation error is averageValidationError * weightedValidationCount because of here averageValidationError
    // is divided by validation count.
    DTMasterParams masterParams = new DTMasterParams(weightedTrainCount, trainError, weightedValidationCount, averageValidationError * weightedValidationCount);
    if (toDoQueue.isEmpty()) {
        if (this.isGBDT) {
            TreeNode treeNode = this.trees.get(this.trees.size() - 1);
            Node node = treeNode.getNode();
            if (this.trees.size() >= this.treeNum) {
                // if all trees including trees read from existing model over treeNum, stop the whole process.
                masterParams.setHalt(true);
                LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (node.getLeft() == null && node.getRight() == null) {
                // if very good performance, here can be some issues, say you'd like to get 5 trees, but in the 2nd
                // tree, you get one perfect tree, no need continue but warn users about such issue: set
                // BaggingSampleRate not to 1 can solve such issue to avoid overfit
                masterParams.setHalt(true);
                LOG.warn("Tree is learned 100% well, there must be overfit here, please tune BaggingSampleRate, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (this.dtEarlyStopDecider != null && (this.enableEarlyStop && this.dtEarlyStopDecider.canStop())) {
                masterParams.setHalt(true);
                LOG.info("Early stop identified, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (vtTriggered) {
                LOG.info("Early stop training by validation tolerance.");
                masterParams.setHalt(true);
            } else {
                // set first tree to true even after ROOT node is set in next tree
                masterParams.setFirstTree(this.trees.size() == 1);
                // finish current tree, no need features information
                treeNode.setFeatures(null);
                TreeNode newRootNode = new TreeNode(this.trees.size(), new Node(Node.ROOT_INDEX), this.learningRate);
                LOG.info("The {} tree is to be built.", this.trees.size());
                this.trees.add(newRootNode);
                newRootNode.setFeatures(getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate));
                // only one node
                todoNodes.put(0, newRootNode);
                masterParams.setTodoNodes(todoNodes);
                // set switch flag
                masterParams.setSwitchToNextTree(true);
            }
        } else {
            // for rf
            masterParams.setHalt(true);
            LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
        }
    } else {
        int nodeIndexInGroup = 0;
        long currMem = 0L;
        List<Integer> depthList = new ArrayList<Integer>();
        if (this.isGBDT) {
            depthList.add(-1);
        }
        if (isRF) {
            for (int i = 0; i < this.trees.size(); i++) {
                depthList.add(-1);
            }
        }
        while (!toDoQueue.isEmpty() && currMem <= this.maxStatsMemory) {
            TreeNode node = this.toDoQueue.poll();
            int treeId = node.getTreeId();
            int oldDepth = this.isGBDT ? depthList.get(0) : depthList.get(treeId);
            int currDepth = Node.indexToLevel(node.getNode().getId());
            if (currDepth > oldDepth) {
                if (this.isGBDT) {
                    // gbdt only for last depth
                    depthList.set(0, currDepth);
                } else {
                    depthList.set(treeId, currDepth);
                }
            }
            List<Integer> subsetFeatures = getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate);
            node.setFeatures(subsetFeatures);
            currMem += getStatsMem(subsetFeatures);
            todoNodes.put(nodeIndexInGroup, node);
            nodeIndexInGroup += 1;
        }
        masterParams.setTreeDepth(depthList);
        masterParams.setTodoNodes(todoNodes);
        masterParams.setSwitchToNextTree(false);
        masterParams.setContinuousRunningStart(false);
        masterParams.setFirstTree(this.trees.size() == 1);
        LOG.info("Todo node size is {}", todoNodes.size());
    }
    if (this.isGBDT) {
        if (masterParams.isSwitchToNextTree()) {
            // send last full growth tree and current todo ROOT node tree
            masterParams.setTrees(trees.subList(trees.size() - 2, trees.size()));
        } else {
            // only send current trees
            masterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
        }
    }
    if (this.isRF) {
        // elements in todoTrees are also the same reference in this.trees, reuse the same object to save memory.
        if (masterParams.getTreeDepth().size() == this.trees.size()) {
            // if normal iteration
            List<TreeNode> todoTrees = new ArrayList<TreeNode>();
            for (int i = 0; i < trees.size(); i++) {
                if (masterParams.getTreeDepth().get(i) >= 0) {
                    // such tree in current iteration treeDepth is not -1, add it to todoTrees.
                    todoTrees.add(trees.get(i));
                } else {
                    // mock a TreeNode instance to make sure no surprise in further serialization. In fact
                    // meaningless.
                    todoTrees.add(new TreeNode(i, new Node(Node.INVALID_INDEX), 1d));
                }
            }
            masterParams.setTrees(todoTrees);
        } else {
            // if last iteration without maxDepthList
            masterParams.setTrees(trees);
        }
    }
    if (this.isGBDT) {
        // set tmp trees to DTOutput
        masterParams.setTmpTrees(this.trees);
    }
    if (context.getCurrentIteration() % 100 == 0) {
        // every 100 iterations do gc explicitly to avoid one case:
        // mapper memory is 2048M and final in our cluster, if -Xmx is 2G, then occasionally oom issue.
        // to fix this issue: 1. set -Xmx to 1800m; 2. call gc to drop unused memory at early stage.
        // this is ugly and if it is stable with 1800m, this line should be removed
        Thread gcThread = new Thread(new Runnable() {

            @Override
            public void run() {
                System.gc();
            }
        });
        gcThread.setDaemon(true);
        gcThread.start();
    }
    // before master result, do checkpoint according to n iteration set by user
    doCheckPoint(context, masterParams, context.getCurrentIteration());
    LOG.debug("weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}", weightedTrainCount, weightedValidationCount, trainError, validationError);
    return masterParams;
}
Also used: ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList), NodeStats (ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats)

Aggregations

HashMap (java.util.HashMap): 3, NodeStats (ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats): 3, ArrayList (java.util.ArrayList): 2, Map (java.util.Map): 2, ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 2, ConcurrentMap (java.util.concurrent.ConcurrentMap): 2, ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 2, IOException (java.io.IOException): 1, List (java.util.List): 1, Entry (java.util.Map.Entry): 1, Random (java.util.Random): 1, CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList): 1, ExecutionException (java.util.concurrent.ExecutionException): 1, ExecutorCompletionService (java.util.concurrent.ExecutorCompletionService): 1, GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException): 1, MemoryLimitedList (ml.shifu.guagua.util.MemoryLimitedList): 1