use of com.tencent.angel.ml.math2.vector.IntIntVector in project angel by Tencent.
the class GBDTController method afterSplit.
/**
 * Applies the split results of the current tree to the local state.
 *
 * <p>Pulls the split-feature, split-value, split-gain and node-grad-stats rows of
 * {@code currentTree} from the PS models, submits one {@code AfterSplitThread} to
 * {@code threadPool} for every node that is active on entry, waits until no entry of
 * {@code activeNodeStat} equals 1, then calls {@code updateValidInsPos()},
 * {@code finishCurrentDepth()} and finally clocks the four matrices.
 *
 * @throws Exception if any of the called operations fails, or if the calling thread is
 *     interrupted while waiting for the submitted tasks
 */
public void afterSplit() throws Exception {
  LOG.info("------After split------");
  long startTime = System.currentTimeMillis();
  // 1. get split feature
  PSModel splitFeatModel = model.getPSModel(this.param.splitFeaturesName);
  IntIntVector splitFeatureVec = (IntIntVector) splitFeatModel.getRow(currentTree);
  // 2. get split value
  PSModel splitValueModel = model.getPSModel(this.param.splitValuesName);
  IntDoubleVector splitValueVec = (IntDoubleVector) splitValueModel.getRow(currentTree);
  // 3. get split gain
  PSModel splitGainModel = model.getPSModel(this.param.splitGainsName);
  IntDoubleVector splitGainVec = (IntDoubleVector) splitGainModel.getRow(currentTree);
  // 4. get node weight
  PSModel nodeGradStatsModel = model.getPSModel(this.param.nodeGradStatsName);
  IntDoubleVector nodeGradStatsVec = (IntDoubleVector) nodeGradStatsModel.getRow(currentTree);
  LOG.info(String.format("Get split result from PS cost %d ms", System.currentTimeMillis() - startTime));
  // 5. split node
  LOG.debug(String.format("Split active node: %s", Arrays.toString(this.activeNode)));
  // iterate over a snapshot of the active flags taken before any task is submitted
  int[] preActiveNode = this.activeNode.clone();
  for (int nid = 0; nid < this.maxNodeNum; nid++) {
    if (preActiveNode[nid] == 1) {
      // update local replica
      this.splitFeats[nid] = splitFeatureVec.get(nid);
      this.splitValues[nid] = splitValueVec.get(nid);
      // create AfterSplit task; status 1 marks the node as still being processed
      this.activeNodeStat[nid].set(1);
      AfterSplitThread t = new AfterSplitThread(this, nid, splitFeatureVec, splitValueVec, splitGainVec, nodeGradStatsVec);
      this.threadPool.submit(t);
    }
  }
  // 6. check thread stats, if all threads finish, continue
  // NOTE(review): presumably AfterSplitThread moves its node's status away from 1 when done - confirm
  final long pollIntervalMs = 1L;
  boolean hasRunning = true;
  while (hasRunning) {
    hasRunning = false;
    for (int nid = 0; nid < this.maxNodeNum; nid++) {
      int stat = this.activeNodeStat[nid].get();
      if (stat == 1) {
        hasRunning = true;
        break;
      }
    }
    if (hasRunning) {
      LOG.debug("current has running thread");
      // back off briefly instead of spinning, so the polling thread does not compete
      // for CPU with the tasks it is waiting for
      try {
        Thread.sleep(pollIntervalMs);
      } catch (InterruptedException e) {
        // keep the interrupt flag visible to callers before propagating
        Thread.currentThread().interrupt();
        throw e;
      }
    }
  }
  updateValidInsPos();
  finishCurrentDepth();
  LOG.info(String.format("After split cost: %d ms", System.currentTimeMillis() - startTime));
  // 7. clock
  Set<String> needFlushMatrixSet = new HashSet<String>(4);
  needFlushMatrixSet.add(this.param.splitFeaturesName);
  needFlushMatrixSet.add(this.param.splitValuesName);
  needFlushMatrixSet.add(this.param.splitGainsName);
  needFlushMatrixSet.add(this.param.nodeGradStatsName);
  clockAllMatrix(needFlushMatrixSet, true);
}
use of com.tencent.angel.ml.math2.vector.IntIntVector in project angel by Tencent.
the class GBDTController method createNewTree.
/**
 * Creates a new tree and prepares the controller state for growing it.
 *
 * <p>Steps: build and register a fresh {@code RegTree}; choose the feature set (pulled from
 * the sampled-features PS model when {@code param.colSample < 1}, otherwise every feature);
 * reset all tree nodes and activate the root; reset the instance-position spans and the
 * validation instance positions; calculate the gradient pairs.
 *
 * @throws Exception if any of the called operations fails
 */
