Search in sources :

Example 6 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class AngelClient method saveModel.

/**
 * Saves the matrices of the given model through the master.
 *
 * <p>Only matrices whose context carries a non-null {@code MatrixConf.MATRIX_SAVE_PATH} attribute
 * are registered in the save context; the attribute value itself is not used here. The output
 * directory of the whole save is read from {@code AngelConf.ANGEL_JOB_OUTPUT_PATH}.
 *
 * @param model the model whose PS matrices should be saved
 * @throws AngelException if {@code master} is null, i.e. the parameter servers were not started
 */
@SuppressWarnings("rawtypes")
@Override
public void saveModel(MLModel model) throws AngelException {
    if (master == null) {
        throw new AngelException("parameter servers are not started, you must execute startPSServer first!!");
    }
    Map<String, PSModel> modelsByName = model.getPSModels();
    ModelSaveContext modelSaveContext = new ModelSaveContext();
    for (PSModel psModel : modelsByName.values()) {
        MatrixContext matrixContext = psModel.getContext();
        // matrices without a configured save path are not part of the saved model
        if (matrixContext.getAttributes().get(MatrixConf.MATRIX_SAVE_PATH) == null) {
            continue;
        }
        String matrixName = matrixContext.getName();
        String formatClassName = conf.get("OUT_FORMAT_CLASS", RowIdColIdValueTextRowFormat.class.getName());
        modelSaveContext.addMatrix(new MatrixSaveContext(matrixName, formatClassName));
    }
    modelSaveContext.setSavePath(conf.get(AngelConf.ANGEL_JOB_OUTPUT_PATH));
    save(modelSaveContext);
    LOG.info("save is finish");
}
Also used : AngelException(com.tencent.angel.exception.AngelException) MatrixContext(com.tencent.angel.ml.matrix.MatrixContext) PSModel(com.tencent.angel.ml.model.PSModel) RowIdColIdValueTextRowFormat(com.tencent.angel.model.output.format.RowIdColIdValueTextRowFormat) MatrixSaveContext(com.tencent.angel.model.MatrixSaveContext) ModelSaveContext(com.tencent.angel.model.ModelSaveContext) Map(java.util.Map) LinkedHashMap(java.util.LinkedHashMap)

Example 7 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class GBDTController method mergeCateFeatSketch.

/**
 * Merges the categorical feature values written by every worker and pushes the merged, sorted
 * values to the sketch matrix.
 *
 * <p>Only the task with index 0 performs the merge, and only when {@code cateFeatList} is not
 * empty; in that case the sketch matrix name is added to the flush set. Every task then calls
 * {@code clockAllMatrix} with that set.
 *
 * @throws Exception propagated from the PS operations or from {@code clockAllMatrix}
 */
