Use of com.alibaba.alink.common.mapper.MapperChain in the Alink project by alibaba.
Class TFTableModelClassificationModelMapper, method loadFromModelData.
/**
 * Builds the mapper sequence for classification prediction from the loaded model data.
 *
 * <p>If the model data carries preprocessing pipeline rows, the corresponding mappers are
 * appended first and the schema fed to the TF model mapper becomes the output schema of the
 * last mapper in {@code mappers}. The TF model mapper is then created, loaded (from a zip
 * file when a zip path is present, otherwise from model rows) and appended. Finally the
 * prediction column index, the sorted labels and the logits flag are stored on this mapper.
 *
 * @param modelData   deserialized model data (meta, preprocessing rows, TF model, labels)
 * @param modelSchema schema of the model table, passed through to the TF model mapper
 */
protected void loadFromModelData(TFTableModelClassificationModelData modelData, TableSchema modelSchema) {
    final Params modelMeta = modelData.getMeta();
    final String outputSignatureDef = modelMeta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_DEF);
    final TypeInformation<?> outputSignatureType = TensorTypes.FLOAT_TENSOR;

    // When no reserved columns are configured, every column of the input data schema is kept.
    String[] keptCols = params.get(HasReservedColsDefaultAsNull.RESERVED_COLS);
    if (null == keptCols) {
        keptCols = getDataSchema().getFieldNames();
    }

    // Schema seen by the TF model mapper; replaced below if preprocessing mappers exist.
    TableSchema currentSchema = getDataSchema();
    if (CollectionUtils.isNotEmpty(modelData.getPreprocessPipelineModelRows())) {
        TableSchema preprocessSchema =
            CsvUtil.schemaStr2Schema(modelData.getPreprocessPipelineModelSchemaStr());
        MapperChain preprocessChain = ModelExporterUtils.loadMapperListFromStages(
            modelData.getPreprocessPipelineModelRows(), preprocessSchema, currentSchema);
        mappers.addAll(Arrays.asList(preprocessChain.getMappers()));
        currentSchema = mappers.get(mappers.size() - 1).getOutputSchema();
    }

    final String[] inputCols = modelMeta.get(TFModelDataConverterUtils.TF_INPUT_COLS);
    final String predCol = params.get(TFTableModelClassificationPredictParams.PREDICTION_COL);

    // The TF mapper emits a single column named after the prediction column.
    Params tfParams = new Params();
    tfParams.set(TFTableModelPredictParams.OUTPUT_SIGNATURE_DEFS, new String[] { outputSignatureDef });
    TableSchema tfOutputSchema = TableSchema.builder().field(predCol, outputSignatureType).build();
    tfParams.set(TFTableModelPredictParams.OUTPUT_SCHEMA_STR, CsvUtil.schema2SchemaStr(tfOutputSchema));
    tfParams.set(TFTableModelPredictParams.SELECTED_COLS, inputCols);
    tfParams.set(TFTableModelPredictParams.RESERVED_COLS, keptCols);

    tfModelMapper = new TFTableModelPredictModelMapper(modelSchema, currentSchema, tfParams, factory);
    if (null == modelData.getTfModelZipPath()) {
        tfModelMapper.loadModel(modelData.getTfModelRows());
    } else {
        tfModelMapper.loadModelFromZipFile(modelData.getTfModelZipPath());
    }
    mappers.add(tfModelMapper);

    predColId = TableUtil.findColIndex(tfModelMapper.getOutputSchema(), predCol);
    sortedLabels = modelData.getSortedLabels();
    isOutputLogits = modelMeta.get(TFModelDataConverterUtils.IS_OUTPUT_LOGITS);
}
Use of com.alibaba.alink.common.mapper.MapperChain in the Alink project by alibaba.
Class LocalPredictor, method merge.
/**
 * Appends the mappers of another predictor to this predictor's mappers and rebuilds
 * {@code mapperList} from the combined list.
 *
 * @param otherPredictor predictor whose mappers are appended after this predictor's own
 */
public void merge(LocalPredictor otherPredictor) {
    this.mappers.addAll(otherPredictor.mappers);
    Mapper[] combined = this.mappers.toArray(new Mapper[0]);
    this.mapperList = new MapperChain(combined);
}
Use of com.alibaba.alink.common.mapper.MapperChain in the Alink project by alibaba.
Class ModelExporterUtils, method loadMapperListFromStages.
/**
 * Creates one mapper per stage and wraps them, in stage order, in a {@link MapperChain}.
 *
 * <p>Each mapper is created against the output schema of the mapper before it; the first
 * one is created against {@code inputSchema}. The mappers are not opened here.
 *
 * @param stages      tuples of (stage, schema, rows) handed to {@code createMapperFromStage}
 * @param inputSchema schema of the data entering the first mapper
 * @return a chain holding the created mappers in stage order
 */
