Search in sources :

Example 1 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class GBDTController method updateNodeGradStats.

// Pushes the gradient statistics of one tree node to the PS.
// Called during splitting in GradHistHelper to record the grad stats of the children nodes after
// the best split has been found; the root node's stats are updated by the leader worker.
// Layout of the pushed vector: sumGrad at index nid, sumHess at index nid + activeNode.length.
public void updateNodeGradStats(int nid, GradStats gradStats) throws Exception {
    LOG.debug(String.format("Update gradStats of node[%d]: sumGrad[%f], sumHess[%f]", nid, gradStats.sumGrad, gradStats.sumHess));
    // 1. build the update: first half carries gradients, second half carries hessians
    final int nodeCount = this.activeNode.length;
    final int dim = 2 * nodeCount;
    IntDoubleVector delta = new IntDoubleVector(dim, new IntDoubleDenseVectorStorage(dim));
    delta.set(nid, gradStats.sumGrad);
    delta.set(nid + nodeCount, gradStats.sumHess);
    // 2. increment the PS matrix named by param.nodeGradStatsName, keyed by the current tree index
    this.model.getPSModel(this.param.nodeGradStatsName).increment(this.currentTree, delta);
}
Also used : IntDoubleDenseVectorStorage(com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage) PSModel(com.tencent.angel.ml.model.PSModel) IntDoubleVector(com.tencent.angel.ml.math2.vector.IntDoubleVector)

Example 2 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class GBDTController method getSketch.

// Pulls the global sketch (row 0 of the matrix named by param.sketchName) from the PS and copies
// it into the local sketches array; only called once by each worker.
// Afterwards, for every categorical feature, counts the leading run of strictly increasing
// values inside that feature's segment [fid * numSplit, (fid + 1) * numSplit) and stores the
// count in cateFeatNum.
public void getSketch() throws Exception {
    PSModel sketch = model.getPSModel(this.param.sketchName);
    LOG.info("------Get sketch from PS------");
    long startTime = System.currentTimeMillis();
    IntDoubleVector sketchVector = (IntDoubleVector) sketch.getRow(0);
    LOG.info(String.format("Get sketch cost: %d ms", System.currentTimeMillis() - startTime));
    // the PS stores doubles; the local replica keeps floats
    for (int i = 0; i < sketchVector.getDim(); i++) {
        this.sketches[i] = (float) sketchVector.get(i);
    }
    // number of categorical feature
    final int numSplit = this.param.numSplit;
    for (int i = 0; i < cateFeatList.size(); i++) {
        int fid = cateFeatList.get(i);
        int start = fid * numSplit;
        // the first slot always counts as one split value
        int splitNum = 1;
        // Compare adjacent pairs within this feature's segment only. The upper index j + 1 must
        // stay below numSplit; otherwise the comparison reads the first slot of the next feature
        // or, when the array ends with this feature's segment, runs past the end of the array.
        for (int j = 0; j + 1 < numSplit; j++) {
            if (this.sketches[start + j + 1] > this.sketches[start + j]) {
                splitNum++;
            } else {
                break;
            }
        }
        this.cateFeatNum.put(fid, splitNum);
    }
    LOG.info("Number of splits of categorical features: " + this.cateFeatNum.entrySet().toString());
}
Also used : PSModel(com.tencent.angel.ml.model.PSModel) IntDoubleVector(com.tencent.angel.ml.math2.vector.IntDoubleVector)

Example 3 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class GBDTController method afterSplit.

