Use of com.alibaba.alink.common.model.DataBridgeModelSource in the Alink project by Alibaba.
From the class ModelMapStreamOp, method linkFrom.
/**
 * Links this operator to its inputs and sets the prediction result as the operator output.
 *
 * <p>{@code inputs[0]} supplies the rows to be mapped. A model stream is assembled from up to
 * two places: a {@code ModelStreamFileSourceStreamOp} when
 * {@code ModelStreamUtils.useModelStreamFile(getParams())} is true, and {@code inputs[1]} when a
 * second input is given. If both are present, the second input is projected onto the field names
 * of the file-based stream's schema and the two streams are unioned.
 *
 * <p>With a model stream, the data stream is connected to the broadcast model stream and
 * processed by {@code PredictProcess}. Without one, the rows go through
 * {@code ModelMapperAdapter} when {@code NUM_THREADS <= 1}, otherwise through
 * {@code ModelMapperAdapterMT}.
 *
 * @param inputs the upstream operators; {@code checkMinOpSize(1, inputs)} is applied first
 * @return this operator, cast to {@code T}
 * @throws RuntimeException wrapping any exception raised while building the topology
 */
@Override
public T linkFrom(StreamOperator<?>... inputs) {
    checkMinOpSize(1, inputs);

    StreamOperator<?> dataOp = inputs[0];
    TableSchema modelSchema = this.model.getSchema();

    try {
        DataBridge bridge = DirectReader.collect(model);
        final DataBridgeModelSource initialModel = new DataBridgeModelSource(bridge);
        final ModelMapper mapper =
            this.mapperBuilder.apply(modelSchema, dataOp.getSchema(), this.getParams());

        // Both stay null unless a model stream is configured via params or passed as inputs[1].
        DataStream<Row> updates = null;
        TableSchema updatesSchema = null;

        if (ModelStreamUtils.useModelStreamFile(getParams())) {
            StreamOperator<?> fileSource = new ModelStreamFileSourceStreamOp()
                .setFilePath(getModelStreamFilePath())
                .setScanInterval(getModelStreamScanInterval())
                .setStartTime(getModelStreamStartTime())
                .setSchemaStr(CsvUtil.schema2SchemaStr(modelSchema))
                .setMLEnvironmentId(getMLEnvironmentId());
            updatesSchema = fileSource.getSchema();
            updates = fileSource.getDataStream();
        }

        if (inputs.length >= 2) {
            StreamOperator<?> extraModelOp = inputs[1];
            if (updates != null) {
                // Project the second input onto the file stream's columns before the union.
                StreamOperator<?> aligned = extraModelOp.select(updatesSchema.getFieldNames());
                updates = updates.union(aligned.getDataStream());
            } else {
                updatesSchema = extraModelOp.getSchema();
                updates = extraModelOp.getDataStream();
            }
        }

        final DataStream<Row> predicted;
        if (updates != null) {
            predicted = dataOp.getDataStream()
                .connect(ModelStreamUtils.broadcastStream(updates))
                .flatMap(new PredictProcess(
                    modelSchema,
                    dataOp.getSchema(),
                    getParams(),
                    mapperBuilder,
                    bridge,
                    ModelStreamUtils.findTimestampColIndexWithAssertAndHint(updatesSchema),
                    ModelStreamUtils.findCountColIndexWithAssertAndHint(updatesSchema)));
        } else {
            // The thread count is only consulted when no model stream is in play.
            final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
            if (numThreads <= 1) {
                predicted = dataOp.getDataStream().map(new ModelMapperAdapter(mapper, initialModel));
            } else {
                predicted = dataOp.getDataStream()
                    .flatMap(new ModelMapperAdapterMT(mapper, initialModel, numThreads));
            }
        }

        this.setOutput(predicted, mapper.getOutputSchema());
        return (T) this;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Use of com.alibaba.alink.common.model.DataBridgeModelSource in the Alink project by Alibaba.
From the class BaseRecommStreamOp, method linkFrom.
/**
 * Links this recommendation operator to its input and sets the mapped rows as the output.
 *
 * <p>The model is collected through {@code DirectReader.collect(model)} and wrapped in a
 * {@code DataBridgeModelSource}. A {@code RecommMapper} is built from the recommendation kernel
 * builder, the recommendation type, the model schema, the input schema and the params. The input
 * rows go through {@code RecommAdapter} when {@code NUM_THREADS <= 1}, otherwise through
 * {@code RecommAdapterMT}.
 *
 * @param inputs the upstream operators; the one returned by {@code checkAndGetFirst(inputs)}
 *               is used as the data source
 * @return this operator, cast to {@code T}
 * @throws RuntimeException wrapping any exception raised while building the topology
 */
@Override
public T linkFrom(StreamOperator<?>... inputs) {
    StreamOperator<?> dataOp = checkAndGetFirst(inputs);
    TableSchema modelSchema = this.model.getSchema();

    try {
        DataBridge bridge = DirectReader.collect(model);
        DataBridgeModelSource recommModel = new DataBridgeModelSource(bridge);
        RecommMapper mapper = new RecommMapper(
            this.recommKernelBuilder,
            this.recommType,
            modelSchema,
            dataOp.getSchema(),
            this.getParams());

        final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
        final DataStream<Row> recommended;
        if (numThreads > 1) {
            recommended = dataOp.getDataStream()
                .flatMap(new RecommAdapterMT(mapper, recommModel, numThreads));
        } else {
            recommended = dataOp.getDataStream().map(new RecommAdapter(mapper, recommModel));
        }

        this.setOutput(recommended, mapper.getOutputSchema());
        return (T) this;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Aggregations