Search in sources:

Example 11 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the project Alink by Alibaba.

From the class GeoKMeansTest, method before.

/**
 * Test fixture: builds the batch and stream source operators from the same six
 * (id, f0, f1) rows and assigns the expected prediction values.
 */
@Before
public void before() {
    final String[] colNames = { "id", "f0", "f1" };
    final Row[] rows = {
        Row.of(0, 0, 0),
        Row.of(1, 8, 8),
        Row.of(2, 1, 2),
        Row.of(3, 9, 10),
        Row.of(4, 3, 1),
        Row.of(5, 10, 7)
    };
    // Each table receives its own column-name array instance.
    inputBatchOp = new TableSourceBatchOp(
        MLEnvironmentFactory.getDefault().createBatchTable(rows, colNames.clone()));
    inputStreamOp = new TableSourceStreamOp(
        MLEnvironmentFactory.getDefault().createStreamTable(rows, colNames.clone()));
    // One expected value per input row.
    expectedPrediction = new double[] { 185.31, 117.08, 117.18, 183.04, 185.32, 183.70 };
}
Also used : Row(org.apache.flink.types.Row) TableSourceStreamOp(com.alibaba.alink.operator.stream.source.TableSourceStreamOp) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) Before(org.junit.Before)

Example 12 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the project Alink by Alibaba.

From the class CorrelationBatchOp, method linkFrom.

/**
 * Computes the correlation of the selected columns of the first input and sets
 * the result as this operator's output.
 *
 * <p>When no columns are selected, all columns of the input are used. The
 * columns are validated with {@code TableUtil.assertNumericalCols} before any
 * computation is set up.
 *
 * <p>For {@code Method.PEARSON} the result of
 * {@code StatisticsHelper.pearsonCorrelation} is serialized through a
 * {@code CorrelationDataConverter}. For any other method the data is first
 * ranked with {@code SpearmanCorrelation.calcRank} and the ranks are fed to a
 * nested {@code CorrelationBatchOp}, whose output becomes this operator's output.
 *
 * @param inputs the upstream operators; the first one supplies the data
 * @return this operator
 */
