Search in sources:

Example 1 with TFTableModelTrainBatchOp

Use of com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp in project Alink by Alibaba.

In class DeepARTrainBatchOp, method linkFrom.

/**
 * Trains a DeepAR model on the first input.
 *
 * <p>Pipeline:
 * <ol>
 *   <li>{@code DeepARPreProcessBatchOp} turns the input into columns "tensor", "v" and "y".</li>
 *   <li>{@code TFTableModelTrainBatchOp} runs {@code res:///tf_algos/deepar_entry.py} on the
 *       "tensor" and "y" columns. "v" is produced by preprocessing but is not selected for training.</li>
 *   <li>All TF model rows are gathered and saved by {@code DeepARModelDataConverter}. The rows are
 *       saved together with a copy of this operator's params, augmented with the
 *       {@code TimeFrequency} read from field 0 of the preprocessing operator's side output 0.</li>
 * </ol>
 *
 * @param inputs the input operators; the first one is obtained via {@code checkAndGetFirst}
 * @return this operator, whose output is the DeepAR model table
 */
@Override
public DeepARTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> input = checkAndGetFirst(inputs);
    BatchOperator<?> preprocessed = new DeepARPreProcessBatchOp(getParams().clone()).setOutputCols("tensor", "v", "y").setMLEnvironmentId(getMLEnvironmentId()).linkFrom(input);
    // Model hyper-parameters handed to the Python script as a nested JSON string under "model_config".
    Map<String, Object> modelConfig = new HashMap<>();
    modelConfig.put("window", getWindow());
    modelConfig.put("stride", getStride());
    // Parameters read by deepar_entry.py; values are all strings, serialized to JSON below.
    Map<String, String> userParams = new HashMap<>();
    userParams.put("tensorCol", "tensor");
    userParams.put("labelCol", "y");
    userParams.put("batch_size", String.valueOf(getBatchSize()));
    userParams.put("num_epochs", String.valueOf(getNumEpochs()));
    userParams.put("model_config", JsonConverter.toJson(modelConfig));
    TFTableModelTrainBatchOp tfTableModelTrainBatchOp = new TFTableModelTrainBatchOp(getParams().clone()).setSelectedCols("tensor", "y").setUserFiles(new String[] { "res:///tf_algos/deepar_entry.py" }).setMainScriptFile("res:///tf_algos/deepar_entry.py").setUserParams(JsonConverter.toJson(userParams)).setMLEnvironmentId(getMLEnvironmentId()).linkFrom(preprocessed);
    // Captured in a final local so the anonymous reduce function below can reference it.
    final Params params = getParams();
    // Non-grouped reduceGroup: every TF model row is handed to a single reduce() call.
    setOutput(tfTableModelTrainBatchOp.getDataSet().reduceGroup(new RichGroupReduceFunction<Row, Row>() {

        // Filled from the "frequency" broadcast set in open().
        private transient TimeFrequency frequency;

        @Override
        public void open(Configuration parameters) throws Exception {
            frequency = getRuntimeContext().getBroadcastVariableWithInitializer("frequency", new BroadcastVariableInitializer<TimeFrequency, TimeFrequency>() {

                // Uses only the first broadcast element; any further elements are ignored.
                // NOTE(review): if the side output is empty, next() throws NoSuchElementException.
                @Override
                public TimeFrequency initializeBroadcastVariable(Iterable<TimeFrequency> data) {
                    return data.iterator().next();
                }
            });
        }

        @Override
        public void reduce(Iterable<Row> values, Collector<Row> out) throws Exception {
            // Materialize all model rows in memory so they can be wrapped in one DeepARModelData.
            List<Row> all = new ArrayList<>();
            for (Row val : values) {
                all.add(val);
            }
            // params is cloned so the operator's own Params are not mutated by the extra TIME_FREQUENCY entry.
            new DeepARModelDataConverter().save(new DeepARModelData(params.clone().set(HasTimeFrequency.TIME_FREQUENCY, frequency), all), out);
        }
    }).withBroadcastSet(preprocessed.getSideOutput(0).getDataSet().map(new MapFunction<Row, TimeFrequency>() {

        // Side output 0 of the preprocessing step carries the TimeFrequency in field 0.
        @Override
        public TimeFrequency map(Row value) throws Exception {
            return (TimeFrequency) value.getField(0);
        }
    }), "frequency"), new DeepARModelDataConverter().getModelSchema());
    return this;
}
Also used: Configuration(org.apache.flink.configuration.Configuration) HashMap(java.util.HashMap) DeepARTrainParams(com.alibaba.alink.params.timeseries.DeepARTrainParams) DeepARPreProcessParams(com.alibaba.alink.params.timeseries.DeepARPreProcessParams) Params(org.apache.flink.ml.api.misc.param.Params) DeepARModelData(com.alibaba.alink.operator.common.timeseries.DeepARModelDataConverter.DeepARModelData) DeepARModelDataConverter(com.alibaba.alink.operator.common.timeseries.DeepARModelDataConverter) HasTimeFrequency(com.alibaba.alink.params.timeseries.HasTimeFrequency) TimeFrequency(com.alibaba.alink.params.timeseries.HasTimeFrequency.TimeFrequency) ArrayList(java.util.ArrayList) List(java.util.List) TFTableModelTrainBatchOp(com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp) Row(org.apache.flink.types.Row)

