
Example 71 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

From the class MergeUtils, method combineServerIntDoubleRowSplits:

private static Vector combineServerIntDoubleRowSplits(List<ServerRow> rowSplits, MatrixMeta matrixMeta, int rowIndex) {
    int colNum = (int) matrixMeta.getColNum();
    int elemNum = 0;
    int size = rowSplits.size();
    for (int i = 0; i < size; i++) {
        elemNum += rowSplits.get(i).size();
    }
    IntDoubleVector row;
    if (matrixMeta.isHash()) {
        row = VFactory.sparseDoubleVector(colNum, elemNum);
    } else {
        if (elemNum >= (int) (storageConvFactor * colNum)) {
            row = VFactory.denseDoubleVector(colNum);
        } else {
            row = VFactory.sparseDoubleVector(colNum, elemNum);
        }
    }
    row.setMatrixId(matrixMeta.getId());
    row.setRowId(rowIndex);
    Collections.sort(rowSplits, serverRowComp);
    for (int i = 0; i < size; i++) {
        if (rowSplits.get(i) == null) {
            continue;
        }
        ((ServerIntDoubleRow) rowSplits.get(i)).mergeTo(row);
    }
    return row;
}
Also used: ServerIntDoubleRow (com.tencent.angel.ps.storage.vector.ServerIntDoubleRow), IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector)
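
The interesting part of this method is the storage choice: a hash-partitioned matrix always gets a sparse result, while a range-partitioned one is promoted to dense storage once the merged element count reaches storageConvFactor * colNum. Below is a minimal, self-contained sketch of that heuristic, assuming a hypothetical value of 0.25 for the storageConvFactor field and an assumed import path for VFactory; the factory calls themselves are the ones used above.

// A minimal sketch of the dense-vs-sparse decision used above; this is not Angel's
// MergeUtils itself. The storageConvFactor value and the VFactory import path are assumptions.
import com.tencent.angel.ml.math2.VFactory;
import com.tencent.angel.ml.math2.vector.IntDoubleVector;

public class StorageChoiceSketch {
    // Assumed stand-in for the MergeUtils field referenced in the method above
    private static final double storageConvFactor = 0.25;

    // Promote to dense storage once the expected element count reaches
    // storageConvFactor * colNum; otherwise keep a sparse vector sized for elemNum entries.
    static IntDoubleVector chooseStorage(int colNum, int elemNum) {
        if (elemNum >= (int) (storageConvFactor * colNum)) {
            return VFactory.denseDoubleVector(colNum);
        }
        return VFactory.sparseDoubleVector(colNum, elemNum);
    }
}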

Example 72 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

From the class MergeUtils, method combineIntDoubleIndexRowSplits:

// //////////////////////////////////////////////////////////////////////////////
// Combine Int key Double value vector
// //////////////////////////////////////////////////////////////////////////////
public static Vector combineIntDoubleIndexRowSplits(int matrixId, int rowId, int resultSize, KeyPart[] keyParts, ValuePart[] valueParts, MatrixMeta matrixMeta) {
    IntDoubleVector vector = VFactory.sparseDoubleVector((int) matrixMeta.getColNum(), resultSize);
    for (int i = 0; i < keyParts.length; i++) {
        mergeTo(vector, keyParts[i], (DoubleValuesPart) valueParts[i]);
    }
    vector.setRowId(rowId);
    vector.setMatrixId(matrixId);
    return vector;
}
Also used: IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector)
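
The mergeTo helper is not shown in this example; conceptually, each (KeyPart, DoubleValuesPart) pair contributes a batch of column indices with their matching double values to the sparse result. The following is a hedged sketch of that merge, with plain int[]/double[] arrays standing in for the key and value parts (their real accessors are not part of this example), and an assumed import path for VFactory.

// A sketch of what the mergeTo(vector, keyPart, valuePart) calls amount to, using plain
// arrays in place of KeyPart and DoubleValuesPart. Not the actual MergeUtils code.
import com.tencent.angel.ml.math2.VFactory;
import com.tencent.angel.ml.math2.vector.IntDoubleVector;

public class IndexMergeSketch {
    static IntDoubleVector combine(int matrixId, int rowId, int colNum,
                                   int[][] keyParts, double[][] valueParts) {
        // Size the sparse result by the total number of returned (index, value) pairs
        int resultSize = 0;
        for (int[] keys : keyParts) {
            resultSize += keys.length;
        }
        IntDoubleVector vector = VFactory.sparseDoubleVector(colNum, resultSize);
        // Merge every split: keyParts[i][j] is a column index, valueParts[i][j] its value
        for (int i = 0; i < keyParts.length; i++) {
            for (int j = 0; j < keyParts[i].length; j++) {
                vector.set(keyParts[i][j], valueParts[i][j]);
            }
        }
        vector.setRowId(rowId);
        vector.setMatrixId(matrixId);
        return vector;
    }
}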

Example 73 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

From the class GBDTController, method updateLeafPreds:

public void updateLeafPreds() throws Exception {
    LOG.info("------Update leaf node predictions------");
    long startTime = System.currentTimeMillis();
    Set<String> needFlushMatrixSet = new HashSet<String>(1);
    if (taskContext.getTaskIndex() == 0) {
        int nodeNum = this.forest[currentTree].nodes.size();
        IntDoubleVector vec = new IntDoubleVector(this.maxNodeNum, new IntDoubleDenseVectorStorage(this.maxNodeNum));
        for (int nid = 0; nid < nodeNum; nid++) {
            if (null != this.forest[currentTree].nodes.get(nid) && this.forest[currentTree].nodes.get(nid).isLeaf()) {
                float weight = this.forest[currentTree].nodes.get(nid).getLeafValue();
                LOG.debug(String.format("Leaf weight of node[%d]: %f", nid, weight));
                vec.set(nid, weight);
            }
        }
        PSModel nodePreds = this.model.getPSModel(this.param.nodePredsName);
        nodePreds.increment(this.currentTree, vec);
        // the leader task adds node prediction to flush list
        needFlushMatrixSet.add(this.param.nodePredsName);
    }
    clockAllMatrix(needFlushMatrixSet, true);
    LOG.info(String.format("Update leaf node predictions cost: %d ms", System.currentTimeMillis() - startTime));
}
Also used: IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage), PSModel (com.tencent.angel.ml.model.PSModel), IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector)
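
The update itself is just a dense IntDoubleVector with one slot per possible tree node: leaf slots carry the leaf weights, everything else stays zero, and the vector is then pushed to row currentTree of the nodePreds matrix via increment. Below is a minimal sketch of building such a vector, with maxNodeNum and the parallel leafIds/leafWeights arrays as placeholder inputs; the constructor and storage class are the ones used in the example above.

// Minimal sketch of the leaf-prediction vector built in updateLeafPreds, outside the
// GBDT controller. Inputs are placeholders for illustration.
import com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage;
import com.tencent.angel.ml.math2.vector.IntDoubleVector;

public class LeafPredSketch {
    static IntDoubleVector buildLeafPreds(int maxNodeNum, int[] leafIds, float[] leafWeights) {
        // One dense slot per possible tree node; non-leaf slots stay at 0
        IntDoubleVector vec = new IntDoubleVector(maxNodeNum, new IntDoubleDenseVectorStorage(maxNodeNum));
        for (int i = 0; i < leafIds.length; i++) {
            vec.set(leafIds[i], leafWeights[i]);
        }
        return vec;
    }
}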

Example 74 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

From the class GBDTController, method findSplit:

// find split
public void findSplit() throws Exception {
    LOG.info("------Find split------");
    long startTime = System.currentTimeMillis();
    // 1. find the responsible tree nodes using a round-robin (RR) scheme
    List<Integer> responsibleTNode = new ArrayList<>();
    int activeTNodeNum = 0;
    for (int nid = 0; nid < this.activeNode.length; nid++) {
        int isActive = this.activeNode[nid];
        if (isActive == 1) {
            if (this.taskContext.getTaskIndex() == activeTNodeNum) {
                responsibleTNode.add(nid);
            }
            if (++activeTNodeNum >= taskContext.getTotalTaskNum()) {
                activeTNodeNum = 0;
            }
        }
    }
    int[] tNodeId = Maths.intList2Arr(responsibleTNode);
    LOG.info(String.format("Task[%d] responsible tree node: %s", this.taskContext.getTaskId().getIndex(), responsibleTNode.toString()));
    // 2. pull gradient histogram
    // the updated indices of the parameter on PS
    int[] updatedIndices = new int[tNodeId.length];
    // the updated split features
    int[] updatedSplitFid = new int[tNodeId.length];
    // the updated split value
    double[] updatedSplitFvalue = new double[tNodeId.length];
    // the updated split gain
    double[] updatedSplitGain = new double[tNodeId.length];
    boolean isServerSplit = taskContext.getConf().getBoolean(MLConf.ML_GBDT_SERVER_SPLIT(), MLConf.DEFAULT_ML_GBDT_SERVER_SPLIT());
    int splitNum = taskContext.getConf().getInt(MLConf.ML_GBDT_SPLIT_NUM(), MLConf.DEFAULT_ML_GBDT_SPLIT_NUM());
    for (int i = 0; i < tNodeId.length; i++) {
        int nid = tNodeId[i];
        LOG.debug(String.format("Task[%d] find best split of tree node: %d", this.taskContext.getTaskIndex(), nid));
        // 2.1. get the name of this node's gradient histogram on PS
        String gradHistName = this.param.gradHistNamePrefix + nid;
        // 2.2. pull the histogram
        long pullStartTime = System.currentTimeMillis();
        PSModel histMat = model.getPSModel(gradHistName);
        IntDoubleVector histogram = null;
        SplitEntry splitEntry = null;
        if (isServerSplit) {
            int matrixId = histMat.getMatrixId();
            GBDTGradHistGetRowFunc func = new GBDTGradHistGetRowFunc(new HistAggrParam(matrixId, 0, param.numSplit, param.minChildWeight, param.regAlpha, param.regLambda));
            splitEntry = ((GBDTGradHistGetRowResult) histMat.get(func)).getSplitEntry();
        } else {
            histogram = (IntDoubleVector) histMat.getRow(0);
            LOG.debug("Get grad histogram without server split mode, histogram size" + histogram.getDim());
        }
        LOG.info(String.format("Pull histogram from PS cost %d ms", System.currentTimeMillis() - pullStartTime));
        GradHistHelper histHelper = new GradHistHelper(this, nid);
        // 2.3. find best split result of this tree node
        if (this.param.isServerSplit) {
            // 2.3.1 using server split
            if (splitEntry.getFid() != -1) {
                int trueSplitFid = this.fSet[splitEntry.getFid()];
                int splitIdx = (int) splitEntry.getFvalue();
                float trueSplitValue = this.sketches[trueSplitFid * this.param.numSplit + splitIdx];
                LOG.info(String.format("Best split of node[%d]: feature[%d], value[%f], " + "true feature[%d], true value[%f], losschg[%f]", nid, splitEntry.getFid(), splitEntry.getFvalue(), trueSplitFid, trueSplitValue, splitEntry.getLossChg()));
                splitEntry.setFid(trueSplitFid);
                splitEntry.setFvalue(trueSplitValue);
            }
            // update the grad stats of the root node on PS, only called once by leader worker
            if (nid == 0) {
                GradStats rootStats = new GradStats(splitEntry.leftGradStat);
                rootStats.add(splitEntry.rightGradStat);
                this.updateNodeGradStats(nid, rootStats);
            }
            // update the grad stats of children node
            if (splitEntry.fid != -1) {
                // update the left child
                this.updateNodeGradStats(2 * nid + 1, splitEntry.leftGradStat);
                // update the right child
                this.updateNodeGradStats(2 * nid + 2, splitEntry.rightGradStat);
            }
            // 2.3.2 record the updated split result (tree node/feature/value/gain) to push to PS
            updatedIndices[i] = nid;
            updatedSplitFid[i] = splitEntry.fid;
            updatedSplitFvalue[i] = splitEntry.fvalue;
            updatedSplitGain[i] = splitEntry.lossChg;
        } else {
            // 2.3.3 otherwise, the returned histogram contains the gradient info
            splitEntry = histHelper.findBestSplit(histogram);
            LOG.info(String.format("Best split of node[%d]: feature[%d], value[%f], losschg[%f]", nid, splitEntry.getFid(), splitEntry.getFvalue(), splitEntry.getLossChg()));
            // 2.3.4 record the updated split result (tree node/feature/value/gain) to push to PS
            updatedIndices[i] = nid;
            updatedSplitFid[i] = splitEntry.fid;
            updatedSplitFvalue[i] = splitEntry.fvalue;
            updatedSplitGain[i] = splitEntry.lossChg;
        }
        // 2.3.5 reset this tree node's gradient histogram to 0
        histMat.zero();
    }
    // 3. push split feature to PS
    IntIntVector splitFeatureVector = new IntIntVector(this.activeNode.length, new IntIntDenseVectorStorage(this.activeNode.length));
    // 4. push split value to PS
    IntDoubleVector splitValueVector = new IntDoubleVector(this.activeNode.length, new IntDoubleDenseVectorStorage(this.activeNode.length));
    // 5. push split gain to PS
    IntDoubleVector splitGainVector = new IntDoubleVector(this.activeNode.length, new IntDoubleDenseVectorStorage(this.activeNode.length));
    for (int i = 0; i < updatedIndices.length; i++) {
        splitFeatureVector.set(updatedIndices[i], updatedSplitFid[i]);
        splitValueVector.set(updatedIndices[i], updatedSplitFvalue[i]);
        splitGainVector.set(updatedIndices[i], updatedSplitGain[i]);
    }
    PSModel splitFeat = model.getPSModel(this.param.splitFeaturesName);
    splitFeat.increment(this.currentTree, splitFeatureVector);
    PSModel splitValue = model.getPSModel(this.param.splitValuesName);
    splitValue.increment(this.currentTree, splitValueVector);
    PSModel splitGain = model.getPSModel(this.param.splitGainsName);
    splitGain.increment(this.currentTree, splitGainVector);
    // 6. set phase to AFTER_SPLIT
    // this.phase = GBDTPhase.AFTER_SPLIT;
    LOG.info(String.format("Find split cost: %d ms", System.currentTimeMillis() - startTime));
    // clock
    Set<String> needFlushMatrixSet = new HashSet<String>(3);
    needFlushMatrixSet.add(this.param.splitFeaturesName);
    needFlushMatrixSet.add(this.param.splitValuesName);
    needFlushMatrixSet.add(this.param.splitGainsName);
    needFlushMatrixSet.add(this.param.nodeGradStatsName);
    clockAllMatrix(needFlushMatrixSet, true);
}
Also used: PSModel (com.tencent.angel.ml.model.PSModel), SplitEntry (com.tencent.angel.ml.GBDT.algo.tree.SplitEntry), HistAggrParam (com.tencent.angel.ml.GBDT.psf.HistAggrParam), GBDTGradHistGetRowFunc (com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowFunc), IntIntVector (com.tencent.angel.ml.math2.vector.IntIntVector), IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage), IntIntDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntIntDenseVectorStorage)
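
Step 1 of findSplit spreads the active tree nodes over the workers with a simple round-robin scheme before any histograms are pulled. The following self-contained sketch reproduces just that scheme, with plain arrays and parameters standing in for the task context (the names here are illustrative only).

// Sketch of the round-robin (RR) assignment in step 1: the j-th active node goes to
// task (j mod totalTaskNum), so each worker processes a disjoint subset of active nodes.
import java.util.ArrayList;
import java.util.List;

public class RoundRobinSketch {
    static List<Integer> responsibleNodes(int[] activeNode, int taskIndex, int totalTaskNum) {
        List<Integer> responsible = new ArrayList<>();
        int activeCount = 0;
        for (int nid = 0; nid < activeNode.length; nid++) {
            if (activeNode[nid] == 1) {
                if (activeCount == taskIndex) {
                    responsible.add(nid);
                }
                if (++activeCount >= totalTaskNum) {
                    activeCount = 0;
                }
            }
        }
        return responsible;
    }
}

For example, with totalTaskNum = 3 and active nodes {1, 2, 5, 6}, task 0 ends up with nodes 1 and 6, task 1 with node 2, and task 2 with node 5, which is exactly what the loop in the method above computes.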

Example 75 with IntDoubleVector

Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in project angel by Tencent.

From the class GBDTController, method runActiveNode:

public void runActiveNode() throws Exception {
    LOG.info("------Run active node------");
    long startTime = System.currentTimeMillis();
    Set<String> needFlushMatrixSet = new HashSet<String>();
    // 1. decide nodes that should be calculated
    Set<Integer> calNodes = new HashSet<>();
    Set<Integer> subNodes = new HashSet<>();
    // 2. decide calculated and subtracted tree nodes
    for (int nid = 0; nid < this.maxNodeNum; nid++) {
        if (this.activeNode[nid] == 1) {
            if (nid == 0) {
                calNodes.add(nid);
            } else {
                int parentNid = (nid - 1) / 2;
                int siblingNid = 4 * parentNid + 3 - nid;
                int sampleNum = this.nodePosEnd[nid] - this.nodePosStart[nid] + 1;
                int siblingSampleNum = this.nodePosEnd[siblingNid] - this.nodePosStart[siblingNid] + 1;
                boolean ltSibling = sampleNum < siblingSampleNum || (sampleNum == siblingSampleNum && nid < siblingNid);
                if (ltSibling) {
                    calNodes.add(nid);
                    subNodes.add(siblingNid);
                } else {
                    calNodes.add(siblingNid);
                    subNodes.add(nid);
                }
            }
        }
    }
    // 3. calculate threads
    Map<Integer, List<Future<Boolean>>> calFutures = new HashMap<>();
    for (int nid : calNodes) {
        histCache[nid] = new IntDoubleVector(this.fSet.length * 2 * this.param.numSplit, new IntDoubleDenseVectorStorage(new double[this.param.numFeature * 2 * this.param.numSplit]));
        calFutures.put(nid, new ArrayList<>());
        int nodeStart = this.nodePosStart[nid];
        int nodeEnd = this.nodePosEnd[nid];
        int batchNum = (nodeEnd - nodeStart + 1) / this.param.batchSize + ((nodeEnd - nodeStart + 1) % this.param.batchSize == 0 ? 0 : 1);
        LOG.info(String.format("Node[%d], start[%d], end[%d], batch[%d]", nid, nodeStart, nodeEnd, batchNum));
        for (int batch = 0; batch < batchNum; batch++) {
            int start = nodeStart + batch * this.param.batchSize;
            int end = nodeStart + (batch + 1) * this.param.batchSize;
            if (end > nodeEnd) {
                end = nodeEnd;
            }
            LOG.info(String.format("Calculate thread: nid[%d], start[%d], end[%d]", nid, start, end));
            Future<Boolean> future = this.threadPool.submit(new HistCalThread(this, nid, start, end));
            calFutures.get(nid).add(future);
        }
    }
    // wait until all threads finish
    for (int nid : calNodes) {
        for (Future<Boolean> future : calFutures.get(nid)) {
            future.get();
        }
    }
    // 4. subtract threads
    Map<Integer, Future<Boolean>> subFutures = new HashMap<>();
    for (int nid : subNodes) {
        int parentId = (nid - 1) / 2;
        histCache[nid] = histCache[parentId].clone();
        LOG.info(String.format("Subtract thread: nid[%d]", nid));
        Future<Boolean> future = this.threadPool.submit(new HistSubThread(this, nid));
        subFutures.put(nid, future);
    }
    // wait until all threads finish
    for (int nid : subNodes) {
        subFutures.get(nid).get();
    }
    // 5. send histograms to PS
    Set<Integer> pushNodes = new HashSet<>(calNodes);
    pushNodes.addAll(subNodes);
    int bytesPerItem = this.taskContext.getConf().getInt(MLConf.ANGEL_COMPRESS_BYTES(), MLConf.DEFAULT_ANGEL_COMPRESS_BYTES());
    if (bytesPerItem < 1 || bytesPerItem > 8) {
        LOG.info("Invalid compress configuration: " + bytesPerItem + ", it should be [1,8].");
        bytesPerItem = MLConf.DEFAULT_ANGEL_COMPRESS_BYTES();
    }
    for (int nid : pushNodes) {
        pushHistogram(nid, bytesPerItem);
        needFlushMatrixSet.add(this.param.gradHistNamePrefix + nid);
    }
    // 6. update histogram cache
    for (int nid : calNodes) {
        if (nid == 0)
            break;
        int parentId = (nid - 1) / 2;
        this.histCache[parentId] = null;
    }
    LOG.info(String.format("Run active node cost: %d ms", System.currentTimeMillis() - startTime));
    // clock
    clockAllMatrix(needFlushMatrixSet, true);
}
Also used: IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage), Future (java.util.concurrent.Future)
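
Step 2 above is the classic histogram-subtraction trick: for each pair of sibling nodes, only the one with fewer samples gets its gradient histogram computed from data (calNodes), while the other is later derived by subtracting that histogram from the cached parent histogram (subNodes). The following self-contained sketch reproduces just that bookkeeping, with plain arrays standing in for the controller fields (names are illustrative).

// Sketch of the calNodes/subNodes selection in step 2. The sibling of node nid under
// parent p is (2p+1) + (2p+2) - nid = 4p + 3 - nid in the 0-indexed complete binary tree.
import java.util.HashSet;
import java.util.Set;

public class SiblingSplitSketch {
    static void splitNodes(int[] activeNode, int[] nodePosStart, int[] nodePosEnd,
                           Set<Integer> calNodes, Set<Integer> subNodes) {
        for (int nid = 0; nid < activeNode.length; nid++) {
            if (activeNode[nid] != 1) {
                continue;
            }
            if (nid == 0) {
                // the root has no sibling, so it is always computed directly
                calNodes.add(nid);
                continue;
            }
            int parent = (nid - 1) / 2;
            int sibling = 4 * parent + 3 - nid;
            int samples = nodePosEnd[nid] - nodePosStart[nid] + 1;
            int siblingSamples = nodePosEnd[sibling] - nodePosStart[sibling] + 1;
            // compute the smaller sibling directly; ties go to the lower node id
            boolean smaller = samples < siblingSamples
                || (samples == siblingSamples && nid < sibling);
            if (smaller) {
                calNodes.add(nid);
                subNodes.add(sibling);
            } else {
                calNodes.add(sibling);
                subNodes.add(nid);
            }
        }
    }
}

The tie-break on equal sample counts mirrors the ltSibling condition in the method above, so each sibling pair is assigned consistently no matter which of the two nodes is visited first.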

Aggregations

IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector): 95 usages
ObjectIterator (it.unimi.dsi.fastutil.objects.ObjectIterator): 55 usages
IntDoubleVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleVectorStorage): 51 usages
Int2DoubleMap (it.unimi.dsi.fastutil.ints.Int2DoubleMap): 51 usages
CompIntDoubleVector (com.tencent.angel.ml.math2.vector.CompIntDoubleVector): 40 usages
IntDoubleSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleSparseVectorStorage): 32 usages
IntFloatVectorStorage (com.tencent.angel.ml.math2.storage.IntFloatVectorStorage): 32 usages
IntIntVectorStorage (com.tencent.angel.ml.math2.storage.IntIntVectorStorage): 32 usages
IntLongVectorStorage (com.tencent.angel.ml.math2.storage.IntLongVectorStorage): 32 usages
LongDoubleVectorStorage (com.tencent.angel.ml.math2.storage.LongDoubleVectorStorage): 30 usages
LongFloatVectorStorage (com.tencent.angel.ml.math2.storage.LongFloatVectorStorage): 30 usages
LongIntVectorStorage (com.tencent.angel.ml.math2.storage.LongIntVectorStorage): 30 usages
LongLongVectorStorage (com.tencent.angel.ml.math2.storage.LongLongVectorStorage): 30 usages
Storage (com.tencent.angel.ml.math2.storage.Storage): 30 usages
IntDoubleSortedVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleSortedVectorStorage): 26 usages
IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage): 23 usages
IntFloatSortedVectorStorage (com.tencent.angel.ml.math2.storage.IntFloatSortedVectorStorage): 20 usages
IntFloatSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntFloatSparseVectorStorage): 20 usages
IntIntSortedVectorStorage (com.tencent.angel.ml.math2.storage.IntIntSortedVectorStorage): 20 usages
IntIntSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntIntSparseVectorStorage): 20 usages