Use of `com.tencent.angel.ml.math2.vector.IntDoubleVector` in project Angel by Tencent.
Class `RangeRouterUtils`, method `splitIntDoubleVector`.
/**
 * Splits an {@code IntDoubleVector} into key/value parts by delegating to the {@code split}
 * overload that matches the vector's storage kind.
 *
 * <p>Sparse storage and sorted storage are split from their index and value arrays. Sparse
 * storage passes {@code false} as the last argument and sorted storage passes {@code true};
 * presumably this flag tells {@code split} whether the keys are already sorted (TODO confirm).
 * Dense storage is split from its value array alone. Any storage that is neither sparse nor
 * dense is cast to {@code IntDoubleSortedVectorStorage}.
 *
 * @param matrixMeta metadata of the matrix the vector belongs to
 * @param vector the vector to split
 * @return the parts produced by the matching {@code split} overload
 */
public static KeyValuePart[] splitIntDoubleVector(MatrixMeta matrixMeta, IntDoubleVector vector) {
  IntDoubleVectorStorage storage = vector.getStorage();
  int rowId = vector.getRowId();

  if (storage.isSparse()) {
    IntDoubleSparseVectorStorage hashStorage = (IntDoubleSparseVectorStorage) storage;
    return split(matrixMeta, rowId, hashStorage.getIndices(), hashStorage.getValues(), false);
  }

  if (storage.isDense()) {
    double[] denseValues = ((IntDoubleDenseVectorStorage) storage).getValues();
    return split(matrixMeta, rowId, denseValues);
  }

  // Neither sparse nor dense: treated as sorted key/value arrays.
  IntDoubleSortedVectorStorage sortedStorage = (IntDoubleSortedVectorStorage) storage;
  return split(matrixMeta, rowId, sortedStorage.getIndices(), sortedStorage.getValues(), true);
}
Use of `com.tencent.angel.ml.math2.vector.IntDoubleVector` in project Angel by Tencent.
Class `CompIntDoubleVectorSplitter`, method `split`.
/**
 * Splits a composite int-double vector into one row-update split per partition.
 *
 * <p>The i-th sub-vector returned by {@code getPartitions()} is paired with the i-th partition
 * key, so both sequences must have the same length.
 *
 * @param vector the vector to split; cast to {@code CompIntDoubleVector}
 * @param parts the partition keys, in the same order as the vector's sub-vectors
 * @return a map from each partition key to the update split for that partition
 * @throws IllegalArgumentException if the number of sub-vectors differs from the number of
 *     partition keys
 */
@Override
public Map<PartitionKey, RowUpdateSplit> split(Vector vector, List<PartitionKey> parts) {
  IntDoubleVector[] vecParts = ((CompIntDoubleVector) vector).getPartitions();
  // A Java `assert` is skipped unless the JVM runs with -ea; check explicitly so a mismatch
  // can neither silently drop partitions nor fail later with an IndexOutOfBoundsException.
  if (vecParts.length != parts.size()) {
    throw new IllegalArgumentException(
        "Composite vector has " + vecParts.length + " sub-vectors but " + parts.size()
            + " partition keys were given");
  }
  int rowId = vector.getRowId();
  Map<PartitionKey, RowUpdateSplit> updateSplitMap = new HashMap<>(parts.size());
  for (int i = 0; i < vecParts.length; i++) {
    PartitionKey part = parts.get(i);
    // Column span covered by this partition, passed to the split as an int.
    int partLen = (int) (part.getEndCol() - part.getStartCol());
    updateSplitMap.put(part, new CompIntDoubleRowUpdateSplit(rowId, vecParts[i], partLen));
  }
  return updateSplitMap;
}
Use of `com.tencent.angel.ml.math2.vector.IntDoubleVector` in project Angel by Tencent.
Class `GBDTController`, method `updateNodeGradStats`.
/**
 * Pushes the gradient statistics of one tree node to the PS.
 *
 * <p>Called while splitting (in GradHistHelper) to update the grad stats of child nodes after
 * the best split has been found. The root node's stats are updated by the leader worker.
 *
 * <p>The update is a dense vector of length {@code 2 * activeNode.length}. The sum of gradients
 * is written to slot {@code nid} and the sum of hessians to slot
 * {@code nid + activeNode.length}. The vector is passed to {@code increment} together with
 * {@code currentTree}.
 *
 * @param nid id of the node whose stats are pushed
 * @param gradStats the node's gradient statistics
 * @throws Exception propagated from the PS model calls
 */
