Use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.
The class ConstructLocalHistogram, method calcWithNodeIdCache.
public int calcWithNodeIdCache(ComContext context, BoostingObjs boostingObjs, Booster booster, HistogramBaseTreeObjs tree, double[] histogram) {
	int nodeSize = 0;
	int sumFeatureCount = 0;
	featureValid.clear();
	Arrays.fill(tree.nodeIdCache, -1);
	tree.baggingFeaturePool.reset();

	// Give each queued node an ordinal id, sample its bagging features,
	// and map every instance in the node's slice to that id.
	for (NodeInfoPair item : tree.queue) {
		item.small.baggingFeatures = Bagging.sampleFeatures(boostingObjs, tree.baggingFeaturePool);
		for (int i = 0; i < item.small.baggingFeatures.length; ++i) {
			featureValid.set(item.small.baggingFeatures[i], true);
		}
		for (int i = item.small.slice.start; i < item.small.slice.end; ++i) {
			tree.nodeIdCache[boostingObjs.indices[i]] = nodeSize;
		}
		nodeSize++;
		sumFeatureCount += boostingObjs.numBaggingFeatures;
		if (item.big != null) {
			item.big.baggingFeatures = Bagging.sampleFeatures(boostingObjs, tree.baggingFeaturePool);
			for (int i = 0; i < item.big.baggingFeatures.length; ++i) {
				featureValid.set(item.big.baggingFeatures[i], true);
			}
			for (int i = item.big.slice.start; i < item.big.slice.end; ++i) {
				tree.nodeIdCache[boostingObjs.indices[i]] = nodeSize;
			}
			nodeSize++;
			sumFeatureCount += boostingObjs.numBaggingFeatures;
		}
	}

	if (boostingObjs.params.get(BaseGbdtTrainBatchOp.USE_EPSILON_APPRO_QUANTILE)) {
		// Build histograms from the epsilon-approximate quantile sketches.
		EpsilonApproQuantile.WQSummary[] summaries = context.getObj(BuildLocalSketch.SKETCH);
		int sumFeatureSize = 0;
		for (int i = 0; i < boostingObjs.data.getN(); ++i) {
			FeatureMeta featureMeta = boostingObjs.data.getFeatureMetas()[i];
			validFeatureOffset[i] = sumFeatureSize;
			if (featureValid.get(i)) {
				sumFeatureSize += DataUtil.getFeatureCategoricalSize(featureMeta, tree.useMissing);
			}
		}
		boostingObjs.data.constructHistogramWithWQSummary(
			useInstanceCount, nodeSize, featureValid, tree.nodeIdCache, validFeatureOffset,
			booster.getGradients(), booster.getHessions(), booster.getWeights(),
			summaries, boostingObjs.executorService, results, featureSplitHistogram
		);
	} else {
		int sumFeatureSize = 0;
		for (int i = 0; i < boostingObjs.data.getN(); ++i) {
			validFeatureOffset[i] = sumFeatureSize;
			if (featureValid.get(i)) {
				sumFeatureSize += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[i], tree.useMissing);
			}
		}
		// Collect the instances that still belong to an active node
		// (nodeIdCache stays -1 for instances outside the queue).
		int validInstanceCount = 0;
		for (int i = 0; i < boostingObjs.data.getM(); ++i) {
			if (tree.nodeIdCache[i] < 0) {
				continue;
			}
			aligned[validInstanceCount] = i;
			validInstanceCount++;
		}
		LOG.info("taskId: {}, calcWithNodeIdCache start", context.getTaskId());
		boostingObjs.data.constructHistogram(
			useInstanceCount, nodeSize, validInstanceCount, featureValid, tree.nodeIdCache,
			validFeatureOffset, aligned, booster.getGradients(), booster.getHessions(),
			booster.getWeights(), boostingObjs.executorService, results, featureSplitHistogram
		);
	}
	LOG.info("taskId: {}, calcWithNodeIdCache end", context.getTaskId());

	// Copy each node's per-feature histogram slices out of the shared
	// featureSplitHistogram buffer into the output array, in queue order.
	int histogramOffset = 0;
	int nodeId = 0;
	for (NodeInfoPair item : tree.queue) {
		for (int i = 0; i < item.small.baggingFeatures.length; ++i) {
			int featureSize = DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[item.small.baggingFeatures[i]], tree.useMissing);
			System.arraycopy(featureSplitHistogram, (validFeatureOffset[item.small.baggingFeatures[i]] * nodeSize + nodeId * featureSize) * STEP, histogram, histogramOffset * STEP, featureSize * STEP);
			histogramOffset += featureSize;
		}
		nodeId++;
		if (item.big != null) {
			for (int i = 0; i < item.big.baggingFeatures.length; ++i) {
				int featureSize = DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[item.big.baggingFeatures[i]], tree.useMissing);
				System.arraycopy(featureSplitHistogram, (validFeatureOffset[item.big.baggingFeatures[i]] * nodeSize + nodeId * featureSize) * STEP, histogram, histogramOffset * STEP, featureSize * STEP);
				histogramOffset += featureSize;
			}
			nodeId++;
		}
	}
	LOG.info("taskId: {}, sumFeatureCount: {}", context.getTaskId(), sumFeatureCount);
	return sumFeatureCount;
}
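The final arraycopy loop implies a flat, feature-major layout for featureSplitHistogram: each sampled feature owns nodeSize contiguous per-node blocks of featureSize bins, with STEP doubles of statistics per bin. A minimal standalone sketch of that offset arithmetic, with a hypothetical STEP value and made-up feature sizes (HistogramLayoutSketch is not Alink code, just an illustration inferred from the arraycopy above):

	// Hypothetical sketch of the flat histogram layout read by the arraycopy
	// above: feature-major, then node, then bin, with STEP doubles per bin.
	public class HistogramLayoutSketch {
		static final int STEP = 4; // assumption: statistics stored per bin

		// Offset (in doubles) of a bin's first statistic in the shared buffer.
		static int binOffset(int validFeatureOffset, int nodeSize, int nodeId, int featureSize, int bin) {
			return (validFeatureOffset * nodeSize + nodeId * featureSize + bin) * STEP;
		}

		public static void main(String[] args) {
			int nodeSize = 2;              // two active nodes
			int sizeA = 3, sizeB = 5;      // bin counts of two sampled features
			int offsetA = 0;               // validFeatureOffset accumulates sizes,
			int offsetB = offsetA + sizeA; // so feature B starts after feature A
			// Node 1's slice of feature B: (3 * 2 + 1 * 5 + 0) * 4 = 44
			System.out.println(binOffset(offsetB, nodeSize, 1, sizeB, 0));
		}
	}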
Use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.
The class HistogramBaseTreeObjs, method init.
void init(BoostingObjs boostingObjs, List<Row> quantileData) {
	if (quantileData != null && !quantileData.isEmpty()) {
		quantileModel = new QuantileDiscretizerModelDataConverter().load(quantileData);
	}
	useMissing = boostingObjs.params.get(USE_MISSING);
	useOneHot = boostingObjs.params.get(USE_ONEHOT);
	featureMetas = boostingObjs.data.getFeatureMetas();

	// Track the largest and the total bin count over all features.
	maxFeatureBins = 0;
	allFeatureBins = 0;
	for (FeatureMeta featureMeta : boostingObjs.data.getFeatureMetas()) {
		int featureSize = DataUtil.getFeatureCategoricalSize(featureMeta, useMissing);
		maxFeatureBins = Math.max(maxFeatureBins, featureSize);
		allFeatureBins += featureSize;
	}
	compareIndex4Categorical = new Integer[maxFeatureBins];
	featureBinOffset = new int[boostingObjs.numBaggingFeatures];
	nodeIdCache = new int[boostingObjs.data.getM()];
	// At most min(2^(maxDepth - 1), maxLeaves) nodes are expanded at once.
	maxNodeSize = Math.min(((int) Math.pow(2, boostingObjs.params.get(GbdtTrainParams.MAX_DEPTH) - 1)), boostingObjs.params.get(GbdtTrainParams.MAX_LEAVES));
	baggingFeaturePool = new Bagging.BaggingFeaturePool(maxNodeSize, boostingObjs.numBaggingFeatures);
}
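The maxNodeSize expression bounds how many nodes can draw from the bagging feature pool in one layer: no more than the 2^(maxDepth - 1) nodes on the deepest expanded level of a full binary tree, and no more than MAX_LEAVES. A minimal sketch of the same arithmetic with assumed parameter values (6 and 20 are hypothetical, not Alink defaults):

	// Mirrors the maxNodeSize expression above with made-up values.
	public class MaxNodeSizeSketch {
		public static void main(String[] args) {
			int maxDepth = 6, maxLeaves = 20; // hypothetical parameter values
			int maxNodeSize = Math.min((int) Math.pow(2, maxDepth - 1), maxLeaves);
			System.out.println(maxNodeSize); // min(2^5, 20) = 20 node slots
		}
	}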
Use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.
The class DataUtil, method createData.
public static Data createData(Params params, List<FeatureMeta> featureMetas, int m, boolean useSort) {
	featureMetas.sort(Comparator.comparingInt(FeatureMeta::getIndex));
	int column = 0;
	int maxColumnIndex = -1;
	// After sorting, feature indices must be contiguous starting from 0.
	// Note: Preconditions substitutes %s placeholders, not %d.
	for (FeatureMeta meta : featureMetas) {
		Preconditions.checkState(meta.getIndex() == column++, "There are empty columns. index: %s", meta.getIndex());
		maxColumnIndex = Math.max(maxColumnIndex, meta.getIndex());
	}
	maxColumnIndex += 1;
	if (Preprocessing.isSparse(params)) {
		return new SparseData(params, featureMetas.toArray(new FeatureMeta[0]), m, maxColumnIndex);
	} else {
		return new DenseData(params, featureMetas.toArray(new FeatureMeta[0]), m, maxColumnIndex, useSort);
	}
}
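The checkState call demands that, once sorted, the i-th meta carries index i, so any gap in the column indices fails fast. A minimal sketch of the same contiguity test using plain integers instead of FeatureMeta (the index list is made up):

	import java.util.Arrays;
	import java.util.List;

	public class ContiguityCheckSketch {
		public static void main(String[] args) {
			List<Integer> indices = Arrays.asList(0, 1, 3); // index 2 is missing
			int column = 0;
			for (int index : indices) {
				if (index != column++) { // same test as createData's checkState
					throw new IllegalStateException("There are empty columns. index: " + index);
				}
			}
		}
	}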