Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in the Angel project by Tencent.
Class GBDTController, method getSketch.
/**
 * Pulls the global sketch (candidate split values) from row 0 of the PS sketch matrix into the
 * local {@code sketches} array, then derives the number of splits of every categorical feature.
 * Only called once by each worker.
 *
 * <p>Feature {@code fid} owns the slots {@code sketches[fid * numSplit]} through
 * {@code sketches[fid * numSplit + numSplit - 1]}. Its split count is the length of the strictly
 * increasing prefix of those slots (at least 1, at most {@code numSplit}).
 *
 * @throws Exception if reading the sketch row from the PS fails
 */
public void getSketch() throws Exception {
  PSModel sketch = model.getPSModel(this.param.sketchName);
  LOG.info("------Get sketch from PS------");
  long startTime = System.currentTimeMillis();
  IntDoubleVector sketchVector = (IntDoubleVector) sketch.getRow(0);
  LOG.info(String.format("Get sketch cost: %d ms", System.currentTimeMillis() - startTime));
  for (int i = 0; i < sketchVector.getDim(); i++) {
    this.sketches[i] = (float) sketchVector.get(i);
  }
  // number of splits of each categorical feature
  int numSplit = this.param.numSplit;
  for (int i = 0; i < cateFeatList.size(); i++) {
    int fid = cateFeatList.get(i);
    int start = fid * numSplit;
    int splitNum = 1;
    // Compare neighbours only inside this feature's slot: j + 1 must stay below numSplit,
    // otherwise we read the next feature's first value (or run past the end of the array).
    for (int j = 0; j + 1 < numSplit; j++) {
      if (this.sketches[start + j + 1] > this.sketches[start + j]) {
        splitNum++;
      } else {
        break;
      }
    }
    this.cateFeatNum.put(fid, splitNum);
  }
  LOG.info("Number of splits of categorical features: " + this.cateFeatNum.entrySet().toString());
}
Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in the Angel project by Tencent.
Class GBDTController, method afterSplit.
/**
 * Applies the split results of {@code currentTree}.
 *
 * <p>Steps:
 * <ol>
 *   <li>Pull split features, split values, split gains and node gradient stats from the PS.</li>
 *   <li>For every node that is active, record its split locally and submit an AfterSplitThread.</li>
 *   <li>Wait until all of those tasks have reset their {@code activeNodeStat} entry.</li>
 *   <li>Update instance positions and finish the current depth.</li>
 *   <li>Clock the four split-result matrices.</li>
 * </ol>
 *
 * @throws Exception if a PS access or the clock fails, or if the wait is interrupted
 */
public void afterSplit() throws Exception {
  LOG.info("------After split------");
  long startTime = System.currentTimeMillis();
  // 1. get split feature
  PSModel splitFeatModel = model.getPSModel(this.param.splitFeaturesName);
  IntIntVector splitFeatureVec = (IntIntVector) splitFeatModel.getRow(currentTree);
  // 2. get split value
  PSModel splitValueModel = model.getPSModel(this.param.splitValuesName);
  IntDoubleVector splitValueVec = (IntDoubleVector) splitValueModel.getRow(currentTree);
  // 3. get split gain
  PSModel splitGainModel = model.getPSModel(this.param.splitGainsName);
  IntDoubleVector splitGainVec = (IntDoubleVector) splitGainModel.getRow(currentTree);
  // 4. get node weight
  PSModel nodeGradStatsModel = model.getPSModel(this.param.nodeGradStatsName);
  IntDoubleVector nodeGradStatsVec = (IntDoubleVector) nodeGradStatsModel.getRow(currentTree);
  LOG.info(String.format("Get split result from PS cost %d ms", System.currentTimeMillis() - startTime));
  // 5. split node
  LOG.debug(String.format("Split active node: %s", Arrays.toString(this.activeNode)));
  // snapshot: the AfterSplit tasks may modify activeNode while we iterate
  int[] preActiveNode = this.activeNode.clone();
  for (int nid = 0; nid < this.maxNodeNum; nid++) {
    if (preActiveNode[nid] == 1) {
      // update local replica
      this.splitFeats[nid] = splitFeatureVec.get(nid);
      this.splitValues[nid] = splitValueVec.get(nid);
      // mark the node as running before the task is submitted, so the wait below sees it
      this.activeNodeStat[nid].set(1);
      AfterSplitThread t = new AfterSplitThread(this, nid, splitFeatureVec, splitValueVec, splitGainVec, nodeGradStatsVec);
      this.threadPool.submit(t);
    }
  }
  // 6. wait until every AfterSplit task has finished
  waitForAfterSplitTasks();
  updateValidInsPos();
  finishCurrentDepth();
  LOG.info(String.format("After split cost: %d ms", System.currentTimeMillis() - startTime));
  // 7. clock
  Set<String> needFlushMatrixSet = new HashSet<String>(4);
  needFlushMatrixSet.add(this.param.splitFeaturesName);
  needFlushMatrixSet.add(this.param.splitValuesName);
  needFlushMatrixSet.add(this.param.splitGainsName);
  needFlushMatrixSet.add(this.param.nodeGradStatsName);
  clockAllMatrix(needFlushMatrixSet, true);
}