public void createNewTree() throws Exception {
  LOG.info("------Create new tree------");
  final long beginMs = System.currentTimeMillis();

  // Step 1: build the tree, initialize its nodes and register it in the forest
  RegTree newTree = new RegTree(this.param);
  newTree.initTreeNodes();
  this.currentDepth = 1;
  this.forest[this.currentTree] = newTree;

  // Step 2: decide which features this tree uses
  final boolean featuresSampled = this.param.colSample < 1;
  if (featuresSampled) {
    // sampled: pull the feature ids chosen for the current tree from the PS
    PSModel sampledFeatModel = model.getPSModel(this.param.sampledFeaturesName);
    IntIntVector sampledFeats = (IntIntVector) sampledFeatModel.getRow(this.currentTree);
    this.fSet = sampledFeats.getStorage().getValues();
    calfPos();
  } else if (this.fSet == null) {
    // not sampled: use every feature; the identity mappings are built only once
    final int featureCount = this.trainDataStore.featureMeta.numFeature;
    this.fSet = new int[featureCount];
    this.fPos = new int[featureCount];
    for (int f = 0; f < featureCount; f++) {
      this.fSet[f] = f;
      this.fPos[f] = f;
    }
  }

  // Step 3: reset every tree node
  for (int node = 0; node < this.maxNodeNum; node++) {
    resetActiveTNodes(node);
  }

  // Step 4: the root is the only active node at the start
  addActiveNode(0);

  // Step 5: the root spans all instances; every other node gets an empty (-1, -1) span
  this.nodePosStart[0] = 0;
  this.nodePosEnd[0] = this.instancePos.length - 1;
  for (int node = 1; node < this.maxNodeNum; node++) {
    this.nodePosStart[node] = -1;
    this.nodePosEnd[node] = -1;
  }
  // validation instances all start at position 0
  Arrays.fill(this.validInsPos, 0);

  // Step 6: calculate gradient
  calGradPairs();
  LOG.info(String.format("Create new tree cost: %d ms", System.currentTimeMillis() - beginMs));
}
use of com.tencent.angel.ml.math2.vector.IntIntVector in project angel by Tencent.
the class MergeUtils method combineServerIntIntRowSplits.
/**
 * Merges the server-side splits of one int-key/int-value row into a single vector.
 *
 * <p>The result is created with {@code VFactory.sparseIntVector} when the matrix is hashed or
 * when the total element count is below {@code storageConvFactor * colNum}; otherwise it is
 * created with {@code VFactory.denseIntVector}. {@code null} entries in {@code rowSplits} are
 * ignored: they contribute nothing to the element count and are skipped when merging.
 *
 * <p>Side effect: {@code rowSplits} is sorted in place with {@code serverRowComp}; any
 * {@code null} entries end up at the tail of the list.
 *
 * @param rowSplits row splits to merge; may contain {@code null} entries
 * @param matrixMeta meta data of the matrix the row belongs to
 * @param rowIndex row id stored on the returned vector
 * @return the merged row, tagged with the matrix id and {@code rowIndex}
 */