public void updateNodeGradStats(int nid, GradStats gradStats) throws Exception {
  LOG.debug(String.format("Update gradStats of node[%d]: sumGrad[%f], sumHess[%f]", nid, gradStats.sumGrad, gradStats.sumHess));

  int nodeCount = this.activeNode.length;
  int updateDim = nodeCount * 2;

  // First half of the update holds sumGrad, second half holds sumHess.
  IntDoubleVector update = new IntDoubleVector(updateDim, new IntDoubleDenseVectorStorage(updateDim));
  update.set(nid, gradStats.sumGrad);
  update.set(nodeCount + nid, gradStats.sumHess);

  this.model.getPSModel(this.param.nodeGradStatsName).increment(this.currentTree, update);
}
Use of `com.tencent.angel.ml.math2.vector.IntDoubleVector` in project Angel by Tencent.
Class `GBDTController`, method `getSketch`.
/**
 * Pulls the global sketch from the PS; only called once by each worker.
 *
 * <p>Copies row 0 of the sketch matrix into {@code sketches}. It then records in
 * {@code cateFeatNum} the number of splits of every categorical feature in
 * {@code cateFeatList}, derived from that feature's slice of {@code numSplit} sketch values
 * starting at {@code fid * numSplit}.
 *
 * @throws Exception propagated from the PS model calls
 */
public void getSketch() throws Exception {
  PSModel sketch = model.getPSModel(this.param.sketchName);
  LOG.info("------Get sketch from PS------");
  long startTime = System.currentTimeMillis();
  IntDoubleVector sketchVector = (IntDoubleVector) sketch.getRow(0);
  LOG.info(String.format("Get sketch cost: %d ms", System.currentTimeMillis() - startTime));
  for (int i = 0; i < sketchVector.getDim(); i++) {
    this.sketches[i] = (float) sketchVector.get(i);
  }
  // number of splits of each categorical feature
  for (int i = 0; i < cateFeatList.size(); i++) {
    int fid = cateFeatList.get(i);
    this.cateFeatNum.put(fid, countCategoricalSplits(fid * this.param.numSplit));
  }
  LOG.info("Number of splits of categorical features: " + this.cateFeatNum.entrySet().toString());
}

/**
 * Counts the splits of one categorical feature: 1 plus the length of the strictly increasing
 * run at the start of its sketch slice {@code sketches[start .. start + numSplit - 1]}.
 *
 * @param start index of the feature's first sketch value
 * @return a value in {@code [1, max(1, numSplit)]}
 */
private int countCategoricalSplits(int start) {
  int splitNum = 1;
  // Only compare adjacent pairs inside this feature's slice. The last pair is
  // (numSplit - 2, numSplit - 1); going one further would read the next feature's first value,
  // or run past the end of the array for the last feature.
  for (int j = 0; j + 1 < this.param.numSplit; j++) {
    if (this.sketches[start + j + 1] > this.sketches[start + j]) {
      splitNum++;
    } else {
      break;
    }
  }
  return splitNum;
}
Use of `com.tencent.angel.ml.math2.vector.IntDoubleVector` in project Angel by Tencent.
Class `GBDTController`, method `afterSplit`.
/**
 * Applies the split results of the current tree (row {@code currentTree}) after a split round.
 *
 * <p>The method runs in these steps:
 * <ol>
 *   <li>Pull split features, split values, split gains and node grad stats from the PS.</li>
 *   <li>For every node active before this call, copy its split feature and value into the
 *       local replica and submit an {@code AfterSplitThread}.</li>
 *   <li>Wait until those tasks have finished.</li>
 *   <li>Update the valid instance positions and finish the current depth.</li>
 *   <li>Clock the four split matrices.</li>
 * </ol>
 *
 * @throws Exception propagated from the PS model calls or the helper steps
 */
