Search in sources :

Example 1 with KMeansTrainModelData

Use of com.alibaba.alink.operator.common.clustering.kmeans.KMeansTrainModelData in project Alink by Alibaba.

The method linkFrom of the class StreamingKMeansStreamOp.

/**
 * Update model with stream in1, predict for stream in2.
 *
 * <p>When only one input is supplied, that same stream is used both for updating the model and
 * for prediction. If no prediction column has been configured, it defaults to
 * {@code "cluster_id"}.
 *
 * @param inputs one or two stream operators; {@code inputs[0]} feeds the model update and
 *               {@code inputs[1]} (if present) is the stream to predict on
 * @return this operator, with its main output set to the prediction result and its side output
 *         set to the stream of updated models
 * @throws RuntimeException if building the streaming pipeline fails; the original exception is
 *                          attached as the cause
 */
@Override
public StreamingKMeansStreamOp linkFrom(StreamOperator<?>... inputs) {
    checkMinOpSize(1, inputs);
    StreamOperator<?> in1 = inputs[0];
    // Fall back to the training stream when no separate prediction stream is given.
    StreamOperator<?> in2 = inputs.length > 1 ? inputs[1] : inputs[0];
    if (!this.getParams().contains(HasPredictionCol.PREDICTION_COL)) {
        this.setPredictionCol("cluster_id");
    }
    // time interval for updating the model, in seconds
    final long timeInterval = getParams().get(TIME_INTERVAL);
    final long halfLife = getParams().get(HALF_LIFE);
    // Weight applied per update interval: 0.5 ^ (timeInterval / halfLife).
    final double decayFactor = Math.pow(0.5, (double) timeInterval / (double) halfLife);
    try {
        DataStream<Row> trainingData = in1.getDataStream();
        DataStream<Row> predictData = in2.getDataStream();
        PredType predType = PredType.fromInputs(getParams());
        // The output schema depends on which optional prediction columns are requested.
        final OutputColsHelper outputColsHelper;
        switch (predType) {
            case PRED: {
                outputColsHelper = new OutputColsHelper(
                    in2.getSchema(),
                    new String[] { getPredictionCol() },
                    new TypeInformation[] { Types.LONG },
                    this.getReservedCols());
                break;
            }
            case PRED_CLUS: {
                outputColsHelper = new OutputColsHelper(
                    in2.getSchema(),
                    new String[] { getPredictionCol(), getPredictionClusterCol() },
                    new TypeInformation[] { Types.LONG, VectorTypes.DENSE_VECTOR },
                    this.getReservedCols());
                break;
            }
            case PRED_DIST: {
                outputColsHelper = new OutputColsHelper(
                    in2.getSchema(),
                    new String[] { getPredictionCol(), getPredictionDistanceCol() },
                    new TypeInformation[] { Types.LONG, Types.DOUBLE },
                    this.getReservedCols());
                break;
            }
            case PRED_CLUS_DIST: {
                outputColsHelper = new OutputColsHelper(
                    in2.getSchema(),
                    new String[] { this.getPredictionCol(), getPredictionClusterCol(), getPredictionDistanceCol() },
                    new TypeInformation[] { Types.LONG, VectorTypes.DENSE_VECTOR, Types.DOUBLE },
                    this.getReservedCols());
                break;
            }
            default:
                // Fail with a clear message instead of a NullPointerException further down.
                throw new IllegalStateException("Unsupported prediction type: " + predType);
        }
        // for direct read
        DataBridge modelDataBridge = DirectReader.collect(batchModel);
        // incremental train on every window of data
        DataStream<Tuple3<DenseVector[], int[], Long>> updateData = trainingData
            .flatMap(new CollectUpdateData(modelDataBridge, in1.getColNames(), timeInterval))
            .name("local_aggregate");
        int taskNum = updateData.getParallelism();
        // Merge the per-task aggregates and update the model with parallelism 1.
        DataStream<KMeansTrainModelData> streamModel = updateData
            .flatMap(new AllDataMerge(taskNum))
            .name("global_aggregate")
            .setParallelism(1)
            .map(new UpdateModelOp(modelDataBridge, decayFactor))
            .name("update_model")
            .setParallelism(1);
        // predict: the model stream is broadcast and connected to the prediction stream
        DataStream<Row> predictResult = predictData
            .connect(streamModel.broadcast())
            .flatMap(new PredictOp(modelDataBridge, in2.getColNames(), outputColsHelper, predType))
            .name("kmeans_prediction");
        this.setOutput(predictResult, outputColsHelper.getResultSchema());
        this.setSideOutputTables(outputModel(streamModel, getMLEnvironmentId()));
        return this;
    } catch (Exception e) {
        // Keep the original exception as the cause so its type and stack trace are not lost.
        throw new RuntimeException(e.getMessage(), e);
    }
}
Also used : KMeansTrainModelData(com.alibaba.alink.operator.common.clustering.kmeans.KMeansTrainModelData) TypeHint(org.apache.flink.api.common.typeinfo.TypeHint) Tuple3(org.apache.flink.api.java.tuple.Tuple3) Row(org.apache.flink.types.Row) OutputColsHelper(com.alibaba.alink.common.utils.OutputColsHelper) DenseVector(com.alibaba.alink.common.linalg.DenseVector) DataBridge(com.alibaba.alink.common.io.directreader.DataBridge)

Aggregations

DataBridge (com.alibaba.alink.common.io.directreader.DataBridge)1 DenseVector (com.alibaba.alink.common.linalg.DenseVector)1 OutputColsHelper (com.alibaba.alink.common.utils.OutputColsHelper)1 KMeansTrainModelData (com.alibaba.alink.operator.common.clustering.kmeans.KMeansTrainModelData)1 TypeHint (org.apache.flink.api.common.typeinfo.TypeHint)1 Tuple3 (org.apache.flink.api.java.tuple.Tuple3)1 Row (org.apache.flink.types.Row)1