
Example 1 with RegTree

Use of com.tencent.angel.ml.GBDT.algo.RegTree in project angel by Tencent.

In class GBDTController, method createNewTree:

// create new tree
// pull sampled features, initialize tree nodes, reset active nodes, reset instance position,
// calculate gradient
public void createNewTree() throws Exception {
    LOG.info("------Create new tree------");
    long startTime = System.currentTimeMillis();
    // 1. create new tree, initialize tree nodes and node stats
    RegTree tree = new RegTree(this.param);
    tree.initTreeNodes();
    this.currentDepth = 1;
    this.forest[this.currentTree] = tree;
    // 2. initialize feature set, if sampled, get from PS, otherwise use all the features
    if (this.param.colSample < 1) {
        // 2.1. pull the sampled features of the current tree
        PSModel featSample = model.getPSModel(this.param.sampledFeaturesName);
        IntIntVector sampleFeatureVector = (IntIntVector) featSample.getRow(this.currentTree);
        this.fSet = sampleFeatureVector.getStorage().getValues();
        calfPos();
        // this.forest[this.currentTree].fset = sampleFeatureVector.getStorage().getValues();
    } else {
        // 2.2. if all features are used, fSet and fPos are built only once (identity mapping)
        if (null == this.fSet) {
            this.fSet = new int[this.trainDataStore.featureMeta.numFeature];
            Arrays.setAll(this.fSet, i -> i);
            this.fPos = new int[this.trainDataStore.featureMeta.numFeature];
            Arrays.setAll(this.fPos, i -> i);
        }
    }
    // 3. reset active tree nodes, set all tree nodes to inactive, set thread status to idle
    for (int nid = 0; nid < this.maxNodeNum; nid++) {
        resetActiveTNodes(nid);
    }
    // 4. set root node to active
    addActiveNode(0);
    // 5. reset instance position, set the root node's span
    this.nodePosStart[0] = 0;
    this.nodePosEnd[0] = this.instancePos.length - 1;
    for (int nid = 1; nid < this.maxNodeNum; nid++) {
        this.nodePosStart[nid] = -1;
        this.nodePosEnd[nid] = -1;
    }
    // reset position of validation instance
    Arrays.setAll(this.validInsPos, i -> 0);
    // 6. calculate gradient
    calGradPairs();
    LOG.info(String.format("Create new tree cost: %d ms", System.currentTimeMillis() - startTime));
}
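Step 2 calls `calfPos()`, whose body is not shown here. Because the all-features branch sets both `fSet` and `fPos` to the identity, one plausible reading is that `fPos` maps a feature id to its slot in `fSet`. The sketch below assumes exactly that. `computeFeaturePositions` and `allFeatures` are hypothetical helpers, not Angel APIs, and using -1 for features outside the sample is an assumption of this sketch.

```java
import java.util.Arrays;

public class FeatureSetSketch {

    /**
     * Hypothetical stand-in for calfPos(): for every feature id, return its index
     * in fSet, or -1 if that feature was not sampled for the current tree.
     */
    static int[] computeFeaturePositions(int[] fSet, int numFeature) {
        int[] fPos = new int[numFeature];
        Arrays.fill(fPos, -1);           // assumption: unsampled features get no slot
        for (int i = 0; i < fSet.length; i++) {
            fPos[fSet[i]] = i;           // feature fSet[i] is stored at slot i
        }
        return fPos;
    }

    /** Mirrors step 2.2: without column sampling, every feature id is used, in order. */
    static int[] allFeatures(int numFeature) {
        int[] fSet = new int[numFeature];
        Arrays.setAll(fSet, i -> i);
        return fSet;
    }

    public static void main(String[] args) {
        int numFeature = 6;
        int[] sampled = {1, 3, 4};       // made-up stand-in for the row pulled from the PS
        System.out.println(Arrays.toString(computeFeaturePositions(sampled, numFeature)));
        // prints "[-1, 0, -1, 1, 2, -1]"
        System.out.println(Arrays.toString(computeFeaturePositions(allFeatures(numFeature), numFeature)));
        // prints "[0, 1, 2, 3, 4, 5]", the identity mapping built in step 2.2
    }
}
```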
Also used: PSModel (com.tencent.angel.ml.model.PSModel), RegTree (com.tencent.angel.ml.GBDT.algo.RegTree), IntIntVector (com.tencent.angel.ml.math2.vector.IntIntVector)
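Steps 3 to 5 of createNewTree reset the per-node bookkeeping: only the root is active, and the root's span `[nodePosStart[0], nodePosEnd[0]]` covers every entry of `instancePos`, while all other nodes get the empty span `-1`. The sketch below models that state in isolation. `NodeSpanSketch`, `resetForNewTree` and `size` are hypothetical names, the boolean `active` array stands in for Angel's active-node handling, and sizing the arrays as a full binary tree (`2^maxDepth - 1` nodes) is an assumption.

```java
import java.util.Arrays;

/** Hypothetical, self-contained model of the per-node state reset in steps 3-5. */
public class NodeSpanSketch {
    final int[] instancePos;   // instance ids; each tree node owns one contiguous slice of this array
    final int[] nodePosStart;  // first slot of a node's slice, -1 if the node holds no instances
    final int[] nodePosEnd;    // last slot (inclusive) of a node's slice, -1 if the node holds no instances
    final boolean[] active;    // stand-in for the active-node bookkeeping

    NodeSpanSketch(int numInstances, int maxDepth) {
        int maxNodeNum = (1 << maxDepth) - 1;  // assumption: full binary tree laid out in an array
        instancePos = new int[numInstances];
        Arrays.setAll(instancePos, i -> i);
        nodePosStart = new int[maxNodeNum];
        nodePosEnd = new int[maxNodeNum];
        active = new boolean[maxNodeNum];
    }

    void resetForNewTree() {
        Arrays.fill(active, false);                // step 3: every node starts inactive
        active[0] = true;                          // step 4: only the root is active
        nodePosStart[0] = 0;                       // step 5: the root's slice covers all instances
        nodePosEnd[0] = instancePos.length - 1;
        for (int nid = 1; nid < nodePosStart.length; nid++) {
            nodePosStart[nid] = -1;                // no other node owns instances yet
            nodePosEnd[nid] = -1;
        }
    }

    /** Number of instances currently assigned to node nid. */
    int size(int nid) {
        return nodePosStart[nid] < 0 ? 0 : nodePosEnd[nid] - nodePosStart[nid] + 1;
    }

    public static void main(String[] args) {
        NodeSpanSketch s = new NodeSpanSketch(10, 3);
        s.resetForNewTree();
        System.out.println(s.size(0) + " " + s.size(1) + " " + s.nodePosStart.length); // prints "10 0 7"
    }
}
```

Keeping each node's instances in one contiguous slice means a later split only has to partition that slice in place, rather than copying instance lists per node.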

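Step 6 delegates to `calGradPairs()`, which is not shown. Gradient-boosted trees are commonly fit on per-instance first- and second-order gradients of the loss. Assuming a logistic loss on raw margins (the loss Angel actually applies depends on its configuration), each pair is g = p - y and h = p(1 - p), with p = sigmoid(margin). `GradPairSketch` and `logisticGradPairs` are hypothetical names, and the small floor on the hessian is a design choice of this sketch, not something taken from Angel.

```java
public class GradPairSketch {

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Logistic-loss gradient pairs. Labels are in {0, 1}; preds are raw margins.
     * Returns {grad, hess}, one entry per instance.
     */
    static double[][] logisticGradPairs(double[] labels, double[] preds) {
        double[] grad = new double[labels.length];
        double[] hess = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            double p = sigmoid(preds[i]);
            grad[i] = p - labels[i];                   // first derivative w.r.t. the margin
            hess[i] = Math.max(p * (1.0 - p), 1e-16);  // second derivative, floored to stay positive
        }
        return new double[][] {grad, hess};
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 0.0};
        double[] preds = {0.0, 0.0};  // zero margins, e.g. before the first tree is added
        double[][] gh = logisticGradPairs(labels, preds);
        for (int i = 0; i < labels.length; i++) {
            System.out.println("instance " + i + ": grad=" + gh[0][i] + ", hess=" + gh[1][i]);
        }
    }
}
```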