public void afterSplit() throws Exception {
  LOG.info("------After split------");
  long startTime = System.currentTimeMillis();
  // 1. get split feature
  PSModel splitFeatModel = model.getPSModel(this.param.splitFeaturesName);
  IntIntVector splitFeatureVec = (IntIntVector) splitFeatModel.getRow(currentTree);
  // 2. get split value
  PSModel splitValueModel = model.getPSModel(this.param.splitValuesName);
  IntDoubleVector splitValueVec = (IntDoubleVector) splitValueModel.getRow(currentTree);
  // 3. get split gain
  PSModel splitGainModel = model.getPSModel(this.param.splitGainsName);
  IntDoubleVector splitGainVec = (IntDoubleVector) splitGainModel.getRow(currentTree);
  // 4. get node weight
  PSModel nodeGradStatsModel = model.getPSModel(this.param.nodeGradStatsName);
  IntDoubleVector nodeGradStatsVec = (IntDoubleVector) nodeGradStatsModel.getRow(currentTree);
  LOG.info(String.format("Get split result from PS cost %d ms", System.currentTimeMillis() - startTime));
  // 5. split node
  LOG.debug(String.format("Split active node: %s", Arrays.toString(this.activeNode)));
  // Snapshot the active set so the decision which nodes to process is fixed up front.
  int[] preActiveNode = this.activeNode.clone();
  for (int nid = 0; nid < this.maxNodeNum; nid++) {
    if (preActiveNode[nid] == 1) {
      // update local replica
      this.splitFeats[nid] = splitFeatureVec.get(nid);
      this.splitValues[nid] = splitValueVec.get(nid);
      // create AfterSplit task; mark it running before submission
      this.activeNodeStat[nid].set(1);
      AfterSplitThread t = new AfterSplitThread(this, nid, splitFeatureVec, splitValueVec, splitGainVec, nodeGradStatsVec);
      // NOTE(review): the Future returned by submit is discarded. If a task fails without
      // changing its activeNodeStat entry away from 1, the wait below never returns — confirm.
      this.threadPool.submit(t);
    }
  }
  // 6. wait until every AfterSplit task has finished
  waitForAfterSplitTasks();
  updateValidInsPos();
  finishCurrentDepth();
  LOG.info(String.format("After split cost: %d ms", System.currentTimeMillis() - startTime));
  // 7. clock
  Set<String> needFlushMatrixSet = new HashSet<String>(4);
  needFlushMatrixSet.add(this.param.splitFeaturesName);
  needFlushMatrixSet.add(this.param.splitValuesName);
  needFlushMatrixSet.add(this.param.splitGainsName);
  needFlushMatrixSet.add(this.param.nodeGradStatsName);
  clockAllMatrix(needFlushMatrixSet, true);
}

/**
 * Blocks until no entry of {@code activeNodeStat} in {@code [0, maxNodeNum)} equals 1.
 *
 * <p>Sleeps briefly between polls instead of spinning, so the waiting thread does not occupy a
 * full CPU core. It logs the "still running" debug message once rather than on every poll.
 *
 * @throws InterruptedException if the waiting thread is interrupted
 */
private void waitForAfterSplitTasks() throws InterruptedException {
  final long pollIntervalMs = 1L;
  boolean loggedWaiting = false;
  while (hasRunningAfterSplitTask()) {
    if (!loggedWaiting) {
      LOG.debug("current has running thread");
      loggedWaiting = true;
    }
    Thread.sleep(pollIntervalMs);
  }
}

/** Returns true if any node in {@code [0, maxNodeNum)} still has {@code activeNodeStat} == 1. */
private boolean hasRunningAfterSplitTask() {
  for (int nid = 0; nid < this.maxNodeNum; nid++) {
    if (this.activeNodeStat[nid].get() == 1) {
      return true;
    }
  }
  return false;
}
Aggregations