// Applies the split results for the current tree: pulls the split feature, split value, split gain
// and node grad stats rows from the PS, submits one AfterSplitThread per active node, spins until
// no node is still marked as running, then calls updateValidInsPos() and finishCurrentDepth() and
// finally clocks the four matrices that were read.
public void afterSplit() throws Exception {
    LOG.info("------After split------");
    long startTime = System.currentTimeMillis();
    // 1. get split feature
    IntIntVector splitFeatureVec = (IntIntVector) model.getPSModel(this.param.splitFeaturesName).getRow(currentTree);
    // 2. get split value
    IntDoubleVector splitValueVec = (IntDoubleVector) model.getPSModel(this.param.splitValuesName).getRow(currentTree);
    // 3. get split gain
    IntDoubleVector splitGainVec = (IntDoubleVector) model.getPSModel(this.param.splitGainsName).getRow(currentTree);
    // 4. get node weight
    IntDoubleVector nodeGradStatsVec = (IntDoubleVector) model.getPSModel(this.param.nodeGradStatsName).getRow(currentTree);
    LOG.info(String.format("Get split result from PS cost %d ms", System.currentTimeMillis() - startTime));
    // 5. split node; iterate over a snapshot of the active flags taken before any task is submitted
    LOG.debug(String.format("Split active node: %s", Arrays.toString(this.activeNode)));
    int[] activeSnapshot = this.activeNode.clone();
    for (int nid = 0; nid < this.maxNodeNum; nid++) {
        if (activeSnapshot[nid] != 1) {
            continue;
        }
        // update local replica
        this.splitFeats[nid] = splitFeatureVec.get(nid);
        this.splitValues[nid] = splitValueVec.get(nid);
        // mark the node as running (status 1), then hand it to the thread pool
        this.activeNodeStat[nid].set(1);
        this.threadPool.submit(new AfterSplitThread(this, nid, splitFeatureVec, splitValueVec, splitGainVec, nodeGradStatsVec));
    }
    // 6. check thread stats: keep polling (without sleeping) until no node has status 1
    boolean anyRunning;
    do {
        anyRunning = false;
        for (int nid = 0; nid < this.maxNodeNum && !anyRunning; nid++) {
            anyRunning = this.activeNodeStat[nid].get() == 1;
        }
        if (anyRunning) {
            LOG.debug("current has running thread");
        }
    } while (anyRunning);
    updateValidInsPos();
    finishCurrentDepth();
    LOG.info(String.format("After split cost: %d ms", System.currentTimeMillis() - startTime));
    // 7. clock the matrices that were pulled above
    Set<String> needFlushMatrixSet = new HashSet<>(Arrays.asList(
        this.param.splitFeaturesName,
        this.param.splitValuesName,
        this.param.splitGainsName,
        this.param.nodeGradStatsName));
    clockAllMatrix(needFlushMatrixSet, true);
}
Also used : PSModel(com.tencent.angel.ml.model.PSModel) IntIntVector(com.tencent.angel.ml.math2.vector.IntIntVector) IntDoubleVector(com.tencent.angel.ml.math2.vector.IntDoubleVector)

Example 4 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class GBDTController method createSketch.

// Creates the data sketch and pushes the candidate split values to the PS.
// Only the task with index 0 computes and pushes the vectors; every task then clocks the sketch
// and categorical feature matrices.
// Sketch layout: the values of feature fid occupy indices fid * numSplit + j; slots of categorical
// features are left at zero there and are written to a separate vector, packed by their position
// in the sorted cateFeatList.
public void createSketch() throws Exception {
    PSModel sketch = model.getPSModel(this.param.sketchName);
    PSModel cateFeat = model.getPSModel(this.param.cateFeatureName);
    if (taskContext.getTaskIndex() == 0) {
        LOG.info("------Create sketch------");
        long startTime = System.currentTimeMillis();
        final int numSplit = this.param.numSplit;
        final int numFeature = this.param.numFeature;
        final int sketchDim = numFeature * numSplit;
        IntDoubleVector sketchVec = new IntDoubleVector(sketchDim, new IntDoubleDenseVectorStorage(new double[sketchDim]));
        // the categorical vector is only allocated when there are categorical features
        IntDoubleVector cateFeatVec = null;
        if (!this.cateFeatList.isEmpty()) {
            final int cateDim = this.cateFeatList.size() * numSplit;
            cateFeatVec = new IntDoubleVector(cateDim, new IntDoubleDenseVectorStorage(new double[cateDim]));
        }
        // 1. calculate candidate split value
        float[][] splits = TYahooSketchSplit.getSplitValue(this.trainDataStore, numSplit, this.cateFeatList);
        if (splits.length != numFeature || splits[0].length != numSplit) {
            // the error is only logged; the (still zero) sketch is pushed below regardless
            LOG.error("Incompatible sketches size.");
        } else {
            for (int fid = 0; fid < splits.length; fid++) {
                if (!cateFeatList.contains(fid)) {
                    float[] featSplits = splits[fid];
                    for (int j = 0; j < featSplits.length; j++) {
                        sketchVec.set(fid * numSplit + j, featSplits[j]);
                    }
                }
            }
        }
        // categorical features
        if (!this.cateFeatList.isEmpty()) {
            Collections.sort(this.cateFeatList);
            for (int i = 0; i < this.cateFeatList.size(); i++) {
                int fid = this.cateFeatList.get(i);
                int offset = i * numSplit;
                float[] featSplits = splits[fid];
                for (int j = 0; j < featSplits.length; j++) {
                    float value = featSplits[j];
                    // a zero after the first slot ends the copy for this feature
                    if (j > 0 && value == 0) {
                        break;
                    }
                    cateFeatVec.set(offset + j, value);
                }
            }
        }
        // 2. push local sketch to PS
        sketch.increment(0, sketchVec);
        if (cateFeatVec != null) {
            cateFeat.increment(this.taskContext.getTaskIndex(), cateFeatVec);
        }
        LOG.info(String.format("Create sketch cost: %d ms", System.currentTimeMillis() - startTime));
    }
    // 3. clock both matrices (executed by every task, not just task 0)
    Set<String> needFlushMatrixSet = new HashSet<>(2);
    needFlushMatrixSet.add(this.param.sketchName);
    needFlushMatrixSet.add(this.param.cateFeatureName);
    clockAllMatrix(needFlushMatrixSet, true);
}
Also used : PSModel(com.tencent.angel.ml.model.PSModel) IntDoubleDenseVectorStorage(com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage) IntDoubleVector(com.tencent.angel.ml.math2.vector.IntDoubleVector)

