Search in sources :

Example 1 with OutputColsHelper

use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.

the class StreamingKMeansStreamOp method linkFrom.

/**
 * Update model with stream in1, predict for stream in2.
 *
 * <p>The first input is the training stream. The second input, when present, is the stream to
 * predict on; otherwise the first input is used for prediction as well. If no prediction column
 * has been configured, it defaults to {@code "cluster_id"}.
 *
 * <p>The prediction result is set as this operator's output and the stream of updated models is
 * registered as side output tables.
 *
 * @param inputs one or two stream operators: training data and, optionally, prediction data
 * @return this operator
 * @throws RuntimeException wrapping (as its cause) any exception raised while assembling the
 *                          training and prediction pipeline
 */
@Override
public StreamingKMeansStreamOp linkFrom(StreamOperator<?>... inputs) {
    checkMinOpSize(1, inputs);
    final StreamOperator<?> in1 = inputs[0];
    // With a single input, the same stream is used for both training and prediction.
    final StreamOperator<?> in2 = inputs.length > 1 ? inputs[1] : inputs[0];
    if (!this.getParams().contains(HasPredictionCol.PREDICTION_COL)) {
        this.setPredictionCol("cluster_id");
    }
    // time interval for updating the model, in seconds
    final long timeInterval = getParams().get(TIME_INTERVAL);
    final long halfLife = getParams().get(HALF_LIFE);
    // decayFactor = 0.5 ^ (timeInterval / halfLife)
    final double decayFactor = Math.pow(0.5, (double) timeInterval / (double) halfLife);
    try {
        DataStream<Row> trainingData = in1.getDataStream();
        DataStream<Row> predictData = in2.getDataStream();
        PredType predType = PredType.fromInputs(getParams());
        // The output columns depend on which optional prediction details were requested.
        final OutputColsHelper outputColsHelper;
        switch(predType) {
            case PRED:
                {
                    outputColsHelper = new OutputColsHelper(in2.getSchema(), new String[] { getPredictionCol() }, new TypeInformation[] { Types.LONG }, this.getReservedCols());
                    break;
                }
            case PRED_CLUS:
                {
                    outputColsHelper = new OutputColsHelper(in2.getSchema(), new String[] { getPredictionCol(), getPredictionClusterCol() }, new TypeInformation[] { Types.LONG, VectorTypes.DENSE_VECTOR }, this.getReservedCols());
                    break;
                }
            case PRED_DIST:
                {
                    outputColsHelper = new OutputColsHelper(in2.getSchema(), new String[] { getPredictionCol(), getPredictionDistanceCol() }, new TypeInformation[] { Types.LONG, Types.DOUBLE }, this.getReservedCols());
                    break;
                }
            case PRED_CLUS_DIST:
                {
                    outputColsHelper = new OutputColsHelper(in2.getSchema(), new String[] { this.getPredictionCol(), getPredictionClusterCol(), getPredictionDistanceCol() }, new TypeInformation[] { Types.LONG, VectorTypes.DENSE_VECTOR, Types.DOUBLE }, this.getReservedCols());
                    break;
                }
            default:
                // Fail with a clear message instead of a later NullPointerException on the helper.
                throw new IllegalStateException("Unsupported prediction type: " + predType);
        }
        // for direct read
        DataBridge modelDataBridge = DirectReader.collect(batchModel);
        // incremental train on every window of data
        DataStream<Tuple3<DenseVector[], int[], Long>> updateData = trainingData.flatMap(new CollectUpdateData(modelDataBridge, in1.getColNames(), timeInterval)).name("local_aggregate");
        int taskNum = updateData.getParallelism();
        // Merge the per-task aggregates and update the model in single-parallelism operators.
        DataStream<KMeansTrainModelData> streamModel = updateData.flatMap(new AllDataMerge(taskNum)).name("global_aggregate").setParallelism(1).map(new UpdateModelOp(modelDataBridge, decayFactor)).name("update_model").setParallelism(1);
        // predict: the updated model stream is broadcast to the prediction operator
        DataStream<Row> predictResult = predictData.connect(streamModel.broadcast()).flatMap(new PredictOp(modelDataBridge, in2.getColNames(), outputColsHelper, predType)).name("kmeans_prediction");
        this.setOutput(predictResult, outputColsHelper.getResultSchema());
        this.setSideOutputTables(outputModel(streamModel, getMLEnvironmentId()));
        return this;
    } catch (Exception e) {
        // Keep the original exception as the cause so its type and stack trace are not lost.
        throw new RuntimeException(e.getMessage(), e);
    }
}
Also used : KMeansTrainModelData(com.alibaba.alink.operator.common.clustering.kmeans.KMeansTrainModelData) TypeHint(org.apache.flink.api.common.typeinfo.TypeHint) Tuple3(org.apache.flink.api.java.tuple.Tuple3) Row(org.apache.flink.types.Row) OutputColsHelper(com.alibaba.alink.common.utils.OutputColsHelper) DenseVector(com.alibaba.alink.common.linalg.DenseVector) DataBridge(com.alibaba.alink.common.io.directreader.DataBridge)