@Override
public CorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);

    // Fall back to every input column when the caller selected none.
    String[] cols = this.getParams().get(SELECTED_COLS);
    if (cols == null) {
        cols = in.getColNames();
    }

    // Rejects the request unless all selected columns are numerical.
    TableUtil.assertNumericalCols(in.getSchema(), cols);

    Method corrType = getMethod();
    if (Method.PEARSON != corrType) {
        // Rank-based path: rank the raw values, then correlate the ranks.
        DataSet<Row> rawData = inputs[0].select(cols).getDataSet();
        DataSet<Row> ranked = SpearmanCorrelation.calcRank(rawData, false);

        // The rank table is declared with one DOUBLE column per selected column.
        TypeInformation[] rankTypes = new TypeInformation[cols.length];
        for (int idx = 0; idx < rankTypes.length; idx++) {
            rankTypes[idx] = Types.DOUBLE;
        }

        BatchOperator rankOp = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(getMLEnvironmentId(), ranked, cols, rankTypes))
            .setMLEnvironmentId(getMLEnvironmentId());
        CorrelationBatchOp nestedCorr = new CorrelationBatchOp()
            .setMLEnvironmentId(getMLEnvironmentId())
            .setSelectedCols(cols);
        rankOp.link(nestedCorr);

        this.setOutput(nestedCorr.getDataSet(), nestedCorr.getSchema());
        return this;
    }

    // Pearson path: only the correlation half (f1) of each tuple is emitted.
    DataSet<Tuple2<TableSummary, CorrelationResult>> summaryAndCorr =
        StatisticsHelper.pearsonCorrelation(in, cols);
    FlatMapFunction<Tuple2<TableSummary, CorrelationResult>, Row> toRows =
        new FlatMapFunction<Tuple2<TableSummary, CorrelationResult>, Row>() {

            private static final long serialVersionUID = -4498296161046449646L;

            @Override
            public void flatMap(Tuple2<TableSummary, CorrelationResult> value, Collector<Row> out) {
                new CorrelationDataConverter().save(value.f1, out);
            }
        };
    DataSet<Row> corrRows = summaryAndCorr.flatMap(toRows);
    this.setOutput(corrRows, new CorrelationDataConverter().getModelSchema());
    return this;
}
Also used : CorrelationDataConverter(com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationDataConverter) CorrelationResult(com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Row(org.apache.flink.types.Row) TableSummary(com.alibaba.alink.operator.common.statistics.basicstatistic.TableSummary)

Example 13 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the project Alink by Alibaba.

From the class VectorCorrelationBatchOp, method linkFrom.

/**
 * Computes the correlation of the selected vector column of the first input
 * and sets the result as this operator's output.
 *
 * <p>For {@code Method.PEARSON} the result of
 * {@code StatisticsHelper.vectorPearsonCorrelation} is serialized through a
 * {@code CorrelationDataConverter}. For any other method the vector column is
 * passed through {@code StatisticsHelper.transformToColumns}, ranked with
 * {@code SpearmanCorrelation.calcRank}, exposed as a single STRING column named
 * {@code "col"} and fed to a nested {@code VectorCorrelationBatchOp}.
 *
 * @param inputs the upstream operators; the first one supplies the data
 * @return this operator
 */
@Override
public VectorCorrelationBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    String vectorColName = getSelectedCol();
    Method corrType = getMethod();

    if (Method.PEARSON != corrType) {
        // Rank-based path: rank the values, then correlate the ranks.
        DataSet<Row> columns = StatisticsHelper.transformToColumns(in, null, vectorColName, null);
        DataSet<Row> ranked = SpearmanCorrelation.calcRank(columns, true);

        String[] rankColNames = new String[] { "col" };
        TypeInformation[] rankColTypes = new TypeInformation[] { Types.STRING };
        BatchOperator rankOp = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(getMLEnvironmentId(), ranked, rankColNames, rankColTypes))
            .setMLEnvironmentId(getMLEnvironmentId());
        VectorCorrelationBatchOp nestedCorr = new VectorCorrelationBatchOp()
            .setMLEnvironmentId(getMLEnvironmentId())
            .setSelectedCol("col");
        rankOp.link(nestedCorr);

        this.setOutput(nestedCorr.getDataSet(), nestedCorr.getSchema());
        return this;
    }

    // Pearson path: only the correlation half (f1) of each tuple is emitted.
    DataSet<Tuple2<BaseVectorSummary, CorrelationResult>> summaryAndCorr =
        StatisticsHelper.vectorPearsonCorrelation(in, vectorColName);
    FlatMapFunction<Tuple2<BaseVectorSummary, CorrelationResult>, Row> toRows =
        new FlatMapFunction<Tuple2<BaseVectorSummary, CorrelationResult>, Row>() {

            private static final long serialVersionUID = 2134644397476490118L;

            @Override
            public void flatMap(Tuple2<BaseVectorSummary, CorrelationResult> value, Collector<Row> out) throws Exception {
                new CorrelationDataConverter().save(value.f1, out);
            }
        };
    DataSet<Row> corrRows = summaryAndCorr.flatMap(toRows);
    this.setOutput(corrRows, new CorrelationDataConverter().getModelSchema());
    return this;
}
Also used : CorrelationDataConverter(com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationDataConverter) CorrelationResult(com.alibaba.alink.operator.common.statistics.basicstatistic.CorrelationResult) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Tuple2(org.apache.flink.api.java.tuple.Tuple2) BaseVectorSummary(com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary) Row(org.apache.flink.types.Row)

Example 14 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the project Alink by Alibaba.

From the class PipelineModel, method collectLocalPredictor.

/**
 * Builds a {@link LocalPredictor} for this pipeline model.
 *
 * <p>Two paths exist:
 * <ul>
 *   <li>If {@code ModelStreamScanParams.MODEL_STREAM_FILE_PATH} is set in the
 *       params, all stages are serialized with
 *       {@code ModelExporterUtils.serializePipelineStages}, the extended model
 *       rows are collected and loaded into a single {@code PipelineModelMapper}.</li>
 *   <li>Otherwise the model data of every {@code ModelBase} stage is collected
 *       in one {@code BatchOperator.collect} call and one {@code Mapper} is
 *       created per stage, each consuming the output schema of the previous one.</li>
 * </ul>
 *
 * @param inputSchema schema of the rows the predictor will receive
 * @return a local predictor wrapping the mapper(s) of this pipeline
 * @throws RuntimeException if the pipeline has no stages (second path only) or
 *                          a stage is not {@code LocalPredictable}
 * @throws Exception        declared by the overridden method
 */
