Use of ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats in project shifu by ShifuML.
Class DTWorker, method doCompute:
/*
 * (non-Javadoc)
 *
 * @see ml.shifu.guagua.worker.AbstractWorkerComputable#doCompute(ml.shifu.guagua.worker.WorkerContext)
 */
@Override
public DTWorkerParams doCompute(WorkerContext<DTMasterParams, DTWorkerParams> context) {
    if (context.isFirstIteration()) {
        return new DTWorkerParams();
    }
    DTMasterParams lastMasterResult = context.getLastMasterResult();
    final List<TreeNode> trees = lastMasterResult.getTrees();
    final Map<Integer, TreeNode> todoNodes = lastMasterResult.getTodoNodes();
    if (todoNodes == null) {
        return new DTWorkerParams();
    }
    LOG.info("Start to work: todoNodes size is {}", todoNodes.size());
    Map<Integer, NodeStats> statistics = initTodoNodeStats(todoNodes);
    double trainError = 0d, validationError = 0d;
    double weightedTrainCount = 0d, weightedValidationCount = 0d;
    // renew random seed
    if (this.isGBDT && !this.gbdtSampleWithReplacement && lastMasterResult.isSwitchToNextTree()) {
        this.baggingRandomMap = new HashMap<Integer, Random>();
    }
    long start = System.nanoTime();
    for (Data data : this.trainingData) {
        if (this.isRF) {
            for (TreeNode treeNode : trees) {
                if (treeNode.getNode().getId() == Node.INVALID_INDEX) {
                    continue;
                }
                Node predictNode = predictNodeIndex(treeNode.getNode(), data, true);
                if (predictNode.getPredict() != null) {
                    // only update when a predict exists; a freshly created tree node has no predict statistics yet
                    float weight = data.subsampleWeights[treeNode.getTreeId()];
                    if (Float.compare(weight, 0f) == 0) {
                        // oob data: weight is 0, so it counts toward validation error instead of training
                        validationError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedValidationCount += data.significance;
                    } else {
                        trainError += weight * data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedTrainCount += weight * data.significance;
                    }
                }
            }
        }
        if (this.isGBDT) {
            if (this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) {
                recoverGBTData(context, data.output, data.predict, data, false);
                trainError += data.significance * loss.computeError(data.predict, data.label);
                weightedTrainCount += data.significance;
            } else {
                if (isNeedRecoverGBDTPredict) {
                    if (this.recoverTrees == null) {
                        this.recoverTrees = recoverCurrentTrees();
                    }
                    // recover gbdt data for fail-over
                    recoverGBTData(context, data.output, data.predict, data, true);
                }
                int currTreeIndex = trees.size() - 1;
                if (lastMasterResult.isSwitchToNextTree()) {
                    if (currTreeIndex >= 1) {
                        Node node = trees.get(currTreeIndex - 1).getNode();
                        Node predictNode = predictNodeIndex(node, data, false);
                        if (predictNode.getPredict() != null) {
                            double predict = predictNode.getPredict().getPredict();
                            if (context.getLastMasterResult().isFirstTree()) {
                                // first tree: take its prediction directly
                                data.predict = (float) predict;
                            } else {
                                // randomly drop this tree's contribution with probability dropOutRate
                                boolean drop = (this.dropOutRate > 0.0 && dropOutRandom.nextDouble() < this.dropOutRate);
                                if (!drop) {
                                    data.predict += (float) (this.learningRate * predict);
                                }
                            }
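                            // the negative gradient below is the pseudo-residual the next tree is fit against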
                            data.output = -1f * loss.computeGradient(data.predict, data.label);
                        }
                        // if not sampling with replacement in gbdt, re-draw the bagging sample for the next tree
                        if (!this.gbdtSampleWithReplacement) {
                            Random random = null;
                            int classValue = (int) (data.label + 0.01f);
                            if (this.isStratifiedSampling) {
                                random = baggingRandomMap.get(classValue);
                                if (random == null) {
                                    random = DTrainUtils.generateRandomBySampleSeed(modelConfig.getTrain().getBaggingSampleSeed(), CommonConstants.NOT_CONFIGURED_BAGGING_SEED);
                                    baggingRandomMap.put(classValue, random);
                                }
                            } else {
                                random = baggingRandomMap.get(0);
                                if (random == null) {
                                    random = DTrainUtils.generateRandomBySampleSeed(modelConfig.getTrain().getBaggingSampleSeed(), CommonConstants.NOT_CONFIGURED_BAGGING_SEED);
                                    baggingRandomMap.put(0, random);
                                }
                            }
                            if (random.nextDouble() <= modelConfig.getTrain().getBaggingSampleRate()) {
                                data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 1f;
                            } else {
                                data.subsampleWeights[currTreeIndex % data.subsampleWeights.length] = 0f;
                            }
                        }
                    }
                }
                if (context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) {
                    Node currTree = trees.get(currTreeIndex).getNode();
                    Node predictNode = predictNodeIndex(currTree, data, true);
                    if (predictNode.getPredict() != null) {
                        trainError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
                        weightedTrainCount += data.significance;
                    }
                } else {
                    trainError += data.significance * loss.computeError(data.predict, data.label);
                    weightedTrainCount += data.significance;
                }
            }
        }
    }
LOG.debug("Compute train error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
if (validationData != null) {
start = System.nanoTime();
for (Data data : this.validationData) {
if (this.isRF) {
for (TreeNode treeNode : trees) {
if (treeNode.getNode().getId() == Node.INVALID_INDEX) {
continue;
}
Node predictNode = predictNodeIndex(treeNode.getNode(), data, true);
if (predictNode.getPredict() != null) {
// only update when not in first node, for treeNode, no predict statistics at that time
validationError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
weightedValidationCount += data.significance;
}
}
}
if (this.isGBDT) {
if (this.isContinuousEnabled && lastMasterResult.isContinuousRunningStart()) {
recoverGBTData(context, data.output, data.predict, data, false);
validationError += data.significance * loss.computeError(data.predict, data.label);
weightedValidationCount += data.significance;
} else {
if (isNeedRecoverGBDTPredict) {
if (this.recoverTrees == null) {
this.recoverTrees = recoverCurrentTrees();
}
// recover gbdt data for fail over
recoverGBTData(context, data.output, data.predict, data, true);
}
int currTreeIndex = trees.size() - 1;
if (lastMasterResult.isSwitchToNextTree()) {
if (currTreeIndex >= 1) {
Node node = trees.get(currTreeIndex - 1).getNode();
Node predictNode = predictNodeIndex(node, data, false);
if (predictNode.getPredict() != null) {
double predict = predictNode.getPredict().getPredict();
if (context.getLastMasterResult().isFirstTree()) {
data.predict = (float) predict;
} else {
data.predict += (float) (this.learningRate * predict);
}
data.output = -1f * loss.computeGradient(data.predict, data.label);
}
}
}
if (context.getLastMasterResult().isFirstTree() && !lastMasterResult.isSwitchToNextTree()) {
Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, true);
if (predictNode.getPredict() != null) {
validationError += data.significance * loss.computeError((float) (predictNode.getPredict().getPredict()), data.label);
weightedValidationCount += data.significance;
}
} else {
validationError += data.significance * loss.computeError(data.predict, data.label);
weightedValidationCount += data.significance;
}
}
}
}
LOG.debug("Compute val error time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
}
    if (this.isGBDT) {
        // reset recover trees to null to save memory
        this.recoverTrees = null;
        if (this.isNeedRecoverGBDTPredict) {
            // no need to recover again
            this.isNeedRecoverGBDTPredict = false;
        }
    }
    start = System.nanoTime();
    CompletionService<Map<Integer, NodeStats>> completionService = new ExecutorCompletionService<Map<Integer, NodeStats>>(this.threadPool);
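    // fan out the node-stats computation: each thread processes one contiguous slice of the training data, and the partial results are merged below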
    int realThreadCount = 0;
    LOG.debug("while todo size {}", todoNodes.size());
    int realRecords = this.trainingData.size();
    int realThreads = this.workerThreadCount > realRecords ? realRecords : this.workerThreadCount;
    int[] trainLows = new int[realThreads];
    int[] trainHighs = new int[realThreads];
    int stepCount = realRecords / realThreads;
    if (realRecords % realThreads != 0) {
        // spread the remainder into stepCount so the last thread is not left with up to 2 * stepCount - 1 records
        stepCount += (realRecords % realThreads) / stepCount;
    }
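    // e.g. realRecords = 103, realThreads = 10: stepCount = 103 / 10 = 10, remainder 3, and 3 / 10 == 0,
    // so stepCount stays 10 and the last thread simply absorbs the 3 leftover records (indexes 100..102)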
    for (int i = 0; i < realThreads; i++) {
        trainLows[i] = i * stepCount;
        if (i != realThreads - 1) {
            trainHighs[i] = trainLows[i] + stepCount - 1;
        } else {
            trainHighs[i] = realRecords - 1;
        }
    }
    for (int i = 0; i < realThreads; i++) {
        final Map<Integer, TreeNode> localTodoNodes = new HashMap<Integer, TreeNode>(todoNodes);
        final Map<Integer, NodeStats> localStatistics = initTodoNodeStats(todoNodes);
        final int startIndex = trainLows[i];
        final int endIndex = trainHighs[i];
        LOG.info("Thread {} todo size {} stats size {} start index {} end index {}", i, localTodoNodes.size(), localStatistics.size(), startIndex, endIndex);
        if (localTodoNodes.size() == 0) {
            continue;
        }
        realThreadCount += 1;
        completionService.submit(new Callable<Map<Integer, NodeStats>>() {
            @Override
            public Map<Integer, NodeStats> call() throws Exception {
                long start = System.nanoTime();
                List<Integer> nodeIndexes = new ArrayList<Integer>(trees.size());
                for (int j = startIndex; j <= endIndex; j++) {
                    Data data = DTWorker.this.trainingData.get(j);
                    nodeIndexes.clear();
                    if (DTWorker.this.isRF) {
                        for (TreeNode treeNode : trees) {
                            if (treeNode.getNode().getId() == Node.INVALID_INDEX) {
                                nodeIndexes.add(Node.INVALID_INDEX);
                            } else {
                                Node predictNode = predictNodeIndex(treeNode.getNode(), data, false);
                                nodeIndexes.add(predictNode.getId());
                            }
                        }
                    }
                    if (DTWorker.this.isGBDT) {
                        int currTreeIndex = trees.size() - 1;
                        Node predictNode = predictNodeIndex(trees.get(currTreeIndex).getNode(), data, false);
                        // update node index
                        nodeIndexes.add(predictNode.getId());
                    }
                    for (Map.Entry<Integer, TreeNode> entry : localTodoNodes.entrySet()) {
                        // only do statistics on effective data
                        Node todoNode = entry.getValue().getNode();
                        int treeId = entry.getValue().getTreeId();
                        int currPredictIndex = 0;
                        if (DTWorker.this.isRF) {
                            currPredictIndex = nodeIndexes.get(entry.getValue().getTreeId());
                        }
                        if (DTWorker.this.isGBDT) {
                            currPredictIndex = nodeIndexes.get(0);
                        }
                        if (todoNode.getId() == currPredictIndex) {
                            List<Integer> features = entry.getValue().getFeatures();
                            if (features.isEmpty()) {
                                features = getAllValidFeatures();
                            }
                            for (Integer columnNum : features) {
                                double[] featureStatistic = localStatistics.get(entry.getKey()).getFeatureStatistics().get(columnNum);
                                float weight = data.subsampleWeights[treeId % data.subsampleWeights.length];
                                if (Float.compare(weight, 0f) != 0) {
                                    // only compute when weight is not 0
                                    short binIndex = data.inputs[DTWorker.this.inputIndexMap.get(columnNum)];
                                    DTWorker.this.impurity.featureUpdate(featureStatistic, binIndex, data.output, data.significance, weight);
                                }
                            }
                        }
                    }
                }
                LOG.debug("Thread computing stats time is {}ms in thread {}", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start), Thread.currentThread().getName());
                return localStatistics;
            }
        });
    }
    int rCnt = 0;
    while (rCnt < realThreadCount) {
        try {
            Map<Integer, NodeStats> currNodeStatsmap = completionService.take().get();
            if (rCnt == 0) {
                statistics = currNodeStatsmap;
            } else {
                for (Entry<Integer, NodeStats> entry : statistics.entrySet()) {
                    NodeStats resultNodeStats = entry.getValue();
                    mergeNodeStats(resultNodeStats, currNodeStatsmap.get(entry.getKey()));
                }
            }
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        rCnt += 1;
    }
    LOG.debug("Compute stats time is {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
    LOG.info("worker count is {}, error is {}, and stats size is {}. weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}", count, trainError, statistics.size(), weightedTrainCount, weightedValidationCount, trainError, validationError);
    return new DTWorkerParams(weightedTrainCount, weightedValidationCount, trainError, validationError, statistics);
}
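The per-thread (and, in DTMaster below, per-worker) partial statistics are combined by mergeNodeStats, which is not part of this listing. A minimal sketch of what such a merge can look like, assuming NodeStats exposes its feature statistics as Map<Integer, double[]> (as initTodoNodeStats below suggests) and that the impurity accumulators are additive across data partitions; this is an illustration, not the verbatim Shifu implementation:

private void mergeNodeStats(NodeStats resultNodeStats, NodeStats otherNodeStats) {
    Map<Integer, double[]> target = resultNodeStats.getFeatureStatistics();
    for (Map.Entry<Integer, double[]> entry : otherNodeStats.getFeatureStatistics().entrySet()) {
        double[] accumulated = target.get(entry.getKey());
        double[] partial = entry.getValue();
        for (int i = 0; i < accumulated.length; i++) {
            // impurity stats are plain sums (counts, weighted sums), so partitions add up
            accumulated[i] += partial[i];
        }
    }
}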
Use of ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats in project shifu by ShifuML.
Class DTWorker, method initTodoNodeStats:
private Map<Integer, NodeStats> initTodoNodeStats(Map<Integer, TreeNode> todoNodes) {
    Map<Integer, NodeStats> statistics = new HashMap<Integer, NodeStats>(todoNodes.size(), 1f);
    for (Map.Entry<Integer, TreeNode> entry : todoNodes.entrySet()) {
        List<Integer> features = entry.getValue().getFeatures();
        if (features.isEmpty()) {
            features = getAllValidFeatures();
        }
        Map<Integer, double[]> featureStatistics = new HashMap<Integer, double[]>(features.size(), 1f);
        for (Integer columnNum : features) {
            ColumnConfig columnConfig = this.columnConfigList.get(columnNum);
            if (columnConfig.isNumerical()) {
                // TODO, how to process null bin
                int featureStatsSize = columnConfig.getBinBoundary().size() * this.impurity.getStatsSize();
                featureStatistics.put(columnNum, new double[featureStatsSize]);
            } else if (columnConfig.isCategorical()) {
                // the last one is for invalid value category like ?, *, ...
                int featureStatsSize = (columnConfig.getBinCategory().size() + 1) * this.impurity.getStatsSize();
                featureStatistics.put(columnNum, new double[featureStatsSize]);
            }
        }
        NodeStats nodeStats = new NodeStats(entry.getValue().getTreeId(), entry.getValue().getNode().getId(), featureStatistics);
        statistics.put(entry.getKey(), nodeStats);
    }
    return statistics;
}
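To make the sizing concrete: a numerical column allocates binBoundary.size() * statsSize doubles, while a categorical column allocates (binCategory.size() + 1) * statsSize, the extra bin catching invalid values. A small illustration with hypothetical counts; only the formulas come from the code above (an impurity such as variance might keep, say, a count, a sum and a squared sum per bin):

int statsSize = 3;                          // hypothetical impurity stats size per bin
int numericalStats = 10 * statsSize;        // 10 bin boundaries -> 30 doubles
int categoricalStats = (5 + 1) * statsSize; // 5 categories + 1 invalid-value bin -> 18 doubles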
Use of ml.shifu.shifu.core.dtrain.dt.DTWorkerParams.NodeStats in project shifu by ShifuML.
Class DTMaster, method doCompute:
@Override
public DTMasterParams doCompute(MasterContext<DTMasterParams, DTWorkerParams> context) {
    if (context.isFirstIteration()) {
        return buildInitialMasterParams();
    }
    if (this.cpMasterParams != null) {
        DTMasterParams tmpMasterParams = rebuildRecoverMasterResultDepthList();
        // set it to null to avoid sending it again in the next iteration
        this.cpMasterParams = null;
        if (this.isGBDT) {
            // no need to send full trees because workers read existing models from HDFS;
            // only set the last tree for node stats, and skip the switch-to-next-tree check
            // because that message may already have been sent to workers
            tmpMasterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
            // set tmp trees for DTOutput
            tmpMasterParams.setTmpTrees(this.trees);
        }
        return tmpMasterParams;
    }
    boolean isFirst = false;
    Map<Integer, NodeStats> nodeStatsMap = null;
    double trainError = 0d, validationError = 0d;
    double weightedTrainCount = 0d, weightedValidationCount = 0d;
    for (DTWorkerParams params : context.getWorkerResults()) {
        if (!isFirst) {
            isFirst = true;
            nodeStatsMap = params.getNodeStatsMap();
        } else {
            Map<Integer, NodeStats> currNodeStatsmap = params.getNodeStatsMap();
            for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
                NodeStats resultNodeStats = entry.getValue();
                mergeNodeStats(resultNodeStats, currNodeStatsmap.get(entry.getKey()));
            }
            // set to null after merging to release memory at the earliest stage
            params.setNodeStatsMap(null);
        }
        trainError += params.getTrainError();
        validationError += params.getValidationError();
        weightedTrainCount += params.getTrainCount();
        weightedValidationCount += params.getValidationCount();
    }
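    // at this point nodeStatsMap holds, for every todo node, the feature statistics summed over all workers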
    for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
        NodeStats nodeStats = entry.getValue();
        int treeId = nodeStats.getTreeId();
        Node doneNode = Node.getNode(trees.get(treeId).getNode(), nodeStats.getNodeId());
        // compute the gain of each candidate feature of this done node from its aggregated NodeStats
        Map<Integer, double[]> statistics = nodeStats.getFeatureStatistics();
        List<GainInfo> gainList = new ArrayList<GainInfo>();
        for (Entry<Integer, double[]> gainEntry : statistics.entrySet()) {
            int columnNum = gainEntry.getKey();
            ColumnConfig config = this.columnConfigList.get(columnNum);
            double[] statsArray = gainEntry.getValue();
            GainInfo gainInfo = this.impurity.computeImpurity(statsArray, config);
            if (gainInfo != null) {
                gainList.add(gainInfo);
            }
        }
        GainInfo maxGainInfo = GainInfo.getGainInfoByMaxGain(gainList);
        if (maxGainInfo == null) {
            // no gain info: set to leaf and continue with the next stats
            doneNode.setLeaf(true);
            continue;
        }
        populateGainInfoToNode(treeId, doneNode, maxGainInfo);
        if (this.isLeafWise) {
            boolean isNotSplit = maxGainInfo.getGain() <= 0d;
            if (!isNotSplit) {
                this.toSplitQueue.offer(new TreeNode(treeId, doneNode));
            } else {
                LOG.info("Node {} in tree {} is not to be split", doneNode.getId(), treeId);
            }
        } else {
            boolean isLeaf = maxGainInfo.getGain() <= 0d || Node.indexToLevel(doneNode.getId()) == this.maxDepth;
            doneNode.setLeaf(isLeaf);
            // level-wise trees split a node as soon as its stats are ready
            splitNodeForLevelWisedTree(isLeaf, treeId, doneNode);
        }
    }
    if (this.isLeafWise) {
        // take nodes from toSplitQueue and split them, at most maxBatchSplitSize per iteration
        int currSplitIndex = 0;
        while (!toSplitQueue.isEmpty() && currSplitIndex < this.maxBatchSplitSize) {
            TreeNode treeNode = this.toSplitQueue.poll();
            splitNodeForLeafWisedTree(treeNode.getTreeId(), treeNode.getNode());
            // count this split so the batch limit actually takes effect
            currSplitIndex += 1;
        }
    }
    Map<Integer, TreeNode> todoNodes = new HashMap<Integer, TreeNode>();
    double averageValidationError = validationError / weightedValidationCount;
    if (this.isGBDT && this.dtEarlyStopDecider != null && averageValidationError > 0) {
        this.dtEarlyStopDecider.add(averageValidationError);
        averageValidationError = this.dtEarlyStopDecider.getCurrentAverageValue();
    }
    boolean vtTriggered = false;
    // validationTolerance == 0d means the tolerance check is not enabled
    if (validationTolerance > 0d && Math.abs(this.bestValidationError - averageValidationError) < this.validationTolerance * averageValidationError) {
        LOG.debug("Debug: bestValidationError {}, averageValidationError {}, validationTolerance {}", bestValidationError, averageValidationError, validationTolerance);
        vtTriggered = true;
    }
    if (averageValidationError < this.bestValidationError) {
        this.bestValidationError = averageValidationError;
    }
    // report validation error as averageValidationError * weightedValidationCount because
    // averageValidationError has already been divided by the validation count
    DTMasterParams masterParams = new DTMasterParams(weightedTrainCount, trainError, weightedValidationCount, averageValidationError * weightedValidationCount);
    if (toDoQueue.isEmpty()) {
        if (this.isGBDT) {
            TreeNode treeNode = this.trees.get(this.trees.size() - 1);
            Node node = treeNode.getNode();
            if (this.trees.size() >= this.treeNum) {
                // stop the whole process once the total number of trees (including trees read from an existing model) reaches treeNum
                masterParams.setHalt(true);
                LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (node.getLeft() == null && node.getRight() == null) {
                // a root-only tree means one tree already fits the data perfectly (say the 2nd of 5 requested trees);
                // continuing is pointless, but warn users because this usually signals overfitting:
                // setting BaggingSampleRate below 1 avoids it
                masterParams.setHalt(true);
                LOG.warn("Tree is learned 100% well, there must be overfit here, please tune BaggingSampleRate, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (this.dtEarlyStopDecider != null && (this.enableEarlyStop && this.dtEarlyStopDecider.canStop())) {
                masterParams.setHalt(true);
                LOG.info("Early stop identified, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (vtTriggered) {
                LOG.info("Early stop training by validation tolerance.");
                masterParams.setHalt(true);
            } else {
                // record whether we are still on the first tree, even though the next tree's ROOT node is created below
                masterParams.setFirstTree(this.trees.size() == 1);
                // the current tree is finished; its feature information is no longer needed
                treeNode.setFeatures(null);
                TreeNode newRootNode = new TreeNode(this.trees.size(), new Node(Node.ROOT_INDEX), this.learningRate);
                LOG.info("The {} tree is to be built.", this.trees.size());
                this.trees.add(newRootNode);
                newRootNode.setFeatures(getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate));
                // only one todo node: the new ROOT
                todoNodes.put(0, newRootNode);
                masterParams.setTodoNodes(todoNodes);
                // set switch flag
                masterParams.setSwitchToNextTree(true);
            }
        } else {
            // for rf, an empty queue means training is done
            masterParams.setHalt(true);
            LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
        }
    } else {
        int nodeIndexInGroup = 0;
        long currMem = 0L;
        List<Integer> depthList = new ArrayList<Integer>();
        if (this.isGBDT) {
            depthList.add(-1);
        }
        if (isRF) {
            for (int i = 0; i < this.trees.size(); i++) {
                depthList.add(-1);
            }
        }
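        // drain the todo queue in batches: collect nodes until the estimated memory for their feature stats would exceed maxStatsMemory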
        while (!toDoQueue.isEmpty() && currMem <= this.maxStatsMemory) {
            TreeNode node = this.toDoQueue.poll();
            int treeId = node.getTreeId();
            int oldDepth = this.isGBDT ? depthList.get(0) : depthList.get(treeId);
            int currDepth = Node.indexToLevel(node.getNode().getId());
            if (currDepth > oldDepth) {
                if (this.isGBDT) {
                    // gbdt tracks only the current (last) tree's depth
                    depthList.set(0, currDepth);
                } else {
                    depthList.set(treeId, currDepth);
                }
            }
            List<Integer> subsetFeatures = getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate);
            node.setFeatures(subsetFeatures);
            currMem += getStatsMem(subsetFeatures);
            todoNodes.put(nodeIndexInGroup, node);
            nodeIndexInGroup += 1;
        }
        masterParams.setTreeDepth(depthList);
        masterParams.setTodoNodes(todoNodes);
        masterParams.setSwitchToNextTree(false);
        masterParams.setContinuousRunningStart(false);
        masterParams.setFirstTree(this.trees.size() == 1);
        LOG.info("Todo node size is {}", todoNodes.size());
    }
    if (this.isGBDT) {
        if (masterParams.isSwitchToNextTree()) {
            // send the last fully grown tree plus the current todo ROOT node tree
            masterParams.setTrees(trees.subList(trees.size() - 2, trees.size()));
        } else {
            // only send the current tree
            masterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
        }
    }
    if (this.isRF) {
        // elements in todoTrees are the same references as in this.trees; reuse the same objects to save memory
        if (masterParams.getTreeDepth().size() == this.trees.size()) {
            // normal iteration
            List<TreeNode> todoTrees = new ArrayList<TreeNode>();
            for (int i = 0; i < trees.size(); i++) {
                if (masterParams.getTreeDepth().get(i) >= 0) {
                    // this tree's depth in the current iteration is not -1, so add it to todoTrees
                    todoTrees.add(trees.get(i));
                } else {
                    // mock a TreeNode instance to keep serialization safe; its content is meaningless
                    todoTrees.add(new TreeNode(i, new Node(Node.INVALID_INDEX), 1d));
                }
            }
            masterParams.setTrees(todoTrees);
        } else {
            // last iteration, where maxDepthList is not populated: send all trees
            masterParams.setTrees(trees);
        }
    }
    if (this.isGBDT) {
        // set tmp trees for DTOutput
        masterParams.setTmpTrees(this.trees);
    }
    if (context.getCurrentIteration() % 100 == 0) {
        // every 100 iterations trigger gc explicitly to avoid one observed case:
        // with mapper memory fixed at 2048M in our cluster and -Xmx at 2G, occasional OOM issues occur.
        // two fixes: 1. set -Xmx to 1800m; 2. call gc to drop unused memory at an early stage.
        // this is ugly; if training is stable with 1800m, this block should be removed
        Thread gcThread = new Thread(new Runnable() {
            @Override
            public void run() {
                System.gc();
            }
        });
        gcThread.setDaemon(true);
        gcThread.start();
    }
    // before returning the master result, checkpoint every n iterations as configured by the user
    doCheckPoint(context, masterParams, context.getCurrentIteration());
    LOG.debug("weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}", weightedTrainCount, weightedValidationCount, trainError, validationError);
    return masterParams;
}
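GainInfo.getGainInfoByMaxGain, used above to pick the winning split, only has to scan the candidates for the maximum gain and return null when there is none, which the caller turns into a leaf. A minimal sketch under the assumption that GainInfo exposes getGain() as used above; illustrative, not the verbatim Shifu implementation:

public static GainInfo getGainInfoByMaxGain(List<GainInfo> gainList) {
    GainInfo maxGainInfo = null;
    for (GainInfo gainInfo : gainList) {
        if (maxGainInfo == null || gainInfo.getGain() > maxGainInfo.getGain()) {
            maxGainInfo = gainInfo;
        }
    }
    // null for an empty candidate list -> the caller marks the node as a leaf
    return maxGainInfo;
}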