public void mergeCateFeatSketch() throws Exception {
    LOG.info("------Merge categorical features------");
    Set<String> needFlushMatrixSet = new HashSet<String>(1);
    // the leader worker
    if (!this.cateFeatList.isEmpty() && this.taskContext.getTaskIndex() == 0) {
        PSModel cateFeat = model.getPSModel(this.param.cateFeatureName);
        PSModel sketch = model.getPSModel(this.param.sketchName);
        int cateFeatNum = cateFeatList.size();
        int numSplit = this.param.numSplit;
        // one set of distinct values per categorical feature, in cateFeatList order
        List<Set<Double>> featSets = new ArrayList<>(cateFeatNum);
        for (int i = 0; i < cateFeatNum; i++) {
            featSets.add(new HashSet<Double>());
        }
        int workerNum = this.taskContext.getConf().getInt(AngelConf.ANGEL_WORKERGROUP_ACTUAL_NUM, 1);
        // merge categorical features: row `worker` of the cateFeat matrix is read for each worker
        for (int worker = 0; worker < workerNum; worker++) {
            IntDoubleVector vec = (IntDoubleVector) cateFeat.getRow(worker);
            for (int i = 0; i < cateFeatNum; i++) {
                // NOTE(review): values are read at offset i * numSplit (position in cateFeatList)
                // but written below at fid * numSplit; presumably the two matrices use different
                // layouts -- confirm against the code that fills the cateFeat matrix.
                int start = i * numSplit;
                Set<Double> featSet = featSets.get(i);
                for (int j = 0; j < numSplit; j++) {
                    featSet.add(vec.get(start + j));
                }
            }
        }
        // create updates
        int dim = this.param.numFeature * numSplit;
        IntDoubleVector cateFeatVec = new IntDoubleVector(dim, new IntDoubleSparseVectorStorage(dim));
        for (int i = 0; i < cateFeatNum; i++) {
            int fid = cateFeatList.get(i);
            int start = fid * numSplit;
            List<Double> sortedValue = new ArrayList<>(featSets.get(i));
            Collections.sort(sortedValue);
            // only checked when assertions are enabled; more than numSplit distinct values would
            // spill into the slots of the next feature
            assert sortedValue.size() < numSplit;
            for (int j = 0; j < sortedValue.size(); j++) {
                cateFeatVec.set(start + j, sortedValue.get(j));
            }
        }
        sketch.increment(0, cateFeatVec);
        needFlushMatrixSet.add(this.param.sketchName);
    }
    clockAllMatrix(needFlushMatrixSet, true);
}
Also used : PSModel(com.tencent.angel.ml.model.PSModel) IntDoubleVector(com.tencent.angel.ml.math2.vector.IntDoubleVector) IntDoubleSparseVectorStorage(com.tencent.angel.ml.math2.storage.IntDoubleSparseVectorStorage)

Example 8 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class GBDTController method updateLeafPreds.

/**
 * Pushes the leaf values of the current tree to the node-prediction matrix.
 *
 * <p>The task with index 0 collects the leaf value of every non-null leaf node of
 * {@code forest[currentTree]} into a dense vector of length {@code maxNodeNum}, increments the
 * node-prediction matrix with it at {@code currentTree}, and registers that matrix for flushing.
 * Every task then calls {@code clockAllMatrix} and logs the elapsed time.
 *
 * @throws Exception propagated from the PS operations or from {@code clockAllMatrix}
 */
public void updateLeafPreds() throws Exception {
    LOG.info("------Update leaf node predictions------");
    final long beginMs = System.currentTimeMillis();
    Set<String> needFlushMatrixSet = new HashSet<String>(1);
    if (taskContext.getTaskIndex() == 0) {
        int totalNodes = this.forest[currentTree].nodes.size();
        IntDoubleVector leafWeights = new IntDoubleVector(this.maxNodeNum, new IntDoubleDenseVectorStorage(this.maxNodeNum));
        for (int nodeId = 0; nodeId < totalNodes; nodeId++) {
            // skip empty slots and internal nodes
            if (null == this.forest[currentTree].nodes.get(nodeId) || !this.forest[currentTree].nodes.get(nodeId).isLeaf()) {
                continue;
            }
            float leafWeight = this.forest[currentTree].nodes.get(nodeId).getLeafValue();
            LOG.debug(String.format("Leaf weight of node[%d]: %f", nodeId, leafWeight));
            leafWeights.set(nodeId, leafWeight);
        }
        PSModel nodePredsModel = this.model.getPSModel(this.param.nodePredsName);
        nodePredsModel.increment(this.currentTree, leafWeights);
        // the leader task adds node prediction to flush list
        needFlushMatrixSet.add(this.param.nodePredsName);
    }
    clockAllMatrix(needFlushMatrixSet, true);
    long elapsedMs = System.currentTimeMillis() - beginMs;
    LOG.info(String.format("Update leaf node predictions cost: %d ms", elapsedMs));
}
Also used : IntDoubleDenseVectorStorage(com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage) PSModel(com.tencent.angel.ml.model.PSModel) IntDoubleVector(com.tencent.angel.ml.math2.vector.IntDoubleVector)

