Use of com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper in the Alink project by Alibaba.
Class TFTableModelClassificationModelMapper, method loadFromModelData:
/**
 * Builds the prediction pipeline for a TF table-model classification model.
 *
 * <p>If {@code modelData} carries preprocessing pipeline rows, the mappers loaded from them are
 * appended to {@code mappers} first. A {@link TFTableModelPredictModelMapper} is then configured
 * from the model meta, loaded with the TF model, and appended after them. Finally the index of the
 * prediction column in the TF mapper's output, the sorted labels and the logits flag are stored.
 *
 * @param modelData   the classification model data (meta, optional preprocessing pipeline rows and
 *                    schema string, TF model as zip path or rows, sorted labels)
 * @param modelSchema the model schema, passed through to the TF predict mapper's constructor
 */
protected void loadFromModelData(TFTableModelClassificationModelData modelData, TableSchema modelSchema) {
    Params meta = modelData.getMeta();
    String tfOutputSignatureDef = meta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_DEF);
    // The TF output signature is always declared as a float tensor here.
    TypeInformation<?> tfOutputSignatureType = TensorTypes.FLOAT_TENSOR;

    TableSchema dataSchema = getDataSchema();
    // Without an explicit RESERVED_COLS setting, every column of the input data is reserved.
    String[] reservedCols = params.get(HasReservedColsDefaultAsNull.RESERVED_COLS);
    if (null == reservedCols) {
        reservedCols = dataSchema.getFieldNames();
    }

    if (CollectionUtils.isNotEmpty(modelData.getPreprocessPipelineModelRows())) {
        TableSchema pipelineModelSchema =
            CsvUtil.schemaStr2Schema(modelData.getPreprocessPipelineModelSchemaStr());
        MapperChain mapperList = ModelExporterUtils.loadMapperListFromStages(
            modelData.getPreprocessPipelineModelRows(), pipelineModelSchema, dataSchema);
        int numPreprocessMappers = mapperList.getMappers().length;
        mappers.addAll(Arrays.asList(mapperList.getMappers()));
        // The TF mapper consumes the output of the last preprocessing mapper. Switch schemas only
        // when the chain actually contributed a mapper, so an empty chain neither throws nor picks
        // up the output schema of a mapper that was already in the list.
        if (numPreprocessMappers > 0) {
            dataSchema = mappers.get(mappers.size() - 1).getOutputSchema();
        }
    }

    String[] tfInputCols = meta.get(TFModelDataConverterUtils.TF_INPUT_COLS);
    String predCol = params.get(TFTableModelClassificationPredictParams.PREDICTION_COL);

    Params tfModelMapperParams = new Params();
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SIGNATURE_DEFS, new String[] { tfOutputSignatureDef });
    // The TF mapper emits a single column, named after the prediction column.
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SCHEMA_STR,
        CsvUtil.schema2SchemaStr(TableSchema.builder().field(predCol, tfOutputSignatureType).build()));
    tfModelMapperParams.set(TFTableModelPredictParams.SELECTED_COLS, tfInputCols);
    tfModelMapperParams.set(TFTableModelPredictParams.RESERVED_COLS, reservedCols);

    tfModelMapper = new TFTableModelPredictModelMapper(modelSchema, dataSchema, tfModelMapperParams, factory);
    // Load from the zipped TF model when a path is recorded; otherwise load from the model rows.
    if (null != modelData.getTfModelZipPath()) {
        tfModelMapper.loadModelFromZipFile(modelData.getTfModelZipPath());
    } else {
        tfModelMapper.loadModel(modelData.getTfModelRows());
    }
    mappers.add(tfModelMapper);

    predColId = TableUtil.findColIndex(tfModelMapper.getOutputSchema(), predCol);
    sortedLabels = modelData.getSortedLabels();
    isOutputLogits = meta.get(TFModelDataConverterUtils.IS_OUTPUT_LOGITS);
}
Use of com.alibaba.alink.operator.common.tensorflow.TFTableModelPredictModelMapper in the Alink project by Alibaba.
Class TFTableModelRegressionModelMapper, method loadFromModelData:
/**
 * Builds the prediction pipeline for a TF table-model regression model.
 *
 * <p>If {@code modelData} carries preprocessing pipeline rows, the mappers loaded from them are
 * appended to {@code mappers} first. A {@link TFTableModelPredictModelMapper} is then configured
 * from the model meta, loaded with the TF model, and appended after them. Finally the index of the
 * prediction column in the TF mapper's output is stored. Unlike the classification variant, the
 * reserved columns handed to the TF mapper also include the prediction column.
 *
 * @param modelData   the regression model data (meta, optional preprocessing pipeline rows and
 *                    schema string, TF model as zip path or rows)
 * @param modelSchema the model schema, passed through to the TF predict mapper's constructor
 */
