Use of com.alibaba.alink.common.mapper.PipelineModelMapper in the Alink project by Alibaba.
Class PipelineModel, method collectLocalPredictor.
/**
 * Builds a {@link LocalPredictor} for this pipeline model.
 *
 * <p>Two modes are supported:
 * <ul>
 *   <li>If {@code ModelStreamScanParams.MODEL_STREAM_FILE_PATH} is set in {@code params}, all stages
 *       are serialized into one model and wrapped in a single {@link PipelineModelMapper}.</li>
 *   <li>Otherwise one mapper is created per stage. Each stage receives the output schema of the
 *       previous stage as its input schema.</li>
 * </ul>
 *
 * <p>In the per-stage mode, model data is collected in a single {@code BatchOperator.collect} call.
 * Each collected result is then matched back to the stage it came from by the stage's position.
 * This keeps the rows aligned even if a {@code ModelBase} stage that is neither a {@code MapModel}
 * nor a {@code BaseRecommender} appears in the pipeline.
 *
 * @param inputSchema schema of the rows that will be fed to the predictor
 * @return a local predictor for this pipeline model
 * @throws RuntimeException if the pipeline has no stages (per-stage mode), or a stage is not
 *                          {@link LocalPredictable}
 * @throws Exception        propagated from model-data collection or mapper construction
 */
@Override
public LocalPredictor collectLocalPredictor(TableSchema inputSchema) throws Exception {
	if (params.get(ModelStreamScanParams.MODEL_STREAM_FILE_PATH) != null) {
		BatchOperator<?> modelSave = ModelExporterUtils.serializePipelineStages(Arrays.asList(transformers), params);
		TableSchema extendSchema = getOutSchema(this, inputSchema);
		// Rows and schema of the serialized model are extended using the pipeline's output
		// schema (extendSchema) before being handed to PipelineModelMapper. The exact row layout
		// is defined by ExtendPipelineModelRow / getExtendModelSchema.
		BatchOperator<?> model = new TableSourceBatchOp(
			DataSetConversionUtil.toTable(
				modelSave.getMLEnvironmentId(),
				modelSave.getDataSet().map(
					new PipelineModelMapper.ExtendPipelineModelRow(extendSchema.getFieldNames().length + 1)),
				PipelineModelMapper.getExtendModelSchema(
					modelSave.getSchema(), extendSchema.getFieldNames(), extendSchema.getFieldTypes())));
		List<Row> modelRows = model.collect();
		ModelMapper mapper = new PipelineModelMapper(model.getSchema(), inputSchema, this.params);
		mapper.loadModel(modelRows);
		return new LocalPredictor(mapper);
	}

	if (null == transformers || transformers.length == 0) {
		throw new RuntimeException("PipelineModel is empty.");
	}

	// Validate every stage. For each model-data operator, remember the index of the stage it
	// belongs to, so collected rows can be routed back by position rather than by a counter.
	List<BatchOperator<?>> allModelData = new ArrayList<>();
	List<Integer> modelStageIndices = new ArrayList<>();
	for (int i = 0; i < transformers.length; i++) {
		TransformerBase<?> transformer = transformers[i];
		if (!(transformer instanceof LocalPredictable)) {
			throw new RuntimeException(transformer.getClass().toString() + " not support local predict.");
		}
		if (transformer instanceof ModelBase) {
			allModelData.add(((ModelBase<?>) transformer).getModelData());
			modelStageIndices.add(i);
		}
	}

	// rowsByStage.get(i) holds the collected model rows of stage i, or null if stage i has no
	// model data.
	List<List<Row>> rowsByStage = new ArrayList<>(transformers.length);
	for (int i = 0; i < transformers.length; i++) {
		rowsByStage.add(null);
	}
	if (!allModelData.isEmpty()) {
		List<List<Row>> collected = BatchOperator.collect(allModelData.toArray(new BatchOperator<?>[0]));
		for (int k = 0; k < modelStageIndices.size(); k++) {
			rowsByStage.set(modelStageIndices.get(k), collected.get(k));
		}
	}

	TableSchema schema = inputSchema;
	List<Mapper> mappers = new ArrayList<>(transformers.length);
	for (int i = 0; i < transformers.length; i++) {
		TransformerBase<?> transformer = transformers[i];
		Mapper mapper;
		if (transformer instanceof MapModel) {
			mapper = ModelExporterUtils.createMapperFromStage(
				transformer, ((MapModel<?>) transformer).modelData.getSchema(), schema, rowsByStage.get(i));
		} else if (transformer instanceof BaseRecommender) {
			mapper = ModelExporterUtils.createMapperFromStage(
				transformer, ((BaseRecommender<?>) transformer).modelData.getSchema(), schema, rowsByStage.get(i));
		} else {
			mapper = ModelExporterUtils.createMapperFromStage(transformer, null, schema, null);
		}
		mappers.add(mapper);
		// Chain stages: the next stage consumes this stage's output schema.
		schema = mapper.getOutputSchema();
	}
	return new LocalPredictor(mappers.toArray(new Mapper[0]));
}
Use of com.alibaba.alink.common.mapper.PipelineModelMapper in the Alink project by Alibaba.
Class ModelExporterUtils, method loadLocalPredictorFromPipelineModel.
/**
 * Loads a {@link LocalPredictor} from a pipeline model saved at {@code filePath}.
 *
 * <p>The saved model rows are first turned into one mapper per stage. If the params recovered
 * from the file's meta do not set {@code ModelStreamScanParams.MODEL_STREAM_FILE_PATH}, those
 * mappers are chained directly. Otherwise the predictor wraps a single
 * {@link PipelineModelMapper}. Its output column names and types are copied from the output
 * schema of the last per-stage mapper.
 *
 * @param filePath    location of the saved pipeline model
 * @param inputSchema schema of the rows that will be fed to the predictor
 * @return the local predictor
 * @throws Exception propagated from reading the file or building the mappers
 */