Example 2 with OutputColsHelper

use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.

the class KeywordsExtractionStreamOp method linkFrom.

/**
 * Links this operator to its single input and applies keyword extraction to the selected text
 * column with a {@code KeywordsExtractionMap}.
 *
 * <p>The result is written to the configured output column; when no output column is set, the
 * selected column name is used as the output column name. All input columns are passed to the
 * {@code OutputColsHelper} as reserved columns.
 *
 * @param inputs the upstream operators; the first one is obtained via {@code checkAndGetFirst}
 * @return this operator, with its output set to the mapped stream
 */
@Override
public KeywordsExtractionStreamOp linkFrom(StreamOperator<?>... inputs) {
    final StreamOperator<?> source = checkAndGetFirst(inputs);
    final String textColName = this.getSelectedCol();
    final int textColPosition = TableUtil.findColIndexWithAssertAndHint(source.getColNames(), textColName);
    // Fall back to the selected column name when no output column is configured.
    final String configuredOutputCol = this.getOutputCol();
    final String resultColName = (configuredOutputCol != null) ? configuredOutputCol : textColName;
    final String[] resultColNames = new String[] { resultColName };
    final TypeInformation[] resultColTypes = new TypeInformation[] { org.apache.flink.table.api.Types.STRING() };
    final OutputColsHelper colsHelper = new OutputColsHelper(source.getSchema(), resultColNames, resultColTypes, source.getColNames());
    final KeywordsExtractionMap extractor = new KeywordsExtractionMap(this.getParams(), textColPosition, colsHelper);
    final DataStream<Row> extracted = source.getDataStream().map(extractor);
    // Set the output into table.
    this.setOutput(extracted, colsHelper.getResultSchema());
    return this;
}
Also used : KeywordsExtractionMap(com.alibaba.alink.operator.common.nlp.KeywordsExtractionMap) Row(org.apache.flink.types.Row) OutputColsHelper(com.alibaba.alink.common.utils.OutputColsHelper)

Example 3 with OutputColsHelper

use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.

the class HugeIndexerStringPredictBatchOp method linkFrom.

/**
 * Maps indices in the selected columns back to string tokens using the given model.
 *
 * <p>Each selected column holds either a single {@code Long} index (basic-typed column) or a
 * {@code Long[]} of indices. Every index is joined against the model data; the tokens found for
 * one column of one record are concatenated with {@code ","} in array order. An index without a
 * match in the model yields the token {@code "notFound"}; a null field is looked up as index
 * {@code -1}.
 *
 * @param inputs {@code inputs[0]} is the model, {@code inputs[1]} is the data to transform
 * @return this operator, with its output set to the transformed data
 * @throws IllegalArgumentException if fewer than two inputs are supplied
 */
