Search in sources :

Example 1 with Method

use of com.alibaba.alink.operator.common.nlp.Method in project Alink by alibaba.

the class KeywordsExtractionBatchOp method linkFrom.

@Override
public KeywordsExtractionBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final String docId = "doc_alink_id";
    String selectedColName = this.getSelectedCol();
    TableUtil.assertSelectedColExist(in.getColNames(), selectedColName);
    String outputColName = this.getOutputCol();
    if (null == outputColName) {
        outputColName = selectedColName;
    }
    OutputColsHelper outputColsHelper = new OutputColsHelper(in.getSchema(), outputColName, Types.STRING);
    final Integer topN = this.getTopN();
    Method method = this.getMethod();
    BatchOperator inWithId = new TableSourceBatchOp(AppendIdBatchOp.appendId(in.getDataSet(), in.getSchema(), getMLEnvironmentId())).setMLEnvironmentId(getMLEnvironmentId());
    DataSet<Row> weights;
    StopWordsRemoverBatchOp filterOp = new StopWordsRemoverBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setSelectedCol(selectedColName).setOutputCol("selectedColName");
    BatchOperator filtered = filterOp.linkFrom(inWithId);
    switch(method) {
        case TF_IDF:
            {
                DocWordCountBatchOp wordCount = new DocWordCountBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setDocIdCol(AppendIdBatchOp.appendIdColName).setContentCol("selectedColName");
                TfidfBatchOp tfIdf = new TfidfBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setDocIdCol(AppendIdBatchOp.appendIdColName).setWordCol("word").setCountCol("cnt");
                BatchOperator op = filtered.link(wordCount).link(tfIdf);
                weights = op.select(AppendIdBatchOp.appendIdColName + ", " + "word, tfidf").getDataSet();
                break;
            }
        case TEXT_RANK:
            {
                DataSet<Row> data = filtered.select(AppendIdBatchOp.appendIdColName + ", selectedColName").getDataSet();
                // Initialize the TextRank class, which runs the text rank algorithm.
                final Params params = getParams();
                weights = data.flatMap(new FlatMapFunction<Row, Row>() {

                    private static final long serialVersionUID = -4083643981693873537L;

                    @Override
                    public void flatMap(Row row, Collector<Row> collector) throws Exception {
                        // For each row, apply the text rank algorithm to get the key words.
                        Row[] out = TextRank.getKeyWords(row, params.get(KeywordsExtractionParams.DAMPING_FACTOR), params.get(KeywordsExtractionParams.WINDOW_SIZE), params.get(KeywordsExtractionParams.MAX_ITER), params.get(KeywordsExtractionParams.EPSILON));
                        for (int i = 0; i < out.length; i++) {
                            collector.collect(out[i]);
                        }
                    }
                });
                break;
            }
        default:
            {
                throw new RuntimeException("Not support this type!");
            }
    }
    DataSet<Row> res = weights.groupBy(new KeySelector<Row, String>() {

        private static final long serialVersionUID = 801794449492798203L;

        @Override
        public String getKey(Row row) {
            Object obj = row.getField(0);
            if (obj == null) {
                return "NULL";
            }
            return row.getField(0).toString();
        }
    }).reduceGroup(new GroupReduceFunction<Row, Row>() {

        private static final long serialVersionUID = -4051509261188494119L;

        @Override
        public void reduce(Iterable<Row> rows, Collector<Row> collector) {
            List<Row> list = new ArrayList<>();
            for (Row row : rows) {
                list.add(row);
            }
            Collections.sort(list, new Comparator<Row>() {

                @Override
                public int compare(Row row1, Row row2) {
                    Double v1 = (double) row1.getField(2);
                    Double v2 = (double) row2.getField(2);
                    return v2.compareTo(v1);
                }
            });
            int len = Math.min(list.size(), topN);
            Row out = new Row(2);
            StringBuilder builder = new StringBuilder();
            for (int i = 0; i < len; i++) {
                builder.append(list.get(i).getField(1).toString());
                if (i != len - 1) {
                    builder.append(" ");
                }
            }
            out.setField(0, list.get(0).getField(0));
            out.setField(1, builder.toString());
            collector.collect(out);
        }
    });
    // Set the output into table.
    Table tmpTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), res, new String[] { docId, outputColName }, new TypeInformation[] { Types.LONG, Types.STRING });
    StringBuilder selectClause = new StringBuilder("a." + outputColName);
    String[] keepColNames = outputColsHelper.getReservedColumns();
    for (int i = 0; i < keepColNames.length; i++) {
        selectClause.append("," + keepColNames[i]);
    }
    JoinBatchOp join = new JoinBatchOp().setMLEnvironmentId(getMLEnvironmentId()).setType("join").setSelectClause(selectClause.toString()).setJoinPredicate(docId + "=" + AppendIdBatchOp.appendIdColName);
    this.setOutputTable(join.linkFrom(new TableSourceBatchOp(tmpTable).setMLEnvironmentId(getMLEnvironmentId()), inWithId).getOutputTable());
    return this;
}
Also used : DataSet(org.apache.flink.api.java.DataSet) KeySelector(org.apache.flink.api.java.functions.KeySelector) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) Comparator(java.util.Comparator) ArrayList(java.util.ArrayList) List(java.util.List) OutputColsHelper(com.alibaba.alink.common.utils.OutputColsHelper) Table(org.apache.flink.table.api.Table) JoinBatchOp(com.alibaba.alink.operator.batch.sql.JoinBatchOp) KeywordsExtractionParams(com.alibaba.alink.params.nlp.KeywordsExtractionParams) Params(org.apache.flink.ml.api.misc.param.Params) Method(com.alibaba.alink.operator.common.nlp.Method) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Row(org.apache.flink.types.Row)

Aggregations

OutputColsHelper (com.alibaba.alink.common.utils.OutputColsHelper)1 BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)1 TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp)1 JoinBatchOp (com.alibaba.alink.operator.batch.sql.JoinBatchOp)1 Method (com.alibaba.alink.operator.common.nlp.Method)1 KeywordsExtractionParams (com.alibaba.alink.params.nlp.KeywordsExtractionParams)1 ArrayList (java.util.ArrayList)1 Comparator (java.util.Comparator)1 List (java.util.List)1 DataSet (org.apache.flink.api.java.DataSet)1 KeySelector (org.apache.flink.api.java.functions.KeySelector)1 Params (org.apache.flink.ml.api.misc.param.Params)1 Table (org.apache.flink.table.api.Table)1 Row (org.apache.flink.types.Row)1