Search in sources:

Example 1 with ConstructModelFlatMapFunction

Usage of com.alibaba.alink.operator.common.tensorflow.CommonUtils.ConstructModelFlatMapFunction in the Alink project by Alibaba.

Taken from the method linkFrom of the class BaseKerasSequentialTrainBatchOp.

/**
 * Builds the Keras Sequential training pipeline on {@code inputs[0]} and sets the resulting TF table
 * model as this operator's output table.
 *
 * <p>When the task type is not {@code TaskType.REGRESSION}, four extra steps run:
 * <ul>
 *   <li>the distinct values of the label column are collected into one sorted list;</li>
 *   <li>labels are mapped to indices with {@code CommonUtils.mapLabelToIndex};</li>
 *   <li>a one-column ("count") operator derived from the sorted labels is passed to the training
 *       op as a second input;</li>
 *   <li>the sorted labels are broadcast into the model-construction step.</li>
 * </ul>
 * For regression, only the training data is passed to the training op.
 *
 * @param inputs input operators; only {@code inputs[0]} (the training data) is read
 * @return this operator, whose output table is the model table
 */
@Override
@SuppressWarnings("unchecked") // for the (T) this self-type cast at the end
public T linkFrom(BatchOperator<?>... inputs) {
    final Long mlEnvId = getMLEnvironmentId();
    final Params params = getParams();
    final TaskType taskType = params.get(HasTaskType.TASK_TYPE);
    final boolean isReg = TaskType.REGRESSION.equals(taskType);
    final String tensorCol = getTensorCol();
    final String labelCol = getLabelCol();

    BatchOperator<?> in = inputs[0];
    // Resolved on the raw input, before labels are (possibly) mapped to indices below.
    final TypeInformation<?> labelType = TableUtil.findColTypeWithAssertAndHint(in.getSchema(), labelCol);

    DataSet<List<Object>> sortedLabels = null;
    BatchOperator<?> numLabelsOp = null;
    if (!isReg) {
        sortedLabels = collectSortedDistinctLabels(in, labelCol);
        in = CommonUtils.mapLabelToIndex(in, labelCol, sortedLabels);
        numLabelsOp = new TableSourceBatchOp(
            DataSetConversionUtil.toTable(mlEnvId, sortedLabels.map(new CountLabelsMapFunction()),
                new String[] {"count"}, new TypeInformation<?>[] {Types.INT}))
            .setMLEnvironmentId(mlEnvId);
    }

    ExecutionEnvironment env = MLEnvironmentFactory.get(mlEnvId).getExecutionEnvironment();
    Map<String, String> userParams =
        buildTrainUserParams(params, tensorCol, labelCol, env.getParallelism() == 1);

    TF2TableModelTrainBatchOp trainBatchOp = new TF2TableModelTrainBatchOp(params)
        .setSelectedCols(tensorCol, labelCol)
        .setUserFiles(RES_PY_FILES)
        .setMainScriptFile(MAIN_SCRIPT_FILE_NAME)
        .setUserParams(JsonConverter.toJson(userParams))
        .setIntraOpParallelism(getIntraOpParallelism())
        .setNumPSs(getNumPSs())
        .setNumWorkers(getNumWorkers())
        .setPythonEnv(params.get(HasPythonEnv.PYTHON_ENV));
    if (isReg) {
        trainBatchOp = trainBatchOp.linkFrom(in);
    } else {
        trainBatchOp = trainBatchOp.linkFrom(in, numLabelsOp);
    }

    String tfOutputSignatureDef = getTfOutputSignatureDef(taskType);
    // The sequence source (from 0 to 0) presumably yields a single record, so the model is
    // constructed once, from the trained TF model delivered as a broadcast set.
    FlatMapOperator<Row, Row> constructModelFlatMapOperator = new NumSeqSourceBatchOp()
        .setFrom(0)
        .setTo(0)
        .setMLEnvironmentId(mlEnvId)
        .getDataSet()
        .flatMap(new ConstructModelFlatMapFunction(params, new String[] {tensorCol},
            tfOutputSignatureDef, TF_OUTPUT_SIGNATURE_TYPE, null, true))
        .withBroadcastSet(trainBatchOp.getDataSet(), CommonUtils.TF_MODEL_BC_NAME);

    BatchOperator<?> modelOp;
    if (isReg) {
        DataSet<Row> modelDataSet = constructModelFlatMapOperator;
        modelOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(mlEnvId, modelDataSet,
            new TFTableModelRegressionModelDataConverter(labelType).getModelSchema()))
            .setMLEnvironmentId(mlEnvId);
    } else {
        // Classification models additionally need the sorted label list to map indices back to labels.
        DataSet<Row> modelDataSet = constructModelFlatMapOperator
            .withBroadcastSet(sortedLabels, CommonUtils.SORTED_LABELS_BC_NAME);
        modelOp = new TableSourceBatchOp(DataSetConversionUtil.toTable(mlEnvId, modelDataSet,
            new TFTableModelClassificationModelDataConverter(labelType).getModelSchema()))
            .setMLEnvironmentId(mlEnvId);
    }
    this.setOutputTable(modelOp.getOutputTable());
    return (T) this;
}