@Override
public HugeIndexerStringPredictBatchOp linkFrom(BatchOperator<?>... inputs) {
    // Fail fast with a readable message instead of an ArrayIndexOutOfBoundsException below.
    if (inputs == null || inputs.length < 2) {
        throw new IllegalArgumentException("HugeIndexerStringPredictBatchOp requires two inputs: the model and the data to predict.");
    }
    Params params = super.getParams();
    BatchOperator<?> model = inputs[0];
    BatchOperator<?> data = inputs[1];
    String[] selectedColNames = params.get(HugeMultiStringIndexerPredictParams.SELECTED_COLS);
    String[] outputColNames = params.get(HugeMultiStringIndexerPredictParams.OUTPUT_COLS);
    if (outputColNames == null) {
        // Overwrite the selected columns when no output columns are configured.
        outputColNames = selectedColNames;
    }
    String[] keepColNames = params.get(HugeMultiStringIndexerPredictParams.RESERVED_COLS);
    TypeInformation<?>[] outputColTypes = new TypeInformation<?>[outputColNames.length];
    TypeInformation<?>[] inputColTypes = TableUtil.findColTypesWithAssert(data.getSchema(), selectedColNames);
    Arrays.fill(outputColTypes, Types.STRING);
    OutputColsHelper outputColsHelper = new OutputColsHelper(data.getSchema(), outputColNames, outputColTypes, keepColNames);
    final int[] selectedColIdx = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);
    // Parsed eagerly from the parameters; it is not referenced again in this method.
    final HandleInvalid handleInvalidStrategy = HandleInvalid.valueOf(params.get(HugeMultiStringIndexerPredictParams.HANDLE_INVALID).toString());
    // Attach a unique id to every record so the per-index results can be merged back later.
    DataSet<Tuple2<Long, Row>> dataWithId = DataSetUtils.zipWithUniqueId(data.getDataSet());
    DataSet<Tuple2<String, Long>> modelData = getModelData(model);
    // tuple: record id, selected column position, position within the column's array, token index
    DataSet<Tuple4<Long, Integer, Integer, Long>> flattened = dataWithId.flatMap(new RichFlatMapFunction<Tuple2<Long, Row>, Tuple4<Long, Integer, Integer, Long>>() {

        private static final long serialVersionUID = -8382461068855755626L;

        @Override
        public void flatMap(Tuple2<Long, Row> value, Collector<Tuple4<Long, Integer, Integer, Long>> out) throws Exception {
            for (int i = 0; i < selectedColIdx.length; i++) {
                Object o = value.f1.getField(selectedColIdx[i]);
                if (null == o) {
                    // A null field is emitted as index -1.
                    out.collect(Tuple4.of(value.f0, i, 0, -1L));
                } else {
                    if (inputColTypes[i].isBasicType()) {
                        // Basic-typed column: the field itself is a single index.
                        out.collect(Tuple4.of(value.f0, i, 0, (Long) o));
                    } else {
                        // Otherwise the field is treated as an array of indices.
                        Long[] ids = (Long[]) o;
                        for (int j = 0; j < ids.length; j++) {
                            out.collect(Tuple4.of(value.f0, i, j, ids[j]));
                        }
                    }
                }
            }
        }
    }).name("flatten_pred_data").returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.INT, Types.LONG));
    // tuple: record id, selected column position, position within the column's array, token
    DataSet<Tuple4<Long, Integer, Integer, String>> indexed = flattened.leftOuterJoin(modelData).where(3).equalTo(1).with(new JoinFunction<Tuple4<Long, Integer, Integer, Long>, Tuple2<String, Long>, Tuple4<Long, Integer, Integer, String>>() {

        private static final long serialVersionUID = 2270459281179536013L;

        @Override
        public Tuple4<Long, Integer, Integer, String> join(Tuple4<Long, Integer, Integer, Long> first, Tuple2<String, Long> second) throws Exception {
            if (second == null) {
                // Left outer join: no model entry for this index.
                return Tuple4.of(first.f0, first.f1, first.f2, "notFound");
            } else {
                return Tuple4.of(first.f0, first.f1, first.f2, second.f0);
            }
        }
    }).name("map_index_to_token").returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.INT, Types.STRING));
    // tuple: record id, prediction result
    DataSet<Tuple2<Long, Row>> aggregateResult = indexed.groupBy(0).reduceGroup(new GroupReduceFunction<Tuple4<Long, Integer, Integer, String>, Tuple2<Long, Row>>() {

        private static final long serialVersionUID = -1581264399340055162L;

        @Override
        public void reduce(Iterable<Tuple4<Long, Integer, Integer, String>> values, Collector<Tuple2<Long, Row>> out) throws Exception {
            Long id = null;
            Row r = new Row(selectedColIdx.length);
            ArrayList<Tuple3<Integer, Integer, String>> list = new ArrayList<>();
            for (Tuple4<Long, Integer, Integer, String> v : values) {
                list.add(Tuple3.of(v.f1, v.f2, v.f3));
                id = v.f0;
            }
            // Order by column position first, then by position within the column's array.
            list.sort(new Comparator<Tuple3<Integer, Integer, String>>() {

                @Override
                public int compare(Tuple3<Integer, Integer, String> o1, Tuple3<Integer, Integer, String> o2) {
                    if (o1.f0.equals(o2.f0)) {
                        return o1.f1.compareTo(o2.f1);
                    }
                    return o1.f0.compareTo(o2.f0);
                }
            });
            ArrayList<String> allFeatures = new ArrayList<>(list.size());
            for (Tuple3<Integer, Integer, String> v : list) {
                allFeatures.add(v.f2);
            }
            String[] originFeatures = new String[selectedColIdx.length];
            // [startIndex, endIndex) is the run of sorted tokens belonging to column lastIndex.
            int startIndex = 0, endIndex = 0;
            int lastIndex = 0;
            for (int i = 0; i < list.size(); i++) {
                Tuple3<Integer, Integer, String> v = list.get(i);
                if (lastIndex != v.f0) {
                    // Column changed: flush the finished run and start a new one at i.
                    originFeatures[lastIndex] = StringUtils.join(allFeatures.subList(startIndex, endIndex), ",");
                    lastIndex = v.f0;
                    startIndex = i;
                    endIndex = i;
                }
                endIndex += 1;
                if (i == list.size() - 1) {
                    // Flush the final run.
                    originFeatures[lastIndex] = StringUtils.join(allFeatures.subList(startIndex, endIndex), ",");
                }
            }
            for (int i = 0; i < originFeatures.length; i++) {
                r.setField(i, originFeatures[i]);
            }
            out.collect(Tuple2.of(id, r));
        }
    }).name("aggregate_result").returns(new TupleTypeInfo<>(Types.LONG, new RowTypeInfo(outputColTypes)));
    // Merge the computed columns back into the original rows by record id.
    DataSet<Row> output = dataWithId.join(aggregateResult).where(0).equalTo(0).with(new JoinFunction<Tuple2<Long, Row>, Tuple2<Long, Row>, Row>() {

        private static final long serialVersionUID = 3724539437313089427L;

        @Override
        public Row join(Tuple2<Long, Row> first, Tuple2<Long, Row> second) throws Exception {
            return outputColsHelper.getResultRow(first.f1, second.f1);
        }
    }).name("merge_result").returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));
    this.setOutput(output, outputColsHelper.getResultSchema());
    return this;
}
Also used : ArrayList(java.util.ArrayList) RowTypeInfo(org.apache.flink.api.java.typeutils.RowTypeInfo) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Comparator(java.util.Comparator) OutputColsHelper(com.alibaba.alink.common.utils.OutputColsHelper) HugeMultiStringIndexerPredictParams(com.alibaba.alink.params.dataproc.HugeMultiStringIndexerPredictParams) Params(org.apache.flink.ml.api.misc.param.Params) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Tuple4(org.apache.flink.api.java.tuple.Tuple4) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Tuple3(org.apache.flink.api.java.tuple.Tuple3) Row(org.apache.flink.types.Row)