Example 9 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class GBDTController method sampleFeature.

/**
 * Samples the feature columns used by the current tree and pushes them to the PS.
 *
 * <p>Sampling happens only on the task with index 0 and only when {@code param.colSample < 1}.
 * In that case the sampled feature ids are pushed with
 * {@code featSample.increment(currentTree, ...)} and the sampled-features matrix is registered
 * for flushing. Every task then calls {@code clockAllMatrix}.
 *
 * @throws Exception propagated from the PS operations or from {@code clockAllMatrix}
 */
public void sampleFeature() throws Exception {
    LOG.info("------Sample feature------");
    PSModel featSample = model.getPSModel(this.param.sampledFeaturesName);
    Set<String> needFlushMatrixSet = new HashSet<String>(1);
    if (this.param.colSample < 1 && taskContext.getTaskIndex() == 0) {
        long startTime = System.currentTimeMillis();
        // push sampled feature set to the current tree
        int[] fset = this.trainDataStore.featureMeta.sampleCol(this.param.colSample);
        IntIntVector sampleFeatureVector = new IntIntVector(fset.length, new IntIntDenseVectorStorage(fset));
        featSample.increment(currentTree, sampleFeatureVector);
        needFlushMatrixSet.add(this.param.sampledFeaturesName);
        LOG.info(String.format("Sample feature cost: %d ms", System.currentTimeMillis() - startTime));
    }
    clockAllMatrix(needFlushMatrixSet, true);
}
Also used : PSModel(com.tencent.angel.ml.model.PSModel) IntIntVector(com.tencent.angel.ml.math2.vector.IntIntVector) IntIntDenseVectorStorage(com.tencent.angel.ml.math2.storage.IntIntDenseVectorStorage)

Example 10 with PSModel

use of com.tencent.angel.ml.model.PSModel in project angel by Tencent.

the class GBDTController method findSplit.

/**
 * Finds the best split of every active tree node this task is responsible for and pushes the
 * results (split feature, split value, split gain) to the PS.
 *
 * <p>Active nodes (entries of {@code activeNode} equal to 1) are assigned to tasks round-robin.
 * For each responsible node the gradient histogram matrix is either asked for a server-side
 * split entry or pulled as a whole and searched locally, depending on the server-split flag read
 * from the task configuration. The node's histogram matrix is zeroed afterwards.
 *
 * <p>The same flag drives both the pull and the processing of the pulled result, so the
 * processing branch never dereferences a value that was not fetched.
 *
 * @throws Exception propagated from the PS operations or from {@code clockAllMatrix}
 */
