Search in sources:

Example 1 with DataBridge

Usage of com.alibaba.alink.common.io.directreader.DataBridge in the Alink project by Alibaba.

Class StreamingKMeansStreamOp, method linkFrom.

/**
 * Updates the model with the first input stream and predicts for the second one.
 *
 * <p>When only one input is given, it serves as both the training stream and the stream to
 * predict. If no prediction column has been set, {@code "cluster_id"} is used.
 *
 * @param inputs the training stream, optionally followed by the stream to predict
 * @return this operator; its output is the prediction stream and its side output tables are
 *     built from the continuously updated model stream
 * @throws RuntimeException if building the pipeline fails; the original exception is kept as
 *     the cause
 */
@Override
public StreamingKMeansStreamOp linkFrom(StreamOperator<?>... inputs) {
    checkMinOpSize(1, inputs);
    StreamOperator<?> in1 = inputs[0];
    // Without a dedicated prediction stream, predict on the training stream itself.
    StreamOperator<?> in2 = inputs.length > 1 ? inputs[1] : inputs[0];
    if (!this.getParams().contains(HasPredictionCol.PREDICTION_COL)) {
        this.setPredictionCol("cluster_id");
    }
    // Time interval for updating the model, in seconds.
    final long timeInterval = getParams().get(TIME_INTERVAL);
    final long halfLife = getParams().get(HALF_LIFE);
    // Decay factor per update interval: 0.5 ^ (timeInterval / halfLife), i.e. it halves
    // once every halfLife.
    final double decayFactor = Math.pow(0.5, (double) timeInterval / (double) halfLife);
    try {
        DataStream<Row> trainingData = in1.getDataStream();
        DataStream<Row> predictData = in2.getDataStream();
        PredType predType = PredType.fromInputs(getParams());
        OutputColsHelper outputColsHelper;
        switch (predType) {
            case PRED: {
                outputColsHelper = new OutputColsHelper(in2.getSchema(), new String[] { getPredictionCol() }, new TypeInformation[] { Types.LONG }, this.getReservedCols());
                break;
            }
            case PRED_CLUS: {
                outputColsHelper = new OutputColsHelper(in2.getSchema(), new String[] { getPredictionCol(), getPredictionClusterCol() }, new TypeInformation[] { Types.LONG, VectorTypes.DENSE_VECTOR }, this.getReservedCols());
                break;
            }
            case PRED_DIST: {
                outputColsHelper = new OutputColsHelper(in2.getSchema(), new String[] { getPredictionCol(), getPredictionDistanceCol() }, new TypeInformation[] { Types.LONG, Types.DOUBLE }, this.getReservedCols());
                break;
            }
            case PRED_CLUS_DIST: {
                outputColsHelper = new OutputColsHelper(in2.getSchema(), new String[] { this.getPredictionCol(), getPredictionClusterCol(), getPredictionDistanceCol() }, new TypeInformation[] { Types.LONG, VectorTypes.DENSE_VECTOR, Types.DOUBLE }, this.getReservedCols());
                break;
            }
            default:
                // Fail fast instead of hitting a NullPointerException on outputColsHelper below.
                throw new IllegalArgumentException("Unsupported prediction type: " + predType);
        }
        // For direct read of the initial batch model.
        DataBridge modelDataBridge = DirectReader.collect(batchModel);
        // Incremental training on every window of data.
        DataStream<Tuple3<DenseVector[], int[], Long>> updateData = trainingData.flatMap(new CollectUpdateData(modelDataBridge, in1.getColNames(), timeInterval)).name("local_aggregate");
        int taskNum = updateData.getParallelism();
        // Merge the local aggregates and update the model in a single task.
        DataStream<KMeansTrainModelData> streamModel = updateData.flatMap(new AllDataMerge(taskNum)).name("global_aggregate").setParallelism(1).map(new UpdateModelOp(modelDataBridge, decayFactor)).name("update_model").setParallelism(1);
        // Predict, with every updated model broadcast to all prediction tasks.
        DataStream<Row> predictResult = predictData.connect(streamModel.broadcast()).flatMap(new PredictOp(modelDataBridge, in2.getColNames(), outputColsHelper, predType)).name("kmeans_prediction");
        this.setOutput(predictResult, outputColsHelper.getResultSchema());
        this.setSideOutputTables(outputModel(streamModel, getMLEnvironmentId()));
        return this;
    } catch (Exception e) {
        // Keep the original exception as the cause so its type and stack trace are not lost.
        throw new RuntimeException(e.getMessage(), e);
    }
}
Also used: KMeansTrainModelData(com.alibaba.alink.operator.common.clustering.kmeans.KMeansTrainModelData) TypeHint(org.apache.flink.api.common.typeinfo.TypeHint) Tuple3(org.apache.flink.api.java.tuple.Tuple3) Row(org.apache.flink.types.Row) OutputColsHelper(com.alibaba.alink.common.utils.OutputColsHelper) DenseVector(com.alibaba.alink.common.linalg.DenseVector) DataBridge(com.alibaba.alink.common.io.directreader.DataBridge)