Example 4 with OutputColsHelper

use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.

the class HugeLookupBatchOp method linkFrom.

/**
 * Looks up values for the selected data columns in a model table by joining on key columns.
 *
 * <p>{@code inputs[0]} is the model (key/value table) and {@code inputs[1]} is the data. The data
 * is left-outer-joined with the model on selected columns = map key columns; the map value
 * columns of the matching model row are appended as output columns. Rows without a match get
 * null values. When the map key/value columns are not configured, the model must have exactly
 * two columns, which are used as key (first) and value (second).
 *
 * @param inputs exactly two operators: the model and the data
 * @return this operator, with its output set to the joined result
 * @throws RuntimeException         if the model does not have exactly two columns and the map
 *                                  key or map value columns are not configured
 * @throws IllegalArgumentException if the number of selected columns differs from the number of
 *                                  key columns, or a selected column's type differs from its
 *                                  key column's type
 */
@Override
public HugeLookupBatchOp linkFrom(BatchOperator<?>... inputs) {
    checkOpSize(2, inputs);
    BatchOperator<?> model = inputs[0];
    BatchOperator<?> data = inputs[1];
    final String[] mapKeyColNames = getMapKeyCols();
    final String[] mapValueColNames = getMapValueCols();
    final String[] selectedColNames = getSelectedCols();
    final String[] reservedColNames = getReservedCols();
    String[] outputColNames = getOutputCols();
    TableSchema modelSchema = model.getSchema();
    TableSchema dataSchema = data.getSchema();
    if (modelSchema.getFieldNames().length != 2 && (mapKeyColNames == null || mapValueColNames == null)) {
        throw new RuntimeException("LookUp err : mapKeyCols and mapValueCols should set in parameters.");
    }
    final int[] selectedColIndices = TableUtil.findColIndicesWithAssertAndHint(dataSchema, selectedColNames);
    // Without explicit names, the model's first column is the key and the second is the value.
    final int[] mapKeyColIndices = (mapKeyColNames != null) ? TableUtil.findColIndicesWithAssertAndHint(modelSchema, mapKeyColNames) : new int[] { 0 };
    final int[] mapValueColIndices = (mapValueColNames != null) ? TableUtil.findColIndicesWithAssertAndHint(modelSchema, mapValueColNames) : new int[] { 1 };
    // The join below pairs selected columns with key columns one-to-one.
    if (selectedColIndices.length != mapKeyColIndices.length) {
        throw new IllegalArgumentException("The number of selected columns (" + selectedColIndices.length + ") does not match the number of map key columns (" + mapKeyColIndices.length + ").");
    }
    if (mapKeyColNames != null && mapValueColNames != null) {
        for (int i = 0; i < selectedColNames.length; ++i) {
            TypeInformation<?> selectedColType = TableUtil.findColTypeWithAssertAndHint(dataSchema, selectedColNames[i]);
            TypeInformation<?> mapKeyColType = TableUtil.findColTypeWithAssertAndHint(modelSchema, mapKeyColNames[i]);
            // Compare by value: equal types are not guaranteed to be the same instance.
            if (!selectedColType.equals(mapKeyColType)) {
                throw new IllegalArgumentException("Data types are not match. selected column type is " + selectedColType + " , and the map key column type is " + mapKeyColType);
            }
        }
    }
    // Names of the model columns that provide the looked-up values.
    final String[] valueColNames = (mapValueColNames != null) ? mapValueColNames : new String[] { modelSchema.getFieldNames()[1] };
    if (null == outputColNames) {
        // Default the output column names to the value column names so names and types line up.
        outputColNames = valueColNames;
    }
    final TypeInformation<?>[] outputColTypes = TableUtil.findColTypesWithAssertAndHint(modelSchema, valueColNames);
    final OutputColsHelper predResultColsHelper = new OutputColsHelper(dataSchema, outputColNames, outputColTypes, reservedColNames);
    DataSet<Row> result = data.getDataSet().leftOuterJoin(model.getDataSet(), JoinHint.REPARTITION_SORT_MERGE).where(selectedColIndices).equalTo(mapKeyColIndices).with(new JoinFunction<Row, Row, Row>() {

        @Override
        public Row join(Row first, Row second) {
            // Fields stay null when the data row has no matching model row.
            Row result = new Row(mapValueColIndices.length);
            if (second != null) {
                for (int i = 0; i < mapValueColIndices.length; ++i) {
                    result.setField(i, second.getField(mapValueColIndices[i]));
                }
            }
            return predResultColsHelper.getResultRow(first, result);
        }
    });
    setOutput(result, predResultColsHelper.getResultSchema());
    return this;
}
Also used : TableSchema(org.apache.flink.table.api.TableSchema) JoinHint(org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Row(org.apache.flink.types.Row) OutputColsHelper(com.alibaba.alink.common.utils.OutputColsHelper)

Example 5 with OutputColsHelper

use of com.alibaba.alink.common.utils.OutputColsHelper in project Alink by alibaba.

the class HugeMultiIndexerStringPredictBatchOp method linkFrom.

/**
 * Maps the indices in the selected columns back to string tokens using the given model.
 *
 * <p>Each selected column holds a single {@code Long} index. Every (column, index) pair is
 * joined against the model data; the matching token becomes the output value for that column.
 * A pair without a match in the model yields the string {@code "null"}; a null field is looked
 * up as index {@code -1}.
 *
 * @param inputs {@code inputs[0]} is the model, {@code inputs[1]} is the data to transform
 * @return this operator, with its output set to the transformed data
 * @throws IllegalArgumentException if fewer than two inputs are supplied
 */
@Override
public HugeMultiIndexerStringPredictBatchOp linkFrom(BatchOperator<?>... inputs) {
    // Fail fast with a readable message instead of an ArrayIndexOutOfBoundsException below.
    if (inputs == null || inputs.length < 2) {
        throw new IllegalArgumentException("HugeMultiIndexerStringPredictBatchOp requires two inputs: the model and the data to predict.");
    }
    Params params = super.getParams();
    BatchOperator<?> model = inputs[0];
    BatchOperator<?> data = inputs[1];
    String[] selectedColNames = params.get(MultiStringIndexerPredictParams.SELECTED_COLS);
    String[] outputColNames = params.get(MultiStringIndexerPredictParams.OUTPUT_COLS);
    if (outputColNames == null) {
        // Overwrite the selected columns when no output columns are configured.
        outputColNames = selectedColNames;
    }
    String[] keepColNames = params.get(StringIndexerPredictParams.RESERVED_COLS);
    TypeInformation<?>[] outputColTypes = new TypeInformation<?>[outputColNames.length];
    Arrays.fill(outputColTypes, Types.STRING);
    OutputColsHelper outputColsHelper = new OutputColsHelper(data.getSchema(), outputColNames, outputColTypes, keepColNames);
    final int[] selectedColIdx = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);
    // Parsed eagerly from the parameters; it is not referenced again in this method.
    final HandleInvalid handleInvalidStrategy = HandleInvalid.valueOf(params.get(StringIndexerPredictParams.HANDLE_INVALID).toString());
    // Attach a unique id to every record so the per-column results can be merged back later.
    DataSet<Tuple2<Long, Row>> dataWithId = DataSetUtils.zipWithUniqueId(data.getDataSet());
    DataSet<String> modelMeta = getModelMeta(model);
    DataSet<Tuple3<Integer, String, Long>> modelData = getModelData(model, modelMeta, selectedColNames);
    // tuple: record id, selected column position, token index
    DataSet<Tuple3<Long, Integer, Long>> flattened = dataWithId.flatMap(new RichFlatMapFunction<Tuple2<Long, Row>, Tuple3<Long, Integer, Long>>() {

        private static final long serialVersionUID = 7795878509849151894L;

        @Override
        public void flatMap(Tuple2<Long, Row> value, Collector<Tuple3<Long, Integer, Long>> out) throws Exception {
            for (int i = 0; i < selectedColIdx.length; i++) {
                Object o = value.f1.getField(selectedColIdx[i]);
                if (o != null) {
                    out.collect(Tuple3.of(value.f0, i, (Long) o));
                } else {
                    // A null field is emitted as index -1.
                    out.collect(Tuple3.of(value.f0, i, -1L));
                }
            }
        }
    }).name("flatten_pred_data").returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.LONG));
    // tuple: record id, selected column position, token
    // Join on (column position, token index) against the model's fields 0 and 2.
    DataSet<Tuple3<Long, Integer, String>> indexed = flattened.leftOuterJoin(modelData).where(1, 2).equalTo(0, 2).with(new JoinFunction<Tuple3<Long, Integer, Long>, Tuple3<Integer, String, Long>, Tuple3<Long, Integer, String>>() {

        private static final long serialVersionUID = -3177975102816197011L;

        @Override
        public Tuple3<Long, Integer, String> join(Tuple3<Long, Integer, Long> first, Tuple3<Integer, String, Long> second) throws Exception {
            if (second == null) {
                // Left outer join: no model entry for this (column, index) pair.
                return Tuple3.of(first.f0, first.f1, "null");
            } else {
                return Tuple3.of(first.f0, first.f1, second.f1);
            }
        }
    }).name("map_index_to_token").returns(new TupleTypeInfo<>(Types.LONG, Types.INT, Types.STRING));
    // tuple: record id, prediction result
    DataSet<Tuple2<Long, Row>> aggregateResult = indexed.groupBy(0).reduceGroup(new RichGroupReduceFunction<Tuple3<Long, Integer, String>, Tuple2<Long, Row>>() {

        private static final long serialVersionUID = 2318140138585310686L;

        @Override
        public void reduce(Iterable<Tuple3<Long, Integer, String>> values, Collector<Tuple2<Long, Row>> out) throws Exception {
            Long id = null;
            // One field per selected column, filled at the column's position.
            Row r = new Row(selectedColIdx.length);
            for (Tuple3<Long, Integer, String> v : values) {
                r.setField(v.f1, v.f2);
                id = v.f0;
            }
            out.collect(Tuple2.of(id, r));
        }
    }).name("aggregate_result").returns(new TupleTypeInfo<>(Types.LONG, new RowTypeInfo(outputColTypes)));
    // Merge the computed columns back into the original rows by record id.
    DataSet<Row> output = dataWithId.join(aggregateResult).where(0).equalTo(0).with(new JoinFunction<Tuple2<Long, Row>, Tuple2<Long, Row>, Row>() {

        private static final long serialVersionUID = 3724539437313089427L;

        @Override
        public Row join(Tuple2<Long, Row> first, Tuple2<Long, Row> second) throws Exception {
            return outputColsHelper.getResultRow(first.f1, second.f1);
        }
    }).name("merge_result").returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));
    this.setOutput(output, outputColsHelper.getResultSchema());
    return this;
}
Also used : RowTypeInfo(org.apache.flink.api.java.typeutils.RowTypeInfo) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) OutputColsHelper(com.alibaba.alink.common.utils.OutputColsHelper) StringIndexerPredictParams(com.alibaba.alink.params.dataproc.StringIndexerPredictParams) HugeMultiStringIndexerPredictParams(com.alibaba.alink.params.dataproc.HugeMultiStringIndexerPredictParams) Params(org.apache.flink.ml.api.misc.param.Params) MultiStringIndexerPredictParams(com.alibaba.alink.params.dataproc.MultiStringIndexerPredictParams) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Tuple3(org.apache.flink.api.java.tuple.Tuple3) Row(org.apache.flink.types.Row)