/**
 * Collects the distinct values of {@code labelCol} into one sorted list.
 *
 * <p>Values are first de-duplicated within each partition to reduce the data shuffled to the final
 * step, then merged and sorted by a group reduce over the whole data set.
 *
 * @param in       the operator holding the label column
 * @param labelCol name of the label column
 * @return a data set carrying the sorted list of distinct label values
 */
private static DataSet<List<Object>> collectSortedDistinctLabels(BatchOperator<?> in, String labelCol) {
    return in.select(labelCol)
        .getDataSet()
        .mapPartition(new DistinctLabelsInPartition())
        .reduceGroup(new SortedDistinctLabels());
}

/**
 * Assembles the user parameters that are serialized to JSON and handed to the training op.
 *
 * @param params        this operator's params (checked for {@code SAVE_CHECKPOINTS_SECS})
 * @param tensorCol     name of the tensor column
 * @param labelCol      name of the label column
 * @param onlyOneWorker whether the execution environment's parallelism is 1
 * @return a mutable map from user-parameter name to its string value
 */
private Map<String, String> buildTrainUserParams(Params params, String tensorCol, String labelCol,
                                                 boolean onlyOneWorker) {
    Map<String, Object> modelConfig = new HashMap<>();
    modelConfig.put("layers", getLayers());

    Map<String, String> userParams = new HashMap<>();
    // An unset (null) value counts as true: the checkpoint is cleaned by default.
    if (!Boolean.FALSE.equals(getRemoveCheckpointBeforeTraining())) {
        userParams.put(DLConstants.REMOVE_CHECKPOINT_BEFORE_TRAINING, "true");
    }
    userParams.put("tensor_cols", JsonConverter.toJson(new String[] {tensorCol}));
    userParams.put("label_col", labelCol);
    userParams.put("label_type", "float");
    userParams.put("batch_size", String.valueOf(getBatchSize()));
    userParams.put("num_epochs", String.valueOf(getNumEpochs()));
    userParams.put("model_config", JsonConverter.toJson(modelConfig));
    userParams.put("optimizer", getOptimizer());
    if (!StringUtils.isNullOrWhitespaceOnly(getCheckpointFilePath())) {
        userParams.put("model_dir", getCheckpointFilePath());
    }
    if (onlyOneWorker) {
        userParams.put("ALINK:ONLY_ONE_WORKER", "true");
    }
    userParams.put("validation_split", String.valueOf(getValidationSplit()));
    userParams.put("save_best_only", String.valueOf(getSaveBestOnly()));
    userParams.put("best_exporter_metric", getBestMetric());
    userParams.put("save_checkpoints_epochs", String.valueOf(getSaveCheckpointsEpochs()));
    if (params.contains(BaseKerasSequentialTrainParams.SAVE_CHECKPOINTS_SECS)) {
        userParams.put("save_checkpoints_secs", String.valueOf(getSaveCheckpointsSecs()));
    }
    return userParams;
}