protected void loadFromModelData(TFTableModelRegressionModelData modelData, TableSchema modelSchema) {
    Params meta = modelData.getMeta();
    String tfOutputSignatureDef = meta.get(TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_DEF);
    // NOTE(review): the signature type is fixed to FLOAT_TENSOR rather than read from
    // TFModelDataConverterUtils.TF_OUTPUT_SIGNATURE_TYPE in the meta; confirm this is intended.
    TypeInformation<?> tfOutputSignatureType = TensorTypes.FLOAT_TENSOR;

    TableSchema dataSchema = getDataSchema();
    // Without an explicit RESERVED_COLS setting, every column of the input data is reserved.
    String[] reservedCols = params.get(HasReservedColsDefaultAsNull.RESERVED_COLS);
    if (null == reservedCols) {
        reservedCols = dataSchema.getFieldNames();
    }

    if (CollectionUtils.isNotEmpty(modelData.getPreprocessPipelineModelRows())) {
        TableSchema pipelineModelSchema =
            CsvUtil.schemaStr2Schema(modelData.getPreprocessPipelineModelSchemaStr());
        MapperChain mapperList = ModelExporterUtils.loadMapperListFromStages(
            modelData.getPreprocessPipelineModelRows(), pipelineModelSchema, dataSchema);
        int numPreprocessMappers = mapperList.getMappers().length;
        mappers.addAll(Arrays.asList(mapperList.getMappers()));
        // The TF mapper consumes the output of the last preprocessing mapper. Switch schemas only
        // when the chain actually contributed a mapper, so an empty chain neither throws nor picks
        // up the output schema of a mapper that was already in the list.
        if (numPreprocessMappers > 0) {
            dataSchema = mappers.get(mappers.size() - 1).getOutputSchema();
        }
    }

    String[] tfInputCols = meta.get(TFModelDataConverterUtils.TF_INPUT_COLS);
    // NOTE(review): this reads PREDICTION_COL through the classification params type; presumably
    // it is the same shared prediction-column parameter — confirm.
    String predCol = params.get(TFTableModelClassificationPredictParams.PREDICTION_COL);
    // The prediction column is appended to the reserved columns handed to the TF mapper.
    String[] tfReservedCols = (String[]) ArrayUtils.addAll(reservedCols, new String[] { predCol });

    Params tfModelMapperParams = new Params();
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SIGNATURE_DEFS, new String[] { tfOutputSignatureDef });
    // The TF mapper emits a single column, named after the prediction column.
    tfModelMapperParams.set(TFTableModelPredictParams.OUTPUT_SCHEMA_STR,
        CsvUtil.schema2SchemaStr(TableSchema.builder().field(predCol, tfOutputSignatureType).build()));
    tfModelMapperParams.set(TFTableModelPredictParams.SELECTED_COLS, tfInputCols);
    tfModelMapperParams.set(TFTableModelPredictParams.RESERVED_COLS, tfReservedCols);

    tfModelMapper = new TFTableModelPredictModelMapper(modelSchema, dataSchema, tfModelMapperParams, factory);
    // Load from the zipped TF model when a path is recorded; otherwise load from the model rows.
    if (null != modelData.getTfModelZipPath()) {
        tfModelMapper.loadModelFromZipFile(modelData.getTfModelZipPath());
    } else {
        tfModelMapper.loadModel(modelData.getTfModelRows());
    }
    mappers.add(tfModelMapper);

    predColId = TableUtil.findColIndex(tfModelMapper.getOutputSchema(), predCol);
}
Aggregations