@Override
public LocalPredictor collectLocalPredictor(TableSchema inputSchema) throws Exception {
    if (params.get(ModelStreamScanParams.MODEL_STREAM_FILE_PATH) != null) {
        BatchOperator<?> modelSave = ModelExporterUtils.serializePipelineStages(Arrays.asList(transformers), params);
        TableSchema extendSchema = getOutSchema(this, inputSchema);
        BatchOperator<?> model = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(
                modelSave.getMLEnvironmentId(),
                modelSave.getDataSet().map(
                    new PipelineModelMapper.ExtendPipelineModelRow(extendSchema.getFieldNames().length + 1)),
                PipelineModelMapper.getExtendModelSchema(
                    modelSave.getSchema(), extendSchema.getFieldNames(), extendSchema.getFieldTypes())));
        List<Row> modelRows = model.collect();
        ModelMapper mapper = new PipelineModelMapper(model.getSchema(), inputSchema, this.params);
        mapper.loadModel(modelRows);
        return new LocalPredictor(mapper);
    }
    if (null == transformers || transformers.length == 0) {
        throw new RuntimeException("PipelineModel is empty.");
    }

    // Gather the model data of every ModelBase stage, in stage order, so that
    // all of it can be fetched with a single collect call.
    List<BatchOperator<?>> allModelData = new ArrayList<>();
    for (TransformerBase<?> transformer : transformers) {
        if (!(transformer instanceof LocalPredictable)) {
            throw new RuntimeException(transformer.getClass().toString() + " not support local predict.");
        }
        if (transformer instanceof ModelBase) {
            allModelData.add(((ModelBase<?>) transformer).getModelData());
        }
    }
    List<List<Row>> allModelDataRows;
    if (!allModelData.isEmpty()) {
        allModelDataRows = BatchOperator.collect(allModelData.toArray(new BatchOperator<?>[0]));
    } else {
        allModelDataRows = new ArrayList<>();
    }

    TableSchema schema = inputSchema;
    // Cursor into allModelDataRows. It must advance once for every stage that
    // contributed an entry above, i.e. for every ModelBase stage.
    int numMapperModel = 0;
    List<Mapper> mappers = new ArrayList<>();
    for (TransformerBase<?> transformer : transformers) {
        Mapper mapper;
        if (transformer instanceof MapModel) {
            mapper = ModelExporterUtils.createMapperFromStage(transformer, ((MapModel<?>) transformer).modelData.getSchema(), schema, allModelDataRows.get(numMapperModel));
            numMapperModel++;
        } else if (transformer instanceof BaseRecommender) {
            mapper = ModelExporterUtils.createMapperFromStage(transformer, ((BaseRecommender<?>) transformer).modelData.getSchema(), schema, allModelDataRows.get(numMapperModel));
            numMapperModel++;
        } else {
            mapper = ModelExporterUtils.createMapperFromStage(transformer, null, schema, null);
            if (transformer instanceof ModelBase) {
                // This stage added an entry to allModelDataRows even though its
                // rows are not used here; skip it so that later stages still
                // read their own model rows instead of this stage's.
                numMapperModel++;
            }
        }
        mappers.add(mapper);
        // The next stage sees the output of this one.
        schema = mapper.getOutputSchema();
    }
    return new LocalPredictor(mappers.toArray(new Mapper[0]));
}
Also used : TableSchema(org.apache.flink.table.api.TableSchema) ArrayList(java.util.ArrayList) PipelineModelMapper(com.alibaba.alink.common.mapper.PipelineModelMapper) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) ModelMapper(com.alibaba.alink.common.mapper.ModelMapper) Mapper(com.alibaba.alink.common.mapper.Mapper) BaseRecommender(com.alibaba.alink.pipeline.recommendation.BaseRecommender) List(java.util.List) Row(org.apache.flink.types.Row)

Example 15 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in the project Alink by Alibaba.

From the class BaseEasyTransferTrainBatchOp, method linkFrom.

