Search in sources:

Example 1 with BroadcastVariableModelSource

Use of com.alibaba.alink.common.model.BroadcastVariableModelSource in the Alink project by Alibaba.

From the linkFrom method of class FlatModelMapBatchOp.

/**
 * Connects the model input ({@code inputs[0]}) and the data input ({@code inputs[1]}) into a flat-mapping job.
 *
 * <p>When the mapper produced by {@code mapperBuilder} implements {@code IterableModelLoader}, each model row is
 * copied once per parallel subtask and routed to the partition with the matching index. Each subtask then loads
 * its copy into the mapper and stores the mapper in {@code IterTaskObjKeeper} under a freshly allocated handle and
 * the subtask index. The data side reaches the mapper through an {@code IterableModelLoaderFlatModelMapperAdapter}
 * built from that handle. The loading stage emits nothing, so the "barrier" broadcast set it feeds is empty.
 * NOTE(review): it presumably exists only to force loading to finish before mapping starts.
 *
 * <p>Otherwise the model rows are rebalanced and attached to a {@code FlatModelMapperAdapter} as the broadcast
 * variable named {@code BROADCAST_MODEL_TABLE_NAME}.
 *
 * @param inputs exactly two operators: the model first, then the data to map
 * @return this operator, whose output is the mapped rows with the mapper's output schema
 * @throws RuntimeException wrapping any exception thrown while building the job
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    checkOpSize(2, inputs);
    try {
        BroadcastVariableModelSource broadcastSource = new BroadcastVariableModelSource(BROADCAST_MODEL_TABLE_NAME);
        final FlatModelMapper mapper =
            this.mapperBuilder.apply(inputs[0].getSchema(), inputs[1].getSchema(), this.getParams());
        final DataSet<Row> output;

        if (!(mapper instanceof IterableModelLoader)) {
            // Plain path: every data-side task receives the whole model as a broadcast variable.
            DataSet<Row> rebalancedModel = inputs[0].getDataSet().rebalance();
            output = inputs[1].getDataSet()
                .flatMap(new FlatModelMapperAdapter(mapper, broadcastSource))
                .withBroadcastSet(rebalancedModel, BROADCAST_MODEL_TABLE_NAME);
        } else {
            final long handle = IterTaskObjKeeper.getNewHandle();
            DataSet<Row> model = inputs[0].getDataSet();

            // Tags each model row with every subtask index, producing one copy per subtask.
            RichFlatMapFunction<Row, Tuple2<Integer, Row>> replicateToAllTasks =
                new RichFlatMapFunction<Row, Tuple2<Integer, Row>>() {

                    private static final long serialVersionUID = 3544759002096859673L;

                    int parallelism;

                    @Override
                    public void open(Configuration parameters) {
                        parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
                    }

                    @Override
                    public void flatMap(Row value, Collector<Tuple2<Integer, Row>> out) {
                        int target = 0;
                        while (target < parallelism) {
                            out.collect(Tuple2.of(target, value));
                            target++;
                        }
                    }
                };

            // Sends each tagged copy to the partition whose index equals its tag.
            Partitioner<Integer> byTaskIndex = new Partitioner<Integer>() {

                private static final long serialVersionUID = -2924355974935165844L;

                @Override
                public int partition(Integer key, int numPartitions) {
                    return key;
                }
            };

            // Removes the routing tag again, leaving the original model row.
            MapFunction<Tuple2<Integer, Row>, Row> dropTaskIndex = new MapFunction<Tuple2<Integer, Row>, Row>() {

                private static final long serialVersionUID = 8884296007768771379L;

                @Override
                public Row map(Tuple2<Integer, Row> value) throws Exception {
                    return value.f1;
                }
            };

            // Loads this subtask's model copy into the mapper and registers it under (handle, subtask index).
            RichMapPartitionFunction<Row, Integer> loadModelOnEachTask = new RichMapPartitionFunction<Row, Integer>() {

                private static final long serialVersionUID = 2358845952757630826L;

                @Override
                public void mapPartition(Iterable<Row> values, Collector<Integer> out) {
                    int subtask = getRuntimeContext().getIndexOfThisSubtask();
                    ((IterableModelLoader) mapper).loadIterableModel(values);
                    IterTaskObjKeeper.put(handle, subtask, mapper);
                }
            };

            DataSet<Row> modelOnEveryTask = model
                .flatMap(replicateToAllTasks)
                .returns(new TupleTypeInfo<>(Types.INT, model.getType()))
                .partitionCustom(byTaskIndex, 0)
                .map(dropTaskIndex)
                .returns(model.getType());
            DataSet<Integer> barrier = modelOnEveryTask.mapPartition(loadModelOnEachTask);
            output = inputs[1].getDataSet()
                .flatMap(new IterableModelLoaderFlatModelMapperAdapter(handle))
                .withBroadcastSet(barrier, "barrier");
        }

        TableSchema schema = mapper.getOutputSchema();
        this.setOutput(output, schema);
        return (T) this;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Also used: Configuration(org.apache.flink.configuration.Configuration) TableSchema(org.apache.flink.table.api.TableSchema) FlatModelMapperAdapter(com.alibaba.alink.common.mapper.FlatModelMapperAdapter) IterableModelLoaderFlatModelMapperAdapter(com.alibaba.alink.common.mapper.IterableModelLoaderFlatModelMapperAdapter) IterableModelLoader(com.alibaba.alink.common.mapper.IterableModelLoader) FlatModelMapper(com.alibaba.alink.common.mapper.FlatModelMapper) RichFlatMapFunction(org.apache.flink.api.common.functions.RichFlatMapFunction) MapFunction(org.apache.flink.api.common.functions.MapFunction) BroadcastVariableModelSource(com.alibaba.alink.common.model.BroadcastVariableModelSource) TupleTypeInfo(org.apache.flink.api.java.typeutils.TupleTypeInfo) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Row(org.apache.flink.types.Row)

Example 2 with BroadcastVariableModelSource

Use of com.alibaba.alink.common.model.BroadcastVariableModelSource in the Alink project by Alibaba.

From the linkFrom method of class ModelMapBatchOp.

/**
 * Connects the model input ({@code inputs[0]}) and the data input ({@code inputs[1]}) into a row-to-row
 * mapping job. The strategy depends on the mapper and the parameters.
 *
 * <ul>
 *   <li>If the mapper implements {@code IterableModelLoader}, each model row is copied to every parallel
 *       subtask. Each subtask loads its copy into the mapper and registers the mapper with
 *       {@code IterTaskObjKeeper} under a fresh handle. The data side then maps rows through an
 *       {@code IterableModelLoaderModelMapperAdapter} if {@code NUM_THREADS <= 1}, or an
 *       {@code IterableModelLoaderModelMapperAdapterMT} otherwise. The loading stage emits nothing, so the
 *       "barrier" broadcast set is empty. NOTE(review): it presumably only orders loading before mapping.</li>
 *   <li>Otherwise, if {@code ModelStreamUtils.useModelStreamFile(getParams())} holds, the model is read from
 *       the broadcast variable in {@code open()}. The mapper is loaded and opened there, and rows go through a
 *       {@code ModelStreamModelMapperAdapter}.</li>
 *   <li>Otherwise the model is attached as a broadcast variable to a {@code ModelMapperAdapter} if
 *       {@code NUM_THREADS <= 1}, or a {@code ModelMapperAdapterMT} otherwise.</li>
 * </ul>
 *
 * <p>In both non-iterable strategies the model rows are rebalanced and broadcast under
 * {@code BROADCAST_MODEL_TABLE_NAME}.
 *
 * @param inputs exactly two operators: the model first, then the data to map
 * @return this operator, whose output is the mapped rows with the mapper's output schema
 * @throws RuntimeException wrapping any exception thrown while building the job
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    checkOpSize(2, inputs);
    try {
        final ModelMapper mapper = this.mapperBuilder.apply(inputs[0].getSchema(), inputs[1].getSchema(), this.getParams());
        // No placeholder initializer: the compiler checks that every branch assigns this exactly once.
        final DataSet<Row> resultRows;
        if (mapper instanceof IterableModelLoader) {
            final long handler = IterTaskObjKeeper.getNewHandle();
            DataSet<Row> modelRows = inputs[0].getDataSet();
            // Copy every model row to every subtask: tag with each subtask index, route by the tag, drop the tag.
            DataSet<Row> distributedModelRows = modelRows.flatMap(new RichFlatMapFunction<Row, Tuple2<Integer, Row>>() {

                private static final long serialVersionUID = 3544759002096859673L;

                int numTask;

                @Override
                public void open(Configuration parameters) {
                    numTask = getRuntimeContext().getNumberOfParallelSubtasks();
                }

                @Override
                public void flatMap(Row value, Collector<Tuple2<Integer, Row>> out) {
                    for (int i = 0; i < numTask; ++i) {
                        out.collect(Tuple2.of(i, value));
                    }
                }
            }).returns(new TupleTypeInfo<>(Types.INT, modelRows.getType())).partitionCustom(new Partitioner<Integer>() {

                private static final long serialVersionUID = -2924355974935165844L;

                @Override
                public int partition(Integer key, int numPartitions) {
                    return key;
                }
            }, 0).map(new MapFunction<Tuple2<Integer, Row>, Row>() {

                private static final long serialVersionUID = 8884296007768771379L;

                @Override
                public Row map(Tuple2<Integer, Row> value) throws Exception {
                    return value.f1;
                }
            }).returns(modelRows.getType());
            // Each subtask loads its full model copy and registers the mapper under (handler, subtask index).
            DataSet<Integer> barrier = distributedModelRows.mapPartition(new RichMapPartitionFunction<Row, Integer>() {

                private static final long serialVersionUID = 2358845952757630826L;

                @Override
                public void mapPartition(Iterable<Row> values, Collector<Integer> out) {
                    int taskId = getRuntimeContext().getIndexOfThisSubtask();
                    ((IterableModelLoader) mapper).loadIterableModel(values);
                    IterTaskObjKeeper.put(handler, taskId, mapper);
                }
            });
            final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
            if (numThreads <= 1) {
                resultRows = inputs[1].getDataSet().map(new IterableModelLoaderModelMapperAdapter(handler)).withBroadcastSet(barrier, "barrier");
            } else {
                resultRows = inputs[1].getDataSet().flatMap(new IterableModelLoaderModelMapperAdapterMT(handler, numThreads)).withBroadcastSet(barrier, "barrier");
            }
        } else {
            final BroadcastVariableModelSource modelSource = new BroadcastVariableModelSource(BROADCAST_MODEL_TABLE_NAME);
            DataSet<Row> modelRows = inputs[0].getDataSet().rebalance();
            if (ModelStreamUtils.useModelStreamFile(getParams())) {
                resultRows = inputs[1].getDataSet().map(new RichMapFunction<Row, Row>() {

                    // Explicit UID, consistent with the other serialized anonymous functions in this job.
                    private static final long serialVersionUID = -6517032916178294126L;

                    ModelStreamModelMapperAdapter modelStreamModelMapper;

                    @Override
                    public void open(Configuration parameters) throws Exception {
                        super.open(parameters);
                        // Named distinctly so it does not shadow the enclosing DataSet 'modelRows'.
                        List<Row> broadcastModelRows = modelSource.getModelRows(getRuntimeContext());
                        mapper.loadModel(broadcastModelRows);
                        mapper.open();
                        modelStreamModelMapper = new ModelStreamModelMapperAdapter(mapper);
                    }

                    @Override
                    public Row map(Row value) throws Exception {
                        return modelStreamModelMapper.map(value);
                    }
                }).withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
            } else {
                final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
                if (numThreads <= 1) {
                    resultRows = inputs[1].getDataSet().map(new ModelMapperAdapter(mapper, modelSource)).withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
                } else {
                    resultRows = inputs[1].getDataSet().flatMap(new ModelMapperAdapterMT(mapper, modelSource, numThreads)).withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
                }
            }
        }
        TableSchema outputSchema = mapper.getOutputSchema();
        this.setOutput(resultRows, outputSchema);
        return (T) this;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Also used: IterableModelLoaderModelMapperAdapterMT(com.alibaba.alink.common.mapper.IterableModelLoaderModelMapperAdapterMT) Configuration(org.apache.flink.configuration.Configuration) TableSchema(org.apache.flink.table.api.TableSchema) ModelStreamModelMapperAdapter(com.alibaba.alink.common.mapper.ModelStreamModelMapperAdapter) IterableModelLoader(com.alibaba.alink.common.mapper.IterableModelLoader) RichFlatMapFunction(org.apache.flink.api.common.functions.RichFlatMapFunction) MapFunction(org.apache.flink.api.common.functions.MapFunction) RichMapFunction(org.apache.flink.api.common.functions.RichMapFunction) IterableModelLoaderModelMapperAdapter(com.alibaba.alink.common.mapper.IterableModelLoaderModelMapperAdapter) BroadcastVariableModelSource(com.alibaba.alink.common.model.BroadcastVariableModelSource) ModelMapperAdapterMT(com.alibaba.alink.common.mapper.ModelMapperAdapterMT) List(java.util.List) TupleTypeInfo(org.apache.flink.api.java.typeutils.TupleTypeInfo) ModelMapper(com.alibaba.alink.common.mapper.ModelMapper) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Row(org.apache.flink.types.Row) ModelMapperAdapter(com.alibaba.alink.common.mapper.ModelMapperAdapter)

Example 3 with BroadcastVariableModelSource

Use of com.alibaba.alink.common.model.BroadcastVariableModelSource in the Alink project by Alibaba.

From the linkFrom method of class BaseRecommBatchOp.

/**
 * Connects the recommendation model ({@code inputs[0]}) with the data to score ({@code inputs[1]}).
 *
 * <p>The model rows are rebalanced and attached to the scoring function as the broadcast variable named
 * {@code BROADCAST_MODEL_TABLE_NAME}. Scoring uses a {@code RecommAdapter} when {@code ModelMapperParams.NUM_THREADS}
 * is at most 1. Otherwise it uses a {@code RecommAdapterMT} configured with that thread count.
 *
 * @param inputs exactly two operators: the model first, then the data to score
 * @return this operator, whose output is the scored rows with the {@code RecommMapper}'s output schema
 * @throws RuntimeException wrapping any exception thrown while building the job
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    checkOpSize(2, inputs);
    try {
        final BroadcastVariableModelSource broadcastSource = new BroadcastVariableModelSource(BROADCAST_MODEL_TABLE_NAME);
        final RecommMapper recommMapper = new RecommMapper(
            this.recommKernelBuilder, this.recommType, inputs[0].getSchema(), inputs[1].getSchema(), this.getParams());
        final DataSet<Row> model = inputs[0].getDataSet().rebalance();
        final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
        final boolean multiThreaded = numThreads > 1;

        final DataSet<Row> scored;
        if (multiThreaded) {
            scored = inputs[1].getDataSet()
                .flatMap(new RecommAdapterMT(recommMapper, broadcastSource, numThreads))
                .withBroadcastSet(model, BROADCAST_MODEL_TABLE_NAME);
        } else {
            scored = inputs[1].getDataSet()
                .map(new RecommAdapter(recommMapper, broadcastSource))
                .withBroadcastSet(model, BROADCAST_MODEL_TABLE_NAME);
        }

        this.setOutput(scored, recommMapper.getOutputSchema());
        return (T) this;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Also used: RecommAdapter(com.alibaba.alink.operator.common.recommendation.RecommAdapter) BroadcastVariableModelSource(com.alibaba.alink.common.model.BroadcastVariableModelSource) RecommAdapterMT(com.alibaba.alink.operator.common.recommendation.RecommAdapterMT) TableSchema(org.apache.flink.table.api.TableSchema) Row(org.apache.flink.types.Row) RecommMapper(com.alibaba.alink.operator.common.recommendation.RecommMapper)

Aggregations

BroadcastVariableModelSource (com.alibaba.alink.common.model.BroadcastVariableModelSource): 3; TableSchema (org.apache.flink.table.api.TableSchema): 3; Row (org.apache.flink.types.Row): 3; IterableModelLoader (com.alibaba.alink.common.mapper.IterableModelLoader): 2; MapFunction (org.apache.flink.api.common.functions.MapFunction): 2; RichFlatMapFunction (org.apache.flink.api.common.functions.RichFlatMapFunction): 2; Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 2; TupleTypeInfo (org.apache.flink.api.java.typeutils.TupleTypeInfo): 2; Configuration (org.apache.flink.configuration.Configuration): 2; FlatModelMapper (com.alibaba.alink.common.mapper.FlatModelMapper): 1; FlatModelMapperAdapter (com.alibaba.alink.common.mapper.FlatModelMapperAdapter): 1; IterableModelLoaderFlatModelMapperAdapter (com.alibaba.alink.common.mapper.IterableModelLoaderFlatModelMapperAdapter): 1; IterableModelLoaderModelMapperAdapter (com.alibaba.alink.common.mapper.IterableModelLoaderModelMapperAdapter): 1; IterableModelLoaderModelMapperAdapterMT (com.alibaba.alink.common.mapper.IterableModelLoaderModelMapperAdapterMT): 1; ModelMapper (com.alibaba.alink.common.mapper.ModelMapper): 1; ModelMapperAdapter (com.alibaba.alink.common.mapper.ModelMapperAdapter): 1; ModelMapperAdapterMT (com.alibaba.alink.common.mapper.ModelMapperAdapterMT): 1; ModelStreamModelMapperAdapter (com.alibaba.alink.common.mapper.ModelStreamModelMapperAdapter): 1; RecommAdapter (com.alibaba.alink.operator.common.recommendation.RecommAdapter): 1; RecommAdapterMT (com.alibaba.alink.operator.common.recommendation.RecommAdapterMT): 1