Example 1 with FeatureMeta

use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.

the class BaseGbdtTrainBatchOp method linkFrom.

@Override
public T linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    LOG.info("gbdt train start");
    if (!Preprocessing.isSparse(getParams())) {
        getParams().set(
            HasCategoricalCols.CATEGORICAL_COLS,
            TableUtil.getCategoricalCols(
                in.getSchema(),
                getParams().get(GbdtTrainParams.FEATURE_COLS),
                getParams().contains(GbdtTrainParams.CATEGORICAL_COLS)
                    ? getParams().get(GbdtTrainParams.CATEGORICAL_COLS)
                    : null));
    }
    LossType loss = getParams().get(LossUtils.LOSS_TYPE);
    getParams().set(ALGO_TYPE, LossUtils.lossTypeToInt(loss));
    rewriteLabelType(in.getSchema(), getParams());
    if (!Preprocessing.isSparse(getParams())) {
        getParams().set(
            ModelParamName.FEATURE_TYPES,
            FlinkTypeConverter.getTypeString(
                TableUtil.findColTypes(in.getSchema(), getParams().get(GbdtTrainParams.FEATURE_COLS))));
    }
    if (LossUtils.isRanking(getParams().get(LossUtils.LOSS_TYPE))) {
        if (!getParams().contains(LambdaMartNdcgParams.GROUP_COL)) {
            throw new IllegalArgumentException("Group column should be set in ranking loss function.");
        }
    }
    String[] trainColNames = trainColsWithGroup();
    // Check that the label column contains no null values; fail fast on the first null.
    final String labelColName = this.getParams().get(HasLabelCol.LABEL_COL);
    final int labelColIdx = TableUtil.findColIndex(in.getSchema(), labelColName);
    in = new TableSourceBatchOp(
        DataSetConversionUtil.toTable(
            in.getMLEnvironmentId(),
            in.getDataSet().map(new MapFunction<Row, Row>() {

                @Override
                public Row map(Row row) throws Exception {
                    if (null == row.getField(labelColIdx)) {
                        throw new RuntimeException("label col has null values.");
                    }
                    return row;
                }
            }),
            in.getSchema())
    ).setMLEnvironmentId(in.getMLEnvironmentId());
    in = Preprocessing.select(in, trainColNames);
    DataSet<Object[]> labels = Preprocessing.generateLabels(in, getParams(), LossUtils.isRegression(loss) || LossUtils.isRanking(loss));
    if (LossUtils.isClassification(loss)) {
        labels = labels.map(new CheckNumLabels4BinaryClassifier());
    }
    DataSet<Row> trainDataSet;
    BatchOperator<?> stringIndexerModel;
    BatchOperator<?> quantileModel;
    if (getParams().get(USE_ONEHOT)) {
        // create empty string indexer model.
        stringIndexerModel = Preprocessing.generateStringIndexerModel(in, new Params());
        // create empty quantile model.
        quantileModel = Preprocessing.generateQuantileDiscretizerModel(
            in,
            new Params()
                .set(HasFeatureCols.FEATURE_COLS, new String[] {})
                .set(HasCategoricalCols.CATEGORICAL_COLS, new String[] {}));
        trainDataSet = Preprocessing.castLabel(
            in, getParams(), labels,
            LossUtils.isRegression(loss) || LossUtils.isRanking(loss)
        ).getDataSet();
    } else if (getParams().get(USE_EPSILON_APPRO_QUANTILE)) {
        // create string indexer model
        stringIndexerModel = Preprocessing.generateStringIndexerModel(in, getParams());
        // create empty quantile model
        quantileModel = Preprocessing.generateQuantileDiscretizerModel(
            in,
            new Params()
                .set(HasFeatureCols.FEATURE_COLS, new String[] {})
                .set(HasCategoricalCols.CATEGORICAL_COLS, new String[] {}));
        trainDataSet = Preprocessing.castLabel(
            Preprocessing.isSparse(getParams())
                ? in
                : Preprocessing.castContinuousCols(
                    Preprocessing.castCategoricalCols(in, stringIndexerModel, getParams()),
                    getParams()),
            getParams(), labels,
            LossUtils.isRegression(loss) || LossUtils.isRanking(loss)
        ).getDataSet();
    } else {
        stringIndexerModel = Preprocessing.generateStringIndexerModel(in, getParams());
        quantileModel = Preprocessing.generateQuantileDiscretizerModel(in, getParams());
        trainDataSet = Preprocessing.castLabel(
            Preprocessing.castToQuantile(
                Preprocessing.isSparse(getParams())
                    ? in
                    : Preprocessing.castContinuousCols(
                        Preprocessing.castCategoricalCols(in, stringIndexerModel, getParams()),
                        getParams()),
                quantileModel,
                getParams()),
            getParams(), labels,
            LossUtils.isRegression(loss) || LossUtils.isRanking(loss)
        ).getDataSet();
    }
    if (LossUtils.isRanking(getParams().get(LossUtils.LOSS_TYPE))) {
        trainDataSet = trainDataSet.partitionCustom(new Partitioner<Number>() {

            private static final long serialVersionUID = -7790649477852624964L;

            @Override
            public int partition(Number key, int numPartitions) {
                return (int) (key.longValue() % numPartitions);
            }
        }, 0);
    }
    DataSet<Tuple2<Double, Long>> sum = trainDataSet.mapPartition(new MapPartitionFunction<Row, Tuple2<Double, Long>>() {

        private static final long serialVersionUID = -8333738060239409640L;

        @Override
        public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Double, Long>> collector) throws Exception {
            double sum = 0.;
            long cnt = 0;
            for (Row row : iterable) {
                sum += ((Number) row.getField(row.getArity() - 1)).doubleValue();
                cnt++;
            }
            collector.collect(Tuple2.of(sum, cnt));
        }
    }).reduce(new ReduceFunction<Tuple2<Double, Long>>() {

        private static final long serialVersionUID = -6464200385237876961L;

        @Override
        public Tuple2<Double, Long> reduce(Tuple2<Double, Long> t0, Tuple2<Double, Long> t1) throws Exception {
            return Tuple2.of(t0.f0 + t1.f0, t0.f1 + t1.f1);
        }
    });
    DataSet<FeatureMeta> featureMetas;
    if (getParams().get(USE_ONEHOT)) {
        featureMetas = DataUtil.createOneHotFeatureMeta(trainDataSet, getParams(), trainColNames);
    } else if (getParams().get(USE_EPSILON_APPRO_QUANTILE)) {
        featureMetas = DataUtil.createEpsilonApproQuantileFeatureMeta(trainDataSet, stringIndexerModel.getDataSet(), getParams(), trainColNames, getMLEnvironmentId());
    } else {
        featureMetas = DataUtil.createFeatureMetas(quantileModel.getDataSet(), stringIndexerModel.getDataSet(), getParams());
    }
    {
        getParams().set(BoosterType.BOOSTER_TYPE, BoosterType.HESSION_BASE);
        getParams().set(CriteriaType.CRITERIA_TYPE, CriteriaType.valueOf(getParams().get(GbdtTrainParams.CRITERIA).toString()));
        if (getParams().get(GbdtTrainParams.NEWTON_STEP)) {
            getParams().set(LeafScoreUpdaterType.LEAF_SCORE_UPDATER_TYPE, LeafScoreUpdaterType.NEWTON_SINGLE_STEP_UPDATER);
        } else {
            getParams().set(LeafScoreUpdaterType.LEAF_SCORE_UPDATER_TYPE, LeafScoreUpdaterType.WEIGHT_AVG_UPDATER);
        }
    }
    IterativeComQueue comQueue = new IterativeComQueue()
        .initWithPartitionedData("trainData", trainDataSet)
        .initWithBroadcastData("gbdt.y.sum", sum)
        .initWithBroadcastData("quantileModel", quantileModel.getDataSet())
        .initWithBroadcastData("stringIndexerModel", stringIndexerModel.getDataSet())
        .initWithBroadcastData("labels", labels)
        .initWithBroadcastData("featureMetas", featureMetas)
        .add(new InitBoostingObjs(getParams()))
        .add(new Boosting())
        .add(new Bagging())
        .add(new InitTreeObjs());
    if (getParams().get(USE_EPSILON_APPRO_QUANTILE)) {
        comQueue
            .add(new BuildLocalSketch())
            .add(new AllReduceT<>(
                BuildLocalSketch.SKETCH,
                BuildLocalSketch.FEATURE_SKETCH_LENGTH,
                new BuildLocalSketch.SketchReducer(getParams()),
                EpsilonApproQuantile.WQSummary.class))
            .add(new FinalizeBuildSketch());
    }
    comQueue
        .add(new ConstructLocalHistogram())
        .add(new ReduceScatter("histogram", "histogram", "recvcnts", AllReduce.SUM))
        .add(new CalcFeatureGain())
        .add(new AllReduceT<>("best", "bestLength", new NodeReducer(), Node.class))
        .add(new SplitInstances())
        .add(new UpdateLeafScore())
        .add(new UpdatePredictionScore())
        .setCompareCriterionOfNode0(new TerminateCriterion())
        .closeWith(new SaveModel(getParams()));
    DataSet<Row> model = comQueue.exec();
    setOutput(
        model,
        new TreeModelDataConverter(
            FlinkTypeConverter.getFlinkType(getParams().get(ModelParamName.LABEL_TYPE_NAME))
        ).getModelSchema());
    this.setSideOutputTables(new Table[] {
        DataSetConversionUtil.toTable(
            getMLEnvironmentId(),
            model.reduceGroup(new TreeModelDataConverter.FeatureImportanceReducer()),
            new String[] {
                getParams().get(TreeModelDataConverter.IMPORTANCE_FIRST_COL),
                getParams().get(TreeModelDataConverter.IMPORTANCE_SECOND_COL)
            },
            new TypeInformation[] { Types.STRING, Types.DOUBLE })
    });
    return (T) this;
}
Also used : TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) MapPartitionFunction(org.apache.flink.api.common.functions.MapPartitionFunction) FeatureMeta(com.alibaba.alink.operator.common.tree.FeatureMeta) IterativeComQueue(com.alibaba.alink.common.comqueue.IterativeComQueue) LambdaMartNdcgParams(com.alibaba.alink.params.regression.LambdaMartNdcgParams) GbdtTrainParams(com.alibaba.alink.params.classification.GbdtTrainParams) Params(org.apache.flink.ml.api.misc.param.Params) TreeModelDataConverter(com.alibaba.alink.operator.common.tree.TreeModelDataConverter) Row(org.apache.flink.types.Row) ReduceScatter(com.alibaba.alink.operator.common.tree.parallelcart.communication.ReduceScatter) AllReduceT(com.alibaba.alink.operator.common.tree.parallelcart.communication.AllReduceT) Collector(org.apache.flink.util.Collector) Partitioner(org.apache.flink.api.common.functions.Partitioner) Tuple2(org.apache.flink.api.java.tuple.Tuple2) LossType(com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType)
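
The mapPartition/reduce pair in linkFrom computes the global label sum and count that are broadcast as "gbdt.y.sum" to initialize boosting. The following minimal sketch reproduces the same pattern against the plain Flink DataSet API; the class name, the toy Double input, and the print at the end are illustrative assumptions, not Alink code.

import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.util.Collector;

public class SumCountSketch {

    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Double> labels = env.fromElements(1.0, 2.0, 3.0, 4.0);
        DataSet<Tuple2<Double, Long>> sumAndCount = labels
            // One local (sum, count) pair per partition ...
            .mapPartition(new MapPartitionFunction<Double, Tuple2<Double, Long>>() {
                @Override
                public void mapPartition(Iterable<Double> values, Collector<Tuple2<Double, Long>> out) {
                    double sum = 0.;
                    long cnt = 0;
                    for (Double v : values) {
                        sum += v;
                        cnt++;
                    }
                    out.collect(Tuple2.of(sum, cnt));
                }
            })
            // ... then a single global combine, exactly as in linkFrom above.
            .reduce(new ReduceFunction<Tuple2<Double, Long>>() {
                @Override
                public Tuple2<Double, Long> reduce(Tuple2<Double, Long> a, Tuple2<Double, Long> b) {
                    return Tuple2.of(a.f0 + b.f0, a.f1 + b.f1);
                }
            });
        // Prints (10.0,4); linkFrom broadcasts this pair as "gbdt.y.sum".
        sumAndCount.print();
    }
}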

Example 2 with FeatureMeta

use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.

the class DenseData method sort.

@Override
public void sort() {
    for (int i = 0; i < n; ++i) {
        final FeatureMeta featureMeta = featureMetas[i];
        if (featureMeta.getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
            Arrays.sort(sortedValues, featureMeta.getIndex() * m, (featureMeta.getIndex() + 1) * m, (o1, o2) -> {
                boolean isMissing1 = Preprocessing.isMissing(o1.val, featureMeta, zeroAsMissing);
                boolean isMissing2 = Preprocessing.isMissing(o2.val, featureMeta, zeroAsMissing);
                if (isMissing1 && isMissing2) {
                    return 0;
                } else if (isMissing1) {
                    return 1;
                } else if (isMissing2) {
                    return -1;
                } else {
                    return Double.compare(o1.val, o2.val);
                }
            });
            for (int j = 0; j < m; ++j) {
                orderedIndices[i * m + j] = i * m + j;
            }
            Arrays.sort(orderedIndices, i * m, (i + 1) * m, Comparator.comparingInt(o -> sortedValues[o].index));
        }
    }
}
Also used : Arrays(java.util.Arrays) FeatureMeta(com.alibaba.alink.operator.common.tree.FeatureMeta) PaiCriteria(com.alibaba.alink.operator.common.tree.parallelcart.criteria.PaiCriteria) Preprocessing(com.alibaba.alink.operator.common.tree.Preprocessing) Node(com.alibaba.alink.operator.common.tree.Node) List(java.util.List) Future(java.util.concurrent.Future) EpsilonApproQuantile(com.alibaba.alink.operator.common.tree.parallelcart.EpsilonApproQuantile) LossUtils(com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils) Row(org.apache.flink.types.Row) BitSet(java.util.BitSet) BaseGbdtTrainBatchOp(com.alibaba.alink.operator.common.tree.parallelcart.BaseGbdtTrainBatchOp) Comparator(java.util.Comparator) Params(org.apache.flink.ml.api.misc.param.Params) ExecutorService(java.util.concurrent.ExecutorService)
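
The comparator in sort() is the key detail: missing values compare greater than every real value, so each continuous feature ends up sorted with its missing entries in one contiguous tail block. Below is a self-contained sketch of the same ordering rule, assuming NaN marks a missing value; plain Double values stand in for Alink's IndexedValue, and all names are illustrative.

import java.util.Arrays;

public class MissingLastSort {

    public static void main(String[] args) {
        // NaN stands in for a missing continuous value.
        Double[] values = {3.0, Double.NaN, 1.0, 2.0, Double.NaN};
        Arrays.sort(values, (o1, o2) -> {
            boolean isMissing1 = Double.isNaN(o1);
            boolean isMissing2 = Double.isNaN(o2);
            if (isMissing1 && isMissing2) {
                return 0;
            } else if (isMissing1) {
                return 1;   // missing sorts after every real value
            } else if (isMissing2) {
                return -1;
            } else {
                return Double.compare(o1, o2);
            }
        });
        // Prints [1.0, 2.0, 3.0, NaN, NaN]: real values ordered, missing in the tail.
        System.out.println(Arrays.toString(values));
    }
}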

Example 3 with FeatureMeta

use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.

the class DenseData method constructHistogramWithWQSummary.

@Override
public void constructHistogramWithWQSummary(boolean useInstanceCount, int nodeSize, BitSet featureValid, int[] nodeIdCache, int[] validFeatureOffset, double[] gradients, double[] hessions, double[] weights, EpsilonApproQuantile.WQSummary[] summaries, ExecutorService executorService, Future<?>[] futures, double[] featureSplitHistogram) {
    final int step = 4;
    for (int i = 0, index = 0; i < getN(); ++i) {
        final FeatureMeta featureMeta = featureMetas[i];
        boolean isContinuous = featureMeta.getType().equals(FeatureMeta.FeatureType.CONTINUOUS);
        futures[i] = null;
        if (!featureValid.get(i)) {
            if (isContinuous) {
                index++;
            }
            continue;
        }
        if (isContinuous) {
            EpsilonApproQuantile.WQSummary summary = summaries[index];
            final int dataOffset = getM() * i;
            final int featureSize = DataUtil.getFeatureCategoricalSize(featureMetas[i], useMissing);
            final int histogramOffset = validFeatureOffset[i] * nodeSize * step;
            final int nextHistogramOffset = histogramOffset + featureSize * nodeSize * step;
            if (useInstanceCount) {
                futures[i] = executorService.submit(() -> {
                    int cursor = 0;
                    Arrays.fill(featureSplitHistogram, histogramOffset, nextHistogramOffset, 0.0);
                    for (int j = 0; j < m; ++j) {
                        final int localRowIndex = sortedValues[dataOffset + j].index;
                        if (nodeIdCache[localRowIndex] < 0) {
                            continue;
                        }
                        while ((cursor < summary.entries.size() && summary.entries.get(cursor).value < sortedValues[dataOffset + j].val)) {
                            cursor++;
                        }
                        if (Preprocessing.isMissing(sortedValues[dataOffset + j].val, featureMeta, zeroAsMissing)) {
                            cursor = summary.entries.size();
                        }
                        final int localValue = cursor;
                        final int node = nodeIdCache[localRowIndex];
                        final int counterIndex = (node * featureSize + localValue) * step + histogramOffset;
                        featureSplitHistogram[counterIndex] += gradients[localRowIndex];
                        featureSplitHistogram[counterIndex + 1] += hessions[localRowIndex];
                        featureSplitHistogram[counterIndex + 2] += weights[localRowIndex];
                        if (weights[localRowIndex] > PaiCriteria.PAI_EPS) {
                            featureSplitHistogram[counterIndex + 3] += 1.0;
                        }
                    }
                });
            } else {
                futures[i] = executorService.submit(() -> {
                    int cursor = 0;
                    Arrays.fill(featureSplitHistogram, histogramOffset, nextHistogramOffset, 0.0);
                    for (int j = 0; j < m; ++j) {
                        final int localRowIndex = sortedValues[dataOffset + j].index;
                        if (nodeIdCache[localRowIndex] < 0) {
                            continue;
                        }
                        while ((cursor < summary.entries.size() && summary.entries.get(cursor).value < sortedValues[dataOffset + j].val)) {
                            cursor++;
                        }
                        if (Preprocessing.isMissing(sortedValues[dataOffset + j].val, featureMeta, zeroAsMissing)) {
                            cursor = summary.entries.size();
                        }
                        final int localValue = cursor;
                        final int node = nodeIdCache[localRowIndex];
                        final int counterIndex = (node * featureSize + localValue) * step + histogramOffset;
                        featureSplitHistogram[counterIndex] += gradients[localRowIndex];
                        featureSplitHistogram[counterIndex + 1] += hessions[localRowIndex];
                        featureSplitHistogram[counterIndex + 2] += weights[localRowIndex];
                        featureSplitHistogram[counterIndex + 3] += 1.0;
                    }
                });
            }
            index++;
        } else {
            final int dataOffset = getM() * i;
            final int featureSize = DataUtil.getFeatureCategoricalSize(featureMetas[i], useMissing);
            final int histogramOffset = validFeatureOffset[i] * nodeSize * step;
            final int nextHistogramOffset = histogramOffset + featureSize * nodeSize * step;
            if (useInstanceCount) {
                futures[i] = executorService.submit(() -> {
                    Arrays.fill(featureSplitHistogram, histogramOffset, nextHistogramOffset, 0.0);
                    for (int j = 0; j < m; ++j) {
                        final int localRowIndex = sortedValues[dataOffset + j].index;
                        if (nodeIdCache[localRowIndex] < 0) {
                            continue;
                        }
                        final int localValue = (int) sortedValues[dataOffset + j].val;
                        final int node = nodeIdCache[localRowIndex];
                        final int counterIndex = (node * featureSize + localValue) * step + histogramOffset;
                        featureSplitHistogram[counterIndex] += gradients[localRowIndex];
                        featureSplitHistogram[counterIndex + 1] += hessions[localRowIndex];
                        featureSplitHistogram[counterIndex + 2] += weights[localRowIndex];
                        if (weights[localRowIndex] > PaiCriteria.PAI_EPS) {
                            featureSplitHistogram[counterIndex + 3] += 1.0;
                        }
                    }
                });
            } else {
                futures[i] = executorService.submit(() -> {
                    Arrays.fill(featureSplitHistogram, histogramOffset, nextHistogramOffset, 0.0);
                    for (int j = 0; j < m; ++j) {
                        final int localRowIndex = sortedValues[dataOffset + j].index;
                        if (nodeIdCache[localRowIndex] < 0) {
                            continue;
                        }
                        final int localValue = (int) sortedValues[dataOffset + j].val;
                        final int node = nodeIdCache[localRowIndex];
                        final int counterIndex = (node * featureSize + localValue) * step + histogramOffset;
                        featureSplitHistogram[counterIndex] += gradients[localRowIndex];
                        featureSplitHistogram[counterIndex + 1] += hessions[localRowIndex];
                        featureSplitHistogram[counterIndex + 2] += weights[localRowIndex];
                        featureSplitHistogram[counterIndex + 3] += 1.0;
                    }
                });
            }
        }
    }
    for (Future<?> future : futures) {
        if (future == null) {
            continue;
        }
        try {
            future.get();
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }
}
Also used : FeatureMeta(com.alibaba.alink.operator.common.tree.FeatureMeta) EpsilonApproQuantile(com.alibaba.alink.operator.common.tree.parallelcart.EpsilonApproQuantile)
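
Both submitted tasks above write into one flat double[] using step = 4 accumulator slots per (node, bin) cell: gradient, hessian (spelled hessions in the source), weight, and instance count, offset by validFeatureOffset[i] * nodeSize * step. Here is a standalone sketch of that indexing scheme; the class and method names are illustrative, not Alink API.

public class HistogramLayout {

    // Four accumulator slots per (node, bin) cell, as in DenseData.
    static final int STEP = 4;

    static void accumulate(double[] hist, int histogramOffset, int featureSize,
                           int node, int bin, double gradient, double hessian, double weight) {
        int base = (node * featureSize + bin) * STEP + histogramOffset;
        hist[base] += gradient;     // sum of gradients
        hist[base + 1] += hessian;  // sum of hessians
        hist[base + 2] += weight;   // sum of instance weights
        hist[base + 3] += 1.0;      // instance count
    }

    public static void main(String[] args) {
        int nodeSize = 2, featureSize = 3;
        double[] hist = new double[nodeSize * featureSize * STEP];
        accumulate(hist, 0, featureSize, 1, 2, 0.5, 1.0, 1.0);
        // Gradient slot of node 1, bin 2: prints 0.5.
        System.out.println(hist[(1 * featureSize + 2) * STEP]);
    }
}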

Example 4 with FeatureMeta

use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.

the class DenseData method createWQSummary.

@Override
public EpsilonApproQuantile.SketchEntry[] createWQSummary(int maxSize, double eps, EpsilonApproQuantile.SketchEntry[] buffer, double[] dynamicWeights, BitSet validFlags) {
    for (int i = 0, index = 0; i < n; ++i) {
        final FeatureMeta featureMeta = featureMetas[i];
        if (featureMeta.getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
            buffer[index].sumTotal = 0.0;
            int featureOffSet = i * m;
            for (int j = 0; j < m; ++j) {
                IndexedValue v = sortedValues[featureOffSet + j];
                if (validFlags.get(v.index) && !Preprocessing.isMissing(v.val, featureMeta, zeroAsMissing)) {
                    buffer[index].sumTotal += dynamicWeights[v.index];
                }
            }
            index++;
        }
    }
    for (int i = 0, index = 0; i < n; ++i) {
        final FeatureMeta featureMeta = featureMetas[i];
        if (featureMeta.getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
            int start = 0;
            int end = m;
            EpsilonApproQuantile.SketchEntry entry = buffer[index];
            if (start == end || entry.sumTotal == 0.0) {
                // empty or all elements are null.
                index++;
                continue;
            }
            entry.init(maxSize);
            int featureOffSet = i * m;
            for (int j = start; j < end; ++j) {
                IndexedValue v = sortedValues[featureOffSet + j];
                if (validFlags.get(v.index) && !Preprocessing.isMissing(v.val, featureMeta, zeroAsMissing)) {
                    entry.push(v.val, dynamicWeights[v.index], maxSize);
                }
            }
            entry.finalize(maxSize);
            index++;
        }
    }
    return buffer;
}
Also used : FeatureMeta(com.alibaba.alink.operator.common.tree.FeatureMeta) EpsilonApproQuantile(com.alibaba.alink.operator.common.tree.parallelcart.EpsilonApproQuantile)
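
createWQSummary makes two passes over each continuous feature: the first totals the weight of valid, non-missing values into sumTotal, the second pushes values into the sketch, skipping features whose total weight is zero. For orientation, here is the exact weighted quantile that such an epsilon-approximate summary estimates; this is a didactic sketch, not Alink's algorithm, and all names are illustrative.

import java.util.Arrays;

public class ExactWeightedQuantile {

    // Smallest value whose cumulative weight reaches q * totalWeight.
    static double weightedQuantile(double[] values, double[] weights, double q) {
        Integer[] order = new Integer[values.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Double.compare(values[a], values[b]));
        double total = 0.0;
        for (double w : weights) {
            total += w;
        }
        double cum = 0.0;
        for (int idx : order) {
            cum += weights[idx];
            if (cum >= q * total) {
                return values[idx];
            }
        }
        return values[order[order.length - 1]];
    }

    public static void main(String[] args) {
        double[] values = {5.0, 1.0, 3.0};
        double[] weights = {1.0, 1.0, 2.0};
        // Total weight 4.0, median threshold 2.0: prints 3.0.
        System.out.println(weightedQuantile(values, weights, 0.5));
    }
}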

Example 5 with FeatureMeta

use of com.alibaba.alink.operator.common.tree.FeatureMeta in project Alink by alibaba.

the class TreeInitObj method calc.

@Override
public void calc(ComContext context) {
    if (context.getStepNo() != 1) {
        return;
    }
    List<Row> dataRows = context.getObj("treeInput");
    List<Row> quantileModel = context.getObj("quantileModel");
    List<Row> stringIndexerModel = context.getObj("stringIndexerModel");
    List<Object[]> labels = context.getObj("labels");
    int nLocalRow = dataRows == null ? 0 : dataRows.size();
    Params localParams = params.clone();
    localParams.set(TASK_ID, context.getTaskId());
    localParams.set(NUM_OF_SUBTASKS, context.getNumTask());
    localParams.set(N_LOCAL_ROW, nLocalRow);
    QuantileDiscretizerModelDataConverter quantileDiscretizerModel = initialMapping(quantileModel);
    List<String> lookUpColNames = new ArrayList<>();
    if (params.get(RandomForestTrainParams.CATEGORICAL_COLS) != null) {
        lookUpColNames.addAll(Arrays.asList(params.get(RandomForestTrainParams.CATEGORICAL_COLS)));
    }
    Map<String, Integer> categoricalColsSize = TreeUtil.extractCategoricalColsSize(stringIndexerModel, lookUpColNames.toArray(new String[0]));
    if (!Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
        categoricalColsSize.put(params.get(RandomForestTrainParams.LABEL_COL), labels.get(0).length);
    }
    FeatureMeta[] featureMetas = TreeUtil.getFeatureMeta(params.get(RandomForestTrainParams.FEATURE_COLS), categoricalColsSize);
    FeatureMeta labelMeta = TreeUtil.getLabelMeta(params.get(RandomForestTrainParams.LABEL_COL), params.get(RandomForestTrainParams.FEATURE_COLS).length, categoricalColsSize);
    TreeObj treeObj;
    if (Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
        treeObj = new RegObj(localParams, quantileDiscretizerModel, featureMetas, labelMeta);
    } else {
        treeObj = new ClassifierObj(localParams, quantileDiscretizerModel, featureMetas, labelMeta);
    }
    int nFeatureCol = localParams.get(RandomForestTrainParams.FEATURE_COLS).length;
    int[] data = new int[nFeatureCol * nLocalRow];
    double[] regLabels = null;
    int[] classifyLabels = null;
    if (Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
        regLabels = new double[nLocalRow];
    } else {
        classifyLabels = new int[nLocalRow];
    }
    int agg = 0;
    for (int iter = 0; iter < nLocalRow; ++iter) {
        for (int i = 0; i < nFeatureCol; ++i) {
            data[i * nLocalRow + agg] = (int) dataRows.get(iter).getField(i);
        }
        if (Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
            regLabels[agg] = (double) dataRows.get(iter).getField(nFeatureCol);
        } else {
            classifyLabels[agg] = (int) dataRows.get(iter).getField(nFeatureCol);
        }
        agg++;
    }
    treeObj.setFeatures(data);
    if (Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
        treeObj.setLabels(regLabels);
    } else {
        treeObj.setLabels(classifyLabels);
    }
    double[] histBuffer = new double[treeObj.getMaxHistBufferSize()];
    context.putObj("allReduce", histBuffer);
    treeObj.setHist(histBuffer);
    treeObj.initialRoot();
    context.putObj("treeObj", treeObj);
}
Also used : ArrayList(java.util.ArrayList) RandomForestTrainParams(com.alibaba.alink.params.classification.RandomForestTrainParams) Params(org.apache.flink.ml.api.misc.param.Params) QuantileDiscretizerModelDataConverter(com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter) FeatureMeta(com.alibaba.alink.operator.common.tree.FeatureMeta) Row(org.apache.flink.types.Row)
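
The packing loop in calc() is column-major: data[i * nLocalRow + row] keeps all values of feature i contiguous, presumably so later per-feature scans touch one contiguous block. A minimal sketch of that layout with illustrative names and toy data:

import java.util.Arrays;

public class ColumnMajorPacking {

    public static void main(String[] args) {
        // Three rows of two discretized features each, as TreeInitObj receives them.
        int[][] rows = { {1, 4}, {2, 5}, {3, 6} };
        int nLocalRow = rows.length;
        int nFeatureCol = rows[0].length;
        int[] data = new int[nFeatureCol * nLocalRow];
        for (int row = 0; row < nLocalRow; ++row) {
            for (int i = 0; i < nFeatureCol; ++i) {
                // Column-major: feature i occupies data[i * nLocalRow .. (i + 1) * nLocalRow).
                data[i * nLocalRow + row] = rows[row][i];
            }
        }
        // Prints [1, 2, 3, 4, 5, 6]: each feature's column is contiguous.
        System.out.println(Arrays.toString(data));
    }
}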

Aggregations

FeatureMeta (com.alibaba.alink.operator.common.tree.FeatureMeta): 8
EpsilonApproQuantile (com.alibaba.alink.operator.common.tree.parallelcart.EpsilonApproQuantile): 3
Params (org.apache.flink.ml.api.misc.param.Params): 3
Row (org.apache.flink.types.Row): 3
QuantileDiscretizerModelDataConverter (com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter): 2
IterativeComQueue (com.alibaba.alink.common.comqueue.IterativeComQueue): 1
TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp): 1
Node (com.alibaba.alink.operator.common.tree.Node): 1
Preprocessing (com.alibaba.alink.operator.common.tree.Preprocessing): 1
TreeModelDataConverter (com.alibaba.alink.operator.common.tree.TreeModelDataConverter): 1
BaseGbdtTrainBatchOp (com.alibaba.alink.operator.common.tree.parallelcart.BaseGbdtTrainBatchOp): 1
AllReduceT (com.alibaba.alink.operator.common.tree.parallelcart.communication.AllReduceT): 1
ReduceScatter (com.alibaba.alink.operator.common.tree.parallelcart.communication.ReduceScatter): 1
PaiCriteria (com.alibaba.alink.operator.common.tree.parallelcart.criteria.PaiCriteria): 1
LossType (com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType): 1
LossUtils (com.alibaba.alink.operator.common.tree.parallelcart.loss.LossUtils): 1
GbdtTrainParams (com.alibaba.alink.params.classification.GbdtTrainParams): 1
RandomForestTrainParams (com.alibaba.alink.params.classification.RandomForestTrainParams): 1
LambdaMartNdcgParams (com.alibaba.alink.params.regression.LambdaMartNdcgParams): 1
ArrayList (java.util.ArrayList): 1