Example 2 with TFTableModelTrainBatchOp

Use of com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp in project Alink by Alibaba.

In class EasyTransferConfigTrainBatchOp, method linkFrom.

/**
 * Trains an EasyTransfer model by delegating to {@link TFTableModelTrainBatchOp}.
 *
 * <p>The first input is the training data. It is narrowed to the selected columns when
 * {@code getSelectedCols()} is non-null. Any further inputs are forwarded unchanged, in order, to
 * the TF training operator.
 *
 * @param inputs the training data followed by optional extra inputs; must contain at least one operator
 * @return this operator, whose output table is the trained TF model table
 * @throws IllegalArgumentException if {@code inputs} is null or empty
 */
@Override
public EasyTransferConfigTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    // Fail fast with a clear message instead of an opaque NPE/ArrayIndexOutOfBoundsException below.
    if (inputs == null || inputs.length == 0) {
        throw new IllegalArgumentException("EasyTransferConfigTrainBatchOp requires at least one input operator.");
    }
    BatchOperator<?> in = inputs[0];
    Params params = getParams();
    if (null != getSelectedCols()) {
        in = in.select(getSelectedCols());
    }
    // Fall back to DLEnvConfig's default Python env when none is configured.
    String pythonEnv = getPythonEnv();
    if (StringUtils.isNullOrWhitespaceOnly(pythonEnv)) {
        pythonEnv = DLEnvConfig.getTF115DefaultPythonEnv();
    }
    // Start from the caller's JSON user params and add the EasyTransfer config under "config_json".
    Map<String, String> userParams = JsonConverter.fromJson(getUserParams(), new TypeToken<Map<String, String>>() {
    }.getType());
    userParams.put("config_json", getConfigJson());
    // Keep any user files already configured on this operator, then append resPyFiles.
    ExternalFilesConfig externalFilesConfig = params.contains(HasUserFiles.USER_FILES)
        ? ExternalFilesConfig.fromJson(params.get(HasUserFiles.USER_FILES))
        : new ExternalFilesConfig();
    externalFilesConfig.addFilePaths(resPyFiles);
    TFTableModelTrainBatchOp tfTableModelTrainBatchOp = new TFTableModelTrainBatchOp()
        .setUserFiles(externalFilesConfig)
        .setMainScriptFile(mainScriptFileName)
        .setNumWorkers(getNumWorkers())
        .setNumPSs(getNumPSs())
        .setUserParams(JsonConverter.toJson(userParams))
        .setPythonEnv(pythonEnv)
        .setIntraOpParallelism(getIntraOpParallelism())
        .setMLEnvironmentId(getMLEnvironmentId());
    // Build a fresh BatchOperator<?>[] rather than cloning `inputs`: a covariant varargs array
    // (e.g. a subclass-typed array) could reject `in` with an ArrayStoreException.
    BatchOperator<?>[] tfInputs = new BatchOperator<?>[inputs.length];
    tfInputs[0] = in;
    System.arraycopy(inputs, 1, tfInputs, 1, inputs.length - 1);
    BatchOperator<?> tfModel = tfTableModelTrainBatchOp.linkFrom(tfInputs);
    this.setOutputTable(tfModel.getOutputTable());
    return this;
}
Also used: TypeToken(com.google.gson.reflect.TypeToken) EasyTransferConfigTrainParams(com.alibaba.alink.params.tensorflow.bert.EasyTransferConfigTrainParams) Params(org.apache.flink.ml.api.misc.param.Params) TFTableModelTrainBatchOp(com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator)