private static Vector combineServerIntIntRowSplits(List<ServerRow> rowSplits, MatrixMeta matrixMeta, int rowIndex) {
  int colNum = (int) matrixMeta.getColNum();

  // Count elements over the non-null splits only, so the null guard in the merge loop
  // below is actually reachable instead of being preceded by a NullPointerException.
  int elemNum = 0;
  for (ServerRow split : rowSplits) {
    if (split != null) {
      elemNum += split.size();
    }
  }

  IntIntVector row;
  if (matrixMeta.isHash()) {
    row = VFactory.sparseIntVector(colNum, elemNum);
  } else {
    if (elemNum >= (int) (storageConvFactor * colNum)) {
      row = VFactory.denseIntVector(colNum);
    } else {
      row = VFactory.sparseIntVector(colNum, elemNum);
    }
  }
  row.setMatrixId(matrixMeta.getId());
  row.setRowId(rowIndex);

  // nullsLast keeps serverRowComp from ever being handed a null entry; for lists without
  // nulls the resulting order is exactly the one serverRowComp defines.
  Collections.sort(rowSplits, java.util.Comparator.nullsLast(serverRowComp));
  for (ServerRow split : rowSplits) {
    if (split == null) {
      continue;
    }
    ((ServerIntIntRow) split).mergeTo(row);
  }
  return row;
}
use of com.tencent.angel.ml.math2.vector.IntIntVector in project angel by Tencent.
the class MergeUtils method combineIntIntIndexRowSplits.
// //////////////////////////////////////////////////////////////////////////////
// Combine Int key Int value vector
// //////////////////////////////////////////////////////////////////////////////
/**
 * Builds one int-key/int-value vector from per-part keys and values.
 *
 * <p>The vector is created with {@code VFactory.sparseIntVector}; each {@code keyParts[i]} is
 * merged into it together with {@code valueParts[i]} cast to {@code IntValuesPart}. The row
 * id and matrix id are set after all parts are merged.
 *
 * @param matrixId matrix id stored on the result
 * @param rowId row id stored on the result
 * @param resultSize second argument passed to {@code VFactory.sparseIntVector}
 *     (presumably the expected number of elements - confirm against VFactory)
 * @param keyParts key parts, one per split
 * @param valueParts value parts, indexed in parallel with {@code keyParts}
 * @param matrixMeta matrix meta data; its column number is used as the vector dimension
 * @return the merged vector
 */
public static Vector combineIntIntIndexRowSplits(int matrixId, int rowId, int resultSize, KeyPart[] keyParts, ValuePart[] valueParts, MatrixMeta matrixMeta) {
  final int dimension = (int) matrixMeta.getColNum();
  final IntIntVector merged = VFactory.sparseIntVector(dimension, resultSize);

  int part = 0;
  while (part < keyParts.length) {
    final IntValuesPart partValues = (IntValuesPart) valueParts[part];
    mergeTo(merged, keyParts[part], partValues);
    part++;
  }

  merged.setRowId(rowId);
  merged.setMatrixId(matrixId);
  return merged;
}
use of com.tencent.angel.ml.math2.vector.IntIntVector in project angel by Tencent.
the class RowSplitCombineUtils method combineIntIntIndexRowSplits.
// //////////////////////////////////////////////////////////////////////////////
// Combine Int key Int value vector
// //////////////////////////////////////////////////////////////////////////////
/**
 * Combines the key/value parts of an index-get result into a single int-key/int-value row.
 *
 * <p>Creates the result through {@code VFactory.sparseIntVector} using the matrix column
 * number as dimension, merges every {@code (keyParts[i], valueParts[i])} pair into it, and
 * then tags it with {@code rowId} and {@code matrixId}.
 *
 * @param matrixId matrix id to set on the returned vector
 * @param rowId row id to set on the returned vector
 * @param resultSize value forwarded to {@code VFactory.sparseIntVector} as its second argument
 * @param keyParts key parts to merge
 * @param valueParts value parts; each element is cast to {@code IntValuesPart}
 * @param matrixMeta meta data supplying the column number
 * @return the combined vector
 */
public static Vector combineIntIntIndexRowSplits(int matrixId, int rowId, int resultSize, KeyPart[] keyParts, ValuePart[] valueParts, MatrixMeta matrixMeta) {
  final int colNum = (int) matrixMeta.getColNum();
  IntIntVector result = VFactory.sparseIntVector(colNum, resultSize);

  final int partCount = keyParts.length;
  for (int idx = 0; idx != partCount; ++idx) {
    KeyPart keys = keyParts[idx];
    IntValuesPart values = (IntValuesPart) valueParts[idx];
    mergeTo(result, keys, values);
  }

  result.setRowId(rowId);
  result.setMatrixId(matrixId);
  return result;
}
Aggregations