
Example 16 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class BaseKerasSequentialTrainBatchOp, method linkFrom.

@Override
public T linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = inputs[0];
    Params params = getParams();
    TaskType taskType = params.get(HasTaskType.TASK_TYPE);
    boolean isReg = TaskType.REGRESSION.equals(taskType);
    String tensorCol = getTensorCol();
    String labelCol = getLabelCol();
    TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelCol);
    DataSet<List<Object>> sortedLabels = null;
    BatchOperator<?> numLabelsOp = null;
    if (!isReg) {
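        // Gather the distinct label values per partition, then merge them into one globally sorted list.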
        sortedLabels = in.select(labelCol).getDataSet().mapPartition(new MapPartitionFunction<Row, Object>() {

            @Override
            public void mapPartition(Iterable<Row> iterable, Collector<Object> collector) throws Exception {
                Set<Object> distinctValue = new HashSet<>();
                for (Row row : iterable) {
                    distinctValue.add(row.getField(0));
                }
                for (Object obj : distinctValue) {
                    collector.collect(obj);
                }
            }
        }).reduceGroup(new GroupReduceFunction<Object, List<Object>>() {

            @Override
            public void reduce(Iterable<Object> iterable, Collector<List<Object>> collector) throws Exception {
                Set<Object> distinctValue = new TreeSet<>();
                for (Object obj : iterable) {
                    distinctValue.add(obj);
                }
                collector.collect(new ArrayList<>(distinctValue));
            }
        });
        in = CommonUtils.mapLabelToIndex(in, labelCol, sortedLabels);
        numLabelsOp = new TableSourceBatchOp(
                DataSetConversionUtil.toTable(
                        getMLEnvironmentId(),
                        sortedLabels.map(new CountLabelsMapFunction()),
                        new String[] { "count" },
                        new TypeInformation<?>[] { Types.INT }))
            .setMLEnvironmentId(getMLEnvironmentId());
    }
    Boolean removeCheckpointBeforeTraining = getRemoveCheckpointBeforeTraining();
    if (null == removeCheckpointBeforeTraining) {
        // default to clean checkpoint
        removeCheckpointBeforeTraining = true;
    }
    Map<String, Object> modelConfig = new HashMap<>();
    modelConfig.put("layers", getLayers());
    Map<String, String> userParams = new HashMap<>();
    if (removeCheckpointBeforeTraining) {
        userParams.put(DLConstants.REMOVE_CHECKPOINT_BEFORE_TRAINING, "true");
    }
    userParams.put("tensor_cols", JsonConverter.toJson(new String[] { tensorCol }));
    userParams.put("label_col", labelCol);
    userParams.put("label_type", "float");
    userParams.put("batch_size", String.valueOf(getBatchSize()));
    userParams.put("num_epochs", String.valueOf(getNumEpochs()));
    userParams.put("model_config", JsonConverter.toJson(modelConfig));
    userParams.put("optimizer", getOptimizer());
    if (!StringUtils.isNullOrWhitespaceOnly(getCheckpointFilePath())) {
        userParams.put("model_dir", getCheckpointFilePath());
    }
    ExecutionEnvironment env = MLEnvironmentFactory.get(getMLEnvironmentId()).getExecutionEnvironment();
    if (env.getParallelism() == 1) {
        userParams.put("ALINK:ONLY_ONE_WORKER", "true");
    }
    userParams.put("validation_split", String.valueOf(getValidationSplit()));
    userParams.put("save_best_only", String.valueOf(getSaveBestOnly()));
    userParams.put("best_exporter_metric", getBestMetric());
    userParams.put("save_checkpoints_epochs", String.valueOf(getSaveCheckpointsEpochs()));
    if (params.contains(BaseKerasSequentialTrainParams.SAVE_CHECKPOINTS_SECS)) {
        userParams.put("save_checkpoints_secs", String.valueOf(getSaveCheckpointsSecs()));
    }
    TF2TableModelTrainBatchOp trainBatchOp = new TF2TableModelTrainBatchOp(params)
        .setSelectedCols(tensorCol, labelCol)
        .setUserFiles(RES_PY_FILES)
        .setMainScriptFile(MAIN_SCRIPT_FILE_NAME)
        .setUserParams(JsonConverter.toJson(userParams))
        .setIntraOpParallelism(getIntraOpParallelism())
        .setNumPSs(getNumPSs())
        .setNumWorkers(getNumWorkers())
        .setPythonEnv(params.get(HasPythonEnv.PYTHON_ENV));
    if (isReg) {
        trainBatchOp = trainBatchOp.linkFrom(in);
    } else {
        trainBatchOp = trainBatchOp.linkFrom(in, numLabelsOp);
    }
    String tfOutputSignatureDef = getTfOutputSignatureDef(taskType);
    FlatMapOperator<Row, Row> constructModelFlatMapOperator = new NumSeqSourceBatchOp()
        .setFrom(0)
        .setTo(0)
        .setMLEnvironmentId(getMLEnvironmentId())
        .getDataSet()
        .flatMap(new ConstructModelFlatMapFunction(params, new String[] { tensorCol },
                tfOutputSignatureDef, TF_OUTPUT_SIGNATURE_TYPE, null, true))
        .withBroadcastSet(trainBatchOp.getDataSet(), CommonUtils.TF_MODEL_BC_NAME);
    BatchOperator<?> modelOp;
    if (isReg) {
        DataSet<Row> modelDataSet = constructModelFlatMapOperator;
        modelOp = new TableSourceBatchOp(
                DataSetConversionUtil.toTable(getMLEnvironmentId(), modelDataSet,
                        new TFTableModelRegressionModelDataConverter(labelType).getModelSchema()))
            .setMLEnvironmentId(getMLEnvironmentId());
    } else {
        DataSet<Row> modelDataSet = constructModelFlatMapOperator
            .withBroadcastSet(sortedLabels, CommonUtils.SORTED_LABELS_BC_NAME);
        modelOp = new TableSourceBatchOp(
                DataSetConversionUtil.toTable(getMLEnvironmentId(), modelDataSet,
                        new TFTableModelClassificationModelDataConverter(labelType).getModelSchema()))
            .setMLEnvironmentId(getMLEnvironmentId());
    }
    this.setOutputTable(modelOp.getOutputTable());
    return (T) this;
}
Also used : ExecutionEnvironment(org.apache.flink.api.java.ExecutionEnvironment) ConstructModelFlatMapFunction(com.alibaba.alink.operator.common.tensorflow.CommonUtils.ConstructModelFlatMapFunction) DataSet(org.apache.flink.api.java.DataSet) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) HasTaskType(com.alibaba.alink.params.dl.HasTaskType) Collector(org.apache.flink.util.Collector) NumSeqSourceBatchOp(com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp) GroupReduceFunction(org.apache.flink.api.common.functions.GroupReduceFunction) BaseKerasSequentialTrainParams(com.alibaba.alink.params.tensorflow.kerasequential.BaseKerasSequentialTrainParams) Params(org.apache.flink.ml.api.misc.param.Params) TFTableModelRegressionModelDataConverter(com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionModelDataConverter) TFTableModelClassificationModelDataConverter(com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelDataConverter) TF2TableModelTrainBatchOp(com.alibaba.alink.operator.batch.tensorflow.TF2TableModelTrainBatchOp) CountLabelsMapFunction(com.alibaba.alink.operator.common.tensorflow.CommonUtils.CountLabelsMapFunction) Row(org.apache.flink.types.Row)
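
Most examples on this page reduce to the same three-step move: convert an intermediate DataSet<Row> to a Table, wrap the Table in a TableSourceBatchOp, and propagate the MLEnvironment id. A minimal sketch of that pattern, assuming a single INT column named "count" as in numLabelsOp above; the class WrapUtil and method wrapAsBatchOp are illustrative helpers, not part of Alink's API:

import com.alibaba.alink.common.utils.DataSetConversionUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.TableSourceBatchOp;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.types.Row;

public final class WrapUtil {

    // Hypothetical helper (not part of Alink): wrap an intermediate DataSet<Row>
    // as a BatchOperator so it can be linked into a pipeline, as numLabelsOp is above.
    public static BatchOperator<?> wrapAsBatchOp(DataSet<Row> rows, long mlEnvId) {
        return new TableSourceBatchOp(
                DataSetConversionUtil.toTable(
                        mlEnvId,
                        rows,
                        new String[] { "count" },
                        new TypeInformation<?>[] { Types.INT }))
            .setMLEnvironmentId(mlEnvId);
    }
}

The same shape recurs at modelOp above, where the schema comes from a model data converter rather than explicit column names and types.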

Example 17 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class WordCountUtil, method trans.

private static BatchOperator<?> trans(BatchOperator<?> in, String[] selectedColNames, String[] keepColNames,
        BatchOperator<?> indexedVocab, String wordColName, String idxColName, boolean isWord,
        String wordDelimiter) {
    String[] colnames = in.getColNames();
    TypeInformation<?>[] coltypes = in.getColTypes();
    int[] colIdxs = findColIdx(selectedColNames, colnames, coltypes);
    int[] appendIdxs = findAppendColIdx(keepColNames, colnames);
    // only 2 cols: word, idx
    DataSet<Row> voc = indexedVocab.select(wordColName + "," + idxColName).getDataSet();
    DataSet<Row> contentMapping = in.getDataSet().map(new GenContentMapping(colIdxs, appendIdxs, isWord, wordDelimiter)).withBroadcastSet(voc, "vocabulary");
    int transColSize = colIdxs.length;
    int keepColSize = keepColNames == null ? 0 : keepColNames.length;
    int outputColSize = transColSize + keepColSize;
    String[] names = new String[outputColSize];
    TypeInformation<?>[] types = new TypeInformation<?>[outputColSize];
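    // Output schema: the transformed columns come first (typed by isWord), followed by the kept columns.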
    int i = 0;
    for (; i < transColSize; ++i) {
        names[i] = colnames[colIdxs[i]];
        types[i] = isWord ? Types.DOUBLE : Types.STRING;
    }
    for (; i < outputColSize; ++i) {
        names[i] = colnames[appendIdxs[i - transColSize]];
        types[i] = coltypes[appendIdxs[i - transColSize]];
    }
    return new TableSourceBatchOp(
            DataSetConversionUtil.toTable(in.getMLEnvironmentId(), contentMapping, names, types))
        .setMLEnvironmentId(in.getMLEnvironmentId());
}
Also used : Row(org.apache.flink.types.Row) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation)
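
The piece worth isolating here is the broadcast join: the two-column vocabulary is shipped to every parallel instance of GenContentMapping under the name "vocabulary". A sketch of how such a mapper can read that broadcast set, assuming the vocabulary rows are (word, index) pairs; VocabLookupMapper is an illustrative name, and the real GenContentMapping does more than this:

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Row;

public class VocabLookupMapper extends RichMapFunction<Row, Row> {

    private transient Map<Object, Object> vocab;

    @Override
    public void open(Configuration parameters) {
        // The name must match the one passed to withBroadcastSet(voc, "vocabulary").
        List<Row> broadcast = getRuntimeContext().getBroadcastVariable("vocabulary");
        vocab = new HashMap<>();
        for (Row row : broadcast) {
            vocab.put(row.getField(0), row.getField(1));
        }
    }

    @Override
    public Row map(Row value) {
        // Replace the word in column 0 with its vocabulary index.
        Row out = new Row(1);
        out.setField(0, vocab.get(value.getField(0)));
        return out;
    }
}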

Example 18 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class LdaTrainBatchOp, method linkFrom.

@Override
public LdaTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    int parallelism = BatchOperator.getExecutionEnvironmentFromOps(in).getParallelism();
    long mlEnvId = getMLEnvironmentId();
    int numTopic = getTopicNum();
    int numIter = getNumIter();
    Integer seed = getRandomSeed();
    boolean setSeed = (seed != null);
    String vectorColName = getSelectedCol();
    Method optimizer = getMethod();
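    // Build the doc-count (vocabulary) model once; it is broadcast to the vectorization step below.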
    final DataSet<DocCountVectorizerModelData> resDocCountModel = DocCountVectorizerTrainBatchOp.generateDocCountModel(getParams(), in);
    int index = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), vectorColName);
    DataSet<Row> resRow = in.getDataSet().flatMap(new Document2Vector(index)).withBroadcastSet(resDocCountModel, "DocCountModel");
    TypeInformation<?>[] types = in.getColTypes().clone();
    types[index] = TypeInformation.of(SparseVector.class);
    BatchOperator<?> trainData = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(mlEnvId, resRow, in.getColNames(), types))
        .setMLEnvironmentId(mlEnvId);
    Tuple2<DataSet<Vector>, DataSet<BaseVectorSummary>> dataAndStat = StatisticsHelper.summaryHelper(trainData, null, vectorColName);
    if (setSeed) {
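        // Deterministic shuffle: hash each vector with the fixed seed, custom-partition on the hash,
        // then sort within every partition so repeated runs yield identical order.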
        DataSet<Tuple2<Long, Vector>> hashValue = dataAndStat.f0.map(new MapHashValue(seed)).partitionCustom(new Partitioner<Long>() {

            private static final long serialVersionUID = 5179898093029365608L;

            @Override
            public int partition(Long key, int numPartitions) {
                return (int) (Math.abs(key) % ((long) numPartitions));
            }
        }, 0);
        dataAndStat.f0 = hashValue.mapPartition(new MapPartitionFunction<Tuple2<Long, Vector>, Vector>() {

            private static final long serialVersionUID = -550512476573928350L;

            @Override
            public void mapPartition(Iterable<Tuple2<Long, Vector>> values, Collector<Vector> out) throws Exception {
                List<Tuple2<Long, Vector>> listValues = Lists.newArrayList(values);
                listValues.sort(new Comparator<Tuple2<Long, Vector>>() {

                    @Override
                    public int compare(Tuple2<Long, Vector> o1, Tuple2<Long, Vector> o2) {
                        int compare1 = o1.f0.compareTo(o2.f0);
                        if (compare1 == 0) {
                            String o1s = o1.f1.toString();
                            String o2s = o2.f1.toString();
                            return o1s.compareTo(o2s);
                        }
                        return compare1;
                    }
                });
                listValues.forEach(x -> out.collect(x.f1));
            }
        }).setParallelism(parallelism);
    }
    double beta = getParams().get(BETA);
    double alpha = getParams().get(ALPHA);
    int gammaShape = 250;
    switch(optimizer) {
        case EM:
            gibbsSample(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel, seed);
            break;
        case Online:
            online(dataAndStat, numTopic, numIter, alpha, beta, resDocCountModel, gammaShape, seed);
            break;
        default:
            throw new NotImplementedException("Optimizer not supported.");
    }
    return this;
}
Also used : LdaUtil(com.alibaba.alink.operator.common.clustering.lda.LdaUtil) OnlineLogLikelihood(com.alibaba.alink.operator.common.clustering.lda.OnlineLogLikelihood) Arrays(java.util.Arrays) Tuple2(org.apache.flink.api.java.tuple.Tuple2) OnlineCorpusStep(com.alibaba.alink.operator.common.clustering.lda.OnlineCorpusStep) LdaTrainParams(com.alibaba.alink.params.clustering.LdaTrainParams) MapFunction(org.apache.flink.api.common.functions.MapFunction) WithModelInfoBatchOp(com.alibaba.alink.common.lazy.WithModelInfoBatchOp) DataSet(org.apache.flink.api.java.DataSet) NotImplementedException(org.apache.commons.lang.NotImplementedException) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Vector(com.alibaba.alink.common.linalg.Vector) RichMapPartitionFunction(org.apache.flink.api.common.functions.RichMapPartitionFunction) Table(org.apache.flink.table.api.Table) AllReduce(com.alibaba.alink.common.comqueue.communication.AllReduce) LdaModelMapper(com.alibaba.alink.operator.common.clustering.LdaModelMapper) UpdateLambdaAndAlpha(com.alibaba.alink.operator.common.clustering.lda.UpdateLambdaAndAlpha) EmCorpusStep(com.alibaba.alink.operator.common.clustering.lda.EmCorpusStep) List(java.util.List) IterativeComQueue(com.alibaba.alink.common.comqueue.IterativeComQueue) DataSetConversionUtil(com.alibaba.alink.common.utils.DataSetConversionUtil) LdaModelDataConverter(com.alibaba.alink.operator.common.clustering.LdaModelDataConverter) BaseVectorSummary(com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary) HashFunction(org.apache.flink.shaded.guava18.com.google.common.hash.HashFunction) MapPartitionFunction(org.apache.flink.api.common.functions.MapPartitionFunction) Row(org.apache.flink.types.Row) DocCountVectorizerModelMapper(com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelMapper) LdaVariable(com.alibaba.alink.operator.common.clustering.lda.LdaVariable) Hashing.murmur3_128(org.apache.flink.shaded.guava18.com.google.common.hash.Hashing.murmur3_128) RichFlatMapFunction(org.apache.flink.api.common.functions.RichFlatMapFunction) BuildOnlineLdaModel(com.alibaba.alink.operator.common.clustering.lda.BuildOnlineLdaModel) TableUtil(com.alibaba.alink.common.utils.TableUtil) HashMap(java.util.HashMap) DocCountVectorizerTrainBatchOp(com.alibaba.alink.operator.batch.nlp.DocCountVectorizerTrainBatchOp) BuildEmLdaModel(com.alibaba.alink.operator.common.clustering.lda.BuildEmLdaModel) ArrayList(java.util.ArrayList) Partitioner(org.apache.flink.api.common.functions.Partitioner) DocCountVectorizerModelData(com.alibaba.alink.operator.common.nlp.DocCountVectorizerModelData) Lists(com.google.common.collect.Lists) RichMapFunction(org.apache.flink.api.common.functions.RichMapFunction) Collector(org.apache.flink.util.Collector) DenseMatrix(com.alibaba.alink.common.linalg.DenseMatrix) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) EmLogLikelihood(com.alibaba.alink.operator.common.clustering.lda.EmLogLikelihood) Types(org.apache.flink.api.common.typeinfo.Types) LdaModelData(com.alibaba.alink.operator.common.clustering.LdaModelData) Configuration(org.apache.flink.configuration.Configuration) RowCollector(com.alibaba.alink.common.utils.RowCollector) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) StatisticsHelper(com.alibaba.alink.operator.common.statistics.StatisticsHelper) RandomDataGenerator(org.apache.commons.math3.random.RandomDataGenerator) FeatureType(com.alibaba.alink.operator.common.nlp.FeatureType) SparseVector(com.alibaba.alink.common.linalg.SparseVector) Comparator(java.util.Comparator) Params(org.apache.flink.ml.api.misc.param.Params)
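
The seeded shuffle in this example is a reusable trick: hash every record with a fixed seed, custom-partition on the hash, and impose a total order inside each partition so reruns are identical. The hand-written comparator can also be replaced by the DataSet API's sortPartition; a sketch on plain Tuple2<Long, String> records (the real code sorts Vectors, using their string form as the tie-breaker; DeterministicShuffle is an illustrative name):

import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;

public final class DeterministicShuffle {

    // Route records by |hash| % numPartitions, so a fixed seed always yields the same layout.
    static final Partitioner<Long> MOD = new Partitioner<Long>() {
        @Override
        public int partition(Long key, int numPartitions) {
            return (int) (Math.abs(key) % numPartitions);
        }
    };

    public static DataSet<Tuple2<Long, String>> shuffle(DataSet<Tuple2<Long, String>> hashed) {
        // partitionCustom fixes placement; the chained sortPartition calls fix the order
        // within each partition (primary key: the hash, tie-breaker: the payload).
        return hashed
            .partitionCustom(MOD, 0)
            .sortPartition(0, Order.ASCENDING)
            .sortPartition(1, Order.ASCENDING);
    }
}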

Example 19 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class GraphEmbedding, method trans2Index.

/**
 * Transform vertices to indices.
 * vocab, schema {NODE_COL:originalType, NODE_INDEX_COL:long}
 * indexedGraph, schema {SOURCE_COL:long, TARGET_COL:long, WEIGHT_COL:double}
 * indexWithType, returned only if in2 is not null, schema {NODE_INDEX_COL:long, NODE_TYPE_COL:string}
 *
 * @param in1    the graph data
 * @param in2    the vertex list with vertex types, optional
 * @param params user-specified parameters
 * @return vocab and indexedGraph, plus indexWithType when in2 is not null
 */
public static BatchOperator[] trans2Index(BatchOperator in1, BatchOperator in2, Params params) {
    String sourceColName = params.get(HasSourceCol.SOURCE_COL);
    String targetColName = params.get(HasTargetCol.TARGET_COL);
    String clause;
    if (params.contains(HasWeightCol.WEIGHT_COL)) {
        String weightColName = params.get(HasWeightCol.WEIGHT_COL);
        clause = "`" + sourceColName + "`, `" + targetColName + "`, `" + weightColName + "`";
    } else {
        clause = "`" + sourceColName + "`, `" + targetColName + "`, 1.0";
    }
    BatchOperator in = in1.select(clause).as(SOURCE_COL + ", " + TARGET_COL + ", " + WEIGHT_COL);
    // count how many times each word appears among the edge endpoints.
    BatchOperator wordCnt = WordCountUtil.count(
            new UnionAllBatchOp()
                .setMLEnvironmentId(in1.getMLEnvironmentId())
                .linkFrom(in.select(SOURCE_COL), in.select(TARGET_COL))
                .as(NODE_COL),
            NODE_COL);
    // give each vocabulary word its index.
    BatchOperator vocab = WordCountUtil.randomIndexVocab(wordCnt, 0)
        .select(WordCountUtil.WORD_COL_NAME + " AS " + NODE_COL + ", "
            + WordCountUtil.INDEX_COL_NAME + " AS " + NODE_INDEX_COL);
    // transform input and vocab to DataSet<Tuple>
    DataSet<Tuple> inDataSet = in.getDataSet().map(new MapFunction<Row, Tuple3<Comparable, Comparable, Comparable>>() {

        private static final long serialVersionUID = 8473819294214049730L;

        @Override
        public Tuple3<Comparable, Comparable, Comparable> map(Row value) throws Exception {
            return Tuple3.of((Comparable) value.getField(0), (Comparable) value.getField(1), (Comparable) value.getField(2));
        }
    });
    DataSet<Tuple2> vocabDataSet = vocab.getDataSet().map(new MapFunction<Row, Tuple2<Comparable, Long>>() {

        private static final long serialVersionUID = 7241884458236714150L;

        @Override
        public Tuple2<Comparable, Long> map(Row value) throws Exception {
            return Tuple2.of((Comparable) value.getField(0), (Long) value.getField(1));
        }
    });
    // join operation
    DataSet<Tuple> joinWithSourceColTuple = HackBatchOpJoin.join(inDataSet, vocabDataSet, 0, 0, new int[][] { { 1, 1 }, { 0, 1 }, { 0, 2 } });
    DataSet<Tuple> indexGraphTuple = HackBatchOpJoin.join(joinWithSourceColTuple, vocabDataSet, 1, 0, new int[][] { { 0, 0 }, { 1, 1 }, { 0, 2 } });
    // build batchOperator
    TypeInformation<?>[] inTypes = in.getColTypes();
    TypeInformation<?>[] vocabTypes = vocab.getColTypes();
    BatchOperator indexedGraphBatchOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(in.getMLEnvironmentId(), indexGraphTuple.map(new MapFunction<Tuple, Row>() {

        private static final long serialVersionUID = -5386264086074581748L;

        @Override
        public Row map(Tuple value) throws Exception {
            Row res = new Row(3);
            res.setField(0, value.getField(0));
            res.setField(1, value.getField(1));
            res.setField(2, value.getField(2));
            return res;
        }
    }), new String[] { SOURCE_COL, TARGET_COL, WEIGHT_COL }, new TypeInformation<?>[] { vocabTypes[1], vocabTypes[1], inTypes[2] }));
    if (null == in2) {
        return new BatchOperator[] { vocab, indexedGraphBatchOp };
    } else {
        BatchOperator in2Selected = in2
            .select("`" + params.get(HasVertexCol.VERTEX_COL) + "`, `" + params.get(HasTypeCol.TYPE_COL) + "`")
            .as(TEMP_NODE_COL + ", " + NODE_TYPE_COL);
        TypeInformation<?>[] types = new TypeInformation[2];
        types[1] = in2.getColTypes()[TableUtil.findColIndex(in2.getSchema(), params.get(HasTypeCol.TYPE_COL))];
        types[0] = vocab.getColTypes()[TableUtil.findColIndex(vocab.getSchema(), NODE_INDEX_COL)];
        DataSet<Tuple> in2Tuple = in2Selected.getDataSet().map(new MapFunction<Row, Tuple2<Comparable, Comparable>>() {

            private static final long serialVersionUID = 3459700988499538679L;

            @Override
            public Tuple2<Comparable, Comparable> map(Row value) throws Exception {
                Tuple2<Comparable, Comparable> res = new Tuple2<>();
                res.setField(value.getField(0), 0);
                res.setField(value.getField(1), 1);
                return res;
            }
        });
        DataSet<Row> indexWithTypeRow = HackBatchOpJoin.join(in2Tuple, vocabDataSet, 0, 0, new int[][] { { 1, 1 }, { 0, 1 } }).map(new MapFunction<Tuple, Row>() {

            private static final long serialVersionUID = -5747375637774394150L;

            @Override
            public Row map(Tuple value) throws Exception {
                int length = value.getArity();
                Row res = new Row(length);
                for (int i = 0; i < length; i++) {
                    res.setField(i, value.getField(i));
                }
                return res;
            }
        });
        BatchOperator indexWithType = new TableSourceBatchOp(
                DataSetConversionUtil.toTable(in.getMLEnvironmentId(), indexWithTypeRow,
                        new String[] { NODE_INDEX_COL, NODE_TYPE_COL }, types))
            .setMLEnvironmentId(in.getMLEnvironmentId());
        return new BatchOperator[] { vocab, indexedGraphBatchOp, indexWithType };
    }
}
Also used : TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) UnionAllBatchOp(com.alibaba.alink.operator.batch.sql.UnionAllBatchOp) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Tuple3(org.apache.flink.api.java.tuple.Tuple3) Row(org.apache.flink.types.Row) Tuple(org.apache.flink.api.java.tuple.Tuple)
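
Conceptually, the two HackBatchOpJoin calls simply replace each edge endpoint with its vocabulary index. HackBatchOpJoin is Alink-internal, so the sketch below expresses the same two-step replacement with the standard DataSet join, assuming string-typed nodes and a (node, index) vocabulary; EdgeIndexer and indexEdges are illustrative names:

import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;

public final class EdgeIndexer {

    // edges: (src, dst, weight); vocab: (node, index).
    public static DataSet<Tuple3<Long, Long, Double>> indexEdges(
            DataSet<Tuple3<String, String, Double>> edges,
            DataSet<Tuple2<String, Long>> vocab) {
        // Step 1: swap the source endpoint for its index.
        DataSet<Tuple3<Long, String, Double>> half = edges
            .join(vocab).where(0).equalTo(0)
            .with((e, v) -> Tuple3.of(v.f1, e.f1, e.f2))
            .returns(Types.TUPLE(Types.LONG, Types.STRING, Types.DOUBLE));
        // Step 2: swap the target endpoint for its index.
        return half
            .join(vocab).where(1).equalTo(0)
            .with((e, v) -> Tuple3.of(e.f0, v.f1, e.f2))
            .returns(Types.TUPLE(Types.LONG, Types.LONG, Types.DOUBLE));
    }
}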

Example 20 with TableSourceBatchOp

Use of com.alibaba.alink.operator.batch.source.TableSourceBatchOp in project Alink by alibaba.

The class ModelExporterUtils, method postOrderDeserialize.

private static <T extends PipelineStageBase<?>> List<T> postOrderDeserialize(StageNode[] stages, BatchOperator<?> unpacked, final TableSchema schema, final int offset) {
    if (stages == null || stages.length == 0) {
        return new ArrayList<>();
    }
    final long[] id = new long[] { stages.length - 1 };
    final BatchOperator<?>[] deserialized = new BatchOperator<?>[] { unpacked };
    List<T> result = new ArrayList<>();
    Consumer<Integer> deserializer = lp -> {
        StageNode stageNode = stages[lp];
        try {
            if (stageNode.identifier != null) {
                stageNode.stage = (PipelineStageBase<?>) Class.forName(stageNode.identifier).getConstructor(Params.class).newInstance(stageNode.params);
            }
        } catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException | InvocationTargetException ex) {
            throw new IllegalArgumentException(ex);
        }
        // leaf node.
        if (stageNode.children == null) {
            if (stageNode.parent >= 0 && stageNode.schemaIndices != null && stageNode.colNames != null) {
                final long localId = id[0];
                final int[] localSchemaIndices = stageNode.schemaIndices;
                BatchOperator<?> model = new TableSourceBatchOp(DataSetConversionUtil.toTable(deserialized[0].getMLEnvironmentId(), deserialized[0].getDataSet().filter(new FilterFunction<Row>() {

                    private static final long serialVersionUID = 355683133177055891L;

                    @Override
                    public boolean filter(Row value) {
                        return value.getField(0).equals(localId);
                    }
                }).map(new MapFunction<Row, Row>() {

                    private static final long serialVersionUID = -4286266312978550037L;

                    @Override
                    public Row map(Row value) throws Exception {
                        Row ret = new Row(localSchemaIndices.length);
                        for (int i = 0; i < localSchemaIndices.length; ++i) {
                            ret.setField(i, value.getField(localSchemaIndices[i] + offset));
                        }
                        return ret;
                    }
                }).returns(new RowTypeInfo(stageNode.types)), new TableSchema(stageNode.colNames, stageNode.types))).setMLEnvironmentId(deserialized[0].getMLEnvironmentId());
                ((ModelBase<?>) stageNode.stage).setModelData(model);
                deserialized[0] = new TableSourceBatchOp(DataSetConversionUtil.toTable(deserialized[0].getMLEnvironmentId(), deserialized[0].getDataSet().filter(new FilterFunction<Row>() {

                    private static final long serialVersionUID = -2803966833769030531L;

                    @Override
                    public boolean filter(Row value) {
                        return !value.getField(0).equals(localId);
                    }
                }), schema)).setMLEnvironmentId(deserialized[0].getMLEnvironmentId());
            }
        } else {
            List<T> pipelineStageBases = new ArrayList<>();
            for (int i = 0; i < stageNode.children.length; ++i) {
                pipelineStageBases.add((T) stages[stageNode.children[i]].stage);
            }
            if (stageNode.stage == null) {
                result.addAll(pipelineStageBases);
                return;
            }
            if (stageNode.stage instanceof Pipeline) {
                stageNode.stage = new Pipeline(pipelineStageBases.toArray(new PipelineStageBase<?>[0]));
            } else if (stageNode.stage instanceof PipelineModel) {
                stageNode.stage = new PipelineModel(pipelineStageBases.toArray(new TransformerBase<?>[0]));
            } else {
                throw new IllegalArgumentException("Unsupported stage.");
            }
        }
        id[0]--;
    };
    postOrder(stages, deserializer);
    return result;
}
Also used : Arrays(java.util.Arrays) Tuple3(org.apache.flink.api.java.tuple.Tuple3) Tuple2(org.apache.flink.api.java.tuple.Tuple2) MLEnvironmentFactory(com.alibaba.alink.common.MLEnvironmentFactory) AkStream(com.alibaba.alink.common.io.filesystem.AkStream) JsonConverter(com.alibaba.alink.common.utils.JsonConverter) RowTypeInfo(org.apache.flink.api.java.typeutils.RowTypeInfo) MapFunction(org.apache.flink.api.common.functions.MapFunction) ModelMapper(com.alibaba.alink.common.mapper.ModelMapper) ComboModelMapper(com.alibaba.alink.common.mapper.ComboModelMapper) AkUtils(com.alibaba.alink.common.io.filesystem.AkUtils) DataSet(org.apache.flink.api.java.DataSet) RecommenderUtil(com.alibaba.alink.pipeline.recommendation.RecommenderUtil) Map(java.util.Map) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) TableSchema(org.apache.flink.table.api.TableSchema) Preconditions(org.apache.flink.util.Preconditions) Collectors(java.util.stream.Collectors) InvocationTargetException(java.lang.reflect.InvocationTargetException) TypeReference(org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference) FilterFunction(org.apache.flink.api.common.functions.FilterFunction) List(java.util.List) DataSetConversionUtil(com.alibaba.alink.common.utils.DataSetConversionUtil) ModelStreamScanParams(com.alibaba.alink.params.ModelStreamScanParams) ComboMapper(com.alibaba.alink.common.mapper.ComboMapper) Row(org.apache.flink.types.Row) IntStream(java.util.stream.IntStream) TableUtil(com.alibaba.alink.common.utils.TableUtil) Lists(org.apache.flink.shaded.guava18.com.google.common.collect.Lists) MapperChain(com.alibaba.alink.common.mapper.MapperChain) BaseRecommender(com.alibaba.alink.pipeline.recommendation.BaseRecommender) ArrayUtils(org.apache.commons.lang3.ArrayUtils) HashMap(java.util.HashMap) PipelineModelMapper.getExtendModelSchema(com.alibaba.alink.common.mapper.PipelineModelMapper.getExtendModelSchema) ArrayList(java.util.ArrayList) FileProcFunction(com.alibaba.alink.common.io.filesystem.AkUtils.FileProcFunction) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) CsvUtil(com.alibaba.alink.operator.common.io.csv.CsvUtil) LinkedList(java.util.LinkedList) FlinkTypeConverter(com.alibaba.alink.operator.common.io.types.FlinkTypeConverter) Types(org.apache.flink.api.common.typeinfo.Types) IOException(java.io.IOException) AkReader(com.alibaba.alink.common.io.filesystem.AkStream.AkReader) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) Consumer(java.util.function.Consumer) FilePath(com.alibaba.alink.common.io.filesystem.FilePath) PipelineModelMapper(com.alibaba.alink.common.mapper.PipelineModelMapper) Comparator(java.util.Comparator) Params(org.apache.flink.ml.api.misc.param.Params) Collections(java.util.Collections) Mapper(com.alibaba.alink.common.mapper.Mapper) FilterFunction(org.apache.flink.api.common.functions.FilterFunction) TableSchema(org.apache.flink.table.api.TableSchema) ArrayList(java.util.ArrayList) MapFunction(org.apache.flink.api.common.functions.MapFunction) RowTypeInfo(org.apache.flink.api.java.typeutils.RowTypeInfo) TableSourceBatchOp(com.alibaba.alink.operator.batch.source.TableSourceBatchOp) List(java.util.List) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) Row(org.apache.flink.types.Row)
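
The load-bearing detail here is traversal order: postOrder visits children before their parent, so a leaf's model data is attached before the enclosing Pipeline or PipelineModel is reassembled from its member stages. A reduced sketch of such a traversal, matching the Consumer<Integer> shape used above (Node stands in for StageNode, trimmed to its child indices):

import java.util.function.Consumer;

public final class PostOrderSketch {

    // Minimal stand-in for StageNode: each node only knows its child indices.
    static final class Node {
        int[] children;
    }

    // Visit every child subtree first, then the node itself.
    static void postOrder(Node[] nodes, int root, Consumer<Integer> visit) {
        if (nodes[root].children != null) {
            for (int child : nodes[root].children) {
                postOrder(nodes, child, visit);
            }
        }
        visit.accept(root);
    }
}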

Aggregations

TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp): 39 uses
Row (org.apache.flink.types.Row): 29 uses
BatchOperator (com.alibaba.alink.operator.batch.BatchOperator): 22 uses
Test (org.junit.Test): 18 uses
TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation): 12 uses
TableSourceStreamOp (com.alibaba.alink.operator.stream.source.TableSourceStreamOp): 10 uses
Params (org.apache.flink.ml.api.misc.param.Params): 10 uses
List (java.util.List): 8 uses
Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 8 uses
StreamOperator (com.alibaba.alink.operator.stream.StreamOperator): 6 uses
ArrayList (java.util.ArrayList): 6 uses
TableSchema (org.apache.flink.table.api.TableSchema): 6 uses
MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp): 5 uses
Comparator (java.util.Comparator): 4 uses
HashMap (java.util.HashMap): 4 uses
MapFunction (org.apache.flink.api.common.functions.MapFunction): 4 uses
DataSet (org.apache.flink.api.java.DataSet): 4 uses
Mapper (com.alibaba.alink.common.mapper.Mapper): 3 uses
ModelMapper (com.alibaba.alink.common.mapper.ModelMapper): 3 uses
PipelineModelMapper (com.alibaba.alink.common.mapper.PipelineModelMapper): 3 uses