
Example 1 with MapperChain

Use of com.alibaba.alink.common.mapper.MapperChain in project Alink by alibaba.

From the class TFTableModelClassificationModelMapper, method loadFromModelData.

protected void loadFromModelData(TFTableModelClassificationModelData modelData, TableSchema modelSchema) {
    Params meta = modelData.getMeta();
    String tfOutputSignatureDef = meta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_DEF);
    TypeInformation<?> tfOutputSignatureType = TensorTypes.FLOAT_TENSOR;
    String[] reservedCols = null == params.get(HasReservedColsDefaultAsNull.RESERVED_COLS) ? getDataSchema().getFieldNames() : params.get(HasReservedColsDefaultAsNull.RESERVED_COLS);
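    TableSchema dataSchema = getDataSchema();
    // If a preprocessing pipeline was saved with the model, prepend its mappers;
    // the TF mapper then consumes the output schema of the last preprocessing mapper.
    if (CollectionUtils.isNotEmpty(modelData.getPreprocessPipelineModelRows())) {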
        String preprocessPipelineModelSchemaStr = modelData.getPreprocessPipelineModelSchemaStr();
        TableSchema pipelineModelSchema = CsvUtil.schemaStr2Schema(preprocessPipelineModelSchemaStr);
        MapperChain mapperList = ModelExporterUtils.loadMapperListFromStages(modelData.getPreprocessPipelineModelRows(), pipelineModelSchema, dataSchema);
        mappers.addAll(Arrays.asList(mapperList.getMappers()));
        dataSchema = mappers.get(mappers.size() - 1).getOutputSchema();
    }
    String[] tfInputCols = meta.get(TFModelDataConverterUtils.TF_INPUT_COLS);
    String predCol = params.get(TFTableModelClassificationPredictParams.PREDICTION_COL);
    Params tfModelMapperParams = new Params();
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SIGNATURE_DEFS, new String[] { tfOutputSignatureDef });
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SCHEMA_STR, CsvUtil.schema2SchemaStr(TableSchema.builder().field(predCol, tfOutputSignatureType).build()));
    tfModelMapperParams.set(TFTableModelPredictParams.SELECTED_COLS, tfInputCols);
    tfModelMapperParams.set(TFTableModelPredictParams.RESERVED_COLS, reservedCols);
    tfModelMapper = new TFTableModelPredictModelMapper(modelSchema, dataSchema, tfModelMapperParams, factory);
    // Load the TF model from a zip file when a path is stored, otherwise from the model rows.
    if (null != modelData.getTfModelZipPath()) {
        tfModelMapper.loadModelFromZipFile(modelData.getTfModelZipPath());
    } else {
        tfModelMapper.loadModel(modelData.getTfModelRows());
    }
    mappers.add(tfModelMapper);
    predColId = TableUtil.findColIndex(tfModelMapper.getOutputSchema(), predCol);
    sortedLabels = modelData.getSortedLabels();
    isOutputLogits = meta.get(TFModelDataConverterUtils.IS_OUTPUT_LOGITS);
}
Also used: MapperChain(com.alibaba.alink.common.mapper.MapperChain) TableSchema(org.apache.flink.table.api.TableSchema) TFTableModelClassificationPredictParams(com.alibaba.alink.params.classification.TFTableModelClassificationPredictParams) TFTableModelPredictParams(com.alibaba.alink.params.tensorflow.savedmodel.TFTableModelPredictParams) Params(org.apache.flink.ml.api.misc.param.Params) TFTableModelPredictModelMapper(com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper)
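Two small steps in Example 1 can be isolated. First, the reserved columns default to every input column when RESERVED_COLS is unset. Second, the prediction column is later located by name in the TF mapper's output schema. The sketch below uses plain String arrays in place of Flink's TableSchema. resolveReservedCols and findColIndex are hypothetical helpers, and the exact-match lookup that returns -1 when the column is absent is an assumption, not a description of Alink's TableUtil.findColIndex.

```java
import java.util.Arrays;

/**
 * Hypothetical helpers (not Alink API) mirroring two steps of loadFromModelData:
 * defaulting reserved columns and finding the prediction column's index.
 * String arrays of column names stand in for TableSchema.
 */
public class ReservedColsSketch {

    // Mirrors: null == params.get(RESERVED_COLS) ? dataSchema.getFieldNames() : params.get(RESERVED_COLS)
    static String[] resolveReservedCols(String[] configured, String[] dataFieldNames) {
        return configured == null ? dataFieldNames.clone() : configured;
    }

    // Assumed contract for this sketch: exact-match search, -1 when the column is absent.
    static int findColIndex(String[] fieldNames, String target) {
        for (int i = 0; i < fieldNames.length; i++) {
            if (fieldNames[i].equals(target)) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        String[] dataCols = {"id", "text", "label"};
        String[] reserved = resolveReservedCols(null, dataCols);
        System.out.println(Arrays.toString(reserved)); // [id, text, label]

        // Assumed output layout: reserved columns followed by the prediction column.
        String[] outputCols = {"id", "text", "label", "pred"};
        System.out.println(findColIndex(outputCols, "pred")); // 3
    }
}
```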

Example 2 with MapperChain

Use of com.alibaba.alink.common.mapper.MapperChain in project Alink by alibaba.

From the class LocalPredictor, method merge.

public void merge(LocalPredictor otherPredictor) {
    // Append the other predictor's mappers, then rebuild the chain from the combined list.
    this.mappers.addAll(otherPredictor.mappers);
    this.mapperList = new MapperChain(this.mappers.toArray(new Mapper[0]));
}
Also used: MapperChain(com.alibaba.alink.common.mapper.MapperChain)
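The rebuild in merge matters because `this.mappers.toArray(new Mapper[0])` returns a fresh array, so a chain built earlier does not see mappers added to the list afterwards. The following dependency-free sketch imitates that pattern. Chain and Predictor are hypothetical stand-ins for MapperChain and LocalPredictor, and each "mapper" is assumed to be a `UnaryOperator<String>` applied left to right.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical sketch of the merge pattern; none of these types are Alink classes.
public class MergeSketch {

    // Stand-in for MapperChain: keeps its own copy of the steps it was built from.
    static final class Chain {
        private final List<UnaryOperator<String>> steps;

        Chain(List<UnaryOperator<String>> steps) {
            this.steps = new ArrayList<>(steps);
        }

        String map(String row) {
            String out = row;
            for (UnaryOperator<String> step : steps) {
                out = step.apply(out);
            }
            return out;
        }
    }

    // Stand-in for LocalPredictor: a mutable mapper list plus a chain built from it.
    static final class Predictor {
        final List<UnaryOperator<String>> mappers = new ArrayList<>();
        Chain chain;

        Predictor(List<UnaryOperator<String>> initial) {
            mappers.addAll(initial);
            chain = new Chain(mappers);
        }

        // Same shape as merge(): append the other predictor's mappers, then rebuild
        // the chain, because the existing chain holds a snapshot of the old list.
        void merge(Predictor other) {
            mappers.addAll(other.mappers);
            chain = new Chain(mappers);
        }
    }

    public static void main(String[] args) {
        List<UnaryOperator<String>> cleaning = List.of(String::trim, String::toLowerCase);
        List<UnaryOperator<String>> tagging = List.of(s -> s + "!");

        Predictor first = new Predictor(cleaning);
        Predictor second = new Predictor(tagging);
        Chain before = first.chain;

        first.merge(second);
        System.out.println(before.map("  Hello "));      // hello
        System.out.println(first.chain.map("  Hello ")); // hello!
    }
}
```

Note that `addAll` copies references, so after a merge both predictors hold the same mapper instances rather than copies of them.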

Example 3 with MapperChain

Use of com.alibaba.alink.common.mapper.MapperChain in project Alink by alibaba.

From the class ModelExporterUtils, method loadMapperListFromStages.

// Note: the mappers in the returned chain have not been opened yet.
public static MapperChain loadMapperListFromStages(List<Tuple3<PipelineStageBase<?>, TableSchema, List<Row>>> stages, TableSchema inputSchema) {
    TableSchema outSchema = inputSchema;
    List<Mapper> mappers = new ArrayList<>();
    for (Tuple3<PipelineStageBase<?>, TableSchema, List<Row>> stageTuple3 : stages) {
        // Build each stage's mapper against the previous mapper's output schema.
        Mapper mapper = createMapperFromStage(stageTuple3.f0, stageTuple3.f1, outSchema, stageTuple3.f2);
        mappers.add(mapper);
        outSchema = mapper.getOutputSchema();
    }
    return new MapperChain(mappers.toArray(new Mapper[0]));
}
Also used: MapperChain(com.alibaba.alink.common.mapper.MapperChain) ModelMapper(com.alibaba.alink.common.mapper.ModelMapper) ComboModelMapper(com.alibaba.alink.common.mapper.ComboModelMapper) ComboMapper(com.alibaba.alink.common.mapper.ComboMapper) PipelineModelMapper(com.alibaba.alink.common.mapper.PipelineModelMapper) Mapper(com.alibaba.alink.common.mapper.Mapper) TableSchema(org.apache.flink.table.api.TableSchema) ArrayList(java.util.ArrayList) List(java.util.List) LinkedList(java.util.LinkedList)
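The core of loadMapperListFromStages is a schema-threading loop. The first mapper is built against the caller's input schema, and every later mapper is built against its predecessor's output schema. The sketch below reproduces only that loop. It assumes a schema can be reduced to a list of column names and treats a stage as a factory from input schema to mapper. SchemaMapper, appendColumn and buildChain are hypothetical names, not Alink types.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch of schema threading; schemas are lists of column names here.
public class SchemaChainSketch {

    // Stand-in for Mapper: only the output schema matters for this illustration.
    interface SchemaMapper {
        List<String> outputSchema();
    }

    // Example stage factory: keeps the input columns and appends one new column.
    static Function<List<String>, SchemaMapper> appendColumn(String newCol) {
        return input -> {
            List<String> out = new ArrayList<>(input);
            out.add(newCol);
            List<String> frozen = Collections.unmodifiableList(out);
            return () -> frozen;
        };
    }

    // Each mapper is created from the current schema, which then advances
    // to that mapper's output schema.
    static List<SchemaMapper> buildChain(List<Function<List<String>, SchemaMapper>> stages,
                                         List<String> inputSchema) {
        List<String> outSchema = inputSchema;
        List<SchemaMapper> mappers = new ArrayList<>();
        for (Function<List<String>, SchemaMapper> stage : stages) {
            SchemaMapper mapper = stage.apply(outSchema);
            mappers.add(mapper);
            outSchema = mapper.outputSchema();
        }
        return mappers;
    }

    public static void main(String[] args) {
        List<Function<List<String>, SchemaMapper>> stages =
            List.of(appendColumn("features"), appendColumn("pred"));
        List<SchemaMapper> chain = buildChain(stages, List.of("id", "text"));
        // The last mapper's output schema is what Examples 1 and 4 read via
        // mappers.get(mappers.size() - 1).getOutputSchema().
        System.out.println(chain.get(chain.size() - 1).outputSchema()); // [id, text, features, pred]
    }
}
```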

Example 4 with MapperChain

Use of com.alibaba.alink.common.mapper.MapperChain in project Alink by alibaba.

From the class TFTableModelRegressionModelMapper, method loadFromModelData.

protected void loadFromModelData(TFTableModelRegressionModelData modelData, TableSchema modelSchema) {
    Params meta = modelData.getMeta();
    String tfOutputSignatureDef = meta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_DEF);
    // TypeInformation <?> tfOutputSignatureType = meta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_TYPE);
    TypeInformation<?> tfOutputSignatureType = TensorTypes.FLOAT_TENSOR;
    String[] reservedCols = null == params.get(HasReservedColsDefaultAsNull.RESERVED_COLS) ? getDataSchema().getFieldNames() : params.get(HasReservedColsDefaultAsNull.RESERVED_COLS);
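    TableSchema dataSchema = getDataSchema();
    // If a preprocessing pipeline was saved with the model, prepend its mappers;
    // the TF mapper then consumes the output schema of the last preprocessing mapper.
    if (CollectionUtils.isNotEmpty(modelData.getPreprocessPipelineModelRows())) {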
        String preprocessPipelineModelSchemaStr = modelData.getPreprocessPipelineModelSchemaStr();
        TableSchema pipelineModelSchema = CsvUtil.schemaStr2Schema(preprocessPipelineModelSchemaStr);
        MapperChain mapperList = ModelExporterUtils.loadMapperListFromStages(modelData.getPreprocessPipelineModelRows(), pipelineModelSchema, dataSchema);
        mappers.addAll(Arrays.asList(mapperList.getMappers()));
        dataSchema = mappers.get(mappers.size() - 1).getOutputSchema();
    }
    String[] tfInputCols = meta.get(TFModelDataConverterUtils.TF_INPUT_COLS);
    String predCol = params.get(TFTableModelClassificationPredictParams.PREDICTION_COL);
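    // Unlike the classification mapper, the prediction column is also passed to the TF mapper as a reserved column.
    String[] tfReservedCols = (String[]) ArrayUtils.addAll(reservedCols, new String[] { predCol });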
    Params tfModelMapperParams = new Params();
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SIGNATURE_DEFS, new String[] { tfOutputSignatureDef });
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SCHEMA_STR, CsvUtil.schema2SchemaStr(TableSchema.builder().field(predCol, tfOutputSignatureType).build()));
    tfModelMapperParams.set(TFTableModelPredictParams.SELECTED_COLS, tfInputCols);
    tfModelMapperParams.set(TFTableModelPredictParams.RESERVED_COLS, tfReservedCols);
    tfModelMapper = new TFTableModelPredictModelMapper(modelSchema, dataSchema, tfModelMapperParams, factory);
    // Load the TF model from a zip file when a path is stored, otherwise from the model rows.
    if (null != modelData.getTfModelZipPath()) {
        tfModelMapper.loadModelFromZipFile(modelData.getTfModelZipPath());
    } else {
        tfModelMapper.loadModel(modelData.getTfModelRows());
    }
    mappers.add(tfModelMapper);
    predColId = TableUtil.findColIndex(tfModelMapper.getOutputSchema(), predCol);
}
Also used: MapperChain(com.alibaba.alink.common.mapper.MapperChain) TableSchema(org.apache.flink.table.api.TableSchema) TFTableModelClassificationPredictParams(com.alibaba.alink.params.classification.TFTableModelClassificationPredictParams) TFTableModelPredictParams(com.alibaba.alink.params.tensorflow.savedmodel.TFTableModelPredictParams) Params(org.apache.flink.ml.api.misc.param.Params) TFTableModelPredictModelMapper(com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper)
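Example 4 follows Example 1 almost line for line. The visible differences are these:

- It appends predCol to the columns passed as RESERVED_COLS.
- It does not read sortedLabels or IS_OUTPUT_LOGITS.

A plain-Java equivalent of the `ArrayUtils.addAll(reservedCols, new String[] { predCol })` call is sketched below. The helper name `append` is hypothetical, and the sketch assumes a non-null input array, which holds here because reservedCols is always defaulted.

```java
import java.util.Arrays;

// Hypothetical plain-Java equivalent of appending one column name to an array.
public class AppendColSketch {

    // Returns a new array; the input array is left unchanged.
    static String[] append(String[] cols, String extra) {
        // Arrays.copyOf pads the extra slot with null, which is then overwritten.
        String[] out = Arrays.copyOf(cols, cols.length + 1);
        out[cols.length] = extra;
        return out;
    }

    public static void main(String[] args) {
        String[] reservedCols = {"id", "text"};
        String[] tfReservedCols = append(reservedCols, "pred");
        System.out.println(Arrays.toString(tfReservedCols)); // [id, text, pred]
        System.out.println(Arrays.toString(reservedCols));   // [id, text]
    }
}
```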

Aggregations

Number of examples above that use each type:

MapperChain (com.alibaba.alink.common.mapper.MapperChain): 4
TableSchema (org.apache.flink.table.api.TableSchema): 3
TFTableModelPredictModelMapper (com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper): 2
TFTableModelClassificationPredictParams (com.alibaba.alink.params.classification.TFTableModelClassificationPredictParams): 2
TFTableModelPredictParams (com.alibaba.alink.params.tensorflow.savedmodel.TFTableModelPredictParams): 2
Params (org.apache.flink.ml.api.misc.param.Params): 2
ComboMapper (com.alibaba.alink.common.mapper.ComboMapper): 1
ComboModelMapper (com.alibaba.alink.common.mapper.ComboModelMapper): 1
Mapper (com.alibaba.alink.common.mapper.Mapper): 1
ModelMapper (com.alibaba.alink.common.mapper.ModelMapper): 1
PipelineModelMapper (com.alibaba.alink.common.mapper.PipelineModelMapper): 1
ArrayList (java.util.ArrayList): 1
LinkedList (java.util.LinkedList): 1
List (java.util.List): 1