Example 3 with TFTableModelTrainBatchOp

Use of com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp in project Alink by Alibaba.

In class LSTNetTrainBatchOp, method linkFrom.

/**
 * Trains an LSTNet model on the first input.
 *
 * <p>The input is preprocessed by {@code LSTNetPreProcessBatchOp} into columns "tensor" and "y".
 * {@code TFTableModelTrainBatchOp} then runs {@code res:///tf_algos/lstnet_entry.py} on those
 * columns. The resulting TF model table becomes this operator's output.
 *
 * @param inputs the input operators; the first one is obtained via {@code checkAndGetFirst}
 * @return this operator, whose output table is the trained model table
 */
@Override
public LSTNetTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> input = checkAndGetFirst(inputs);
    BatchOperator<?> preprocessed = new LSTNetPreProcessBatchOp(getParams().clone())
        .setOutputCols("tensor", "y")
        .setMLEnvironmentId(getMLEnvironmentId())
        .linkFrom(input);
    // Model hyper-parameters handed to the Python script as a nested JSON string under "model_config".
    Map<String, Object> modelConfig = new HashMap<>();
    modelConfig.put("window", getWindow());
    modelConfig.put("horizon", getHorizon());
    // Parameters read by lstnet_entry.py; values are all strings, serialized to JSON below.
    Map<String, String> userParams = new HashMap<>();
    userParams.put("tensorCol", "tensor");
    userParams.put("labelCol", "y");
    userParams.put("batch_size", String.valueOf(getBatchSize()));
    userParams.put("num_epochs", String.valueOf(getNumEpochs()));
    userParams.put("model_config", JsonConverter.toJson(modelConfig));
    // The environment id is set BEFORE linkFrom(), matching the preprocessing call above. Previously
    // it was set after linking, presumably too late to affect the environment the training graph is
    // built in.
    TFTableModelTrainBatchOp tfTableModelTrainBatchOp = new TFTableModelTrainBatchOp(getParams().clone())
        .setSelectedCols("tensor", "y")
        .setUserFiles(new String[] { "res:///tf_algos/lstnet_entry.py" })
        .setMainScriptFile("res:///tf_algos/lstnet_entry.py")
        .setUserParams(JsonConverter.toJson(userParams))
        .setMLEnvironmentId(getMLEnvironmentId())
        .linkFrom(preprocessed);
    setOutputTable(tfTableModelTrainBatchOp.getOutputTable());
    return this;
}
Also used: HashMap(java.util.HashMap) TFTableModelTrainBatchOp(com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp)

Example 4 with TFTableModelTrainBatchOp

Use of com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp in project Alink by Alibaba.

In class TFTableModelPredictStreamOpTest, method test.

/**
 * Trains a TF DNN model in batch mode and uses it for prediction on a random stream.
 *
 * <p>The test changes the global batch parallelism of the default ML environment. Both the batch
 * and stream parallelism are saved up front and restored in {@code finally}, so other tests are
 * unaffected even if this one fails.
 */
