Search in sources :

Example 1 with PipelineModelMapper

Use of com.alibaba.alink.common.mapper.PipelineModelMapper in the project Alink by Alibaba.

The method collectLocalPredictor of the class PipelineModel.

/**
 * Builds a {@link LocalPredictor} for this pipeline model.
 *
 * <p>If {@code ModelStreamScanParams.MODEL_STREAM_FILE_PATH} is present in {@code params}, the
 * transformers are serialized into a single model table, each row is passed through
 * {@code PipelineModelMapper.ExtendPipelineModelRow}, the rows are collected, and the predictor is
 * backed by one {@link PipelineModelMapper} loaded with them.
 *
 * <p>Otherwise one {@link Mapper} is created per transformer; each mapper receives the output
 * schema of the mapper before it (the first one receives {@code inputSchema}).
 *
 * @param inputSchema schema handed to the first mapper / to the pipeline model mapper
 * @return a predictor wrapping either the single pipeline model mapper or the mapper chain
 * @throws RuntimeException on the mapper-chain path, if there are no transformers or if a
 *     transformer is not {@code LocalPredictable}
 * @throws Exception declared by the signature; raised by the calls made here
 */
@Override
public LocalPredictor collectLocalPredictor(TableSchema inputSchema) throws Exception {
    // Model-stream path: a single PipelineModelMapper over the serialized pipeline.
    if (params.get(ModelStreamScanParams.MODEL_STREAM_FILE_PATH) != null) {
        BatchOperator<?> serializedStages =
            ModelExporterUtils.serializePipelineStages(Arrays.asList(transformers), params);
        TableSchema pipelineOutSchema = getOutSchema(this, inputSchema);
        String[] outColNames = pipelineOutSchema.getFieldNames();
        BatchOperator<?> extendedModel = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                serializedStages.getMLEnvironmentId(),
                serializedStages.getDataSet().map(
                    new PipelineModelMapper.ExtendPipelineModelRow(outColNames.length + 1)),
                PipelineModelMapper.getExtendModelSchema(
                    serializedStages.getSchema(), outColNames, pipelineOutSchema.getFieldTypes())));
        List<Row> extendedRows = extendedModel.collect();
        ModelMapper pipelineMapper = new PipelineModelMapper(extendedModel.getSchema(), inputSchema, this.params);
        pipelineMapper.loadModel(extendedRows);
        return new LocalPredictor(pipelineMapper);
    }

    if (transformers == null || transformers.length == 0) {
        throw new RuntimeException("PipelineModel is empty.");
    }

    // First pass: validate every stage and gather the model data operators to collect.
    List<BatchOperator<?>> pendingModelData = new ArrayList<>();
    for (TransformerBase<?> stage : transformers) {
        if (stage instanceof LocalPredictable) {
            if (stage instanceof ModelBase) {
                pendingModelData.add(((ModelBase<?>) stage).getModelData());
            }
        } else {
            throw new RuntimeException(stage.getClass().toString() + " not support local predict.");
        }
    }

    List<List<Row>> collectedModelRows = new ArrayList<>();
    if (!pendingModelData.isEmpty()) {
        collectedModelRows = BatchOperator.collect(pendingModelData.toArray(new BatchOperator<?>[0]));
    }

    // Second pass: build the mapper chain, threading each mapper's output schema into the next.
    List<Mapper> chain = new ArrayList<>();
    TableSchema currentSchema = inputSchema;
    int nextModelIndex = 0;
    for (TransformerBase<?> stage : transformers) {
        final Mapper stageMapper;
        if (stage instanceof MapModel) {
            stageMapper = ModelExporterUtils.createMapperFromStage(
                stage,
                ((MapModel<?>) stage).modelData.getSchema(),
                currentSchema,
                collectedModelRows.get(nextModelIndex++));
        } else if (stage instanceof BaseRecommender) {
            stageMapper = ModelExporterUtils.createMapperFromStage(
                stage,
                ((BaseRecommender<?>) stage).modelData.getSchema(),
                currentSchema,
                collectedModelRows.get(nextModelIndex++));
        } else {
            // Stages without model data are created with no model schema and no model rows.
            stageMapper = ModelExporterUtils.createMapperFromStage(stage, null, currentSchema, null);
        }
        chain.add(stageMapper);
        currentSchema = stageMapper.getOutputSchema();
    }
    return new LocalPredictor(chain.toArray(new Mapper[0]));
}
Also used : TableSchema(org.apache.flink.table.api.TableSchema) ArrayList(java.util.ArrayList) PipelineModelMapper(com.alibaba.alink.common.mapper.PipelineModelMapper) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) ModelMapper(com.alibaba.alink.common.mapper.ModelMapper) PipelineModelMapper(com.alibaba.alink.common.mapper.PipelineModelMapper) ModelMapper(com.alibaba.alink.common.mapper.ModelMapper) PipelineModelMapper(com.alibaba.alink.common.mapper.PipelineModelMapper) Mapper(com.alibaba.alink.common.mapper.Mapper) BaseRecommender(com.alibaba.alink.pipeline.recommendation.BaseRecommender) ArrayList(java.util.ArrayList) List(java.util.List) Row(org.apache.flink.types.Row)

