Search in sources :

Example 1 with TFTableModelPredictModelMapper

Use of com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper in the Alink project by Alibaba.

In class TFTableModelClassificationModelMapper, method loadFromModelData:

/**
 * Builds the mapper chain used for TF table-model classification prediction.
 *
 * <p>If the model data carries a preprocessing pipeline, its mappers are appended to {@code mappers}
 * first. The TF model's input schema then becomes the output schema of the last preprocessing mapper.
 * A {@link TFTableModelPredictModelMapper} is then built and loaded, either from a zip path or from
 * model rows, and appended as well.
 *
 * @param modelData   classification model data: meta params, optional preprocessing pipeline, TF model
 *                    (zip path or rows) and sorted labels
 * @param modelSchema schema of the model table, passed through to the TF mapper
 */
protected void loadFromModelData(TFTableModelClassificationModelData modelData, TableSchema modelSchema) {
    Params meta = modelData.getMeta();
    String tfOutputSignatureDef = meta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_DEF);
    // The TF output column is always declared as a float tensor here.
    TypeInformation<?> tfOutputSignatureType = TensorTypes.FLOAT_TENSOR;

    // Reserved columns default to every column of the ORIGINAL (pre-preprocessing) input schema.
    TableSchema dataSchema = getDataSchema();
    String[] userReservedCols = params.get(HasReservedColsDefaultAsNull.RESERVED_COLS);
    String[] reservedCols = null == userReservedCols ? dataSchema.getFieldNames() : userReservedCols;

    if (CollectionUtils.isNotEmpty(modelData.getPreprocessPipelineModelRows())) {
        String preprocessPipelineModelSchemaStr = modelData.getPreprocessPipelineModelSchemaStr();
        TableSchema pipelineModelSchema = CsvUtil.schemaStr2Schema(preprocessPipelineModelSchemaStr);
        MapperChain mapperList = ModelExporterUtils.loadMapperListFromStages(
            modelData.getPreprocessPipelineModelRows(), pipelineModelSchema, dataSchema);
        int mapperCountBefore = mappers.size();
        mappers.addAll(Arrays.asList(mapperList.getMappers()));
        // Only switch to a new input schema when the preprocessing chain actually contributed mappers.
        // Otherwise mappers.get(size - 1) would throw on an empty list, or pick up an unrelated,
        // pre-existing mapper's schema.
        if (mappers.size() > mapperCountBefore) {
            dataSchema = mappers.get(mappers.size() - 1).getOutputSchema();
        }
    }

    String[] tfInputCols = meta.get(TFModelDataConverterUtils.TF_INPUT_COLS);
    String predCol = params.get(TFTableModelClassificationPredictParams.PREDICTION_COL);

    // The TF mapper emits a single output column named after the prediction column.
    Params tfModelMapperParams = new Params();
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SIGNATURE_DEFS, new String[] { tfOutputSignatureDef });
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SCHEMA_STR,
        CsvUtil.schema2SchemaStr(TableSchema.builder().field(predCol, tfOutputSignatureType).build()));
    tfModelMapperParams.set(TFTableModelPredictParams.SELECTED_COLS, tfInputCols);
    tfModelMapperParams.set(TFTableModelPredictParams.RESERVED_COLS, reservedCols);

    tfModelMapper = new TFTableModelPredictModelMapper(modelSchema, dataSchema, tfModelMapperParams, factory);
    // Prefer the zipped model on disk when a path is present; otherwise load from the model rows.
    if (null != modelData.getTfModelZipPath()) {
        tfModelMapper.loadModelFromZipFile(modelData.getTfModelZipPath());
    } else {
        tfModelMapper.loadModel(modelData.getTfModelRows());
    }
    mappers.add(tfModelMapper);

    // Position of the prediction column within the TF mapper's output schema.
    predColId = TableUtil.findColIndex(tfModelMapper.getOutputSchema(), predCol);
    sortedLabels = modelData.getSortedLabels();
    // NOTE(review): presumably tells downstream code whether the TF output is raw logits — confirm.
    isOutputLogits = meta.get(TFModelDataConverterUtils.IS_OUTPUT_LOGITS);
}
Also used: MapperChain (com.alibaba.alink.common.mapper.MapperChain), TableSchema (org.apache.flink.table.api.TableSchema), TFTableModelClassificationPredictParams (com.alibaba.alink.params.classification.TFTableModelClassificationPredictParams), TFTableModelPredictParams (com.alibaba.alink.params.tensorflow.savedmodel.TFTableModelPredictParams), Params (org.apache.flink.ml.api.misc.param.Params), TFTableModelPredictModelMapper (com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper)

Example 2 with TFTableModelPredictModelMapper

Use of com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper in the Alink project by Alibaba.

In class TFTableModelRegressionModelMapper, method loadFromModelData:

/**
 * Builds the mapper chain used for TF table-model regression prediction.
 *
 * <p>If the model data carries a preprocessing pipeline, its mappers are appended to {@code mappers}
 * first. The TF model's input schema then becomes the output schema of the last preprocessing mapper.
 * A {@link TFTableModelPredictModelMapper} is then built and loaded, either from a zip path or from
 * model rows, and appended as well.
 *
 * @param modelData   regression model data: meta params, optional preprocessing pipeline and TF model
 *                    (zip path or rows)
 * @param modelSchema schema of the model table, passed through to the TF mapper
 */
protected void loadFromModelData(TFTableModelRegressionModelData modelData, TableSchema modelSchema) {
    Params meta = modelData.getMeta();
    String tfOutputSignatureDef = meta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_DEF);
    // NOTE(review): the output signature type is not read from meta; the TF output column is always
    // declared as a float tensor here.
    TypeInformation<?> tfOutputSignatureType = TensorTypes.FLOAT_TENSOR;

    // Reserved columns default to every column of the ORIGINAL (pre-preprocessing) input schema.
    TableSchema dataSchema = getDataSchema();
    String[] userReservedCols = params.get(HasReservedColsDefaultAsNull.RESERVED_COLS);
    String[] reservedCols = null == userReservedCols ? dataSchema.getFieldNames() : userReservedCols;

    if (CollectionUtils.isNotEmpty(modelData.getPreprocessPipelineModelRows())) {
        String preprocessPipelineModelSchemaStr = modelData.getPreprocessPipelineModelSchemaStr();
        TableSchema pipelineModelSchema = CsvUtil.schemaStr2Schema(preprocessPipelineModelSchemaStr);
        MapperChain mapperList = ModelExporterUtils.loadMapperListFromStages(
            modelData.getPreprocessPipelineModelRows(), pipelineModelSchema, dataSchema);
        int mapperCountBefore = mappers.size();
        mappers.addAll(Arrays.asList(mapperList.getMappers()));
        // Only switch to a new input schema when the preprocessing chain actually contributed mappers.
        // Otherwise mappers.get(size - 1) would throw on an empty list, or pick up an unrelated,
        // pre-existing mapper's schema.
        if (mappers.size() > mapperCountBefore) {
            dataSchema = mappers.get(mappers.size() - 1).getOutputSchema();
        }
    }

    String[] tfInputCols = meta.get(TFModelDataConverterUtils.TF_INPUT_COLS);
    // NOTE(review): the regression mapper reads PREDICTION_COL through the classification params
    // interface; presumably this is the same shared ParamInfo — confirm.
    String predCol = params.get(TFTableModelClassificationPredictParams.PREDICTION_COL);
    // The prediction column is appended to the reserved columns handed to the TF mapper.
    String[] tfReservedCols = (String[]) ArrayUtils.addAll(reservedCols, new String[] { predCol });

    // The TF mapper emits a single output column named after the prediction column.
    Params tfModelMapperParams = new Params();
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SIGNATURE_DEFS, new String[] { tfOutputSignatureDef });
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SCHEMA_STR,
        CsvUtil.schema2SchemaStr(TableSchema.builder().field(predCol, tfOutputSignatureType).build()));
    tfModelMapperParams.set(TFTableModelPredictParams.SELECTED_COLS, tfInputCols);
    tfModelMapperParams.set(TFTableModelPredictParams.RESERVED_COLS, tfReservedCols);

    tfModelMapper = new TFTableModelPredictModelMapper(modelSchema, dataSchema, tfModelMapperParams, factory);
    // Prefer the zipped model on disk when a path is present; otherwise load from the model rows.
    if (null != modelData.getTfModelZipPath()) {
        tfModelMapper.loadModelFromZipFile(modelData.getTfModelZipPath());
    } else {
        tfModelMapper.loadModel(modelData.getTfModelRows());
    }
    mappers.add(tfModelMapper);

    // Position of the prediction column within the TF mapper's output schema.
    predColId = TableUtil.findColIndex(tfModelMapper.getOutputSchema(), predCol);
}
Also used: MapperChain (com.alibaba.alink.common.mapper.MapperChain), TableSchema (org.apache.flink.table.api.TableSchema), TFTableModelClassificationPredictParams (com.alibaba.alink.params.classification.TFTableModelClassificationPredictParams), TFTableModelPredictParams (com.alibaba.alink.params.tensorflow.savedmodel.TFTableModelPredictParams), Params (org.apache.flink.ml.api.misc.param.Params), TFTableModelPredictModelMapper (com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper)

Aggregations

MapperChain (com.alibaba.alink.common.mapper.MapperChain): 2
TFTableModelPredictModelMapper (com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper): 2
TFTableModelClassificationPredictParams (com.alibaba.alink.params.classification.TFTableModelClassificationPredictParams): 2
TFTableModelPredictParams (com.alibaba.alink.params.tensorflow.savedmodel.TFTableModelPredictParams): 2
Params (org.apache.flink.ml.api.misc.param.Params): 2
TableSchema (org.apache.flink.table.api.TableSchema): 2