
Example 6 with FeatureMeta

Use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.

The class ConstructLocalHistogram, method calcWithNodeIdCache.

public int calcWithNodeIdCache(ComContext context, BoostingObjs boostingObjs, Booster booster, HistogramBaseTreeObjs tree, double[] histogram) {
    int nodeSize = 0;
    int sumFeatureCount = 0;
    featureValid.clear();
    Arrays.fill(tree.nodeIdCache, -1);
    tree.baggingFeaturePool.reset();
    // Sample bagging features for each queued node and record, for every instance
    // in the node's data slice, which node it belongs to (nodeIdCache).
    for (NodeInfoPair item : tree.queue) {
        item.small.baggingFeatures = Bagging.sampleFeatures(boostingObjs, tree.baggingFeaturePool);
        for (int i = 0; i < item.small.baggingFeatures.length; ++i) {
            featureValid.set(item.small.baggingFeatures[i], true);
        }
        for (int i = item.small.slice.start; i < item.small.slice.end; ++i) {
            tree.nodeIdCache[boostingObjs.indices[i]] = nodeSize;
        }
        nodeSize++;
        sumFeatureCount += boostingObjs.numBaggingFeatures;
        if (item.big != null) {
            item.big.baggingFeatures = Bagging.sampleFeatures(boostingObjs, tree.baggingFeaturePool);
            for (int i = 0; i < item.big.baggingFeatures.length; ++i) {
                featureValid.set(item.big.baggingFeatures[i], true);
            }
            for (int i = item.big.slice.start; i < item.big.slice.end; ++i) {
                tree.nodeIdCache[boostingObjs.indices[i]] = nodeSize;
            }
            nodeSize++;
            sumFeatureCount += boostingObjs.numBaggingFeatures;
        }
    }
    // Two construction paths: use the epsilon-approximate quantile sketch summaries when
    // enabled, otherwise build the histograms directly from the node-assigned instances.
    if (boostingObjs.params.get(BaseGbdtTrainBatchOp.USE_EPSILON_APPRO_QUANTILE)) {
        EpsilonApproQuantile.WQSummary[] summaries = context.getObj(BuildLocalSketch.SKETCH);
        int sumFeatureSize = 0;
        for (int i = 0; i < boostingObjs.data.getN(); ++i) {
            FeatureMeta featureMeta = boostingObjs.data.getFeatureMetas()[i];
            validFeatureOffset[i] = sumFeatureSize;
            if (featureValid.get(i)) {
                sumFeatureSize += DataUtil.getFeatureCategoricalSize(featureMeta, tree.useMissing);
            }
        }
        boostingObjs.data.constructHistogramWithWQSummary(useInstanceCount, nodeSize, featureValid, tree.nodeIdCache, validFeatureOffset, booster.getGradients(), booster.getHessions(), booster.getWeights(), summaries, boostingObjs.executorService, results, featureSplitHistogram);
    } else {
        int sumFeatureSize = 0;
        for (int i = 0; i < boostingObjs.data.getN(); ++i) {
            validFeatureOffset[i] = sumFeatureSize;
            if (featureValid.get(i)) {
                sumFeatureSize += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[i], tree.useMissing);
            }
        }
        int validInstanceCount = 0;
        for (int i = 0; i < boostingObjs.data.getM(); ++i) {
            if (tree.nodeIdCache[i] < 0) {
                continue;
            }
            aligned[validInstanceCount] = i;
            validInstanceCount++;
        }
        LOG.info("taskId: {}, calcWithNodeIdCache start", context.getTaskId());
        boostingObjs.data.constructHistogram(useInstanceCount, nodeSize, validInstanceCount, featureValid, tree.nodeIdCache, validFeatureOffset, aligned, booster.getGradients(), booster.getHessions(), booster.getWeights(), boostingObjs.executorService, results, featureSplitHistogram);
    }
    LOG.info("taskId: {}, calcWithNodeIdCache end", context.getTaskId());
    // Copy each queued node's histograms for its sampled (bagging) features out of the
    // flat featureSplitHistogram buffer into the output histogram, feature by feature.
    int histogramOffset = 0;
    int nodeId = 0;
    for (NodeInfoPair item : tree.queue) {
        for (int i = 0; i < item.small.baggingFeatures.length; ++i) {
            int featureSize = DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[item.small.baggingFeatures[i]], tree.useMissing);
            System.arraycopy(featureSplitHistogram, (validFeatureOffset[item.small.baggingFeatures[i]] * nodeSize + nodeId * featureSize) * STEP, histogram, histogramOffset * STEP, featureSize * STEP);
            histogramOffset += featureSize;
        }
        nodeId++;
        if (item.big != null) {
            for (int i = 0; i < item.big.baggingFeatures.length; ++i) {
                int featureSize = DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[item.big.baggingFeatures[i]], tree.useMissing);
                System.arraycopy(featureSplitHistogram, (validFeatureOffset[item.big.baggingFeatures[i]] * nodeSize + nodeId * featureSize) * STEP, histogram, histogramOffset * STEP, featureSize * STEP);
                histogramOffset += featureSize;
            }
            nodeId++;
        }
    }
    LOG.info("taskId: {}, sumFeatureCount: {}", context.getTaskId(), sumFeatureCount);
    return sumFeatureCount;
}
Also used: FeatureMeta (com.alibaba.alink.operator.common.tree.FeatureMeta)
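The System.arraycopy calls at the end of this method rely on the flat layout of featureSplitHistogram: all nodes' bins for one feature are stored together, starting at that feature's validFeatureOffset, with STEP values per bin. Below is a minimal standalone sketch of just that offset arithmetic; the class name and the STEP value are hypothetical, not Alink API.

public final class HistogramOffsetSketch {

    // Hypothetical value: number of statistics stored per bin (e.g. gradient, hessian,
    // weight, count); the real STEP constant lives in the Alink histogram code.
    static final int STEP = 4;

    // Same expression as the arraycopy source index in calcWithNodeIdCache:
    // (validFeatureOffset[feature] * nodeSize + nodeId * featureSize) * STEP
    static int sourceOffset(int validFeatureOffset, int nodeSize, int nodeId, int featureSize) {
        return (validFeatureOffset * nodeSize + nodeId * featureSize) * STEP;
    }

    public static void main(String[] args) {
        // A feature whose bins start at offset 5 in the valid-feature layout,
        // 3 bins per node, 2 nodes currently expanded.
        System.out.println(sourceOffset(5, 2, 0, 3)); // node 0 -> 40
        System.out.println(sourceOffset(5, 2, 1, 3)); // node 1 -> 52
    }
}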

Example 7 with FeatureMeta

Use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.

The class HistogramBaseTreeObjs, method init.

void init(BoostingObjs boostingObjs, List<Row> quantileData) {
    if (quantileData != null && !quantileData.isEmpty()) {
        quantileModel = new QuantileDiscretizerModelDataConverter().load(quantileData);
    }
    useMissing = boostingObjs.params.get(USE_MISSING);
    useOneHot = boostingObjs.params.get(USE_ONEHOT);
    featureMetas = boostingObjs.data.getFeatureMetas();
    maxFeatureBins = 0;
    allFeatureBins = 0;
    for (FeatureMeta featureMeta : boostingObjs.data.getFeatureMetas()) {
        int featureSize = DataUtil.getFeatureCategoricalSize(featureMeta, useMissing);
        maxFeatureBins = Math.max(maxFeatureBins, featureSize);
        allFeatureBins += featureSize;
    }
    compareIndex4Categorical = new Integer[maxFeatureBins];
    featureBinOffset = new int[boostingObjs.numBaggingFeatures];
    nodeIdCache = new int[boostingObjs.data.getM()];
    maxNodeSize = Math.min(((int) Math.pow(2, boostingObjs.params.get(GbdtTrainParams.MAX_DEPTH) - 1)), boostingObjs.params.get(GbdtTrainParams.MAX_LEAVES));
    baggingFeaturePool = new Bagging.BaggingFeaturePool(maxNodeSize, boostingObjs.numBaggingFeatures);
}
Also used: FeatureMeta (com.alibaba.alink.operator.common.tree.FeatureMeta), QuantileDiscretizerModelDataConverter (com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter)
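For context, a minimal standalone sketch of the sizing logic in init: how maxFeatureBins, allFeatureBins, and maxNodeSize fall out of the per-feature bin counts and the depth/leaves parameters. The featureSize helper is a hypothetical stand-in for DataUtil.getFeatureCategoricalSize; the extra bin for missing values is an assumption, not confirmed Alink behavior.

import java.util.Arrays;
import java.util.List;

public final class TreeObjSizingSketch {

    // Hypothetical stand-in for DataUtil.getFeatureCategoricalSize(featureMeta, useMissing).
    static int featureSize(int numBins, boolean useMissing) {
        return useMissing ? numBins + 1 : numBins; // assumed extra bin for missing values
    }

    public static void main(String[] args) {
        List<Integer> binCounts = Arrays.asList(16, 3, 255); // example per-feature bin counts
        boolean useMissing = true;

        // Mirror of the loop in HistogramBaseTreeObjs.init.
        int maxFeatureBins = 0;
        int allFeatureBins = 0;
        for (int bins : binCounts) {
            int size = featureSize(bins, useMissing);
            maxFeatureBins = Math.max(maxFeatureBins, size);
            allFeatureBins += size;
        }

        // Upper bound on nodes expanded at once, exactly as in init:
        // min(2^(maxDepth - 1), maxLeaves).
        int maxDepth = 6;
        int maxLeaves = 32;
        int maxNodeSize = Math.min((int) Math.pow(2, maxDepth - 1), maxLeaves);

        System.out.printf("maxFeatureBins=%d, allFeatureBins=%d, maxNodeSize=%d%n",
            maxFeatureBins, allFeatureBins, maxNodeSize);
    }
}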

Example 8 with FeatureMeta

Use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.

The class DataUtil, method createData.

public static Data createData(Params params, List<FeatureMeta> featureMetas, int m, boolean useSort) {
    featureMetas.sort(Comparator.comparingInt(FeatureMeta::getIndex));
    int column = 0;
    int maxColumnIndex = -1;
    for (FeatureMeta meta : featureMetas) {
        Preconditions.checkState(meta.getIndex() == column++, "There are empty columns. index: %d", meta.getIndex());
        maxColumnIndex = Math.max(maxColumnIndex, meta.getIndex());
    }
    maxColumnIndex += 1;
    if (Preprocessing.isSparse(params)) {
        return new SparseData(params, featureMetas.toArray(new FeatureMeta[0]), m, maxColumnIndex);
    } else {
        return new DenseData(params, featureMetas.toArray(new FeatureMeta[0]), m, maxColumnIndex, useSort);
    }
}
Also used: FeatureMeta (com.alibaba.alink.operator.common.tree.FeatureMeta)
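The loop in createData enforces that the sorted feature indices form a gap-free range 0..n-1 before the backing Data container is sized. A minimal standalone sketch of just that check, with FeatureMeta reduced to plain int indices (class name hypothetical):

public final class FeatureIndexCheckSketch {

    // Same contiguity check as in DataUtil.createData: indices must be exactly 0..n-1
    // after sorting, otherwise a column is missing. Returns the column count that would
    // be used to size the Data container.
    static int checkContiguous(int[] sortedFeatureIndices) {
        int column = 0;
        int maxColumnIndex = -1;
        for (int index : sortedFeatureIndices) {
            if (index != column++) {
                throw new IllegalStateException("There are empty columns. index: " + index);
            }
            maxColumnIndex = Math.max(maxColumnIndex, index);
        }
        return maxColumnIndex + 1;
    }

    public static void main(String[] args) {
        System.out.println(checkContiguous(new int[] {0, 1, 2})); // 3
        // checkContiguous(new int[] {0, 2, 3}) would throw: index 2 found where 1 was expected
    }
}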

Aggregations

FeatureMeta (com.alibaba.alink.operator.common.tree.FeatureMeta): 8
EpsilonApproQuantile (com.alibaba.alink.operator.common.tree.parallelcart.EpsilonApproQuantile): 3
Params (org.apache.flink.ml.api.misc.param.Params): 3
Row (org.apache.flink.types.Row): 3
QuantileDiscretizerModelDataConverter (com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter): 2
IterativeComQueue (com.alibaba.alink.common.comqueue.IterativeComQueue): 1
TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp): 1
Node (com.alibaba.alink.operator.common.tree.Node): 1
Preprocessing (com.alibaba.alink.operator.common.tree.Preprocessing): 1
TreeModelDataConverter (com.alibaba.alink.operator.common.tree.TreeModelDataConverter): 1
BaseGbdtTrainBatchOp (com.alibaba.alink.operator.common.tree.parallelcart.BaseGbdtTrainBatchOp): 1
AllReduceT (com.alibaba.alink.operator.common.tree.parallelcart.communication.AllReduceT): 1
ReduceScatter (com.alibaba.alink.operator.common.tree.parallelcart.communication.ReduceScatter): 1
PaiCriteria (com.alibaba.alink.operator.common.tree.parallelcart.criteria.PaiCriteria): 1
LossType (com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType): 1
LossUtils (com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils): 1
GbdtTrainParams (com.alibaba.alink.params.classification.GbdtTrainParams): 1
RandomForestTrainParams (com.alibaba.alink.params.classification.RandomForestTrainParams): 1
LambdaMartNdcgParams (com.alibaba.alink.params.regression.LambdaMartNdcgParams): 1
ArrayList (java.util.ArrayList): 1