Search in sources:

Example 21 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class ModelExporterUtils, method preOrderSerialize:

private static BatchOperator<?> preOrderSerialize(StageNode[] stages, BatchOperator<?> packed, final TableSchema schema, final int offset) {
    if (stages == null || stages.length == 0) {
        return packed;
    }
    final int len = schema.getFieldTypes().length;
    final long[] id = new long[1];
    final BatchOperator<?>[] localPacked = new BatchOperator<?>[] { packed };
    Consumer<Integer> serializeModelData = lp -> {
        StageNode stageNode = stages[lp];
        if (stageNode.parent >= 0 && stageNode.schemaIndices != null && stageNode.children == null && stageNode.stage instanceof ModelBase<?>) {
            final long localId = id[0];
            final int[] localSchemaIndices = stageNode.schemaIndices;
            DataSet<Row> modelData = ((ModelBase<?>) stageNode.stage).getModelData().getDataSet().map(new MapFunction<Row, Row>() {

                private static final long serialVersionUID = 5218543921039328938L;

                @Override
                public Row map(Row value) {
                    Row ret = new Row(len);
                    ret.setField(0, localId);
                    for (int i = 0; i < localSchemaIndices.length; ++i) {
                        ret.setField(localSchemaIndices[i] + offset, value.getField(i));
                    }
                    return ret;
                }
            }).returns(new RowTypeInfo(schema.getFieldTypes()));
            localPacked[0] = new TableSourceBatchOp(
                DataSetConversionUtil.toTable(
                    localPacked[0].getMLEnvironmentId(),
                    localPacked[0].getDataSet().union(modelData),
                    schema))
                .setMLEnvironmentId(localPacked[0].getMLEnvironmentId());
        }
        id[0]++;
    };
    preOrder(stages, serializeModelData);
    return localPacked[0];
}
Also used: Arrays(java.util.Arrays) Tuple3(org.apache.flink.api.java.tuple.Tuple3) Tuple2(org.apache.flink.api.java.tuple.Tuple2) MLEnvironmentFactory(com.alibaba.alink.common.MLEnvironmentFactory) AkStream(com.alibaba.alink.common.io.filesystem.AkStream) JsonConverter(com.alibaba.alink.common.utils.JsonConverter) RowTypeInfo(org.apache.flink.api.java.typeutils.RowTypeInfo) MapFunction(org.apache.flink.api.common.functions.MapFunction) ModelMapper(com.alibaba.alink.common.mapper.ModelMapper) ComboModelMapper(com.alibaba.alink.common.mapper.ComboModelMapper) AkUtils(com.alibaba.alink.common.io.filesystem.AkUtils) DataSet(org.apache.flink.api.java.DataSet) RecommenderUtil(com.alibaba.alink.pipeline.recommendation.RecommenderUtil) Map(java.util.Map) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) TableSchema(org.apache.flink.table.api.TableSchema) Preconditions(org.apache.flink.util.Preconditions) Collectors(java.util.stream.Collectors) InvocationTargetException(java.lang.reflect.InvocationTargetException) TypeReference(org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference) FilterFunction(org.apache.flink.api.common.functions.FilterFunction) List(java.util.List) DataSetConversionUtil(com.alibaba.alink.common.utils.DataSetConversionUtil) ModelStreamScanParams(com.alibaba.alink.params.ModelStreamScanParams) ComboMapper(com.alibaba.alink.common.mapper.ComboMapper) Row(org.apache.flink.types.Row) IntStream(java.util.stream.IntStream) TableUtil(com.alibaba.alink.common.utils.TableUtil) Lists(org.apache.flink.shaded.guava18.com.google.common.collect.Lists) MapperChain(com.alibaba.alink.common.mapper.MapperChain) BaseRecommender(com.alibaba.alink.pipeline.recommendation.BaseRecommender) ArrayUtils(org.apache.commons.lang3.ArrayUtils) HashMap(java.util.HashMap) PipelineModelMapper.getExtendModelSchema(com.alibaba.alink.common.mapper.PipelineModelMapper.getExtendModelSchema) ArrayList(java.util.ArrayList) FileProcFunction(com.alibaba.alink.common.io.filesystem.AkUtils.FileProcFunction) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) CsvUtil(com.alibaba.alink.operator.common.io.csv.CsvUtil) LinkedList(java.util.LinkedList) FlinkTypeConverter(com.alibaba.alink.operator.common.io.types.FlinkTypeConverter) Types(org.apache.flink.api.common.typeinfo.Types) IOException(java.io.IOException) AkReader(com.alibaba.alink.common.io.filesystem.AkStream.AkReader) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) Consumer(java.util.function.Consumer) FilePath(com.alibaba.alink.common.io.filesystem.FilePath) PipelineModelMapper(com.alibaba.alink.common.mapper.PipelineModelMapper) Comparator(java.util.Comparator) Params(org.apache.flink.ml.api.misc.param.Params) Collections(java.util.Collections) Mapper(com.alibaba.alink.common.mapper.Mapper)
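
The pivotal move in preOrderSerialize is its last statement: the serialized model rows are unioned into the accumulated DataSet, and the result is wrapped back into the operator world via DataSetConversionUtil.toTable plus TableSourceBatchOp. Below is a minimal, self-contained sketch of that DataSet-to-operator round trip; the schema, sample rows, and class name are invented for illustration, and only the wrapping pattern is taken from the example.

import java.util.Arrays;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

import com.alibaba.alink.common.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;

public class DataSetRoundTripSketch {

    public static void main(String[] args) throws Exception {
        // Toy in-memory source; the schema and values are placeholders.
        MemSourceBatchOp source = new MemSourceBatchOp(
            Arrays.asList(Row.of(1L, "a"), Row.of(2L, "b")),
            new TableSchema(
                new String[] { "id", "val" },
                new TypeInformation<?>[] { Types.LONG, Types.STRING }));

        // Drop down to the raw Flink DataSet API (e.g. to union in extra rows)...
        DataSet<Row> rows = source.getDataSet();

        // ...then wrap the DataSet back into a BatchOperator, as preOrderSerialize
        // does after unioning the serialized model data.
        BatchOperator<?> wrapped = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(source.getMLEnvironmentId(), rows, source.getSchema()))
            .setMLEnvironmentId(source.getMLEnvironmentId());

        System.out.println(wrapped.collect());
    }
}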

Example 22 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class BaseTuning, method findBestTVSplit:

protected Tuple2<Pipeline, Report> findBestTVSplit(BatchOperator<?> in, double ratio, PipelineCandidatesBase candidates) {
    int nIter = candidates.size();
    SplitBatchOp sbo = new SplitBatchOp()
        .setFraction(ratio)
        .setMLEnvironmentId(getMLEnvironmentId())
        .linkFrom(new TableSourceBatchOp(
            DataSetConversionUtil.toTable(in.getMLEnvironmentId(), shuffle(in.getDataSet()), in.getSchema()))
            .setMLEnvironmentId(getMLEnvironmentId()));
    int bestIdx = -1;
    double bestMetric = 0.;
    ArrayList<Double> experienceScores = new ArrayList<>(nIter);
    List<Report.ReportElement> reportElements = new ArrayList<>();
    for (int i = 0; i < nIter; i++) {
        Tuple2<Pipeline, List<Tuple3<Integer, ParamInfo, Object>>> cur;
        try {
            cur = candidates.get(i, experienceScores);
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
        double metric = Double.NaN;
        try {
            metric = tuningEvaluator.evaluate(cur.f0.fit(sbo).transform(sbo.getSideOutput(0)));
        } catch (Exception ex) {
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("BestTVSplit, i: %d, best: %f, metric: %f, exception: %s", i, bestMetric, metric, ExceptionUtils.stringifyException(ex)));
            }
            experienceScores.add(i, metric);
            reportElements.add(new Report.ReportElement(cur.f0, cur.f1, metric, ExceptionUtils.stringifyException(ex)));
            continue;
        }
        experienceScores.add(i, metric);
        if (Double.isNaN(metric)) {
            reportElements.add(new Report.ReportElement(cur.f0, cur.f1, metric, "Metric is nan."));
            continue;
        }
        reportElements.add(new Report.ReportElement(cur.f0, cur.f1, metric));
        if (bestIdx == -1) {
            bestMetric = metric;
            bestIdx = i;
        } else {
            if ((tuningEvaluator.isLargerBetter() && bestMetric < metric) || (!tuningEvaluator.isLargerBetter() && bestMetric > metric)) {
                bestMetric = metric;
                bestIdx = i;
            }
        }
        if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
            System.out.println(String.format("BestTVSplit, i: %d, best: %f, metric: %f", i, bestMetric, metric));
        }
    }
    if (bestIdx < 0) {
        throw new RuntimeException("Can not find a best model. Report: " + new Report(tuningEvaluator, reportElements).toPrettyJson());
    }
    try {
        return Tuple2.of(candidates.get(bestIdx, experienceScores).f0, new Report(tuningEvaluator, reportElements));
    } catch (CloneNotSupportedException e) {
        throw new RuntimeException(e);
    }
}
Also used: ArrayList(java.util.ArrayList) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) SplitBatchOp(com.alibaba.alink.operator.batch.dataproc.SplitBatchOp) Pipeline(com.alibaba.alink.pipeline.Pipeline) List(java.util.List) ParamInfo(org.apache.flink.ml.api.misc.param.ParamInfo)
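
The split at the top of findBestTVSplit deserves isolating: SplitBatchOp emits the sampled fraction as its main output and the remaining rows as side output 0, which is why the loop fits on sbo and evaluates on sbo.getSideOutput(0). A minimal sketch of just that mechanic, with an invented source and fraction:

import java.util.Arrays;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.SplitBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;

public class TrainValidationSplitSketch {

    public static void main(String[] args) throws Exception {
        BatchOperator<?> source = new MemSourceBatchOp(
            Arrays.asList(Row.of(1L), Row.of(2L), Row.of(3L), Row.of(4L), Row.of(5L)),
            new TableSchema(new String[] { "id" }, new TypeInformation<?>[] { Types.LONG }));

        // Roughly 80% of rows land on the main output, the rest on side output 0.
        SplitBatchOp split = new SplitBatchOp()
            .setFraction(0.8)
            .linkFrom(source);

        // Caveat: each collect() triggers its own batch job, so the split is
        // re-sampled per call; findBestTVSplit consumes both outputs in one job.
        System.out.println("train: " + split.collect());
        System.out.println("validation: " + split.getSideOutput(0).collect());
    }
}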

Example 23 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class BaseTuning, method kFoldCv:

private Tuple2<Double, String> kFoldCv(DataSet<Tuple2<Integer, Row>> splitData, Pipeline pipeline, TableSchema schema, int k) {
    double ret = 0.;
    int validSize = 0;
    StringBuilder reason = new StringBuilder();
    for (int i = 0; i < k; ++i) {
        final int loop = i;
        DataSet<Row> trainInput = splitData.filter(new FilterFunction<Tuple2<Integer, Row>>() {

            private static final long serialVersionUID = 2249884521437544236L;

            @Override
            public boolean filter(Tuple2<Integer, Row> value) {
                return value.f0 != loop;
            }
        }).map(new MapFunction<Tuple2<Integer, Row>, Row>() {

            private static final long serialVersionUID = 2618229645786221757L;

            @Override
            public Row map(Tuple2<Integer, Row> value) {
                return value.f1;
            }
        });
        DataSet<Row> testInput = splitData.filter(new FilterFunction<Tuple2<Integer, Row>>() {

            private static final long serialVersionUID = 5811166054549336470L;

            @Override
            public boolean filter(Tuple2<Integer, Row> value) {
                return value.f0 == loop;
            }
        }).map(new MapFunction<Tuple2<Integer, Row>, Row>() {

            private static final long serialVersionUID = -1760709990316111721L;

            @Override
            public Row map(Tuple2<Integer, Row> value) {
                return value.f1;
            }
        });
        PipelineModel model = pipeline.fit(new TableSourceBatchOp(
            DataSetConversionUtil.toTable(getMLEnvironmentId(), trainInput, schema))
            .setMLEnvironmentId(getMLEnvironmentId()));
        double localMetric = Double.NaN;
        try {
            localMetric = tuningEvaluator.evaluate(model.transform(new TableSourceBatchOp(
                DataSetConversionUtil.toTable(getMLEnvironmentId(), testInput, schema))));
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("kFoldCv, k: %d, i: %d, metric: %f", k, i, localMetric));
            }
        } catch (Exception ex) {
            if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
                System.out.println(String.format("kFoldCv err, k: %d, i: %d, metric: %f, exception: %s", k, i, localMetric, ExceptionUtils.stringifyException(ex)));
            }
            reason.append(ExceptionUtils.stringifyException(ex)).append("\n");
            continue;
        }
        ret += localMetric;
        validSize++;
    }
    if (validSize == 0) {
        reason.append("valid size is zero.").append("\n");
        return Tuple2.of(Double.NaN, reason.toString());
    }
    ret /= validSize;
    return Tuple2.of(ret, reason.toString());
}
Also used: FilterFunction(org.apache.flink.api.common.functions.FilterFunction) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) PipelineModel(com.alibaba.alink.pipeline.PipelineModel) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Row(org.apache.flink.types.Row)
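
kFoldCv expects splitData with the fold index already stored in field f0; each iteration filters f0 != i into the training set and f0 == i into the test set. The sketch below manufactures input of that shape with a naive per-subtask round-robin tagger; the tagging scheme is purely an assumption for illustration, not Alink's actual fold assignment.

import java.util.Arrays;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;

public class FoldTaggingSketch {

    public static void main(String[] args) throws Exception {
        MemSourceBatchOp source = new MemSourceBatchOp(
            Arrays.asList(Row.of(1L), Row.of(2L), Row.of(3L), Row.of(4L), Row.of(5L), Row.of(6L)),
            new TableSchema(new String[] { "id" }, new TypeInformation<?>[] { Types.LONG }));

        final int k = 3;
        // Tag every row with a fold index in [0, k); the counter is per parallel
        // subtask, which is fine for a sketch but not a balanced k-fold split.
        DataSet<Tuple2<Integer, Row>> splitData = source.getDataSet()
            .map(new MapFunction<Row, Tuple2<Integer, Row>>() {

                private int counter = 0;

                @Override
                public Tuple2<Integer, Row> map(Row value) {
                    return Tuple2.of(counter++ % k, value);
                }
            });

        System.out.println(splitData.collect());
    }
}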

Example 24 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class DerbyCatalogTest, method sinkBatch:

@Test
public void sinkBatch() throws Exception {
    Row[] rows = new Row[] {
        Row.of(0L, "string", new Date(0), new BigDecimal("0.00"), 0.0, 0.0f, 0, (short) 0,
            new Time(0), new Timestamp(0), new byte[] { 0, 1 })
    };
    MemSourceBatchOp memSourceBatchOp = new MemSourceBatchOp(
        Arrays.asList(rows),
        new TableSchema(
            new String[] { "col_long", "col_string", "col_date", "col_bigdecimal", "col_double",
                "col_float", "col_int", "col_short", "col_time", "col_timestamp", "col_varcharforbit" },
            new TypeInformation<?>[] { Types.LONG, Types.STRING, Types.SQL_DATE, Types.BIG_DEC,
                Types.DOUBLE, Types.FLOAT, Types.INT, Types.SHORT, Types.SQL_TIME,
                Types.SQL_TIMESTAMP, Types.PRIMITIVE_ARRAY(Types.BYTE) }));
    derby.sinkBatch(new ObjectPath(DERBY_SCHEMA, DERBY_DB_TABLE), memSourceBatchOp.getOutputTable(), new Params(), memSourceBatchOp.getMLEnvironmentId());
    BatchOperator.execute();
    Assert.assertFalse(new TableSourceBatchOp(
        derby.sourceBatch(new ObjectPath(DERBY_SCHEMA, DERBY_DB_TABLE), new Params(),
            MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID))
        .collect()
        .isEmpty());
}
Also used: ObjectPath(org.apache.flink.table.catalog.ObjectPath) TableSchema(org.apache.flink.table.api.TableSchema) Params(org.apache.flink.ml.api.misc.param.Params) Time(java.sql.Time) Timestamp(java.sql.Timestamp) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) Date(java.sql.Date) BigDecimal(java.math.BigDecimal) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) Row(org.apache.flink.types.Row) Test(org.junit.Test)
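
The test round-trips typed rows through a Derby catalog whose setup is elided above, so the sketch below keeps only the in-memory half: it shows that a MemSourceBatchOp carries the declared column types through collect(). Types and values are taken from the test, trimmed to three columns.

import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.List;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;

public class TypedRowsSketch {

    public static void main(String[] args) throws Exception {
        MemSourceBatchOp source = new MemSourceBatchOp(
            Arrays.asList(Row.of(0L, new BigDecimal("0.00"), new Timestamp(0))),
            new TableSchema(
                new String[] { "col_long", "col_bigdecimal", "col_timestamp" },
                new TypeInformation<?>[] { Types.LONG, Types.BIG_DEC, Types.SQL_TIMESTAMP }));

        // collect() runs the batch job and returns rows carrying the declared types.
        List<Row> rows = source.collect();
        System.out.println(rows);
    }
}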

Example 25 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class KeywordsExtractionBatchOp, method linkFrom:

@Override
public KeywordsExtractionBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final String docId = "doc_alink_id";
    String selectedColName = this.getSelectedCol();
    TableUtil.assertSelectedColExist(in.getColNames(), selectedColName);
    String outputColName = this.getOutputCol();
    if (null == outputColName) {
        outputColName = selectedColName;
    }
    OutputColsHelper outputColsHelper = new OutputColsHelper(in.getSchema(), outputColName, Types.STRING);
    final Integer topN = this.getTopN();
    Method method = this.getMethod();
    BatchOperator<?> inWithId = new TableSourceBatchOp(
        AppendIdBatchOp.appendId(in.getDataSet(), in.getSchema(), getMLEnvironmentId()))
        .setMLEnvironmentId(getMLEnvironmentId());
    DataSet<Row> weights;
    StopWordsRemoverBatchOp filterOp = new StopWordsRemoverBatchOp()
        .setMLEnvironmentId(getMLEnvironmentId())
        .setSelectedCol(selectedColName)
        .setOutputCol("selectedColName");
    BatchOperator<?> filtered = filterOp.linkFrom(inWithId);
    switch(method) {
        case TF_IDF:
            {
                DocWordCountBatchOp wordCount = new DocWordCountBatchOp()
                    .setMLEnvironmentId(getMLEnvironmentId())
                    .setDocIdCol(AppendIdBatchOp.appendIdColName)
                    .setContentCol("selectedColName");
                TfidfBatchOp tfIdf = new TfidfBatchOp()
                    .setMLEnvironmentId(getMLEnvironmentId())
                    .setDocIdCol(AppendIdBatchOp.appendIdColName)
                    .setWordCol("word")
                    .setCountCol("cnt");
                BatchOperator<?> op = filtered.link(wordCount).link(tfIdf);
                weights = op.select(AppendIdBatchOp.appendIdColName + ", word, tfidf").getDataSet();
                break;
            }
        case TEXT_RANK:
            {
                DataSet<Row> data = filtered.select(AppendIdBatchOp.appendIdColName + ", selectedColName").getDataSet();
                // Initialize the TextRank class, which runs the text rank algorithm.
                final Params params = getParams();
                weights = data.flatMap(new FlatMapFunction<Row, Row>() {

                    private static final long serialVersionUID = -4083643981693873537L;

                    @Override
                    public void flatMap(Row row, Collector<Row> collector) throws Exception {
                        // For each row, apply the text rank algorithm to get the key words.
                        Row[] out = TextRank.getKeyWords(row, params.get(KeywordsExtractionParams.DAMPING_FACTOR), params.get(KeywordsExtractionParams.WINDOW_SIZE), params.get(KeywordsExtractionParams.MAX_ITER), params.get(KeywordsExtractionParams.EPSILON));
                        for (int i = 0; i < out.length; i++) {
                            collector.collect(out[i]);
                        }
                    }
                });
                break;
            }
        default:
            {
                throw new RuntimeException("Not support this type!");
            }
    }
    DataSet<Row> res = weights.groupBy(new KeySelector<Row, String>() {

        private static final long serialVersionUID = 801794449492798203L;

        @Override
        public String getKey(Row row) {
            Object obj = row.getField(0);
            return obj == null ? "NULL" : obj.toString();
        }
    }).reduceGroup(new GroupReduceFunction<Row, Row>() {

        private static final long serialVersionUID = -4051509261188494119L;

        @Override
        public void reduce(Iterable<Row> rows, Collector<Row> collector) {
            List<Row> list = new ArrayList<>();
            for (Row row : rows) {
                list.add(row);
            }
            Collections.sort(list, new Comparator<Row>() {

                @Override
                public int compare(Row row1, Row row2) {
                    Double v1 = (double) row1.getField(2);
                    Double v2 = (double) row2.getField(2);
                    return v2.compareTo(v1);
                }
            });
            int len = Math.min(list.size(), topN);
            Row out = new Row(2);
            StringBuilder builder = new StringBuilder();
            for (int i = 0; i < len; i++) {
                builder.append(list.get(i).getField(1).toString());
                if (i != len - 1) {
                    builder.append(" ");
                }
            }
            out.setField(0, list.get(0).getField(0));
            out.setField(1, builder.toString());
            collector.collect(out);
        }
    });
    // Set the output into table.
    Table tmpTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), res, new String[] { docId, outputColName }, new TypeInformation[] { Types.LONG, Types.STRING });
    StringBuilder selectClause = new StringBuilder("a." + outputColName);
    String[] keepColNames = outputColsHelper.getReservedColumns();
    for (int i = 0; i < keepColNames.length; i++) {
        selectClause.append("," + keepColNames[i]);
    }
    JoinBatchOp join = new JoinBatchOp()
        .setMLEnvironmentId(getMLEnvironmentId())
        .setType("join")
        .setSelectClause(selectClause.toString())
        .setJoinPredicate(docId + "=" + AppendIdBatchOp.appendIdColName);
    this.setOutputTable(join
        .linkFrom(new TableSourceBatchOp(tmpTable).setMLEnvironmentId(getMLEnvironmentId()), inWithId)
        .getOutputTable());
    return this;
}
Also used: DataSet(org.apache.flink.api.java.DataSet) KeySelector(org.apache.flink.api.java.functions.KeySelector) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) Comparator(java.util.Comparator) ArrayList(java.util.ArrayList) List(java.util.List) OutputColsHelper(com.alibaba.alink.common.utils.OutputColsHelper) Table(org.apache.flink.table.api.Table) JoinBatchOp(com.alibaba.alink.operator.batch.sql.JoinBatchOp) KeywordsExtractionParams(com.alibaba.alink.params.nlp.KeywordsExtractionParams) Params(org.apache.flink.ml.api.misc.param.Params) Method(com.alibaba.alink.operator.common.nlp.Method) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Row(org.apache.flink.types.Row)
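
Two patterns carry this method: AppendIdBatchOp.appendId gives every document a stable id before any per-row work, and a JoinBatchOp stitches the computed keywords back onto the original rows by that id. The sketch below reduces linkFrom to that skeleton. It is a sketch only: the AppendIdBatchOp import path, the AS aliases in select, and all column names are assumptions, while the wrapping, the a/b table aliases, and the join shape mirror the code above.

import java.util.Arrays;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

import com.alibaba.alink.operator.batch.BatchOperator;
// Assumed package for AppendIdBatchOp; adjust if it differs in your Alink version.
import com.alibaba.alink.operator.batch.dataproc.AppendIdBatchOp;
import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import com.alibaba.alink.operator.batch.sql.JoinBatchOp;

public class AppendIdJoinSketch {

    public static void main(String[] args) throws Exception {
        BatchOperator<?> in = new MemSourceBatchOp(
            Arrays.asList(Row.of("some document text"), Row.of("another document")),
            new TableSchema(new String[] { "text" }, new TypeInformation<?>[] { Types.STRING }));

        // Step 1: append a unique id column, then re-enter the operator world.
        BatchOperator<?> inWithId = new TableSourceBatchOp(
            AppendIdBatchOp.appendId(in.getDataSet(), in.getSchema(), in.getMLEnvironmentId()))
            .setMLEnvironmentId(in.getMLEnvironmentId());

        // Step 2: stand-in for the real keyword extraction; alias columns so there
        // is something keyed by the id to join back (AS support is an assumption).
        BatchOperator<?> results = inWithId.select(
            AppendIdBatchOp.appendIdColName + " AS doc_id, text AS keywords");

        // Step 3: join results back onto the id-tagged input, as linkFrom does.
        JoinBatchOp join = new JoinBatchOp()
            .setMLEnvironmentId(in.getMLEnvironmentId())
            .setType("join")
            .setSelectClause("a.keywords, b.text")
            .setJoinPredicate("doc_id=" + AppendIdBatchOp.appendIdColName);

        System.out.println(join.linkFrom(results, inWithId).collect());
    }
}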

Aggregations

TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp): 39
Row (org.apache.flink.types.Row): 29
BatchOperator (com.alibaba.alink.operator.batch.BatchOperator): 22
Test (org.junit.Test): 18
TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation): 12
TableSourceStreamOp (com.alibaba.alink.operator.stream.source.TableSourceStreamOp): 10
Params (org.apache.flink.ml.api.misc.param.Params): 10
List (java.util.List): 8
Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 8
StreamOperator (com.alibaba.alink.operator.stream.StreamOperator): 6
ArrayList (java.util.ArrayList): 6
TableSchema (org.apache.flink.table.api.TableSchema): 6
MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp): 5
Comparator (java.util.Comparator): 4
HashMap (java.util.HashMap): 4
MapFunction (org.apache.flink.api.common.functions.MapFunction): 4
DataSet (org.apache.flink.api.java.DataSet): 4
Mapper (com.alibaba.alink.common.mapper.Mapper): 3
ModelMapper (com.alibaba.alink.common.mapper.ModelMapper): 3
PipelineModelMapper (com.alibaba.alink.common.mapper.PipelineModelMapper): 3