static LocalPredictor loadLocalPredictorFromPipelineModel(FilePath filePath, TableSchema inputSchema) throws Exception {
	Tuple2<TableSchema, List<Row>> modelFile = AkUtils.readFromPath(filePath);
	// Meta row plus its schema; the stages and params are deserialized from it.
	Tuple2<TableSchema, Row> metaOfFile = ModelExporterUtils.loadMetaFromAkFile(filePath);
	Tuple2<StageNode[], Params> stagesWithParams =
		ModelExporterUtils.deserializePipelineStagesAndParamsFromMeta(metaOfFile.f1, metaOfFile.f0);

	TableSchema modelSchema = modelFile.f0;
	List<Row> modelRows = modelFile.f1;
	Mapper[] stageMappers = loadMapperListFromStages(modelRows, modelSchema, inputSchema).getMappers();
	Params pipelineParams = stagesWithParams.f1;

	// No model-stream path configured: predict with the chained per-stage mappers.
	if (null == pipelineParams.get(ModelStreamScanParams.MODEL_STREAM_FILE_PATH)) {
		return new LocalPredictor(stageMappers);
	}

	// Model-stream mode: record the final output schema in the params so PipelineModelMapper
	// knows the pipeline's output columns and types.
	TableSchema lastOutputSchema = stageMappers[stageMappers.length - 1].getOutputSchema();
	pipelineParams.set(PipelineModelMapper.PIPELINE_TRANSFORM_OUT_COL_NAMES, lastOutputSchema.getFieldNames());
	pipelineParams.set(
		PipelineModelMapper.PIPELINE_TRANSFORM_OUT_COL_TYPES,
		FlinkTypeConverter.getTypeString(lastOutputSchema.getFieldTypes()));

	PipelineModelMapper streamMapper = new PipelineModelMapper(modelSchema, inputSchema, pipelineParams);
	streamMapper.loadModel(modelRows);
	return new LocalPredictor(streamMapper);
}
Aggregations