Example 1 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

From class BaseGbdtTrainBatchOp, method linkFrom:

@Override
public T linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    LOG.info("gbdt train start");
    if (!Preprocessing.isSparse(getParams())) {
        getParams().set(
            HasCategoricalCols.CATEGORICAL_COLS,
            TableUtil.getCategoricalCols(
                in.getSchema(),
                getParams().get(GbdtTrainParams.FEATURE_COLS),
                getParams().contains(GbdtTrainParams.CATEGORICAL_COLS)
                    ? getParams().get(GbdtTrainParams.CATEGORICAL_COLS) : null));
    }
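    // Resolve the configured loss and record it as the internal algorithm type.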
    LossType loss = getParams().get(LossUtils.LOSS_TYPE);
    getParams().set(ALGO_TYPE, LossUtils.lossTypeToInt(loss));
    rewriteLabelType(in.getSchema(), getParams());
    if (!Preprocessing.isSparse(getParams())) {
        getParams().set(
            ModelParamName.FEATURE_TYPES,
            FlinkTypeConverter.getTypeString(
                TableUtil.findColTypes(in.getSchema(), getParams().get(GbdtTrainParams.FEATURE_COLS))));
    }
    if (LossUtils.isRanking(getParams().get(LossUtils.LOSS_TYPE))) {
        if (!getParams().contains(LambdaMartNdcgParams.GROUP_COL)) {
            throw new IllegalArgumentException("Group column should be set in ranking loss function.");
        }
    }
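    // Assemble the training columns (including the group column when ranking).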
    String[] trainColNames = trainColsWithGroup();
    // Check whether the label column contains null values.
    final String labelColName = this.getParams().get(HasLabelCol.LABEL_COL);
    final int labelColIdx = TableUtil.findColIndex(in.getSchema(), labelColName);
    in = new TableSourceBatchOp(
        DataSetConversionUtil.toTable(
            in.getMLEnvironmentId(),
            in.getDataSet().map(new MapFunction<Row, Row>() {

                @Override
                public Row map(Row row) throws Exception {
                    if (null == row.getField(labelColIdx)) {
                        throw new RuntimeException("label col has null values.");
                    }
                    return row;
                }
            }),
            in.getSchema()))
        .setMLEnvironmentId(in.getMLEnvironmentId());
    in = Preprocessing.select(in, trainColNames);
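    // Generate the label set; regression and ranking losses are treated as numeric.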
    DataSet<Object[]> labels = Preprocessing.generateLabels(in, getParams(), LossUtils.isRegression(loss) || LossUtils.isRanking(loss));
    if (LossUtils.isClassification(loss)) {
        labels = labels.map(new CheckNumLabels4BinaryClassifier());
    }
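    // Build the preprocessing models and the training data; the three branches differ in how features are discretized.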
    DataSet<Row> trainDataSet;
    BatchOperator<?> stringIndexerModel;
    BatchOperator<?> quantileModel;
    if (getParams().get(USE_ONEHOT)) {
        // create empty string indexer model.
        stringIndexerModel = Preprocessing.generateStringIndexerModel(in, new Params());
        // create empty quantile model.
        quantileModel = Preprocessing.generateQuantileDiscretizerModel(
            in,
            new Params()
                .set(HasFeatureCols.FEATURE_COLS, new String[] {})
                .set(HasCategoricalCols.CATEGORICAL_COLS, new String[] {}));
        trainDataSet = Preprocessing.castLabel(
            in, getParams(), labels,
            LossUtils.isRegression(loss) || LossUtils.isRanking(loss)).getDataSet();
    } else if (getParams().get(USE_EPSILON_APPRO_QUANTILE)) {
        // create string indexer model
        stringIndexerModel = Preprocessing.generateStringIndexerModel(in, getParams());
        // create empty quantile model
        quantileModel = Preprocessing.generateQuantileDiscretizerModel(
            in,
            new Params()
                .set(HasFeatureCols.FEATURE_COLS, new String[] {})
                .set(HasCategoricalCols.CATEGORICAL_COLS, new String[] {}));
        trainDataSet = Preprocessing.castLabel(
            Preprocessing.isSparse(getParams())
                ? in
                : Preprocessing.castContinuousCols(
                    Preprocessing.castCategoricalCols(in, stringIndexerModel, getParams()), getParams()),
            getParams(), labels,
            LossUtils.isRegression(loss) || LossUtils.isRanking(loss)).getDataSet();
    } else {
        stringIndexerModel = Preprocessing.generateStringIndexerModel(in, getParams());
        quantileModel = Preprocessing.generateQuantileDiscretizerModel(in, getParams());
        trainDataSet = Preprocessing.castLabel(
            Preprocessing.castToQuantile(
                Preprocessing.isSparse(getParams())
                    ? in
                    : Preprocessing.castContinuousCols(
                        Preprocessing.castCategoricalCols(in, stringIndexerModel, getParams()), getParams()),
                quantileModel, getParams()),
            getParams(), labels,
            LossUtils.isRegression(loss) || LossUtils.isRanking(loss)).getDataSet();
    }
    if (LossUtils.isRanking(getParams().get(LossUtils.LOSS_TYPE))) {
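        // Ranking: custom-partition on the group id (field 0) so rows of the same group land in one partition.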
        trainDataSet = trainDataSet.partitionCustom(new Partitioner<Number>() {

            private static final long serialVersionUID = -7790649477852624964L;

            @Override
            public int partition(Number key, int numPartitions) {
                return (int) (key.longValue() % numPartitions);
            }
        }, 0);
    }
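    // Sum the label column (last field) and count the rows; broadcast later as "gbdt.y.sum".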
    DataSet<Tuple2<Double, Long>> sum = trainDataSet.mapPartition(new MapPartitionFunction<Row, Tuple2<Double, Long>>() {

        private static final long serialVersionUID = -8333738060239409640L;

        @Override
        public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<Double, Long>> collector) throws Exception {
            double sum = 0.;
            long cnt = 0;
            for (Row row : iterable) {
                sum += ((Number) row.getField(row.getArity() - 1)).doubleValue();
                cnt++;
            }
            collector.collect(Tuple2.of(sum, cnt));
        }
    }).reduce(new ReduceFunction<Tuple2<Double, Long>>() {

        private static final long serialVersionUID = -6464200385237876961L;

        @Override
        public Tuple2<Double, Long> reduce(Tuple2<Double, Long> t0, Tuple2<Double, Long> t1) throws Exception {
            return Tuple2.of(t0.f0 + t1.f0, t0.f1 + t1.f1);
        }
    });
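    // Build per-feature metadata matching the chosen discretization strategy.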
    DataSet<FeatureMeta> featureMetas;
    if (getParams().get(USE_ONEHOT)) {
        featureMetas = DataUtil.createOneHotFeatureMeta(trainDataSet, getParams(), trainColNames);
    } else if (getParams().get(USE_EPSILON_APPRO_QUANTILE)) {
        featureMetas = DataUtil.createEpsilonApproQuantileFeatureMeta(trainDataSet, stringIndexerModel.getDataSet(), getParams(), trainColNames, getMLEnvironmentId());
    } else {
        featureMetas = DataUtil.createFeatureMetas(quantileModel.getDataSet(), stringIndexerModel.getDataSet(), getParams());
    }
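    // Fix the booster type, split criterion, and leaf-score updater from the parameters.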
    {
        getParams().set(BoosterType.BOOSTER_TYPE, BoosterType.HESSION_BASE);
        getParams().set(CriteriaType.CRITERIA_TYPE, CriteriaType.valueOf(getParams().get(GbdtTrainParams.CRITERIA).toString()));
        if (getParams().get(GbdtTrainParams.NEWTON_STEP)) {
            getParams().set(LeafScoreUpdaterType.LEAF_SCORE_UPDATER_TYPE, LeafScoreUpdaterType.NEWTON_SINGLE_STEP_UPDATER);
        } else {
            getParams().set(LeafScoreUpdaterType.LEAF_SCORE_UPDATER_TYPE, LeafScoreUpdaterType.WEIGHT_AVG_UPDATER);
        }
    }
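    // Assemble the distributed training loop: boosting, bagging, histogram construction,
    // split finding, and leaf updates all run inside an IterativeComQueue.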
    IterativeComQueue comQueue = new IterativeComQueue()
        .initWithPartitionedData("trainData", trainDataSet)
        .initWithBroadcastData("gbdt.y.sum", sum)
        .initWithBroadcastData("quantileModel", quantileModel.getDataSet())
        .initWithBroadcastData("stringIndexerModel", stringIndexerModel.getDataSet())
        .initWithBroadcastData("labels", labels)
        .initWithBroadcastData("featureMetas", featureMetas)
        .add(new InitBoostingObjs(getParams()))
        .add(new Boosting())
        .add(new Bagging())
        .add(new InitTreeObjs());
    if (getParams().get(USE_EPSILON_APPRO_QUANTILE)) {
        comQueue.add(new BuildLocalSketch())
            .add(new AllReduceT<>(
                BuildLocalSketch.SKETCH, BuildLocalSketch.FEATURE_SKETCH_LENGTH,
                new BuildLocalSketch.SketchReducer(getParams()),
                EpsilonApproQuantile.WQSummary.class))
            .add(new FinalizeBuildSketch());
    }
    comQueue.add(new ConstructLocalHistogram())
        .add(new ReduceScatter("histogram", "histogram", "recvcnts", AllReduce.SUM))
        .add(new CalcFeatureGain())
        .add(new AllReduceT<>("best", "bestLength", new NodeReducer(), Node.class))
        .add(new SplitInstances())
        .add(new UpdateLeafScore())
        .add(new UpdatePredictionScore())
        .setCompareCriterionOfNode0(new TerminateCriterion())
        .closeWith(new SaveModel(getParams()));
    DataSet<Row> model = comQueue.exec();
    setOutput(
        model,
        new TreeModelDataConverter(
            FlinkTypeConverter.getFlinkType(getParams().get(ModelParamName.LABEL_TYPE_NAME))
        ).getModelSchema());
    this.setSideOutputTables(new Table[] {
        DataSetConversionUtil.toTable(
            getMLEnvironmentId(),
            model.reduceGroup(new TreeModelDataConverter.FeatureImportanceReducer()),
            new String[] {
                getParams().get(TreeModelDataConverter.IMPORTANCE_FIRST_COL),
                getParams().get(TreeModelDataConverter.IMPORTANCE_SECOND_COL)
            },
            new TypeInformation[] { Types.STRING, Types.DOUBLE })
    });
    return (T) this;
}
Also used : TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) MapPartitionFunction(org.apache.flink.api.common.functions.MapPartitionFunction) FeatureMeta(com.alibaba.alink.operator.common.tree.FeatureMeta) IterativeComQueue(com.alibaba.alink.common.comqueue.IterativeComQueue) LambdaMartNdcgParams(com.alibaba.alink.params.regression.LambdaMartNdcgParams) GbdtTrainParams(com.alibaba.alink.params.classification.GbdtTrainParams) Params(org.apache.flink.ml.api.misc.param.Params) TreeModelDataConverter(com.alibaba.alink.operator.common.tree.TreeModelDataConverter) Row(org.apache.flink.types.Row) ReduceScatter(com.alibaba.alink.operator.common.tree.parallelcart.communication.ReduceScatter) AllReduceT(com.alibaba.alink.operator.common.tree.parallelcart.communication.AllReduceT) Collector(org.apache.flink.util.Collector) Partitioner(org.apache.flink.api.common.functions.Partitioner) Tuple2(org.apache.flink.api.java.tuple.Tuple2) LossType(com.alibaba.alink.operator.common.tree.parallelcart.loss.LossType)
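
For orientation, here is a minimal sketch of how this operator is typically reached from user code. It is not taken from the example above: the Table inputTable and the column names are placeholders, and GbdtTrainBatchOp is assumed to be the concrete subclass of BaseGbdtTrainBatchOp.

// Hypothetical usage: wrap a Flink Table as a batch source, then link it
// into the GBDT trainer (inputTable and column names are placeholders).
BatchOperator<?> source = new TableSourceBatchOp(inputTable);
BatchOperator<?> model = new GbdtTrainBatchOp()
    .setFeatureCols("f0", "f1")
    .setLabelCol("label")
    .linkFrom(source);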

Example 2 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

From class TestUtil, method printTable:

public static void printTable(Table table) throws Exception {
    TableImpl tableImpl = (TableImpl) table;
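    // Dispatch on the underlying table environment: stream tables need print() plus an explicit execute().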
    if (tableImpl.getTableEnvironment() instanceof StreamTableEnvironment) {
        new TableSourceStreamOp(table).print();
        StreamOperator.execute();
    } else {
        new TableSourceBatchOp(table).print();
    }
}
Also used : TableImpl(org.apache.flink.table.api.internal.TableImpl) StreamTableEnvironment(org.apache.flink.table.api.bridge.java.StreamTableEnvironment) TableSourceStreamOp(com.alibaba.alink.operator.stream.source.TableSourceStreamOp) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp)
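
Usage of the helper is symmetric for batch and stream tables; a hypothetical sketch, with batchTable and streamTable as placeholders:

// The batch path prints immediately; the stream path prints and then runs execute().
TestUtil.printTable(batchTable);
TestUtil.printTable(streamTable);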

Example 3 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

From class MaxAbsTest, method test:

@Test
public void test() throws Exception {
    BatchOperator batchData = new TableSourceBatchOp(GenerateData.getBatchTable());
    StreamOperator streamData = new TableSourceStreamOp(GenerateData.getStreamTable());
    MaxAbsScalerModel model = new MaxAbsScaler().setSelectedCols("f0", "f1").setOutputCols("f0_1", "f1_1").fit(batchData);
    model.transform(batchData).lazyCollect();
    model.transform(streamData).print();
    MaxAbsScalerTrainBatchOp op = new MaxAbsScalerTrainBatchOp().setSelectedCols("f0", "f1").linkFrom(batchData);
    List<Row> rows = new MaxAbsScalerPredictBatchOp().linkFrom(op, batchData).collect();
    rows.sort(StandardScalerTest.compare);
    assertEquals(rows.get(0), Row.of(null, null));
    StandardScalerTest.assertRow(rows.get(1), Row.of(-0.25, -1.));
    StandardScalerTest.assertRow(rows.get(2), Row.of(0.25, 0.666));
    StandardScalerTest.assertRow(rows.get(3), Row.of(1.0, 0.6666));
    new MaxAbsScalerPredictStreamOp(op).linkFrom(streamData).print();
    StreamOperator.execute();
}
Also used : MaxAbsScaler(com.alibaba.alink.pipeline.dataproc.MaxAbsScaler) MaxAbsScalerPredictStreamOp(com.alibaba.alink.operator.stream.dataproc.MaxAbsScalerPredictStreamOp) MaxAbsScalerModel(com.alibaba.alink.pipeline.dataproc.MaxAbsScalerModel) TableSourceStreamOp(com.alibaba.alink.operator.stream.source.TableSourceStreamOp) Row(org.apache.flink.types.Row) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) StreamOperator(com.alibaba.alink.operator.stream.StreamOperator) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Test(org.junit.Test)
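
For reference, max-abs scaling divides each column by its maximum absolute value, so outputs land in [-1, 1]. A minimal standalone sketch with made-up numbers:

// Max-abs scaling: x' = x / max(|x|), computed per column (numbers are made up).
double[] col = {1.0, -4.0, 3.0};
double maxAbs = 4.0; // max(|1.0|, |-4.0|, |3.0|)
for (double v : col) {
    System.out.println(v / maxAbs); // prints 0.25, -1.0, 0.75
}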

Example 4 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

From class VectorMaxAbsTest, method testModelInfo:

@Test
public void testModelInfo() {
    BatchOperator batchData = new TableSourceBatchOp(GenerateData.getDenseBatch());
    VectorMaxAbsScalerTrainBatchOp trainOp = new VectorMaxAbsScalerTrainBatchOp().setSelectedCol("vec").linkFrom(batchData);
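    // collectModelInfo() triggers execution and fetches the model summary to the client.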
    VectorMaxAbsScalarModelInfo modelInfo = trainOp.getModelInfoBatchOp().collectModelInfo();
    System.out.println(modelInfo.getMaxAbs().length);
    System.out.println(modelInfo.toString());
}
Also used : TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) VectorMaxAbsScalarModelInfo(com.alibaba.alink.operator.common.dataproc.vector.VectorMaxAbsScalarModelInfo) Test(org.junit.Test)
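
The model summary can be inspected further; a hedged sketch, assuming getMaxAbs() returns the per-dimension maxima in input order:

// Hypothetical inspection; only getMaxAbs() appears in the test above, and
// the per-dimension ordering is an assumption.
double[] maxAbs = modelInfo.getMaxAbs();
for (int i = 0; i < maxAbs.length; i++) {
    System.out.println("dim " + i + ": maxAbs = " + maxAbs[i]);
}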

Example 5 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

From class VectorMaxAbsTest, method test:

@Test
public void test() throws Exception {
    BatchOperator batchData = new TableSourceBatchOp(GenerateData.getDenseBatch());
    StreamOperator streamData = new TableSourceStreamOp(GenerateData.getDenseStream());
    VectorMaxAbsScalerTrainBatchOp op = new VectorMaxAbsScalerTrainBatchOp().setSelectedCol("vec").linkFrom(batchData);
    List<Row> rows = new VectorMaxAbsScalerPredictBatchOp().setOutputCol("res").linkFrom(op, batchData).collect();
    VectorStandardScalerTest.assertDv(VectorUtil.getDenseVector(rows.get(0).getField(1)), new DenseVector(new double[] { 0.25, 0.6666 }));
    VectorStandardScalerTest.assertDv(VectorUtil.getDenseVector(rows.get(1).getField(1)), new DenseVector(new double[] { -0.25, -1. }));
    VectorStandardScalerTest.assertDv(VectorUtil.getDenseVector(rows.get(2).getField(1)), new DenseVector(new double[] { 1., 0.6666 }));
    new VectorMaxAbsScalerPredictStreamOp(op).setOutputCol("res").linkFrom(streamData).print();
    StreamOperator.execute();
}
Also used : TableSourceStreamOp(com.alibaba.alink.operator.stream.source.TableSourceStreamOp) Row(org.apache.flink.types.Row) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) StreamOperator(com.alibaba.alink.operator.stream.StreamOperator) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) DenseVector(com.alibaba.alink.common.linalg.DenseVector) VectorMaxAbsScalerPredictStreamOp(com.alibaba.alink.operator.stream.dataproc.vector.VectorMaxAbsScalerPredictStreamOp) Test(org.junit.Test)
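
Note the pattern these tests share: batch results materialize when collect() or print() is called, while the stream side runs only after StreamOperator.execute(). Distilled, reusing the names from the test above:

// Train once, then predict in batch and stream; the stream job starts on execute().
VectorMaxAbsScalerTrainBatchOp train =
    new VectorMaxAbsScalerTrainBatchOp().setSelectedCol("vec").linkFrom(batchData);
new VectorMaxAbsScalerPredictBatchOp().setOutputCol("res").linkFrom(train, batchData).print();
new VectorMaxAbsScalerPredictStreamOp(train).setOutputCol("res").linkFrom(streamData).print();
StreamOperator.execute();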

Aggregations

TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp): 39 uses
Row (org.apache.flink.types.Row): 29 uses
BatchOperator (com.alibaba.alink.operator.batch.BatchOperator): 22 uses
Test (org.junit.Test): 18 uses
TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation): 12 uses
TableSourceStreamOp (com.alibaba.alink.operator.stream.source.TableSourceStreamOp): 10 uses
Params (org.apache.flink.ml.api.misc.param.Params): 10 uses
List (java.util.List): 8 uses
Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 8 uses
StreamOperator (com.alibaba.alink.operator.stream.StreamOperator): 6 uses
ArrayList (java.util.ArrayList): 6 uses
TableSchema (org.apache.flink.table.api.TableSchema): 6 uses
MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp): 5 uses
Comparator (java.util.Comparator): 4 uses
HashMap (java.util.HashMap): 4 uses
MapFunction (org.apache.flink.api.common.functions.MapFunction): 4 uses
DataSet (org.apache.flink.api.java.DataSet): 4 uses
Mapper (com.alibaba.alink.common.mapper.Mapper): 3 uses
ModelMapper (com.alibaba.alink.common.mapper.ModelMapper): 3 uses
PipelineModelMapper (com.alibaba.alink.common.mapper.PipelineModelMapper): 3 uses