Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class DTrainUtils, method getNumericalIds:
public static List<Integer> getNumericalIds(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) {
    List<Integer> numericalIds = new ArrayList<>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig config : columnConfigList) {
        if (isAfterVarSelect) {
            if (config.isNumerical() && config.isFinalSelect() && !config.isTarget() && !config.isMeta()) {
                numericalIds.add(config.getColumnNum());
            }
        } else {
            if (config.isNumerical() && !config.isTarget() && !config.isMeta()
                    && CommonUtils.isGoodCandidate(config, hasCandidates)) {
                numericalIds.add(config.getColumnNum());
            }
        }
    }
    return numericalIds;
}
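A minimal usage sketch follows. The setters used to populate ColumnConfig (setColumnNum, setColumnType, setFinalSelect) and the ColumnConfig.ColumnType enum constants are assumptions mirroring the getters seen above; a real run would deserialize ColumnConfig.json instead of building the list by hand.

import java.util.ArrayList;
import java.util.List;

import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.core.dtrain.DTrainUtils;

public class NumericalIdsDemo {

    public static void main(String[] args) {
        // Hand-built column list for illustration only; setter names and the
        // ColumnType constants are assumptions, not confirmed by this excerpt.
        List<ColumnConfig> columnConfigList = new ArrayList<ColumnConfig>();

        ColumnConfig amount = new ColumnConfig();
        amount.setColumnNum(0);
        amount.setColumnType(ColumnConfig.ColumnType.N); // numerical (assumed enum constant)
        amount.setFinalSelect(true);
        columnConfigList.add(amount);

        ColumnConfig region = new ColumnConfig();
        region.setColumnNum(1);
        region.setColumnType(ColumnConfig.ColumnType.C); // categorical (assumed enum constant)
        region.setFinalSelect(true);
        columnConfigList.add(region);

        // After variable selection, only final-selected, non-target, non-meta
        // numerical columns are returned.
        List<Integer> numericalIds = DTrainUtils.getNumericalIds(columnConfigList, true);
        System.out.println(numericalIds); // expected: [0]
    }
}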
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class DTrainUtils, method getCategoricalIds:
public static List<Integer> getCategoricalIds(List<ColumnConfig> columnConfigList, boolean isAfterVarSelect) {
    List<Integer> results = new ArrayList<>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig config : columnConfigList) {
        if (isAfterVarSelect) {
            if (config.isFinalSelect() && !config.isTarget() && !config.isMeta() && config.isCategorical()) {
                results.add(config.getColumnNum());
            }
        } else {
            if (!config.isTarget() && !config.isMeta()
                    && CommonUtils.isGoodCandidate(config, hasCandidates) && config.isCategorical()) {
                results.add(config.getColumnNum());
            }
        }
    }
    return results;
}
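Because both helpers apply the same target/meta/selection filters and differ only in the type check, their results partition the selected features. Continuing the NumericalIdsDemo sketch above:

// Continuing the sketch: a column passes exactly one of the two type checks,
// so the two id lists are always disjoint.
List<Integer> numericalIds = DTrainUtils.getNumericalIds(columnConfigList, true);
List<Integer> categoricalIds = DTrainUtils.getCategoricalIds(columnConfigList, true);
assert java.util.Collections.disjoint(numericalIds, categoricalIds);
System.out.println(categoricalIds); // expected: [1]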
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class DTMaster, method doCompute:
@Override
public DTMasterParams doCompute(MasterContext<DTMasterParams, DTWorkerParams> context) {
    if (context.isFirstIteration()) {
        return buildInitialMasterParams();
    }
    if (this.cpMasterParams != null) {
        DTMasterParams tmpMasterParams = rebuildRecoverMasterResultDepthList();
        // set to null to avoid sending it again in the next iteration
        this.cpMasterParams = null;
        if (this.isGBDT) {
            // no need to send full trees because workers will read existing models from HDFS;
            // only send the last tree for node stats. Don't check switching to the next tree
            // here, since that message may already have been sent to workers.
            tmpMasterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
            // set tmp trees for DTOutput
            tmpMasterParams.setTmpTrees(this.trees);
        }
        return tmpMasterParams;
    }
    boolean isFirst = false;
    Map<Integer, NodeStats> nodeStatsMap = null;
    double trainError = 0d, validationError = 0d;
    double weightedTrainCount = 0d, weightedValidationCount = 0d;
    for (DTWorkerParams params : context.getWorkerResults()) {
        if (!isFirst) {
            isFirst = true;
            nodeStatsMap = params.getNodeStatsMap();
        } else {
            Map<Integer, NodeStats> currNodeStatsMap = params.getNodeStatsMap();
            for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
                NodeStats resultNodeStats = entry.getValue();
                mergeNodeStats(resultNodeStats, currNodeStatsMap.get(entry.getKey()));
            }
            // set to null after merging to release memory as early as possible
            params.setNodeStatsMap(null);
        }
        trainError += params.getTrainError();
        validationError += params.getValidationError();
        weightedTrainCount += params.getTrainCount();
        weightedValidationCount += params.getValidationCount();
    }
    for (Entry<Integer, NodeStats> entry : nodeStatsMap.entrySet()) {
        NodeStats nodeStats = entry.getValue();
        int treeId = nodeStats.getTreeId();
        Node doneNode = Node.getNode(trees.get(treeId).getNode(), nodeStats.getNodeId());
        Map<Integer, double[]> statistics = nodeStats.getFeatureStatistics();
        List<GainInfo> gainList = new ArrayList<GainInfo>();
        for (Entry<Integer, double[]> gainEntry : statistics.entrySet()) {
            int columnNum = gainEntry.getKey();
            ColumnConfig config = this.columnConfigList.get(columnNum);
            double[] statsArray = gainEntry.getValue();
            GainInfo gainInfo = this.impurity.computeImpurity(statsArray, config);
            if (gainInfo != null) {
                gainList.add(gainInfo);
            }
        }
        GainInfo maxGainInfo = GainInfo.getGainInfoByMaxGain(gainList);
        if (maxGainInfo == null) {
            // no gain info: mark as leaf and continue with the next node stats
            doneNode.setLeaf(true);
            continue;
        }
        populateGainInfoToNode(treeId, doneNode, maxGainInfo);
        if (this.isLeafWise) {
            boolean isNotSplit = maxGainInfo.getGain() <= 0d;
            if (!isNotSplit) {
                this.toSplitQueue.offer(new TreeNode(treeId, doneNode));
            } else {
                LOG.info("Node {} in tree {} is not to be split", doneNode.getId(), treeId);
            }
        } else {
            boolean isLeaf = maxGainInfo.getGain() <= 0d || Node.indexToLevel(doneNode.getId()) == this.maxDepth;
            doneNode.setLeaf(isLeaf);
            // in level-wise growth, split a node as soon as its stats are ready
            splitNodeForLevelWisedTree(isLeaf, treeId, doneNode);
        }
    }
    if (this.isLeafWise) {
        // drain nodes from toSplitQueue and split them, at most maxBatchSplitSize per iteration;
        // note the increment: without it the batch-size cap would never take effect
        int currSplitIndex = 0;
        while (!toSplitQueue.isEmpty() && currSplitIndex++ < this.maxBatchSplitSize) {
            TreeNode treeNode = this.toSplitQueue.poll();
            splitNodeForLeafWisedTree(treeNode.getTreeId(), treeNode.getNode());
        }
    }
    Map<Integer, TreeNode> todoNodes = new HashMap<Integer, TreeNode>();
    double averageValidationError = validationError / weightedValidationCount;
    if (this.isGBDT && this.dtEarlyStopDecider != null && averageValidationError > 0) {
        this.dtEarlyStopDecider.add(averageValidationError);
        averageValidationError = this.dtEarlyStopDecider.getCurrentAverageValue();
    }
    boolean vtTriggered = false;
    // validationTolerance == 0d means the tolerance check is disabled
    if (validationTolerance > 0d
            && Math.abs(this.bestValidationError - averageValidationError) < this.validationTolerance * averageValidationError) {
        LOG.debug("Debug: bestValidationError {}, averageValidationError {}, validationTolerance {}",
                bestValidationError, averageValidationError, validationTolerance);
        vtTriggered = true;
    }
    if (averageValidationError < this.bestValidationError) {
        this.bestValidationError = averageValidationError;
    }
    // DTMasterParams expects the summed validation error, so multiply the average
    // (already divided by the validation count) back by weightedValidationCount.
    DTMasterParams masterParams = new DTMasterParams(weightedTrainCount, trainError, weightedValidationCount,
            averageValidationError * weightedValidationCount);
    if (toDoQueue.isEmpty()) {
        if (this.isGBDT) {
            TreeNode treeNode = this.trees.get(this.trees.size() - 1);
            Node node = treeNode.getNode();
            if (this.trees.size() >= this.treeNum) {
                // all trees (including trees read from an existing model) reach treeNum: stop the whole process
                masterParams.setHalt(true);
                LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (node.getLeft() == null && node.getRight() == null) {
                // a single-node tree means the data was fit perfectly, which usually signals overfitting:
                // say you want 5 trees but the 2nd one is already perfect; there is no point continuing,
                // so warn the user. Setting BaggingSampleRate to a value other than 1 avoids this.
                masterParams.setHalt(true);
                LOG.warn("Tree is learned 100% well, there must be overfit here, please tune BaggingSampleRate, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (this.dtEarlyStopDecider != null && this.enableEarlyStop && this.dtEarlyStopDecider.canStop()) {
                masterParams.setHalt(true);
                LOG.info("Early stop identified, training is stopped in iteration {}.", context.getCurrentIteration());
            } else if (vtTriggered) {
                LOG.info("Early stop training by validation tolerance.");
                masterParams.setHalt(true);
            } else {
                // treat the first tree as first even after the ROOT node of the next tree is set
                masterParams.setFirstTree(this.trees.size() == 1);
                // the current tree is finished; its feature information is no longer needed
                treeNode.setFeatures(null);
                TreeNode newRootNode = new TreeNode(this.trees.size(), new Node(Node.ROOT_INDEX), this.learningRate);
                LOG.info("The {} tree is to be built.", this.trees.size());
                this.trees.add(newRootNode);
                newRootNode.setFeatures(getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate));
                // the new tree has only one node
                todoNodes.put(0, newRootNode);
                masterParams.setTodoNodes(todoNodes);
                // set switch flag
                masterParams.setSwitchToNextTree(true);
            }
        } else {
            // random forest
            masterParams.setHalt(true);
            LOG.info("Queue is empty, training is stopped in iteration {}.", context.getCurrentIteration());
        }
    } else {
        int nodeIndexInGroup = 0;
        long currMem = 0L;
        List<Integer> depthList = new ArrayList<Integer>();
        if (this.isGBDT) {
            depthList.add(-1);
        }
        if (isRF) {
            for (int i = 0; i < this.trees.size(); i++) {
                depthList.add(-1);
            }
        }
        while (!toDoQueue.isEmpty() && currMem <= this.maxStatsMemory) {
            TreeNode node = this.toDoQueue.poll();
            int treeId = node.getTreeId();
            int oldDepth = this.isGBDT ? depthList.get(0) : depthList.get(treeId);
            int currDepth = Node.indexToLevel(node.getNode().getId());
            if (currDepth > oldDepth) {
                if (this.isGBDT) {
                    // GBDT only tracks the depth of the last (current) tree
                    depthList.set(0, currDepth);
                } else {
                    depthList.set(treeId, currDepth);
                }
            }
            List<Integer> subsetFeatures = getSubsamplingFeatures(this.featureSubsetStrategy, this.featureSubsetRate);
            node.setFeatures(subsetFeatures);
            currMem += getStatsMem(subsetFeatures);
            todoNodes.put(nodeIndexInGroup, node);
            nodeIndexInGroup += 1;
        }
        masterParams.setTreeDepth(depthList);
        masterParams.setTodoNodes(todoNodes);
        masterParams.setSwitchToNextTree(false);
        masterParams.setContinuousRunningStart(false);
        masterParams.setFirstTree(this.trees.size() == 1);
        LOG.info("Todo node size is {}", todoNodes.size());
    }
    if (this.isGBDT) {
        if (masterParams.isSwitchToNextTree()) {
            // send the last fully grown tree plus the new tree holding the todo ROOT node
            masterParams.setTrees(trees.subList(trees.size() - 2, trees.size()));
        } else {
            // only send the current tree
            masterParams.setTrees(trees.subList(trees.size() - 1, trees.size()));
        }
    }
    if (this.isRF) {
        // elements in todoTrees share references with this.trees; reuse the same objects to save memory
        if (masterParams.getTreeDepth().size() == this.trees.size()) {
            // normal iteration
            List<TreeNode> todoTrees = new ArrayList<TreeNode>();
            for (int i = 0; i < trees.size(); i++) {
                if (masterParams.getTreeDepth().get(i) >= 0) {
                    // this tree's depth in the current iteration is not -1, so add it to todoTrees
                    todoTrees.add(trees.get(i));
                } else {
                    // mock a TreeNode instance so later serialization holds no surprises;
                    // the instance itself is meaningless
                    todoTrees.add(new TreeNode(i, new Node(Node.INVALID_INDEX), 1d));
                }
            }
            masterParams.setTrees(todoTrees);
        } else {
            // last iteration, without maxDepthList
            masterParams.setTrees(trees);
        }
    }
    if (this.isGBDT) {
        // set tmp trees for DTOutput
        masterParams.setTmpTrees(this.trees);
    }
    if (context.getCurrentIteration() % 100 == 0) {
        // every 100 iterations trigger gc explicitly to work around one case: with 2048M mapper
        // memory (final in our cluster) and -Xmx at 2G, occasional OOM issues appear. Two fixes:
        // 1. set -Xmx to 1800m; 2. call gc to drop unused memory early. This is ugly, and if
        // things stay stable at 1800m this block should be removed.
        Thread gcThread = new Thread(new Runnable() {
            @Override
            public void run() {
                System.gc();
            }
        });
        gcThread.setDaemon(true);
        gcThread.start();
    }
    // before returning the master result, checkpoint according to the iteration interval set by the user
    doCheckPoint(context, masterParams, context.getCurrentIteration());
    LOG.debug("weightedTrainCount {}, weightedValidationCount {}, trainError {}, validationError {}",
            weightedTrainCount, weightedValidationCount, trainError, validationError);
    return masterParams;
}
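doCompute relies on mergeNodeStats to fold each worker's statistics into the first worker's map, but the helper itself is not part of this excerpt. A plausible sketch, assuming NodeStats.getFeatureStatistics() returns per-column double[] arrays of equal length across workers (the element-wise addition is an assumption consistent with the arrays later being handed to impurity.computeImpurity):

// Hedged sketch of the merge step: sum the worker-local statistics arrays
// element-wise into the result map. Assumes both maps share the same keys
// and array lengths, which holds when workers stat the same todo nodes.
private void mergeNodeStats(NodeStats resultNodeStats, NodeStats nodeStats) {
    Map<Integer, double[]> featureStatistics = resultNodeStats.getFeatureStatistics();
    for (Entry<Integer, double[]> entry : nodeStats.getFeatureStatistics().entrySet()) {
        double[] accumulated = featureStatistics.get(entry.getKey());
        double[] incoming = entry.getValue();
        for (int i = 0; i < accumulated.length; i++) {
            accumulated[i] += incoming[i];
        }
    }
}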
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class CorrelationMapper, method setup:
@Override
protected void setup(Context context) throws IOException, InterruptedException {
    loadConfigFiles(context);
    this.dataSetDelimiter = modelConfig.getDataSetDelimiter();
    this.dataPurifier = new DataPurifier(modelConfig, false);
    this.isComputeAll = Boolean.valueOf(context.getConfiguration().get(Constants.SHIFU_CORRELATION_COMPUTE_ALL, "false"));
    for (ColumnConfig config : columnConfigList) {
        if (config.isCategorical()) {
            Map<String, Integer> map = new HashMap<String, Integer>();
            if (config.getBinCategory() != null) {
                for (int i = 0; i < config.getBinCategory().size(); i++) {
                    List<String> cvals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
                    for (String cval : cvals) {
                        map.put(cval, i);
                    }
                }
            }
            this.categoricalIndexMap.put(config.getColumnNum(), map);
        }
    }
    if (modelConfig != null && modelConfig.getPosTags() != null) {
        this.posTagSet = new HashSet<String>(modelConfig.getPosTags());
    }
    if (modelConfig != null && modelConfig.getNegTags() != null) {
        this.negTagSet = new HashSet<String>(modelConfig.getNegTags());
    }
    if (modelConfig != null && modelConfig.getFlattenTags() != null) {
        this.tagSet = new HashSet<String>(modelConfig.getFlattenTags());
    }
    if (modelConfig != null) {
        this.tags = modelConfig.getSetTags();
    }
}
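setup only builds categoricalIndexMap; the lookup side is not shown here. A hedged sketch of how the mapper might resolve a raw categorical value to its bin index (the getCategoricalBinIndex helper and its -1 fallback are illustrative, not taken from CorrelationMapper):

// Hedged sketch: resolve a raw categorical value to its bin index using the
// per-column maps built in setup(). Returns -1 for non-categorical columns
// and unseen categories; the real mapper may handle missing values differently.
private int getCategoricalBinIndex(int columnNum, String rawValue) {
    Map<String, Integer> map = this.categoricalIndexMap.get(columnNum);
    if (map == null) {
        return -1; // not a categorical column
    }
    Integer binIndex = map.get(rawValue);
    return binIndex == null ? -1 : binIndex;
}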
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
In class PostTrainModelProcessor, method updateColumnConfigWithBinAvgScore:
/**
 * Read the bin average scores and update them in the given column config list.
 *
 * @param columnConfigList
 *            input column config list
 * @return updated column config list
 * @throws IOException
 *             for any IO exception
 */
private List<ColumnConfig> updateColumnConfigWithBinAvgScore(List<ColumnConfig> columnConfigList) throws IOException {
    List<Scanner> scanners = ShifuFileUtils.getDataScanners(pathFinder.getBinAvgScorePath(),
            modelConfig.getDataSet().getSource());
    for (Scanner scanner : scanners) {
        while (scanner.hasNextLine()) {
            // each line looks like: columnNum|score1|score2|...
            List<Integer> scores = new ArrayList<Integer>();
            String[] raw = scanner.nextLine().split("\\|");
            int columnNum = Integer.parseInt(raw[0]);
            for (int i = 1; i < raw.length; i++) {
                scores.add(Integer.valueOf(raw[i]));
            }
            ColumnConfig config = columnConfigList.get(columnNum);
            config.setBinAvgScore(scores);
        }
    }
    // release scanner resources
    closeScanners(scanners);
    return columnConfigList;
}
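Each scanned line is expected to look like columnNum|score1|score2|…, one line per column. A tiny standalone sketch of that parsing with made-up values (the scores below are illustrative, not real model output):

import java.util.ArrayList;
import java.util.List;

public class BinAvgScoreLineDemo {
    public static void main(String[] args) {
        // Illustrative input only: column 7 with three bin average scores.
        String line = "7|120|345|560";
        String[] raw = line.split("\\|");
        int columnNum = Integer.parseInt(raw[0]);
        List<Integer> scores = new ArrayList<Integer>();
        for (int i = 1; i < raw.length; i++) {
            scores.add(Integer.valueOf(raw[i]));
        }
        System.out.println(columnNum + " -> " + scores); // prints: 7 -> [120, 345, 560]
    }
}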