/**
 * Blocks until no entry of {@code activeNodeStat} (for nodes {@code 0..maxNodeNum-1}) equals 1.
 *
 * <p>Polls with a short sleep instead of spinning, so the waiting thread does not occupy a full
 * core while the AfterSplit tasks are running.
 *
 * @throws InterruptedException if the waiting thread is interrupted
 */
private void waitForAfterSplitTasks() throws InterruptedException {
  boolean logged = false;
  while (true) {
    boolean hasRunning = false;
    for (int nid = 0; nid < this.maxNodeNum; nid++) {
      if (this.activeNodeStat[nid].get() == 1) {
        hasRunning = true;
        break;
      }
    }
    if (!hasRunning) {
      return;
    }
    if (!logged) {
      LOG.debug("current has running thread");
      logged = true;
    }
    Thread.sleep(1);
  }
}
Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in the Angel project by Tencent.
Class GBDTController, method createSketch.
/**
 * Creates the data sketch and pushes the candidate split values to the PS.
 *
 * <p>Only task 0 computes the sketch, via {@code TYahooSketchSplit.getSplitValue}.
 * <ul>
 *   <li>Continuous features go to row 0 of the sketch matrix. Feature {@code fid} occupies slots
 *       {@code fid * numSplit} through {@code fid * numSplit + numSplit - 1}.</li>
 *   <li>Categorical features are sorted by feature id. Their values go into the categorical-feature
 *       matrix, where the i-th categorical feature occupies slots {@code i * numSplit} through
 *       {@code i * numSplit + numSplit - 1}.</li>
 * </ul>
 * Every task clocks both matrices afterwards.
 *
 * <p>Note: on task 0 this sorts {@code cateFeatList} in place.
 *
 * @throws Exception if a PS access or the clock fails
 */
public void createSketch() throws Exception {
  PSModel sketch = model.getPSModel(this.param.sketchName);
  PSModel cateFeat = model.getPSModel(this.param.cateFeatureName);
  if (taskContext.getTaskIndex() == 0) {
    LOG.info("------Create sketch------");
    long startTime = System.currentTimeMillis();
    int numFeature = this.param.numFeature;
    int numSplit = this.param.numSplit;
    int sketchSize = numFeature * numSplit;
    IntDoubleVector sketchVec =
        new IntDoubleVector(sketchSize, new IntDoubleDenseVectorStorage(new double[sketchSize]));
    IntDoubleVector cateFeatVec = null;
    if (!this.cateFeatList.isEmpty()) {
      int cateSize = this.cateFeatList.size() * numSplit;
      cateFeatVec =
          new IntDoubleVector(cateSize, new IntDoubleDenseVectorStorage(new double[cateSize]));
    }
    // 1. calculate candidate split value
    float[][] splits =
        TYahooSketchSplit.getSplitValue(this.trainDataStore, numSplit, this.cateFeatList);
    // O(1) membership test instead of a List.contains scan for every feature
    Set<Integer> cateFeatSet = new HashSet<Integer>(this.cateFeatList);
    if (splits.length == numFeature && splits.length > 0 && splits[0].length == numSplit) {
      for (int fid = 0; fid < splits.length; fid++) {
        if (cateFeatSet.contains(fid)) {
          continue;
        }
        // never write more than numSplit values, or this feature would overwrite the next one
        int len = Math.min(splits[fid].length, numSplit);
        for (int j = 0; j < len; j++) {
          sketchVec.set(fid * numSplit + j, splits[fid][j]);
        }
      }
    } else {
      LOG.error("Incompatible sketches size.");
    }
    // categorical features
    if (!this.cateFeatList.isEmpty()) {
      Collections.sort(this.cateFeatList);
      for (int i = 0; i < this.cateFeatList.size(); i++) {
        int fid = this.cateFeatList.get(i);
        // splits may be shorter than expected (see the size check above); skip instead of crashing
        if (fid < 0 || fid >= splits.length) {
          LOG.error(String.format("No sketch for categorical feature %d", fid));
          continue;
        }
        int start = i * numSplit;
        int len = Math.min(splits[fid].length, numSplit);
        for (int j = 0; j < len; j++) {
          // stop at the first zero past index 0 (presumably unused trailing slots)
          if (splits[fid][j] == 0 && j > 0) {
            break;
          }
          cateFeatVec.set(start + j, splits[fid][j]);
        }
      }
    }
    // 2. push local sketch to PS
    sketch.increment(0, sketchVec);
    if (null != cateFeatVec) {
      cateFeat.increment(this.taskContext.getTaskIndex(), cateFeatVec);
    }
    LOG.info(String.format("Create sketch cost: %d ms", System.currentTimeMillis() - startTime));
  }
  Set<String> needFlushMatrixSet = new HashSet<String>(2);
  needFlushMatrixSet.add(this.param.sketchName);
  needFlushMatrixSet.add(this.param.cateFeatureName);
  clockAllMatrix(needFlushMatrixSet, true);
}
Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in the Angel project by Tencent.
Class GradHistThread, method run.
/**
 * Builds the gradient histogram of node {@code nid} and pushes it to the PS.
 *
 * <p>The histogram is built over the instance range {@code insStart}/{@code insEnd}, as
 * interpreted by {@code GradHistHelper.buildHistogram}. It is pushed either as a plain increment
 * (8 bytes per item) or via a {@code QuantifyDoubleFunc} update with the configured compression
 * width.
 *
 * <p>The node's {@code activeNodeStat} entry is decremented in a {@code finally} block. This
 * guarantees that a controller polling that counter is always released, even when building the
 * histogram throws. A failed push is logged and otherwise ignored.
 */
