Usage of com.alibaba.alink.common.model.BroadcastVariableModelSource in the Alibaba Alink project.
Class FlatModelMapBatchOp, method linkFrom:
/**
 * Connects this operator to its model input ({@code inputs[0]}) and data input ({@code inputs[1]}),
 * flat-maps every data row with the model, and sets the result as this operator's output.
 *
 * <p>If the mapper produced by {@code mapperBuilder} is an {@link IterableModelLoader}, every model
 * row is replicated to every parallel subtask, loaded there through
 * {@code loadIterableModel}, and registered in {@link IterTaskObjKeeper} under a fresh handle; the
 * data side then runs an {@link IterableModelLoaderFlatModelMapperAdapter} with that handle.
 * Otherwise the rebalanced model rows are attached to a {@link FlatModelMapperAdapter} as the
 * broadcast set {@code BROADCAST_MODEL_TABLE_NAME}.
 *
 * @param inputs the model operator followed by the data operator; checked with
 *     {@code checkOpSize(2, inputs)} before anything else
 * @return this operator
 * @throws RuntimeException wrapping any checked exception raised while building the plan;
 *     unchecked exceptions propagate unchanged
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    checkOpSize(2, inputs);
    try {
        final FlatModelMapper flatMapper =
            this.mapperBuilder.apply(inputs[0].getSchema(), inputs[1].getSchema(), this.getParams());
        final DataSet<Row> resultRows;
        if (flatMapper instanceof IterableModelLoader) {
            final long handler = IterTaskObjKeeper.getNewHandle();
            final DataSet<Row> distributedModelRows = replicateToAllSubtasks(inputs[0].getDataSet());
            // Each subtask loads its full copy of the model and registers the mapper under
            // (handler, subtask index). Nothing is collected: this data set exists only so the data
            // side can declare it as the "barrier" broadcast set.
            final DataSet<Integer> barrier = distributedModelRows.mapPartition(
                new RichMapPartitionFunction<Row, Integer>() {
                    private static final long serialVersionUID = 2358845952757630826L;

                    @Override
                    public void mapPartition(Iterable<Row> values, Collector<Integer> out) {
                        int taskId = getRuntimeContext().getIndexOfThisSubtask();
                        ((IterableModelLoader) flatMapper).loadIterableModel(values);
                        IterTaskObjKeeper.put(handler, taskId, flatMapper);
                    }
                });
            resultRows = inputs[1].getDataSet()
                .flatMap(new IterableModelLoaderFlatModelMapperAdapter(handler))
                .withBroadcastSet(barrier, "barrier");
        } else {
            // The broadcast model source is only needed on this path.
            final BroadcastVariableModelSource modelSource =
                new BroadcastVariableModelSource(BROADCAST_MODEL_TABLE_NAME);
            final DataSet<Row> modelRows = inputs[0].getDataSet().rebalance();
            resultRows = inputs[1].getDataSet()
                .flatMap(new FlatModelMapperAdapter(flatMapper, modelSource))
                .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
        }
        this.setOutput(resultRows, flatMapper.getOutputSchema());
        @SuppressWarnings("unchecked") // T is this operator's own self-type
        final T self = (T) this;
        return self;
    } catch (RuntimeException ex) {
        // Already unchecked: rethrow as-is instead of hiding its type behind another wrapper.
        throw ex;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}

/**
 * Replicates every row of {@code rows} so that each partition of the result holds one complete
 * copy of {@code rows}.
 *
 * <p>The flatMap emits one copy per subtask of its own parallelism, tagging copy {@code i} with
 * key {@code i}. A custom partitioner routes key {@code i} to partition {@code i}, and the tag is
 * then dropped. NOTE(review): this assumes the flatMap and the partitioned map run with the same
 * parallelism; an out-of-range key is not handled here.
 *
 * @param rows the rows to replicate
 * @return the replicated rows, typed with {@code rows.getType()}
 */