/**
 * Emits each value of field 0 at most once per partition.
 *
 * <p>Static, so the serialized function does not drag the enclosing operator instance along.
 */
private static class DistinctLabelsInPartition implements MapPartitionFunction<Row, Object> {
    private static final long serialVersionUID = 1L;

    @Override
    public void mapPartition(Iterable<Row> rows, Collector<Object> out) throws Exception {
        Set<Object> distinct = new HashSet<>();
        for (Row row : rows) {
            distinct.add(row.getField(0));
        }
        for (Object label : distinct) {
            out.collect(label);
        }
    }
}

/**
 * Merges the per-partition label values into a single list in natural (TreeSet) order.
 *
 * <p>Because a TreeSet with natural ordering is used, label values must be mutually
 * {@link Comparable} and non-null.
 */
private static class SortedDistinctLabels implements GroupReduceFunction<Object, List<Object>> {
    private static final long serialVersionUID = 1L;

    @Override
    public void reduce(Iterable<Object> labels, Collector<List<Object>> out) throws Exception {
        Set<Object> sorted = new TreeSet<>();
        for (Object label : labels) {
            sorted.add(label);
        }
        out.collect(new ArrayList<>(sorted));
    }
}
Also used: ExecutionEnvironment (org.apache.flink.api.java.ExecutionEnvironment), ConstructModelFlatMapFunction (com.alibaba.alink.operator.common.tensorflow.CommonUtils.ConstructModelFlatMapFunction), DataSet (org.apache.flink.api.java.DataSet), TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp), TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation), HasTaskType (com.alibaba.alink.params.dl.HasTaskType), Collector (org.apache.flink.util.Collector), NumSeqSourceBatchOp (com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp), GroupReduceFunction (org.apache.flink.api.common.functions.GroupReduceFunction), BaseKerasSequentialTrainParams (com.alibaba.alink.params.tensorflow.kerasequential.BaseKerasSequentialTrainParams), Params (org.apache.flink.ml.api.misc.param.Params), TFTableModelRegressionModelDataConverter (com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionModelDataConverter), TFTableModelClassificationModelDataConverter (com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelDataConverter), TF2TableModelTrainBatchOp (com.alibaba.alink.operator.batch.tensorflow.TF2TableModelTrainBatchOp), CountLabelsMapFunction (com.alibaba.alink.operator.common.tensorflow.CommonUtils.CountLabelsMapFunction), Row (org.apache.flink.types.Row)

Aggregations

NumSeqSourceBatchOp (com.alibaba.alink.operator.batch.source.NumSeqSourceBatchOp): 1; TableSourceBatchOp (com.alibaba.alink.operator.batch.source.TableSourceBatchOp): 1; TF2TableModelTrainBatchOp (com.alibaba.alink.operator.batch.tensorflow.TF2TableModelTrainBatchOp): 1; TFTableModelClassificationModelDataConverter (com.alibaba.alink.operator.common.classification.tensorflow.TFTableModelClassificationModelDataConverter): 1; TFTableModelRegressionModelDataConverter (com.alibaba.alink.operator.common.regression.tensorflow.TFTableModelRegressionModelDataConverter): 1; ConstructModelFlatMapFunction (com.alibaba.alink.operator.common.tensorflow.CommonUtils.ConstructModelFlatMapFunction): 1; CountLabelsMapFunction (com.alibaba.alink.operator.common.tensorflow.CommonUtils.CountLabelsMapFunction): 1; HasTaskType (com.alibaba.alink.params.dl.HasTaskType): 1; BaseKerasSequentialTrainParams (com.alibaba.alink.params.tensorflow.kerasequential.BaseKerasSequentialTrainParams): 1; GroupReduceFunction (org.apache.flink.api.common.functions.GroupReduceFunction): 1; TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation): 1; DataSet (org.apache.flink.api.java.DataSet): 1; ExecutionEnvironment (org.apache.flink.api.java.ExecutionEnvironment): 1; Params (org.apache.flink.ml.api.misc.param.Params): 1; Row (org.apache.flink.types.Row): 1; Collector (org.apache.flink.util.Collector): 1