Example 2 with PipelineModelMapper

Use of com.alibaba.alink.common.mapper.PipelineModelMapper in the project Alink by Alibaba.

The method loadLocalPredictorFromPipelineModel of the class ModelExporterUtils.

/**
 * Creates a {@link LocalPredictor} from a pipeline model stored at {@code filePath}.
 *
 * <p>The model schema and rows are read with {@code AkUtils.readFromPath}, the meta entry with
 * {@code loadMetaFromAkFile}, and the pipeline params are deserialized from that meta entry. A
 * mapper array is always built from the model rows. If the params contain
 * {@code ModelStreamScanParams.MODEL_STREAM_FILE_PATH}, the output column names and types of the
 * last mapper are written into the params and the predictor is backed by a single
 * {@link PipelineModelMapper} loaded with the model rows; otherwise the predictor wraps the
 * mapper array directly.
 *
 * @param filePath location of the stored pipeline model
 * @param inputSchema schema of the rows the predictor will receive
 * @return the predictor for the stored pipeline model
 * @throws Exception declared by the signature; raised by the calls made here
 */
static LocalPredictor loadLocalPredictorFromPipelineModel(FilePath filePath, TableSchema inputSchema) throws Exception {
    Tuple2<TableSchema, List<Row>> modelContent = AkUtils.readFromPath(filePath);
    TableSchema modelSchema = modelContent.f0;
    List<Row> modelRows = modelContent.f1;

    Tuple2<TableSchema, Row> metaEntry = ModelExporterUtils.loadMetaFromAkFile(filePath);
    Tuple2<StageNode[], Params> decoded =
        ModelExporterUtils.deserializePipelineStagesAndParamsFromMeta(metaEntry.f1, metaEntry.f0);

    Mapper[] stageMappers = loadMapperListFromStages(modelRows, modelSchema, inputSchema).getMappers();
    Params pipelineParams = decoded.f1;

    // Without a model-stream path the plain mapper chain is all that is needed.
    if (pipelineParams.get(ModelStreamScanParams.MODEL_STREAM_FILE_PATH) == null) {
        return new LocalPredictor(stageMappers);
    }

    // Record the output columns of the last mapper in the params before building the mapper.
    TableSchema finalOutSchema = stageMappers[stageMappers.length - 1].getOutputSchema();
    pipelineParams.set(PipelineModelMapper.PIPELINE_TRANSFORM_OUT_COL_NAMES, finalOutSchema.getFieldNames());
    pipelineParams.set(
        PipelineModelMapper.PIPELINE_TRANSFORM_OUT_COL_TYPES,
        FlinkTypeConverter.getTypeString(finalOutSchema.getFieldTypes()));

    PipelineModelMapper streamAwareMapper = new PipelineModelMapper(modelSchema, inputSchema, pipelineParams);
    streamAwareMapper.loadModel(modelRows);
    return new LocalPredictor(streamAwareMapper);
}
Also used : ModelMapper(com.alibaba.alink.common.mapper.ModelMapper) ComboModelMapper(com.alibaba.alink.common.mapper.ComboModelMapper) ComboMapper(com.alibaba.alink.common.mapper.ComboMapper) PipelineModelMapper(com.alibaba.alink.common.mapper.PipelineModelMapper) Mapper(com.alibaba.alink.common.mapper.Mapper) TableSchema(org.apache.flink.table.api.TableSchema) ModelStreamScanParams(com.alibaba.alink.params.ModelStreamScanParams) Params(org.apache.flink.ml.api.misc.param.Params) List(java.util.List) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) Row(org.apache.flink.types.Row) PipelineModelMapper(com.alibaba.alink.common.mapper.PipelineModelMapper)

Aggregations

Mapper (com.alibaba.alink.common.mapper.Mapper)2 ModelMapper (com.alibaba.alink.common.mapper.ModelMapper)2 PipelineModelMapper (com.alibaba.alink.common.mapper.PipelineModelMapper)2 ArrayList (java.util.ArrayList)2 List (java.util.List)2 TableSchema (org.apache.flink.table.api.TableSchema)2 Row (org.apache.flink.types.Row)2 ComboMapper (com.alibaba.alink.common.mapper.ComboMapper)1 ComboModelMapper (com.alibaba.alink.common.mapper.ComboModelMapper)1 BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)1 TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp)1 ModelStreamScanParams (com.alibaba.alink.params.ModelStreamScanParams)1 BaseRecommender (com.alibaba.alink.pipeline.recommendation.BaseRecommender)1 LinkedList (java.util.LinkedList)1 Params (org.apache.flink.ml.api.misc.param.Params)1