@Category(DLTest.class)
@Test
public void test() throws Exception {
    int savedBatchParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    int savedStreamParallelism = MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().getParallelism();
    // NOTE(review): presumably sized to host 2 workers + 1 PS (see setNumWorkers/setNumPSs below).
    BatchOperator.setParallelism(3);
    try {
        BatchOperator<?> source = new RandomTableSourceBatchOp().setNumRows(100L).setNumCols(10);
        String[] colNames = source.getColNames();
        source = source.select("*, case when RAND() > 0.5 then 1. else 0. end as label");
        String label = "label";
        StreamOperator<?> streamSource = new RandomTableSourceStreamOp().setNumCols(10).setMaxRows(100L);
        Map<String, Object> userParams = new HashMap<>();
        userParams.put("featureCols", JsonConverter.toJson(colNames));
        userParams.put("labelCol", label);
        userParams.put("batch_size", 16);
        userParams.put("num_epochs", 1);
        TFTableModelTrainBatchOp tfTableModelTrainBatchOp = new TFTableModelTrainBatchOp()
            .setUserFiles(new String[] { "res:///tf_dnn_train.py" })
            .setMainScriptFile("res:///tf_dnn_train.py")
            .setUserParams(JsonConverter.toJson(userParams))
            .setNumWorkers(2)
            .setNumPSs(1)
            .linkFrom(source);
        TFTableModelPredictStreamOp tfTableModelPredictStreamOp = new TFTableModelPredictStreamOp(tfTableModelTrainBatchOp)
            .setOutputSchemaStr("logits double")
            .setOutputSignatureDefs(new String[] { "logits" })
            .setSignatureDefKey("predict")
            .setSelectedCols(colNames)
            .linkFrom(streamSource);
        tfTableModelPredictStreamOp.print();
        StreamOperator.execute();
    } finally {
        // Restore what was actually changed (batch) as well as the stream setting the original test restored.
        BatchOperator.setParallelism(savedBatchParallelism);
        StreamOperator.setParallelism(savedStreamParallelism);
    }
}
Also used: HashMap(java.util.HashMap) TFTableModelTrainBatchOp(com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp) RandomTableSourceBatchOp(com.alibaba.alink.operator.batch.source.RandomTableSourceBatchOp) RandomTableSourceStreamOp(com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp) Category(org.junit.experimental.categories.Category) DLTest(com.alibaba.alink.testutil.categories.DLTest) Test(org.junit.Test)

Aggregations (number of examples above that use each class):

TFTableModelTrainBatchOp (com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp): 4; HashMap (java.util.HashMap): 3; Params (org.apache.flink.ml.api.misc.param.Params): 2; BatchOperator (com.alibaba.alink.operator.batch.BatchOperator): 1; RandomTableSourceBatchOp (com.alibaba.alink.operator.batch.source.RandomTableSourceBatchOp): 1; DeepARModelDataConverter (com.alibaba.alink.operator.common.timeseries.DeepARModelDataConverter): 1; DeepARModelData (com.alibaba.alink.operator.common.timeseries.DeepARModelDataConverter.DeepARModelData): 1; RandomTableSourceStreamOp (com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp): 1; EasyTransferConfigTrainParams (com.alibaba.alink.params.tensorflow.bert.EasyTransferConfigTrainParams): 1; DeepARPreProcessParams (com.alibaba.alink.params.timeseries.DeepARPreProcessParams): 1; DeepARTrainParams (com.alibaba.alink.params.timeseries.DeepARTrainParams): 1; HasTimeFrequency (com.alibaba.alink.params.timeseries.HasTimeFrequency): 1; TimeFrequency (com.alibaba.alink.params.timeseries.HasTimeFrequency.TimeFrequency): 1; DLTest (com.alibaba.alink.testutil.categories.DLTest): 1; TypeToken (com.google.gson.reflect.TypeToken): 1; ArrayList (java.util.ArrayList): 1; List (java.util.List): 1; Configuration (org.apache.flink.configuration.Configuration): 1; Row (org.apache.flink.types.Row): 1; Test (org.junit.Test): 1