Example 2 with DataBridge

Usage of com.alibaba.alink.common.io.directreader.DataBridge in the Alink project by Alibaba.

Class ModelMapStreamOp, method linkFrom.

/**
 * Applies this operator's model mapper to the first input stream.
 *
 * <p>Model updates may come from a model stream file (when configured) and/or from an optional
 * second input. When both are present, the second input is projected onto the file stream's
 * columns and the two are unioned. With no model-update stream, rows are mapped against the
 * static model, in parallel when more than one thread is configured.
 *
 * @param inputs the data stream to predict, optionally followed by a model stream
 * @return this operator, with the mapped rows as output
 * @throws RuntimeException wrapping any failure while building the pipeline
 */
@Override
public T linkFrom(StreamOperator<?>... inputs) {
    checkMinOpSize(1, inputs);
    final StreamOperator<?> dataOp = inputs[0];
    final TableSchema modelSchema = this.model.getSchema();
    try {
        final DataBridge dataBridge = DirectReader.collect(model);
        final DataBridgeModelSource staticModel = new DataBridgeModelSource(dataBridge);
        final ModelMapper mapper = this.mapperBuilder.apply(modelSchema, dataOp.getSchema(), this.getParams());

        // Collect the (optional) stream of model updates.
        TableSchema updateSchema = null;
        DataStream<Row> updateStream = null;
        if (ModelStreamUtils.useModelStreamFile(getParams())) {
            StreamOperator<?> fileOp = new ModelStreamFileSourceStreamOp()
                .setFilePath(getModelStreamFilePath())
                .setScanInterval(getModelStreamScanInterval())
                .setStartTime(getModelStreamStartTime())
                .setSchemaStr(CsvUtil.schema2SchemaStr(modelSchema))
                .setMLEnvironmentId(getMLEnvironmentId());
            updateSchema = fileOp.getSchema();
            updateStream = fileOp.getDataStream();
        }
        if (inputs.length > 1) {
            StreamOperator<?> extraOp = inputs[1];
            if (updateStream != null) {
                // Align the extra stream's columns with the file stream before merging.
                extraOp = extraOp.select(updateSchema.getFieldNames());
                updateStream = updateStream.union(extraOp.getDataStream());
            } else {
                updateSchema = extraOp.getSchema();
                updateStream = extraOp.getDataStream();
            }
        }

        final DataStream<Row> mapped;
        if (updateStream == null) {
            final int threads = getParams().get(ModelMapperParams.NUM_THREADS);
            if (threads > 1) {
                mapped = dataOp.getDataStream().flatMap(new ModelMapperAdapterMT(mapper, staticModel, threads));
            } else {
                mapped = dataOp.getDataStream().map(new ModelMapperAdapter(mapper, staticModel));
            }
        } else {
            final int timestampIdx = ModelStreamUtils.findTimestampColIndexWithAssertAndHint(updateSchema);
            final int countIdx = ModelStreamUtils.findCountColIndexWithAssertAndHint(updateSchema);
            mapped = dataOp.getDataStream()
                .connect(ModelStreamUtils.broadcastStream(updateStream))
                .flatMap(new PredictProcess(modelSchema, dataOp.getSchema(), getParams(), mapperBuilder, dataBridge, timestampIdx, countIdx));
        }
        this.setOutput(mapped, mapper.getOutputSchema());
        return (T) this;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Also used: TableSchema(org.apache.flink.table.api.TableSchema) PredictProcess(com.alibaba.alink.operator.common.stream.model.PredictProcess) ModelMapper(com.alibaba.alink.common.mapper.ModelMapper) ModelMapperAdapterMT(com.alibaba.alink.common.mapper.ModelMapperAdapterMT) DataBridgeModelSource(com.alibaba.alink.common.model.DataBridgeModelSource) Row(org.apache.flink.types.Row) ModelStreamFileSourceStreamOp(com.alibaba.alink.operator.stream.source.ModelStreamFileSourceStreamOp) ModelMapperAdapter(com.alibaba.alink.common.mapper.ModelMapperAdapter) DataBridge(com.alibaba.alink.common.io.directreader.DataBridge)

Example 3 with DataBridge

Usage of com.alibaba.alink.common.io.directreader.DataBridge in the Alink project by Alibaba.

Class BaseRecommStreamOp, method linkFrom.

/**
 * Applies this operator's recommendation mapper to the first input stream.
 *
 * <p>Rows are mapped one at a time when at most one thread is configured, and with the
 * multi-threaded adapter otherwise.
 *
 * @param inputs the stream to generate recommendations for (only the first is used)
 * @return this operator, with the recommendation rows as output
 * @throws RuntimeException wrapping any failure while building the pipeline
 */
@Override
public T linkFrom(StreamOperator<?>... inputs) {
    final StreamOperator<?> source = checkAndGetFirst(inputs);
    final TableSchema modelSchema = this.model.getSchema();
    try {
        final DataBridgeModelSource staticModel = new DataBridgeModelSource(DirectReader.collect(model));
        final RecommMapper recommMapper = new RecommMapper(this.recommKernelBuilder, this.recommType, modelSchema, source.getSchema(), this.getParams());
        final int threads = getParams().get(ModelMapperParams.NUM_THREADS);
        final DataStream<Row> recommended;
        if (threads > 1) {
            recommended = source.getDataStream().flatMap(new RecommAdapterMT(recommMapper, staticModel, threads));
        } else {
            recommended = source.getDataStream().map(new RecommAdapter(recommMapper, staticModel));
        }
        this.setOutput(recommended, recommMapper.getOutputSchema());
        return (T) this;
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
}
Also used: RecommAdapter(com.alibaba.alink.operator.common.recommendation.RecommAdapter) TableSchema(org.apache.flink.table.api.TableSchema) RecommAdapterMT(com.alibaba.alink.operator.common.recommendation.RecommAdapterMT) DataBridgeModelSource(com.alibaba.alink.common.model.DataBridgeModelSource) Row(org.apache.flink.types.Row) RecommMapper(com.alibaba.alink.operator.common.recommendation.RecommMapper) DataBridge(com.alibaba.alink.common.io.directreader.DataBridge)

Aggregations

DataBridge (com.alibaba.alink.common.io.directreader.DataBridge): 3; Row (org.apache.flink.types.Row): 3; DataBridgeModelSource (com.alibaba.alink.common.model.DataBridgeModelSource): 2; TableSchema (org.apache.flink.table.api.TableSchema): 2; DenseVector (com.alibaba.alink.common.linalg.DenseVector): 1; ModelMapper (com.alibaba.alink.common.mapper.ModelMapper): 1; ModelMapperAdapter (com.alibaba.alink.common.mapper.ModelMapperAdapter): 1; ModelMapperAdapterMT (com.alibaba.alink.common.mapper.ModelMapperAdapterMT): 1; OutputColsHelper (com.alibaba.alink.common.utils.OutputColsHelper): 1; KMeansTrainModelData (com.alibaba.alink.operator.common.clustering.kmeans.KMeansTrainModelData): 1; RecommAdapter (com.alibaba.alink.operator.common.recommendation.RecommAdapter): 1; RecommAdapterMT (com.alibaba.alink.operator.common.recommendation.RecommAdapterMT): 1; RecommMapper (com.alibaba.alink.operator.common.recommendation.RecommMapper): 1; PredictProcess (com.alibaba.alink.operator.common.stream.model.PredictProcess): 1; ModelStreamFileSourceStreamOp (com.alibaba.alink.operator.stream.source.ModelStreamFileSourceStreamOp): 1; TypeHint (org.apache.flink.api.common.typeinfo.TypeHint): 1; Tuple3 (org.apache.flink.api.java.tuple.Tuple3): 1