private static DataSet<Row> replicateToAllSubtasks(DataSet<Row> rows) {
    return rows
        .flatMap(new RichFlatMapFunction<Row, Tuple2<Integer, Row>>() {
            private static final long serialVersionUID = 3544759002096859673L;
            int numTask;

            @Override
            public void open(Configuration parameters) {
                numTask = getRuntimeContext().getNumberOfParallelSubtasks();
            }

            @Override
            public void flatMap(Row value, Collector<Tuple2<Integer, Row>> out) {
                for (int i = 0; i < numTask; ++i) {
                    out.collect(Tuple2.of(i, value));
                }
            }
        })
        .returns(new TupleTypeInfo<>(Types.INT, rows.getType()))
        .partitionCustom(new Partitioner<Integer>() {
            private static final long serialVersionUID = -2924355974935165844L;

            @Override
            public int partition(Integer key, int numPartitions) {
                return key;
            }
        }, 0)
        .map(new MapFunction<Tuple2<Integer, Row>, Row>() {
            private static final long serialVersionUID = 8884296007768771379L;

            @Override
            public Row map(Tuple2<Integer, Row> value) throws Exception {
                return value.f1;
            }
        })
        .returns(rows.getType());
}
Usage of com.alibaba.alink.common.model.BroadcastVariableModelSource in the Alibaba Alink project.
Class ModelMapBatchOp, method linkFrom:
/**
 * Connects this operator to its model input ({@code inputs[0]}) and data input ({@code inputs[1]}),
 * maps every data row with the model, and sets the result as this operator's output.
 *
 * <p>Three execution paths are chosen from the mapper and the parameters:
 * <ul>
 *   <li>{@link IterableModelLoader} mappers: every model row is replicated to every parallel
 *       subtask, loaded there, and registered in {@link IterTaskObjKeeper} under a fresh handle.
 *       The data side runs an {@code IterableModelLoaderModelMapperAdapter}, or its
 *       {@code ...MT} variant when {@code NUM_THREADS > 1}, with that handle.</li>
 *   <li>When {@code ModelStreamUtils.useModelStreamFile(getParams())} is true, a map function
 *       loads the broadcast model in {@code open()} and delegates to a
 *       {@link ModelStreamModelMapperAdapter}.</li>
 *   <li>Otherwise, a {@link ModelMapperAdapter} (single thread) or {@link ModelMapperAdapterMT}
 *       ({@code NUM_THREADS > 1}) reads the rebalanced model rows from the broadcast set
 *       {@code BROADCAST_MODEL_TABLE_NAME}.</li>
 * </ul>
 *
 * @param inputs the model operator followed by the data operator; checked with
 *     {@code checkOpSize(2, inputs)} before anything else
 * @return this operator
 * @throws RuntimeException wrapping any checked exception raised while building the plan;
 *     unchecked exceptions propagate unchanged
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    checkOpSize(2, inputs);
    try {
        final ModelMapper mapper =
            this.mapperBuilder.apply(inputs[0].getSchema(), inputs[1].getSchema(), this.getParams());
        final DataSet<Row> resultRows;
        if (mapper instanceof IterableModelLoader) {
            final long handler = IterTaskObjKeeper.getNewHandle();
            final DataSet<Row> distributedModelRows = replicateToAllSubtasks(inputs[0].getDataSet());
            // Each subtask loads its full copy of the model and registers the mapper under
            // (handler, subtask index). Nothing is collected: this data set exists only so the data
            // side can declare it as the "barrier" broadcast set.
            final DataSet<Integer> barrier = distributedModelRows.mapPartition(
                new RichMapPartitionFunction<Row, Integer>() {
                    private static final long serialVersionUID = 2358845952757630826L;

                    @Override
                    public void mapPartition(Iterable<Row> values, Collector<Integer> out) {
                        int taskId = getRuntimeContext().getIndexOfThisSubtask();
                        ((IterableModelLoader) mapper).loadIterableModel(values);
                        IterTaskObjKeeper.put(handler, taskId, mapper);
                    }
                });
            final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
            if (numThreads <= 1) {
                resultRows = inputs[1].getDataSet()
                    .map(new IterableModelLoaderModelMapperAdapter(handler))
                    .withBroadcastSet(barrier, "barrier");
            } else {
                resultRows = inputs[1].getDataSet()
                    .flatMap(new IterableModelLoaderModelMapperAdapterMT(handler, numThreads))
                    .withBroadcastSet(barrier, "barrier");
            }
        } else {
            final BroadcastVariableModelSource modelSource =
                new BroadcastVariableModelSource(BROADCAST_MODEL_TABLE_NAME);
            final DataSet<Row> modelRows = inputs[0].getDataSet().rebalance();
            if (ModelStreamUtils.useModelStreamFile(getParams())) {
                resultRows = inputs[1].getDataSet().map(new RichMapFunction<Row, Row>() {
                    private static final long serialVersionUID = -6127486281093424720L;
                    ModelStreamModelMapperAdapter modelStreamModelMapper;

                    @Override
                    public void open(Configuration parameters) throws Exception {
                        super.open(parameters);
                        // Load the broadcast model once per subtask, then wrap the opened mapper.
                        List<Row> loadedModelRows = modelSource.getModelRows(getRuntimeContext());
                        mapper.loadModel(loadedModelRows);
                        mapper.open();
                        modelStreamModelMapper = new ModelStreamModelMapperAdapter(mapper);
                    }

                    @Override
                    public Row map(Row value) throws Exception {
                        return modelStreamModelMapper.map(value);
                    }
                }).withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
            } else {
                // Read only on this path, as before: the model-stream path never consults NUM_THREADS.
                final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
                if (numThreads <= 1) {
                    resultRows = inputs[1].getDataSet()
                        .map(new ModelMapperAdapter(mapper, modelSource))
                        .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
                } else {
                    resultRows = inputs[1].getDataSet()
                        .flatMap(new ModelMapperAdapterMT(mapper, modelSource, numThreads))
                        .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
                }
            }
        }
        this.setOutput(resultRows, mapper.getOutputSchema());
        @SuppressWarnings("unchecked") // T is this operator's own self-type
        final T self = (T) this;
        return self;
    } catch (RuntimeException ex) {
        // Already unchecked: rethrow as-is instead of hiding its type behind another wrapper.
        throw ex;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}

/**
 * Replicates every row of {@code rows} so that each partition of the result holds one complete
 * copy of {@code rows}.
 *
 * <p>The flatMap emits one copy per subtask of its own parallelism, tagging copy {@code i} with
 * key {@code i}. A custom partitioner routes key {@code i} to partition {@code i}, and the tag is
 * then dropped. NOTE(review): this assumes the flatMap and the partitioned map run with the same
 * parallelism; an out-of-range key is not handled here.
 *
 * @param rows the rows to replicate
 * @return the replicated rows, typed with {@code rows.getType()}
 */