/**
 * Wires up the EasyTransfer training job and sets the resulting model table as
 * this operator's output.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>For {@code TaskType.CLASSIFICATION}, the distinct labels are reduced
 *       with {@code SortLabelsReduceGroupFunction} and the label column is
 *       rewritten by {@code mapLabelToIntIndex}.</li>
 *   <li>The input is transformed by a {@code PipelineModel} holding a
 *       {@code BertTokenizer}; that pipeline model is also saved so it can be
 *       broadcast to the model-construction step.</li>
 *   <li>User parameters (checkpoint directory, pretrained checkpoint path, app
 *       name) and external files are assembled.</li>
 *   <li>An {@code EasyTransferConfigTrainBatchOp} is linked from the
 *       preprocessed first input plus the remaining inputs unchanged.</li>
 *   <li>The trained output is routed to one partition and turned into model
 *       rows by {@code ConstructModelMapPartitionFunction}.</li>
 * </ol>
 *
 * @param inputs the upstream operators; the first one is the training data
 * @return this operator, cast to {@code T}
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = inputs[0];
    Params params = getParams();
    TaskType taskType = params.get(HasTaskType.TASK_TYPE);
    String labelCol = params.get(HasLabelCol.LABEL_COL);
    TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelCol);

    // Only classification needs the label set; it stays null otherwise.
    DataSet<List<Object>> labelsInOrder = null;
    if (TaskType.CLASSIFICATION.equals(taskType)) {
        labelsInOrder = in.select(labelCol).distinct().getDataSet().reduceGroup(new SortLabelsReduceGroupFunction());
        in = mapLabelToIntIndex(in, labelCol, labelsInOrder);
    }

    // Tokenize the input and keep the preprocessing pipeline as a saved model.
    BertTokenizer tokenizer = new BertTokenizer(params.clone())
        .set(HasMaxSeqLengthDefaultAsNull.MAX_SEQ_LENGTH, params.get(HasMaxSeqLength.MAX_SEQ_LENGTH));
    PipelineModel preprocessModel = new PipelineModel(tokenizer);
    in = preprocessModel.transform(in);
    BatchOperator<?> preprocessModelOp = preprocessModel.save();
    String preprocessModelSchemaStr = CsvUtil.schema2SchemaStr(preprocessModelOp.getSchema());

    Map<String, String> userParams = new HashMap<>();
    String bertModelName = params.get(HasBertModelName.BERT_MODEL_NAME);

    // An explicitly configured, non-null model path wins over the named BERT resource.
    String ckptPath;
    if (params.contains(HasModelPath.MODEL_PATH) && (null != params.get(HasModelPath.MODEL_PATH))) {
        ckptPath = params.get(HasModelPath.MODEL_PATH);
    } else {
        ckptPath = BertResources.getBertModelCkpt(bertModelName);
    }

    String checkpointDir = params.get(HasCheckpointFilePathDefaultAsNull.CHECKPOINT_FILE_PATH);
    if (!StringUtils.isNullOrWhitespaceOnly(checkpointDir)) {
        userParams.put("model_dir", checkpointDir);
    }

    ExternalFilesConfig externalFiles;
    if (params.contains(HasUserFiles.USER_FILES)) {
        externalFiles = ExternalFilesConfig.fromJson(params.get(HasUserFiles.USER_FILES));
    } else {
        externalFiles = new ExternalFilesConfig();
    }

    if (PythonFileUtils.isLocalFile(ckptPath)) {
        // Local checkpoint (should be a directory): drop the leading "file://" characters.
        userParams.put("pretrained_ckpt_path", ckptPath.substring("file://".length()));
    } else {
        // Non-local checkpoint: register it as an external file and refer to it by file name.
        externalFiles.addFilePaths(ckptPath);
        userParams.put("pretrained_ckpt_path", PythonFileUtils.getCompressedFileName(ckptPath));
    }

    Map<String, Map<String, Object>> config = getConfig(getParams(), false);
    String configJson = JsonConverter.toJson(config);
    LOG.info("EasyTransfer config: {}", configJson);
    if (AlinkGlobalConfiguration.isPrintProcessInfo()) {
        System.out.println("EasyTransfer config: " + configJson);
    }

    BertTaskName taskName = params.get(HasTaskName.TASK_NAME);
    userParams.put("app_name", taskName.name());

    EasyTransferConfigTrainBatchOp trainer = new EasyTransferConfigTrainBatchOp()
        .setSelectedCols(ArrayUtils.add(SAFE_MODEL_INPUTS, labelCol))
        .setConfigJson(configJson)
        .setUserFiles(externalFiles)
        .setUserParams(JsonConverter.toJson(userParams))
        .setNumWorkers(params.get(HasNumWorkersDefaultAsNull.NUM_WORKERS))
        .setNumPSs(params.get(HasNumPssDefaultAsNull.NUM_PSS))
        .setPythonEnv(params.get(HasPythonEnv.PYTHON_ENV))
        .setIntraOpParallelism(params.get(HasIntraOpParallelism.INTRA_OP_PARALLELISM))
        .setMLEnvironmentId(getMLEnvironmentId());

    // Same inputs as received, except that slot 0 holds the preprocessed data.
    BatchOperator<?>[] tfInputs = new BatchOperator<?>[inputs.length];
    tfInputs[0] = in;
    System.arraycopy(inputs, 1, tfInputs, 1, inputs.length - 1);
    BatchOperator<?> tfModel = trainer.linkFrom(tfInputs);

    String tfOutputSignatureDef = EasyTransferUtils.getTfOutputSignatureDef(taskType);

    // Every record, keyed by field 0, is routed to partition 0.
    Partitioner<Long> toFirstPartition = new Partitioner<Long>() {

        @Override
        public int partition(Long key, int numPartitions) {
            return 0;
        }
    };
    MapPartitionOperator<Row, Row> modelBuilder = tfModel.getDataSet()
        .partitionCustom(toFirstPartition, 0)
        .mapPartition(new ConstructModelMapPartitionFunction(params, SAFE_MODEL_INPUTS, tfOutputSignatureDef, TF_OUTPUT_SIGNATURE_TYPE, preprocessModelSchemaStr))
        .withBroadcastSet(preprocessModelOp.getDataSet(), PREPROCESS_PIPELINE_MODEL_BC_NAME);

    // The label set is broadcast only for classification.
    DataSet<Row> modelRows;
    if (TaskType.CLASSIFICATION.equals(taskType)) {
        modelRows = modelBuilder.withBroadcastSet(labelsInOrder, SORTED_LABELS_BC_NAME);
    } else {
        modelRows = modelBuilder;
    }

    BatchOperator<?> modelOp = new TableSourceBatchOp(
        DataSetConversionUtil.toTable(
            getMLEnvironmentId(),
            modelRows,
            new TFTableModelClassificationModelDataConverter(labelType).getModelSchema()))
        .setMLEnvironmentId(getMLEnvironmentId());
    this.setOutputTable(modelOp.getOutputTable());
    return (T) this;
}
Also used : HashMap(java.util.HashMap) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) PipelineModel(com.alibaba.alink.pipeline.PipelineModel) BertTokenizer(com.alibaba.alink.pipeline.nlp.BertTokenizer) HasTaskType(com.alibaba.alink.params.dl.HasTaskType) SortLabelsReduceGroupFunction(com.alibaba.alink.operator.common.tensorflow.CommonUtils.SortLabelsReduceGroupFunction) List(java.util.List) Params(org.apache.flink.ml.api.misc.param.Params) TFTableModelClassificationModelDataConverter(com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelDataConverter) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Row(org.apache.flink.types.Row) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) ConstructModelMapPartitionFunction(com.alibaba.alink.operator.common.tensorflow.CommonUtils.ConstructModelMapPartitionFunction)

Aggregations

TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp)39 Row (org.apache.flink.types.Row)29 BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)22 Test (org.junit.Test)18 TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation)12 TableSourceStreamOp (com.alibaba.alink.operator.stream.source.TableSourceStreamOp)10 Params (org.apache.flink.ml.api.misc.param.Params)10 List (java.util.List)8 Tuple2 (org.apache.flink.api.java.tuple.Tuple2)8 StreamOperator (com.alibaba.alink.operator.stream.StreamOperator)6 ArrayList (java.util.ArrayList)6 TableSchema (org.apache.flink.table.api.TableSchema)6 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)5 Comparator (java.util.Comparator)4 HashMap (java.util.HashMap)4 MapFunction (org.apache.flink.api.common.functions.MapFunction)4 DataSet (org.apache.flink.api.java.DataSet)4 Mapper (com.alibaba.alink.common.mapper.Mapper)3 ModelMapper (com.alibaba.alink.common.mapper.ModelMapper)3 PipelineModelMapper (com.alibaba.alink.common.mapper.PipelineModelMapper)3