Example 5 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class GBDTController method createNewTree.

// Creates a new tree: registers a fresh RegTree in the forest, sets up the feature set (pulled
// from the PS when column sampling is enabled, otherwise all features), resets the active nodes
// and the instance positions, and calculates the gradient pairs.
public void createNewTree() throws Exception {
    LOG.info("------Create new tree------");
    long startTime = System.currentTimeMillis();
    // 1. create new tree, initialize tree nodes and node stats
    RegTree newTree = new RegTree(this.param);
    newTree.initTreeNodes();
    this.currentDepth = 1;
    this.forest[this.currentTree] = newTree;
    // 2. initialize feature set, if sampled, get from PS, otherwise use all the features
    if (this.param.colSample < 1) {
        // 2.1. pull the sampled features of the current tree
        IntIntVector sampledFeatures = (IntIntVector) model.getPSModel(this.param.sampledFeaturesName).getRow(this.currentTree);
        this.fSet = sampledFeatures.getStorage().getValues();
        calfPos();
    } else if (this.fSet == null) {
        // 2.2. all features are used: fSet and fPos are identity mappings, built only once
        final int featureCount = this.trainDataStore.featureMeta.numFeature;
        int[] identity = new int[featureCount];
        for (int f = 0; f < featureCount; f++) {
            identity[f] = f;
        }
        this.fSet = identity;
        this.fPos = identity.clone();
    }
    // 3. reset active tree nodes, set all tree nodes to inactive, set thread status to idle
    for (int nid = 0; nid < this.maxNodeNum; nid++) {
        resetActiveTNodes(nid);
    }
    // 4. set root node to active
    addActiveNode(0);
    // 5. reset instance position: the root spans every instance, all other nodes get -1
    this.nodePosStart[0] = 0;
    this.nodePosEnd[0] = this.instancePos.length - 1;
    for (int nid = 1; nid < this.maxNodeNum; nid++) {
        this.nodePosEnd[nid] = -1;
        this.nodePosStart[nid] = -1;
    }
    // reset position of validation instance
    Arrays.fill(this.validInsPos, 0);
    // 6. calculate gradient
    calGradPairs();
    LOG.info(String.format("Create new tree cost: %d ms", System.currentTimeMillis() - startTime));
}
Also used : PSModel(com.tencent.angel.ml.model.PSModel) RegTree(com.tencent.angel.ml.GBDT.algo.RegTree) IntIntVector(com.tencent.angel.ml.math2.vector.IntIntVector)

Aggregations

PSModel (com.tencent.angel.ml.model.PSModel): 11, IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector): 7, IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage): 4, IntIntVector (com.tencent.angel.ml.math2.vector.IntIntVector): 4, IntIntDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntIntDenseVectorStorage): 2, AngelException (com.tencent.angel.exception.AngelException): 1, RegTree (com.tencent.angel.ml.GBDT.algo.RegTree): 1, SplitEntry (com.tencent.angel.ml.GBDT.algo.tree.SplitEntry): 1, GBDTGradHistGetRowFunc (com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowFunc): 1, HistAggrParam (com.tencent.angel.ml.GBDT.psf.HistAggrParam): 1, IntDoubleSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleSparseVectorStorage): 1, MatrixContext (com.tencent.angel.ml.matrix.MatrixContext): 1, QuantifyDoubleFunc (com.tencent.angel.ml.psf.compress.QuantifyDoubleFunc): 1, MatrixSaveContext (com.tencent.angel.model.MatrixSaveContext): 1, ModelSaveContext (com.tencent.angel.model.ModelSaveContext): 1, RowIdColIdValueTextRowFormat (com.tencent.angel.model.output.format.RowIdColIdValueTextRowFormat): 1, LinkedHashMap (java.util.LinkedHashMap): 1, Map (java.util.Map): 1, AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1