private static DataSet<Row> replicateToAllSubtasks(DataSet<Row> rows) {
    return rows
        .flatMap(new RichFlatMapFunction<Row, Tuple2<Integer, Row>>() {
            private static final long serialVersionUID = 3544759002096859673L;
            int numTask;

            @Override
            public void open(Configuration parameters) {
                numTask = getRuntimeContext().getNumberOfParallelSubtasks();
            }

            @Override
            public void flatMap(Row value, Collector<Tuple2<Integer, Row>> out) {
                for (int i = 0; i < numTask; ++i) {
                    out.collect(Tuple2.of(i, value));
                }
            }
        })
        .returns(new TupleTypeInfo<>(Types.INT, rows.getType()))
        .partitionCustom(new Partitioner<Integer>() {
            private static final long serialVersionUID = -2924355974935165844L;

            @Override
            public int partition(Integer key, int numPartitions) {
                return key;
            }
        }, 0)
        .map(new MapFunction<Tuple2<Integer, Row>, Row>() {
            private static final long serialVersionUID = 8884296007768771379L;

            @Override
            public Row map(Tuple2<Integer, Row> value) throws Exception {
                return value.f1;
            }
        })
        .returns(rows.getType());
}
Usage of com.alibaba.alink.common.model.BroadcastVariableModelSource in the Alibaba Alink project.
Class BaseRecommBatchOp, method linkFrom:
/**
 * Connects this recommendation operator to its model input ({@code inputs[0]}) and data input
 * ({@code inputs[1]}), and sets the recommendation results as this operator's output.
 *
 * <p>A {@link RecommMapper} is built from {@code recommKernelBuilder}, {@code recommType}, both
 * input schemas and the operator parameters. The rebalanced model rows are attached as the
 * broadcast set {@code BROADCAST_MODEL_TABLE_NAME}. A {@link RecommAdapter} is used when
 * {@code NUM_THREADS <= 1}, otherwise a {@link RecommAdapterMT} with that thread count.
 *
 * @param inputs the model operator followed by the data operator; checked with
 *     {@code checkOpSize(2, inputs)} before anything else
 * @return this operator
 * @throws RuntimeException wrapping any checked exception raised while building the plan;
 *     unchecked exceptions propagate unchanged
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    checkOpSize(2, inputs);
    try {
        final BroadcastVariableModelSource modelSource =
            new BroadcastVariableModelSource(BROADCAST_MODEL_TABLE_NAME);
        final RecommMapper mapper = new RecommMapper(this.recommKernelBuilder, this.recommType,
            inputs[0].getSchema(), inputs[1].getSchema(), this.getParams());
        final DataSet<Row> modelRows = inputs[0].getDataSet().rebalance();
        final int numThreads = getParams().get(ModelMapperParams.NUM_THREADS);
        final DataSet<Row> resultRows;
        if (numThreads <= 1) {
            resultRows = inputs[1].getDataSet()
                .map(new RecommAdapter(mapper, modelSource))
                .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
        } else {
            resultRows = inputs[1].getDataSet()
                .flatMap(new RecommAdapterMT(mapper, modelSource, numThreads))
                .withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
        }
        this.setOutput(resultRows, mapper.getOutputSchema());
        @SuppressWarnings("unchecked") // T is this operator's own self-type
        final T self = (T) this;
        return self;
    } catch (RuntimeException ex) {
        // Already unchecked: rethrow as-is instead of hiding its type behind another wrapper.
        throw ex;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Aggregations