
Example 1 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

Source: class RangeRouterUtils, method splitIntDoubleVector.

public static KeyValuePart[] splitIntDoubleVector(MatrixMeta matrixMeta, IntDoubleVector vector) {
    IntDoubleVectorStorage storage = vector.getStorage();
    if (storage.isSparse()) {
        // Get keys and values
        IntDoubleSparseVectorStorage sparseStorage = (IntDoubleSparseVectorStorage) storage;
        int[] keys = sparseStorage.getIndices();
        double[] values = sparseStorage.getValues();
        return split(matrixMeta, vector.getRowId(), keys, values, false);
    } else if (storage.isDense()) {
        // Get values
        IntDoubleDenseVectorStorage denseStorage = (IntDoubleDenseVectorStorage) storage;
        double[] values = denseStorage.getValues();
        return split(matrixMeta, vector.getRowId(), values);
    } else {
        // Key and value array pair
        IntDoubleSortedVectorStorage sortStorage = (IntDoubleSortedVectorStorage) storage;
        int[] keys = sortStorage.getIndices();
        double[] values = sortStorage.getValues();
        return split(matrixMeta, vector.getRowId(), keys, values, true);
    }
}
Also used: IntDoubleSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleSparseVectorStorage), IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage), IntDoubleVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleVectorStorage), IntDoubleSortedVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleSortedVectorStorage)
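Example 1 dispatches on the storage type and passes the raw key/value arrays to an overloaded `split` helper that is not shown on this page; the boolean argument is presumably whether the keys are already sorted (`false` for hash-based sparse storage, `true` for sorted storage). The following is a minimal sketch of what such a range split could look like, assuming each partition covers a half-open column range taken from a hypothetical `boundaries` array that starts at 0. `RangeSplitSketch` and its `Part` type are illustrative only, not Angel classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch: split one row into column-range partitions (not Angel's API). */
public class RangeSplitSketch {

  /** The slice of a sparse row that falls into [startCol, endCol). */
  public static final class Part {
    public final long startCol;
    public final long endCol;
    public final int[] keys;
    public final double[] values;

    Part(long startCol, long endCol, int[] keys, double[] values) {
      this.startCol = startCol;
      this.endCol = endCol;
      this.keys = keys;
      this.values = values;
    }
  }

  /**
   * Sparse case. boundaries has n + 1 entries and partition i covers
   * [boundaries[i], boundaries[i + 1]). Assumes boundaries[0] == 0 and
   * 0 <= key < boundaries[n]; keys at or beyond the last boundary are dropped.
   */
  public static List<Part> split(long[] boundaries, int[] keys, double[] values, boolean sorted) {
    int[] k = keys;
    double[] v = values;
    if (!sorted) {
      // Hash-based storage has no key order, so sort the (key, value) pairs by key first.
      Integer[] order = new Integer[keys.length];
      for (int i = 0; i < order.length; i++) {
        order[i] = i;
      }
      Arrays.sort(order, (a, b) -> Integer.compare(keys[a], keys[b]));
      k = new int[keys.length];
      v = new double[values.length];
      for (int i = 0; i < order.length; i++) {
        k[i] = keys[order[i]];
        v[i] = values[order[i]];
      }
    }
    List<Part> parts = new ArrayList<>();
    int from = 0;
    for (int i = 0; i + 1 < boundaries.length; i++) {
      long end = boundaries[i + 1];
      int to = from;
      while (to < k.length && k[to] < end) {
        to++; // keys are ascending, so a single forward scan covers all partitions
      }
      parts.add(new Part(boundaries[i], end,
          Arrays.copyOfRange(k, from, to), Arrays.copyOfRange(v, from, to)));
      from = to;
    }
    return parts;
  }

  /** Dense case: the value array is simply cut at the boundaries. */
  public static List<double[]> splitDense(long[] boundaries, double[] values) {
    List<double[]> parts = new ArrayList<>();
    for (int i = 0; i + 1 < boundaries.length; i++) {
      parts.add(Arrays.copyOfRange(values, (int) boundaries[i], (int) boundaries[i + 1]));
    }
    return parts;
  }
}
```

Sorting up front keeps the per-partition pass linear in the number of keys; a sorted storage skips that step entirely, which is one plausible reason the helper takes the flag.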

Example 2 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

Source: class CompIntDoubleVectorSplitter, method split.

@Override
public Map<PartitionKey, RowUpdateSplit> split(Vector vector, List<PartitionKey> parts) {
    IntDoubleVector[] vecParts = ((CompIntDoubleVector) vector).getPartitions();
    assert vecParts.length == parts.size();
    Map<PartitionKey, RowUpdateSplit> updateSplitMap = new HashMap<>(parts.size());
    for (int i = 0; i < vecParts.length; i++) {
        PartitionKey part = parts.get(i);
        // width of this partition's column range [startCol, endCol)
        int partDim = (int) (part.getEndCol() - part.getStartCol());
        updateSplitMap.put(part, new CompIntDoubleRowUpdateSplit(vector.getRowId(), vecParts[i], partDim));
    }
    return updateSplitMap;
}
Also used: HashMap (java.util.HashMap), PartitionKey (com.tencent.angel.PartitionKey), CompIntDoubleRowUpdateSplit (com.tencent.angel.psagent.matrix.oplog.cache.CompIntDoubleRowUpdateSplit), RowUpdateSplit (com.tencent.angel.psagent.matrix.oplog.cache.RowUpdateSplit), CompIntDoubleVector (com.tencent.angel.ml.math2.vector.CompIntDoubleVector), IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector)
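In Example 2 each sub-vector of the `CompIntDoubleVector` is paired with the `PartitionKey` at the same position, and the partition width `getEndCol() - getStartCol()` is narrowed to an `int` with a plain cast. A minimal sketch of the addressing this layout implies, assuming the partitions are contiguous column ranges sorted by start column: a global column is mapped to its partition by binary search over the start columns, then to a local offset. `CompositeIndexSketch` is a hypothetical class; it uses `Math.toIntExact` so an oversized range fails with an exception instead of being silently truncated.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: address a row split into contiguous column ranges,
 * in the spirit of one PartitionKey [startCol, endCol) per sub-vector.
 */
public class CompositeIndexSketch {
  private final long[] startCols; // ascending; startCols[i] is the first column of partition i
  private final long endCol;      // exclusive end of the last partition (the row dimension)

  public CompositeIndexSketch(long[] startCols, long endCol) {
    this.startCols = startCols.clone();
    this.endCol = endCol;
  }

  /** Width of partition i (endCol - startCol); throws if it does not fit in an int. */
  public int width(int i) {
    long end = (i + 1 < startCols.length) ? startCols[i + 1] : endCol;
    return Math.toIntExact(end - startCols[i]);
  }

  /** Index of the partition whose range contains the global column col. */
  public int partitionOf(long col) {
    if (col < startCols[0] || col >= endCol) {
      throw new IndexOutOfBoundsException("column " + col);
    }
    int pos = Arrays.binarySearch(startCols, col);
    // When col is not a start column, binarySearch returns -(insertionPoint) - 1,
    // and the owning partition is the one just before the insertion point.
    return pos >= 0 ? pos : -pos - 2;
  }

  /** Offset of col inside its partition. */
  public int localIndex(long col) {
    return Math.toIntExact(col - startCols[partitionOf(col)]);
  }
}
```

With partitions starting at columns 0, 4 and 10 in a row of dimension 16, column 5 belongs to partition 1 at local offset 1, and the last partition has width 6.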

Example 3 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

Source: class GBDTController, method updateNodeGradStats.

// update the node's grad stats on the PS
// called by GradHistHelper during splitting: after the best split is found, the grad stats of the child nodes are updated
// the root node's stats are updated by the leader worker
public void updateNodeGradStats(int nid, GradStats gradStats) throws Exception {
    LOG.debug(String.format("Update gradStats of node[%d]: sumGrad[%f], sumHess[%f]", nid, gradStats.sumGrad, gradStats.sumHess));
    // 1. create the update
    IntDoubleVector vec = new IntDoubleVector(2 * this.activeNode.length, new IntDoubleDenseVectorStorage(2 * this.activeNode.length));
    vec.set(nid, gradStats.sumGrad);
    vec.set(nid + this.activeNode.length, gradStats.sumHess);
    // 2. push the update to PS
    PSModel nodeGradStats = this.model.getPSModel(this.param.nodeGradStatsName);
    nodeGradStats.increment(this.currentTree, vec);
}
Also used: IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage), PSModel (com.tencent.angel.ml.model.PSModel), IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector)
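Example 3 packs both statistics of every tree node into one dense row of length `2 * activeNode.length`: `sumGrad` at index `nid` and `sumHess` at `nid + activeNode.length`, and pushes that row with `increment`. The sketch below reproduces only the layout, using a plain `double[]` instead of `IntDoubleVector`. `GradStatsLayout` is hypothetical, and its `increment` method assumes the PS applies the update as an element-wise addition, which the method name suggests but the source does not show.

```java
/**
 * Hypothetical sketch of the packed node grad-stats layout:
 * indices [0, numNodes) hold sumGrad, [numNodes, 2 * numNodes) hold sumHess.
 */
public class GradStatsLayout {
  private final int numNodes;

  public GradStatsLayout(int numNodes) {
    this.numNodes = numNodes;
  }

  /** Builds a delta row that only touches node nid. */
  public double[] encode(int nid, double sumGrad, double sumHess) {
    if (nid < 0 || nid >= numNodes) {
      throw new IndexOutOfBoundsException("nid " + nid);
    }
    double[] delta = new double[2 * numNodes];
    delta[nid] = sumGrad;
    delta[nid + numNodes] = sumHess;
    return delta;
  }

  public double sumGrad(double[] packed, int nid) {
    return packed[nid];
  }

  public double sumHess(double[] packed, int nid) {
    return packed[nid + numNodes];
  }

  /** Assumed PS semantics of increment: add the delta element-wise to the stored row. */
  public static void increment(double[] row, double[] delta) {
    for (int i = 0; i < row.length; i++) {
      row[i] += delta[i];
    }
  }
}
```

Because every entry outside the two touched slots is zero, pushes from different workers for different nodes can be summed without interfering with each other.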

Example 4 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

Source: class GBDTController, method getSketch.

// pull the global sketch from PS, only called once by each worker
public void getSketch() throws Exception {
    PSModel sketch = model.getPSModel(this.param.sketchName);
    LOG.info("------Get sketch from PS------");
    long startTime = System.currentTimeMillis();
    IntDoubleVector sketchVector = (IntDoubleVector) sketch.getRow(0);
    LOG.info(String.format("Get sketch cost: %d ms", System.currentTimeMillis() - startTime));
    for (int i = 0; i < sketchVector.getDim(); i++) {
        this.sketches[i] = (float) sketchVector.get(i);
    }
    // number of splits of each categorical feature
    for (int i = 0; i < cateFeatList.size(); i++) {
        int fid = cateFeatList.get(i);
        int start = fid * this.param.numSplit;
        int splitNum = 1;
        // j + 1 must stay below numSplit, otherwise start + j + 1 reads the next feature's segment
        for (int j = 0; j + 1 < this.param.numSplit; j++) {
            if (this.sketches[start + j + 1] > this.sketches[start + j]) {
                splitNum++;
            } else {
                break;
            }
        }
        this.cateFeatNum.put(fid, splitNum);
    }
    LOG.info("Number of splits of categorical features: " + this.cateFeatNum.entrySet().toString());
}
Also used: PSModel (com.tencent.angel.ml.model.PSModel), IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector)
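The categorical-feature loop in Example 4 treats the sketch as `numSplit` consecutive candidate values per feature and counts how long the values keep strictly increasing from the start of that feature's segment. A standalone sketch of the count follows. `CategorySplitCount` is hypothetical, and the loop bound `j + 1 < numSplit` is a deliberate choice so the comparison never reads the first entry of the next feature, or past the end of the array for the last feature.

```java
/**
 * Hypothetical standalone version of the categorical-split count in getSketch().
 * Feature fid owns sketches[fid * numSplit, (fid + 1) * numSplit).
 */
public class CategorySplitCount {

  /** Counts the leading run of strictly increasing sketch values for feature fid. */
  public static int countSplits(float[] sketches, int fid, int numSplit) {
    int start = fid * numSplit;
    int splitNum = 1;
    // Compare entries j and j + 1 only while j + 1 is still inside this feature's segment.
    for (int j = 0; j + 1 < numSplit; j++) {
      if (sketches[start + j + 1] > sketches[start + j]) {
        splitNum++;
      } else {
        break;
      }
    }
    return splitNum;
  }
}
```

For sketches `{0, 1, 2, 0, 5, 6, 7, 8}` with `numSplit = 4`, feature 0 yields 3 and feature 1 yields 4; a loop running `j` up to `numSplit - 1` inclusive would read index 8 for feature 1, past the end of the array.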

Example 5 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

Source: class GBDTController, method afterSplit.

public void afterSplit() throws Exception {
    LOG.info("------After split------");
    long startTime = System.currentTimeMillis();
    // 1. get split feature
    PSModel splitFeatModel = model.getPSModel(this.param.splitFeaturesName);
    IntIntVector splitFeatureVec = (IntIntVector) splitFeatModel.getRow(currentTree);
    // 2. get split value
    PSModel splitValueModel = model.getPSModel(this.param.splitValuesName);
    IntDoubleVector splitValueVec = (IntDoubleVector) splitValueModel.getRow(currentTree);
    // 3. get split gain
    PSModel splitGainModel = model.getPSModel(this.param.splitGainsName);
    IntDoubleVector splitGainVec = (IntDoubleVector) splitGainModel.getRow(currentTree);
    // 4. get node grad stats (per-node sumGrad and sumHess, used for the node weights)
    PSModel nodeGradStatsModel = model.getPSModel(this.param.nodeGradStatsName);
    IntDoubleVector nodeGradStatsVec = (IntDoubleVector) nodeGradStatsModel.getRow(currentTree);
    LOG.info(String.format("Get split result from PS cost %d ms", System.currentTimeMillis() - startTime));
    // 5. split node
    LOG.debug(String.format("Split active node: %s", Arrays.toString(this.activeNode)));
    int[] preActiveNode = this.activeNode.clone();
    for (int nid = 0; nid < this.maxNodeNum; nid++) {
        if (preActiveNode[nid] == 1) {
            // update local replica
            this.splitFeats[nid] = splitFeatureVec.get(nid);
            this.splitValues[nid] = splitValueVec.get(nid);
            // create AfterSplit task
            this.activeNodeStat[nid].set(1);
            AfterSplitThread t = new AfterSplitThread(this, nid, splitFeatureVec, splitValueVec, splitGainVec, nodeGradStatsVec);
            this.threadPool.submit(t);
        }
    }
    // wait until every AfterSplit thread has finished (a node status of 1 means its thread is still running)
    boolean hasRunning = true;
    while (hasRunning) {
        hasRunning = false;
        for (int nid = 0; nid < this.maxNodeNum; nid++) {
            int stat = this.activeNodeStat[nid].get();
            if (stat == 1) {
                hasRunning = true;
                break;
            }
        }
        if (hasRunning) {
            LOG.debug("current has running thread");
        }
    }
    updateValidInsPos();
    finishCurrentDepth();
    LOG.info(String.format("After split cost: %d ms", System.currentTimeMillis() - startTime));
    // 6. clock
    Set<String> needFlushMatrixSet = new HashSet<String>(4);
    needFlushMatrixSet.add(this.param.splitFeaturesName);
    needFlushMatrixSet.add(this.param.splitValuesName);
    needFlushMatrixSet.add(this.param.splitGainsName);
    needFlushMatrixSet.add(this.param.nodeGradStatsName);
    clockAllMatrix(needFlushMatrixSet, true);
}
Also used: PSModel (com.tencent.angel.ml.model.PSModel), IntIntVector (com.tencent.angel.ml.math2.vector.IntIntVector), IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector)
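Example 5 waits for the `AfterSplitThread` tasks by repeatedly scanning `activeNodeStat` without sleeping, which keeps a core busy for as long as any task runs. One common alternative, sketched here under the assumption that each per-node job can be expressed as a `Callable`, is `ExecutorService.invokeAll`, which blocks until every task has finished and lets `Future.get` rethrow a task's failure. `AfterSplitWaitSketch` is hypothetical and the task body is a placeholder, not Angel's split logic.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

/**
 * Hypothetical alternative to the polling loop in afterSplit(): submit one task
 * per active node and block on invokeAll instead of spinning on status flags.
 */
public class AfterSplitWaitSketch {

  /** Runs one task per active node and returns the ids of the processed nodes. */
  public static List<Integer> processActiveNodes(int[] activeNode, ExecutorService pool)
      throws InterruptedException, ExecutionException {
    List<Callable<Integer>> tasks = new ArrayList<>();
    for (int nid = 0; nid < activeNode.length; nid++) {
      if (activeNode[nid] == 1) {
        final int node = nid;
        // Placeholder for the real per-node work (AfterSplitThread in the example).
        tasks.add(() -> node);
      }
    }
    List<Integer> done = new ArrayList<>();
    // invokeAll blocks until every task has completed, normally or exceptionally.
    for (Future<Integer> f : pool.invokeAll(tasks)) {
      done.add(f.get()); // rethrows a task's exception wrapped in ExecutionException
    }
    return done;
  }
}
```

`invokeAll` returns the futures in task order, so the processed node ids come back in ascending node order, and a failing task surfaces as an exception instead of a status flag that may never be reset.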

Aggregations

IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector): 95
ObjectIterator (it.unimi.dsi.fastutil.objects.ObjectIterator): 55
IntDoubleVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleVectorStorage): 51
Int2DoubleMap (it.unimi.dsi.fastutil.ints.Int2DoubleMap): 51
CompIntDoubleVector (com.tencent.angel.ml.math2.vector.CompIntDoubleVector): 40
IntDoubleSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleSparseVectorStorage): 32
IntFloatVectorStorage (com.tencent.angel.ml.math2.storage.IntFloatVectorStorage): 32
IntIntVectorStorage (com.tencent.angel.ml.math2.storage.IntIntVectorStorage): 32
IntLongVectorStorage (com.tencent.angel.ml.math2.storage.IntLongVectorStorage): 32
LongDoubleVectorStorage (com.tencent.angel.ml.math2.storage.LongDoubleVectorStorage): 30
LongFloatVectorStorage (com.tencent.angel.ml.math2.storage.LongFloatVectorStorage): 30
LongIntVectorStorage (com.tencent.angel.ml.math2.storage.LongIntVectorStorage): 30
LongLongVectorStorage (com.tencent.angel.ml.math2.storage.LongLongVectorStorage): 30
Storage (com.tencent.angel.ml.math2.storage.Storage): 30
IntDoubleSortedVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleSortedVectorStorage): 26
IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage): 23
IntFloatSortedVectorStorage (com.tencent.angel.ml.math2.storage.IntFloatSortedVectorStorage): 20
IntFloatSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntFloatSparseVectorStorage): 20
IntIntSortedVectorStorage (com.tencent.angel.ml.math2.storage.IntIntSortedVectorStorage): 20
IntIntSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntIntSparseVectorStorage): 20