@Override
public void run() {
  LOG.debug(String.format("Run active node[%d]", this.nid));
  try {
    // 1. name of this node's grad histogram on PS (used in the error message)
    String histParaName = this.controller.param.gradHistNamePrefix + nid;
    // 2. build the grad histogram of this node
    GradHistHelper histMaker = new GradHistHelper(this.controller, this.nid);
    IntDoubleVector histogram = histMaker.buildHistogram(insStart, insEnd);
    int bytesPerItem = this.controller.taskContext.getConf().getInt(MLConf.ANGEL_COMPRESS_BYTES(), MLConf.DEFAULT_ANGEL_COMPRESS_BYTES());
    if (bytesPerItem < 1 || bytesPerItem > 8) {
      LOG.info("Invalid compress configuration: " + bytesPerItem + ", it should be [1,8].");
      bytesPerItem = MLConf.DEFAULT_ANGEL_COMPRESS_BYTES();
    }
    // 3. push the histograms to PS: uncompressed at 8 bytes, otherwise quantized to bytesPerItem * 8 bits
    try {
      if (bytesPerItem == 8) {
        this.model.increment(0, histogram);
      } else {
        QuantifyDoubleFunc func = new QuantifyDoubleFunc(this.model.getMatrixId(), 0, histogram, bytesPerItem * 8);
        this.model.update(func);
      }
    } catch (Exception e) {
      LOG.error(histParaName + " increment failed, ", e);
    }
  } finally {
    // 4. reset thread stats to finished, even if building the histogram failed
    this.controller.activeNodeStat[this.nid].decrementAndGet();
  }
  LOG.debug(String.format("Active node[%d] finish", this.nid));
}
Use of com.tencent.angel.ml.math2.vector.IntDoubleVector in the Angel project by Tencent.
Class HistCalThread, method call.
/**
 * Builds the gradient histogram of node {@code nid} over the instance range {@code start}/{@code end}
 * and adds it into the controller's cached histogram {@code histCache[nid]}.
 *
 * @return always {@code true}
 * @throws Exception propagated from histogram construction
 */
@Override
public Boolean call() throws Exception {
  GradHistHelper histMaker = new GradHistHelper(this.controller, this.nid);
  IntDoubleVector localHist = histMaker.buildHistogram(start, end);
  // guard: the debug arguments are otherwise computed even when debug logging is off
  if (LOG.isDebugEnabled()) {
    LOG.debug(String.format("Batch histogram[%d]: %s", nid, Arrays.toString(localHist.get(new int[] { 0, 1, 2, 3, 4, 5 }))));
  }
  // Lock the shared per-node histogram rather than this task object. synchronized (this) only
  // excludes callers of this same instance, not other tasks doing iadd on histCache[nid].
  IntDoubleVector nodeHist = this.controller.histCache[nid];
  synchronized (nodeHist) {
    nodeHist.iadd(localHist);
    if (LOG.isDebugEnabled()) {
      LOG.debug(String.format("Calculated histogram[%d]: %s", nid, Arrays.toString(nodeHist.get(new int[] { 0, 1, 2, 3, 4, 5 }))));
    }
  }
  return true;
}
Aggregations