public void findSplit() throws Exception {
    LOG.info("------Find split------");
    long startTime = System.currentTimeMillis();
    // 1. find responsible tree node, using RR scheme
    List<Integer> responsibleTNode = new ArrayList<>();
    // round-robin cursor: index of the task that owns the next active node
    int activeTNodeNum = 0;
    for (int nid = 0; nid < this.activeNode.length; nid++) {
        int isActive = this.activeNode[nid];
        if (isActive == 1) {
            if (this.taskContext.getTaskIndex() == activeTNodeNum) {
                responsibleTNode.add(nid);
            }
            if (++activeTNodeNum >= taskContext.getTotalTaskNum()) {
                activeTNodeNum = 0;
            }
        }
    }
    int[] tNodeId = Maths.intList2Arr(responsibleTNode);
    LOG.info(String.format("Task[%d] responsible tree node: %s", this.taskContext.getTaskId().getIndex(), responsibleTNode.toString()));
    // 2. pull gradient histogram
    // the updated indices of the parameter on PS
    int[] updatedIndices = new int[tNodeId.length];
    // the updated split features
    int[] updatedSplitFid = new int[tNodeId.length];
    // the updated split value
    double[] updatedSplitFvalue = new double[tNodeId.length];
    // the updated split gain
    double[] updatedSplitGain = new double[tNodeId.length];
    // single source of truth for the split mode, used for both pulling and processing
    boolean isServerSplit = taskContext.getConf().getBoolean(MLConf.ML_GBDT_SERVER_SPLIT(), MLConf.DEFAULT_ML_GBDT_SERVER_SPLIT());
    for (int i = 0; i < tNodeId.length; i++) {
        int nid = tNodeId[i];
        LOG.debug(String.format("Task[%d] find best split of tree node: %d", this.taskContext.getTaskIndex(), nid));
        // 2.1. get the name of this node's gradient histogram on PS
        String gradHistName = this.param.gradHistNamePrefix + nid;
        // 2.2. pull the histogram
        long pullStartTime = System.currentTimeMillis();
        PSModel histMat = model.getPSModel(gradHistName);
        IntDoubleVector histogram = null;
        SplitEntry splitEntry = null;
        if (isServerSplit) {
            int matrixId = histMat.getMatrixId();
            GBDTGradHistGetRowFunc func = new GBDTGradHistGetRowFunc(new HistAggrParam(matrixId, 0, param.numSplit, param.minChildWeight, param.regAlpha, param.regLambda));
            splitEntry = ((GBDTGradHistGetRowResult) histMat.get(func)).getSplitEntry();
        } else {
            histogram = (IntDoubleVector) histMat.getRow(0);
            LOG.debug("Get grad histogram without server split mode, histogram size" + histogram.getDim());
        }
        LOG.info(String.format("Pull histogram from PS cost %d ms", System.currentTimeMillis() - pullStartTime));
        GradHistHelper histHelper = new GradHistHelper(this, nid);
        // 2.3. find best split result of this tree node
        if (isServerSplit) {
            // 2.3.1 using server split: the returned fid indexes fSet and the returned fvalue
            // is an index into the sketches of that feature, so map both back to real values
            if (splitEntry.getFid() != -1) {
                int trueSplitFid = this.fSet[splitEntry.getFid()];
                int splitIdx = (int) splitEntry.getFvalue();
                float trueSplitValue = this.sketches[trueSplitFid * this.param.numSplit + splitIdx];
                LOG.info(String.format("Best split of node[%d]: feature[%d], value[%f], " + "true feature[%d], true value[%f], losschg[%f]", nid, splitEntry.getFid(), splitEntry.getFvalue(), trueSplitFid, trueSplitValue, splitEntry.getLossChg()));
                splitEntry.setFid(trueSplitFid);
                splitEntry.setFvalue(trueSplitValue);
            }
            // update the grad stats of the root node on PS, only called once by leader worker
            if (nid == 0) {
                GradStats rootStats = new GradStats(splitEntry.leftGradStat);
                rootStats.add(splitEntry.rightGradStat);
                this.updateNodeGradStats(nid, rootStats);
            }
            // update the grad stats of children node
            if (splitEntry.fid != -1) {
                // update the left child
                this.updateNodeGradStats(2 * nid + 1, splitEntry.leftGradStat);
                // update the right child
                this.updateNodeGradStats(2 * nid + 2, splitEntry.rightGradStat);
            }
        } else {
            // 2.3.2 otherwise, the returned histogram contains the gradient info
            splitEntry = histHelper.findBestSplit(histogram);
            LOG.info(String.format("Best split of node[%d]: feature[%d], value[%f], losschg[%f]", nid, splitEntry.getFid(), splitEntry.getFvalue(), splitEntry.getLossChg()));
        }
        // 2.3.3 the updated split result (tree node/feature/value/gain) on PS,
        // recorded the same way for both split modes
        updatedIndices[i] = nid;
        updatedSplitFid[i] = splitEntry.fid;
        updatedSplitFvalue[i] = splitEntry.fvalue;
        updatedSplitGain[i] = splitEntry.lossChg;
        // 2.3.4 reset this tree node's gradient histogram to 0
        histMat.zero();
    }
    // 3. push split feature to PS
    IntIntVector splitFeatureVector = new IntIntVector(this.activeNode.length, new IntIntDenseVectorStorage(this.activeNode.length));
    // 4. push split value to PS
    IntDoubleVector splitValueVector = new IntDoubleVector(this.activeNode.length, new IntDoubleDenseVectorStorage(this.activeNode.length));
    // 5. push split gain to PS
    IntDoubleVector splitGainVector = new IntDoubleVector(this.activeNode.length, new IntDoubleDenseVectorStorage(this.activeNode.length));
    // only the entries of the nodes handled by this task are filled in
    for (int i = 0; i < updatedIndices.length; i++) {
        splitFeatureVector.set(updatedIndices[i], updatedSplitFid[i]);
        splitValueVector.set(updatedIndices[i], updatedSplitFvalue[i]);
        splitGainVector.set(updatedIndices[i], updatedSplitGain[i]);
    }
    PSModel splitFeat = model.getPSModel(this.param.splitFeaturesName);
    splitFeat.increment(this.currentTree, splitFeatureVector);
    PSModel splitValue = model.getPSModel(this.param.splitValuesName);
    splitValue.increment(this.currentTree, splitValueVector);
    PSModel splitGain = model.getPSModel(this.param.splitGainsName);
    splitGain.increment(this.currentTree, splitGainVector);
    // 6. set phase to AFTER_SPLIT
    // this.phase = GBDTPhase.AFTER_SPLIT;
    LOG.info(String.format("Find split cost: %d ms", System.currentTimeMillis() - startTime));
    // clock
    Set<String> needFlushMatrixSet = new HashSet<String>(3);
    needFlushMatrixSet.add(this.param.splitFeaturesName);
    needFlushMatrixSet.add(this.param.splitValuesName);
    needFlushMatrixSet.add(this.param.splitGainsName);
    needFlushMatrixSet.add(this.param.nodeGradStatsName);
    clockAllMatrix(needFlushMatrixSet, true);
}
Also used : PSModel(com.tencent.angel.ml.model.PSModel) SplitEntry(com.tencent.angel.ml.GBDT.algo.tree.SplitEntry) HistAggrParam(com.tencent.angel.ml.GBDT.psf.HistAggrParam) GBDTGradHistGetRowFunc(com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowFunc) IntIntVector(com.tencent.angel.ml.math2.vector.IntIntVector) IntDoubleVector(com.tencent.angel.ml.math2.vector.IntDoubleVector) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) IntDoubleDenseVectorStorage(com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage) IntIntDenseVectorStorage(com.tencent.angel.ml.math2.storage.IntIntDenseVectorStorage)

Aggregations

PSModel (com.tencent.angel.ml.model.PSModel)11 IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector)7 IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage)4 IntIntVector (com.tencent.angel.ml.math2.vector.IntIntVector)4 IntIntDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntIntDenseVectorStorage)2 AngelException (com.tencent.angel.exception.AngelException)1 RegTree (com.tencent.angel.ml.GBDT.algo.RegTree)1 SplitEntry (com.tencent.angel.ml.GBDT.algo.tree.SplitEntry)1 GBDTGradHistGetRowFunc (com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowFunc)1 HistAggrParam (com.tencent.angel.ml.GBDT.psf.HistAggrParam)1 IntDoubleSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleSparseVectorStorage)1 MatrixContext (com.tencent.angel.ml.matrix.MatrixContext)1 QuantifyDoubleFunc (com.tencent.angel.ml.psf.compress.QuantifyDoubleFunc)1 MatrixSaveContext (com.tencent.angel.model.MatrixSaveContext)1 ModelSaveContext (com.tencent.angel.model.ModelSaveContext)1 RowIdColIdValueTextRowFormat (com.tencent.angel.model.output.format.RowIdColIdValueTextRowFormat)1 LinkedHashMap (java.util.LinkedHashMap)1 Map (java.util.Map)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1