Use of com.alibaba.alink.common.io.directreader.DataBridge in the Alink project by Alibaba.
Class StreamingKMeansStreamOp, method linkFrom.
/**
 * Updates the model with stream in1 and predicts for stream in2.
 *
 * <p>{@code inputs[0]} is the training stream. {@code inputs[1]}, when present, is the stream to
 * predict on; otherwise the first input is used for both roles. If no prediction column has been
 * configured, it is set to {@code "cluster_id"}. The prediction result becomes this operator's
 * output and the stream of updated models is registered as a side output.
 *
 * @param inputs one or two stream operators (at least one is required)
 * @return this operator
 * @throws RuntimeException if building the pipeline fails; the original exception is attached
 *                          as the cause
 */
@Override
public StreamingKMeansStreamOp linkFrom(StreamOperator<?>... inputs) {
    checkMinOpSize(1, inputs);
    StreamOperator<?> in1 = inputs[0];
    // Without a second input, the first stream is used for both training and prediction.
    StreamOperator<?> in2 = inputs.length > 1 ? inputs[1] : inputs[0];
    if (!this.getParams().contains(HasPredictionCol.PREDICTION_COL)) {
        this.setPredictionCol("cluster_id");
    }
    // time interval for updating the model, in seconds
    final long timeInterval = getParams().get(TIME_INTERVAL);
    final long halfLife = getParams().get(HALF_LIFE);
    // Evaluates to exactly 0.5 when timeInterval equals halfLife.
    final double decayFactor = Math.pow(0.5, (double) timeInterval / (double) halfLife);
    try {
        DataStream<Row> trainingData = in1.getDataStream();
        DataStream<Row> predictData = in2.getDataStream();
        PredType predType = PredType.fromInputs(getParams());

        // The prediction type only decides which columns are appended and their types.
        final String[] outputCols;
        final TypeInformation<?>[] outputTypes;
        switch (predType) {
            case PRED:
                outputCols = new String[] {getPredictionCol()};
                outputTypes = new TypeInformation<?>[] {Types.LONG};
                break;
            case PRED_CLUS:
                outputCols = new String[] {getPredictionCol(), getPredictionClusterCol()};
                outputTypes = new TypeInformation<?>[] {Types.LONG, VectorTypes.DENSE_VECTOR};
                break;
            case PRED_DIST:
                outputCols = new String[] {getPredictionCol(), getPredictionDistanceCol()};
                outputTypes = new TypeInformation<?>[] {Types.LONG, Types.DOUBLE};
                break;
            case PRED_CLUS_DIST:
                outputCols = new String[] {
                    getPredictionCol(), getPredictionClusterCol(), getPredictionDistanceCol()
                };
                outputTypes = new TypeInformation<?>[] {
                    Types.LONG, VectorTypes.DENSE_VECTOR, Types.DOUBLE
                };
                break;
            default:
                // Fail with a clear message instead of a later NullPointerException.
                throw new IllegalStateException("Unsupported prediction type: " + predType);
        }
        OutputColsHelper outputColsHelper =
            new OutputColsHelper(in2.getSchema(), outputCols, outputTypes, this.getReservedCols());

        // for direct read
        DataBridge modelDataBridge = DirectReader.collect(batchModel);

        // incremental train on every window of data
        DataStream<Tuple3<DenseVector[], int[], Long>> updateData = trainingData
            .flatMap(new CollectUpdateData(modelDataBridge, in1.getColNames(), timeInterval))
            .name("local_aggregate");
        int taskNum = updateData.getParallelism();
        // Merging and model update both run with parallelism 1.
        DataStream<KMeansTrainModelData> streamModel = updateData
            .flatMap(new AllDataMerge(taskNum))
            .name("global_aggregate")
            .setParallelism(1)
            .map(new UpdateModelOp(modelDataBridge, decayFactor))
            .name("update_model")
            .setParallelism(1);

        // predict
        DataStream<Row> predictResult = predictData
            .connect(streamModel.broadcast())
            .flatMap(new PredictOp(modelDataBridge, in2.getColNames(), outputColsHelper, predType))
            .name("kmeans_prediction");

        this.setOutput(predictResult, outputColsHelper.getResultSchema());
        this.setSideOutputTables(outputModel(streamModel, getMLEnvironmentId()));
        return this;
    } catch (Exception e) {
        // Keep the same message as before, but preserve the cause so the stack trace is not lost.
        throw new RuntimeException(e.getMessage(), e);
    }
}
Use of com.alibaba.alink.common.io.directreader.DataBridge in the Alink project by Alibaba.
Class ModelMapStreamOp, method linkFrom.
/**
 * Links the first input to a model-based mapping and sets the mapped rows as this operator's output.
 *
 * <p>The model held by this operator is collected through {@code DirectReader}. A model stream is
 * assembled from a model-stream file source (when the parameters request one) and/or from
 * {@code inputs[1]} (when present). With a model stream, the data stream is connected to the
 * broadcast model stream and processed by {@code PredictProcess}; without one, the rows go through
 * {@code ModelMapperAdapter}, or {@code ModelMapperAdapterMT} when {@code NUM_THREADS} is above 1.
 *
 * @param inputs the data operator, optionally followed by a model-stream operator
 * @return this operator
 * @throws RuntimeException wrapping any exception raised while building the pipeline
 */
