Use of com.alibaba.alink.operator.common.clustering.kmeans.KMeansTrainModelData in project Alink by Alibaba.
The class StreamingKMeansStreamOp, method linkFrom.
/**
 * Updates the model with stream in1 and predicts for stream in2.
 *
 * <p>The first input is the training stream. The second input, when present, is the stream
 * to predict on; otherwise the first input is used for both training and prediction. If no
 * prediction column has been configured, it defaults to {@code "cluster_id"}.
 *
 * <p>The prediction result is set as this operator's output, and the stream of updated
 * models is exposed through the side-output tables.
 *
 * @param inputs one or two stream operators: training data, and optionally prediction data
 * @return this operator
 * @throws RuntimeException if building the pipeline fails; the original exception is
 *                          attached as the cause
 */
@Override
public StreamingKMeansStreamOp linkFrom(StreamOperator<?>... inputs) {
    checkMinOpSize(1, inputs);
    StreamOperator<?> in1 = inputs[0];
    // Fall back to the training stream when no separate prediction stream is supplied.
    StreamOperator<?> in2 = inputs.length > 1 ? inputs[1] : inputs[0];
    if (!this.getParams().contains(HasPredictionCol.PREDICTION_COL)) {
        this.setPredictionCol("cluster_id");
    }
    // time interval for updating the model, in seconds
    final long timeInterval = getParams().get(TIME_INTERVAL);
    final long halfLife = getParams().get(HALF_LIFE);
    // Decay applied per update interval: 0.5 ^ (timeInterval / halfLife).
    final double decayFactor = Math.pow(0.5, (double) timeInterval / (double) halfLife);
    try {
        DataStream<Row> trainingData = in1.getDataStream();
        DataStream<Row> predictData = in2.getDataStream();
        PredType predType = PredType.fromInputs(getParams());
        // The output columns appended to in2's schema depend on the prediction type.
        final OutputColsHelper outputColsHelper;
        switch (predType) {
            case PRED: {
                outputColsHelper = new OutputColsHelper(
                    in2.getSchema(),
                    new String[] { getPredictionCol() },
                    new TypeInformation[] { Types.LONG },
                    this.getReservedCols());
                break;
            }
            case PRED_CLUS: {
                outputColsHelper = new OutputColsHelper(
                    in2.getSchema(),
                    new String[] { getPredictionCol(), getPredictionClusterCol() },
                    new TypeInformation[] { Types.LONG, VectorTypes.DENSE_VECTOR },
                    this.getReservedCols());
                break;
            }
            case PRED_DIST: {
                outputColsHelper = new OutputColsHelper(
                    in2.getSchema(),
                    new String[] { getPredictionCol(), getPredictionDistanceCol() },
                    new TypeInformation[] { Types.LONG, Types.DOUBLE },
                    this.getReservedCols());
                break;
            }
            case PRED_CLUS_DIST: {
                outputColsHelper = new OutputColsHelper(
                    in2.getSchema(),
                    new String[] { this.getPredictionCol(), getPredictionClusterCol(), getPredictionDistanceCol() },
                    new TypeInformation[] { Types.LONG, VectorTypes.DENSE_VECTOR, Types.DOUBLE },
                    this.getReservedCols());
                break;
            }
            default:
                // Fail here with a clear message instead of a later NullPointerException.
                throw new IllegalArgumentException("Unsupported prediction type: " + predType);
        }
        // for direct read
        DataBridge modelDataBridge = DirectReader.collect(batchModel);
        // incremental train on every window of data
        DataStream<Tuple3<DenseVector[], int[], Long>> updateData = trainingData
            .flatMap(new CollectUpdateData(modelDataBridge, in1.getColNames(), timeInterval))
            .name("local_aggregate");
        int taskNum = updateData.getParallelism();
        // Merging and the model update both run with parallelism 1.
        DataStream<KMeansTrainModelData> streamModel = updateData
            .flatMap(new AllDataMerge(taskNum))
            .name("global_aggregate")
            .setParallelism(1)
            .map(new UpdateModelOp(modelDataBridge, decayFactor))
            .name("update_model")
            .setParallelism(1);
        // predict: the updated model stream is broadcast and connected to the prediction stream
        DataStream<Row> predictResult = predictData
            .connect(streamModel.broadcast())
            .flatMap(new PredictOp(modelDataBridge, in2.getColNames(), outputColsHelper, predType))
            .name("kmeans_prediction");
        this.setOutput(predictResult, outputColsHelper.getResultSchema());
        this.setSideOutputTables(outputModel(streamModel, getMLEnvironmentId()));
        return this;
    } catch (Exception e) {
        // Keep the original message but attach the cause so the stack trace is preserved.
        throw new RuntimeException(e.getMessage(), e);
    }
}
Aggregations