Aggregations

OutputColsHelper (com.alibaba.alink.common.utils.OutputColsHelper)9 Row (org.apache.flink.types.Row)8 BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)5 TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation)5 Tuple3 (org.apache.flink.api.java.tuple.Tuple3)5 Params (org.apache.flink.ml.api.misc.param.Params)5 HugeMultiStringIndexerPredictParams (com.alibaba.alink.params.dataproc.HugeMultiStringIndexerPredictParams)4 Tuple2 (org.apache.flink.api.java.tuple.Tuple2)4 RowTypeInfo (org.apache.flink.api.java.typeutils.RowTypeInfo)4 MultiStringIndexerPredictParams (com.alibaba.alink.params.dataproc.MultiStringIndexerPredictParams)2 StringIndexerPredictParams (com.alibaba.alink.params.dataproc.StringIndexerPredictParams)2 ArrayList (java.util.ArrayList)2 Comparator (java.util.Comparator)2 MapPartitionFunction (org.apache.flink.api.common.functions.MapPartitionFunction)2 RichGroupReduceFunction (org.apache.flink.api.common.functions.RichGroupReduceFunction)2 Configuration (org.apache.flink.configuration.Configuration)2 Collector (org.apache.flink.util.Collector)2 DataBridge (com.alibaba.alink.common.io.directreader.DataBridge)1 DenseVector (com.alibaba.alink.common.linalg.DenseVector)1 TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp)1