public static MapperChain loadMapperListFromStages(List<Tuple3<PipelineStageBase<?>, TableSchema, List<Row>>> stages, TableSchema inputSchema) {
    List<Mapper> chained = new ArrayList<>();
    TableSchema currentSchema = inputSchema;
    for (Tuple3<PipelineStageBase<?>, TableSchema, List<Row>> stage : stages) {
        Mapper stageMapper = createMapperFromStage(stage.f0, stage.f1, currentSchema, stage.f2);
        chained.add(stageMapper);
        // The next stage consumes what this one produces.
        currentSchema = stageMapper.getOutputSchema();
    }
    Mapper[] asArray = chained.toArray(new Mapper[0]);
    return new MapperChain(asArray);
}
Use of com.alibaba.alink.common.mapper.MapperChain in the Alink project by alibaba.
Class TFTableModelRegressionModelMapper, method loadFromModelData.
/**
 * Builds the mapper sequence for regression prediction from the loaded model data.
 *
 * <p>If the model data carries preprocessing pipeline rows, the corresponding mappers are
 * appended first and the schema fed to the TF model mapper becomes the output schema of the
 * last mapper in {@code mappers}. The TF model mapper is then created, loaded (from a zip
 * file when a zip path is present, otherwise from model rows) and appended, and the index of
 * the prediction column in its output schema is stored in {@code predColId}.
 *
 * @param modelData   deserialized model data (meta, preprocessing rows, TF model)
 * @param modelSchema schema of the model table, passed through to the TF model mapper
 */
protected void loadFromModelData(TFTableModelRegressionModelData modelData, TableSchema modelSchema) {
    final Params modelMeta = modelData.getMeta();
    final String outputSignatureDef = modelMeta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_DEF);
    // The output type is fixed to FLOAT_TENSOR; it is not read from the meta entry
    // TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_TYPE.
    final TypeInformation<?> outputSignatureType = TensorTypes.FLOAT_TENSOR;

    // When no reserved columns are configured, every column of the input data schema is kept.
    String[] keptCols = params.get(HasReservedColsDefaultAsNull.RESERVED_COLS);
    if (null == keptCols) {
        keptCols = getDataSchema().getFieldNames();
    }

    // Schema seen by the TF model mapper; replaced below if preprocessing mappers exist.
    TableSchema currentSchema = getDataSchema();
    if (CollectionUtils.isNotEmpty(modelData.getPreprocessPipelineModelRows())) {
        TableSchema preprocessSchema =
            CsvUtil.schemaStr2Schema(modelData.getPreprocessPipelineModelSchemaStr());
        MapperChain preprocessChain = ModelExporterUtils.loadMapperListFromStages(
            modelData.getPreprocessPipelineModelRows(), preprocessSchema, currentSchema);
        mappers.addAll(Arrays.asList(preprocessChain.getMappers()));
        currentSchema = mappers.get(mappers.size() - 1).getOutputSchema();
    }

    final String[] inputCols = modelMeta.get(TFModelDataConverterUtils.TF_INPUT_COLS);
    // NOTE(review): the prediction column is read via the *classification* params class in
    // this regression mapper; confirm this is intentional and not a copy-paste leftover.
    final String predCol = params.get(TFTableModelClassificationPredictParams.PREDICTION_COL);

    // The prediction column name is appended to the reserved columns given to the TF mapper.
    final String[] predColOnly = new String[] { predCol };
    final String[] tfKeptCols = (String[]) ArrayUtils.addAll(keptCols, predColOnly);

    // The TF mapper emits a single column named after the prediction column.
    Params tfParams = new Params();
    tfParams.set(TFTableModelPredictParams.OUTPUT_SIGNATURE_DEFS, new String[] { outputSignatureDef });
    TableSchema tfOutputSchema = TableSchema.builder().field(predCol, outputSignatureType).build();
    tfParams.set(TFTableModelPredictParams.OUTPUT_SCHEMA_STR, CsvUtil.schema2SchemaStr(tfOutputSchema));
    tfParams.set(TFTableModelPredictParams.SELECTED_COLS, inputCols);
    tfParams.set(TFTableModelPredictParams.RESERVED_COLS, tfKeptCols);

    tfModelMapper = new TFTableModelPredictModelMapper(modelSchema, currentSchema, tfParams, factory);
    if (null == modelData.getTfModelZipPath()) {
        tfModelMapper.loadModel(modelData.getTfModelRows());
    } else {
        tfModelMapper.loadModelFromZipFile(modelData.getTfModelZipPath());
    }
    mappers.add(tfModelMapper);

    predColId = TableUtil.findColIndex(tfModelMapper.getOutputSchema(), predCol);
}
Aggregations