Usage of `com.alibaba.alink.operator.common.recommendation.RecommAdapterMT` in the Alink project (by Alibaba):
class `BaseRecommStreamOp`, method `linkFrom`.
/**
 * Links this recommendation operator to its streaming input.
 *
 * <p>The model held in {@code this.model} is collected eagerly through {@code DirectReader}
 * and exposed to the mapper via a {@code DataBridgeModelSource}. Each input row is then
 * processed by a {@code RecommMapper}: with {@code ModelMapperParams.NUM_THREADS <= 1} a
 * {@code RecommAdapter} is applied with {@code map}, otherwise a {@code RecommAdapterMT}
 * using that many threads is applied with {@code flatMap}.
 *
 * @param inputs upstream operators; only the first one is used (see {@code checkAndGetFirst})
 * @return this operator, with its output set to the recommendation results
 * @throws RuntimeException wrapping any checked exception raised while building the pipeline;
 *     unchecked exceptions are propagated unchanged
 */
@Override
public T linkFrom(StreamOperator<?>... inputs) {
    StreamOperator<?> in = checkAndGetFirst(inputs);
    TableSchema modelSchema = this.model.getSchema();
    try {
        DataBridge modelDataBridge = DirectReader.collect(model);
        DataBridgeModelSource modelSource = new DataBridgeModelSource(modelDataBridge);
        RecommMapper mapper = new RecommMapper(
            this.recommKernelBuilder, this.recommType, modelSchema, in.getSchema(), this.getParams());

        // Read the thread count once so the branch test and the adapter use the same value.
        final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
        DataStream<Row> resultRows;
        if (numThreads <= 1) {
            resultRows = in.getDataStream().map(new RecommAdapter(mapper, modelSource));
        } else {
            resultRows = in.getDataStream().flatMap(new RecommAdapterMT(mapper, modelSource, numThreads));
        }

        TableSchema outputSchema = mapper.getOutputSchema();
        this.setOutput(resultRows, outputSchema);

        // Self-type idiom: T is expected to be the concrete subclass of this operator.
        @SuppressWarnings("unchecked")
        T self = (T) this;
        return self;
    } catch (RuntimeException ex) {
        // Do not nest unchecked exceptions inside another RuntimeException; keep their type.
        throw ex;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Usage of `com.alibaba.alink.operator.common.recommendation.RecommAdapterMT` in the Alink project (by Alibaba):
class `BaseRecommBatchOp`, method `linkFrom`.
/**
 * Links this recommendation operator to its batch inputs.
 *
 * <p>{@code inputs[0]} supplies the model and {@code inputs[1]} the data to score. The model
 * rows are rebalanced and attached to the scoring operator as a broadcast set named
 * {@code BROADCAST_MODEL_TABLE_NAME}, which a {@code BroadcastVariableModelSource} reads back.
 * With {@code ModelMapperParams.NUM_THREADS <= 1} a {@code RecommAdapter} is applied with
 * {@code map}; otherwise a {@code RecommAdapterMT} using that many threads is applied with
 * {@code flatMap}.
 *
 * @param inputs exactly two operators (checked by {@code checkOpSize}): model, then data
 * @return this operator, with its output set to the recommendation results
 * @throws RuntimeException wrapping any checked exception raised while building the pipeline;
 *     unchecked exceptions are propagated unchanged
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    checkOpSize(2, inputs);
    try {
        BroadcastVariableModelSource modelSource = new BroadcastVariableModelSource(BROADCAST_MODEL_TABLE_NAME);
        RecommMapper mapper = new RecommMapper(
            this.recommKernelBuilder, this.recommType, inputs[0].getSchema(), inputs[1].getSchema(), this.getParams());
        DataSet<Row> modelRows = inputs[0].getDataSet().rebalance();

        // Read the thread count once so the branch test and the adapter use the same value.
        final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
        DataSet<Row> resultRows;
        if (numThreads <= 1) {
            resultRows = inputs[1].getDataSet()
                .map(new RecommAdapter(mapper, modelSource))
                .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
        } else {
            resultRows = inputs[1].getDataSet()
                .flatMap(new RecommAdapterMT(mapper, modelSource, numThreads))
                .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
        }

        TableSchema outputSchema = mapper.getOutputSchema();
        this.setOutput(resultRows, outputSchema);

        // Self-type idiom: T is expected to be the concrete subclass of this operator.
        @SuppressWarnings("unchecked")
        T self = (T) this;
        return self;
    } catch (RuntimeException ex) {
        // Do not nest unchecked exceptions inside another RuntimeException; keep their type.
        throw ex;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Aggregations