@Override
public T linkFrom(StreamOperator<?>... inputs) {
    checkMinOpSize(1, inputs);
    StreamOperator<?> dataOp = inputs[0];
    TableSchema modelSchema = this.model.getSchema();
    try {
        DataBridge bridge = DirectReader.collect(model);
        final DataBridgeModelSource modelSource = new DataBridgeModelSource(bridge);
        final ModelMapper mapper =
            this.mapperBuilder.apply(modelSchema, dataOp.getSchema(), this.getParams());

        DataStream<Row> modelUpdates = null;
        TableSchema modelUpdatesSchema = null;

        // Optional source #1: model stream read from files.
        if (ModelStreamUtils.useModelStreamFile(getParams())) {
            StreamOperator<?> fileSourceOp = new ModelStreamFileSourceStreamOp()
                .setFilePath(getModelStreamFilePath())
                .setScanInterval(getModelStreamScanInterval())
                .setStartTime(getModelStreamStartTime())
                .setSchemaStr(CsvUtil.schema2SchemaStr(modelSchema))
                .setMLEnvironmentId(getMLEnvironmentId());
            modelUpdatesSchema = fileSourceOp.getSchema();
            modelUpdates = fileSourceOp.getDataStream();
        }

        // Optional source #2: model stream passed in as the second input.
        if (inputs.length > 1) {
            StreamOperator<?> inputModelOp = inputs[1];
            if (modelUpdates != null) {
                // Select the file source's field names before the union.
                inputModelOp = inputModelOp.select(modelUpdatesSchema.getFieldNames());
                modelUpdates = modelUpdates.union(inputModelOp.getDataStream());
            } else {
                modelUpdatesSchema = inputModelOp.getSchema();
                modelUpdates = inputModelOp.getDataStream();
            }
        }

        DataStream<Row> mappedRows;
        if (modelUpdates == null) {
            // Mapping uses only the model collected above.
            if (getParams().get(ModelMapperParams.NUM_THREADS) > 1) {
                mappedRows = dataOp.getDataStream().flatMap(
                    new ModelMapperAdapterMT(
                        mapper, modelSource, getParams().get(ModelMapperParams.NUM_THREADS)));
            } else {
                mappedRows = dataOp.getDataStream().map(new ModelMapperAdapter(mapper, modelSource));
            }
        } else {
            mappedRows = dataOp.getDataStream()
                .connect(ModelStreamUtils.broadcastStream(modelUpdates))
                .flatMap(new PredictProcess(
                    modelSchema,
                    dataOp.getSchema(),
                    getParams(),
                    mapperBuilder,
                    bridge,
                    ModelStreamUtils.findTimestampColIndexWithAssertAndHint(modelUpdatesSchema),
                    ModelStreamUtils.findCountColIndexWithAssertAndHint(modelUpdatesSchema)));
        }

        TableSchema outputSchema = mapper.getOutputSchema();
        this.setOutput(mappedRows, outputSchema);
        return (T) this;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Use of com.alibaba.alink.common.io.directreader.DataBridge in the Alink project by Alibaba.
Class BaseRecommStreamOp, method linkFrom.
/**
 * Links the first input to a recommendation mapping and sets the mapped rows as this operator's
 * output.
 *
 * <p>The model held by this operator is collected through {@code DirectReader} and wrapped in a
 * {@code DataBridgeModelSource}. Rows go through {@code RecommAdapterMT} when {@code NUM_THREADS}
 * is above 1, and through {@code RecommAdapter} otherwise. The output schema is taken from the
 * mapper.
 *
 * @param inputs the operators to link; the first one, as returned by {@code checkAndGetFirst},
 *               supplies the data stream
 * @return this operator
 * @throws RuntimeException wrapping any exception raised while building the pipeline
 */
@Override
public T linkFrom(StreamOperator<?>... inputs) {
    StreamOperator<?> dataOp = checkAndGetFirst(inputs);
    TableSchema modelSchema = this.model.getSchema();
    try {
        DataBridge bridge = DirectReader.collect(model);
        DataBridgeModelSource modelSource = new DataBridgeModelSource(bridge);
        RecommMapper mapper = new RecommMapper(
            this.recommKernelBuilder,
            this.recommType,
            modelSchema,
            dataOp.getSchema(),
            this.getParams());

        DataStream<Row> recommended;
        if (getParams().get(ModelMapperParams.NUM_THREADS) > 1) {
            // Multi-threaded adapter.
            recommended = dataOp.getDataStream().flatMap(
                new RecommAdapterMT(mapper, modelSource, getParams().get(ModelMapperParams.NUM_THREADS)));
        } else {
            recommended = dataOp.getDataStream().map(new RecommAdapter(mapper, modelSource));
        }

        TableSchema resultSchema = mapper.getOutputSchema();
        this.setOutput(recommended, resultSchema);
        return (T) this;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Aggregations