Search in sources :

Example 1 with TimeFrequency

Usage of com.alibaba.alink.params.timeseries.HasTimeFrequency.TimeFrequency in the Alink project by Alibaba.

Class DeepARTrainBatchOp, method linkFrom.

/**
 * Wires up the DeepAR training pipeline and registers the resulting model table as this
 * operator's output.
 *
 * <p>The steps, in order:
 * <ol>
 *   <li>run {@code DeepARPreProcessBatchOp} on the first input, producing the columns
 *       {@code "tensor"}, {@code "v"} and {@code "y"};</li>
 *   <li>train through {@code TFTableModelTrainBatchOp} on the {@code "tensor"} / {@code "y"}
 *       columns, using the script {@code res:///tf_algos/deepar_entry.py};</li>
 *   <li>gather every row emitted by the trainer in a single group-reduce and serialize it with
 *       {@code DeepARModelDataConverter}, together with this operator's params extended by the
 *       {@code TimeFrequency} value broadcast from side output 0 of the preprocessing step.</li>
 * </ol>
 *
 * @param inputs candidate input operators; the one returned by {@code checkAndGetFirst} is used
 * @return this operator, with its output set to the serialized model
 */
@Override
public DeepARTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> source = checkAndGetFirst(inputs);

    // Step 1: preprocessing. Its main output feeds the trainer; its side output 0 is read
    // further below to obtain the TimeFrequency value.
    BatchOperator<?> prepared = new DeepARPreProcessBatchOp(getParams().clone())
        .setOutputCols("tensor", "v", "y")
        .setMLEnvironmentId(getMLEnvironmentId())
        .linkFrom(source);

    // Step 2: assemble the settings passed to the training script as a JSON string.
    Map<String, Object> networkConfig = new HashMap<>();
    networkConfig.put("window", getWindow());
    networkConfig.put("stride", getStride());

    Map<String, String> scriptArgs = new HashMap<>();
    scriptArgs.put("tensorCol", "tensor");
    scriptArgs.put("labelCol", "y");
    scriptArgs.put("batch_size", String.valueOf(getBatchSize()));
    scriptArgs.put("num_epochs", String.valueOf(getNumEpochs()));
    scriptArgs.put("model_config", JsonConverter.toJson(networkConfig));

    // The same resource is registered both as a user file and as the main script.
    final String entryScript = "res:///tf_algos/deepar_entry.py";
    TFTableModelTrainBatchOp trainer = new TFTableModelTrainBatchOp(getParams().clone())
        .setSelectedCols("tensor", "y")
        .setUserFiles(new String[] { entryScript })
        .setMainScriptFile(entryScript)
        .setUserParams(JsonConverter.toJson(scriptArgs))
        .setMLEnvironmentId(getMLEnvironmentId())
        .linkFrom(prepared);

    // Captured by the reduce function below; cloned there before being modified.
    final Params trainParams = getParams();

    // Step 3a: reducer that packs all trainer output rows plus the params into the model format.
    RichGroupReduceFunction<Row, Row> modelPacker = new RichGroupReduceFunction<Row, Row>() {

        // Filled in open() from the "frequency" broadcast variable.
        private transient TimeFrequency timeFrequency;

        @Override
        public void open(Configuration parameters) throws Exception {
            timeFrequency = getRuntimeContext().getBroadcastVariableWithInitializer(
                "frequency",
                new BroadcastVariableInitializer<TimeFrequency, TimeFrequency>() {

                    @Override
                    public TimeFrequency initializeBroadcastVariable(Iterable<TimeFrequency> data) {
                        // Only the first broadcast element is used.
                        return data.iterator().next();
                    }
                });
        }

        @Override
        public void reduce(Iterable<Row> values, Collector<Row> out) throws Exception {
            // Materialize the whole group: DeepARModelData takes the rows as a list.
            List<Row> modelRows = new ArrayList<>();
            for (Row row : values) {
                modelRows.add(row);
            }

            Params modelMeta = trainParams.clone().set(HasTimeFrequency.TIME_FREQUENCY, timeFrequency);
            DeepARModelData modelData = new DeepARModelData(modelMeta, modelRows);
            new DeepARModelDataConverter().save(modelData, out);
        }
    };

    // Step 3b: mapper that pulls the TimeFrequency out of field 0 of each side-output row.
    MapFunction<Row, TimeFrequency> frequencyExtractor = new MapFunction<Row, TimeFrequency>() {

        @Override
        public TimeFrequency map(Row value) throws Exception {
            return (TimeFrequency) value.getField(0);
        }
    };

    setOutput(
        trainer.getDataSet()
            .reduceGroup(modelPacker)
            .withBroadcastSet(
                prepared.getSideOutput(0).getDataSet().map(frequencyExtractor),
                "frequency"),
        new DeepARModelDataConverter().getModelSchema());
    return this;
}
Also used : Configuration(org.apache.flink.configuration.Configuration) HashMap(java.util.HashMap) DeepARTrainParams(com.alibaba.alink.params.timeseries.DeepARTrainParams) DeepARPreProcessParams(com.alibaba.alink.params.timeseries.DeepARPreProcessParams) Params(org.apache.flink.ml.api.misc.param.Params) DeepARModelData(com.alibaba.alink.operator.common.timeseries.DeepARModelDataConverter.DeepARModelData) DeepARModelDataConverter(com.alibaba.alink.operator.common.timeseries.DeepARModelDataConverter) HasTimeFrequency(com.alibaba.alink.params.timeseries.HasTimeFrequency) TimeFrequency(com.alibaba.alink.params.timeseries.HasTimeFrequency.TimeFrequency) ArrayList(java.util.ArrayList) List(java.util.List) TFTableModelTrainBatchOp(com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp) Row(org.apache.flink.types.Row)

Aggregations

TFTableModelTrainBatchOp (com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp)1 DeepARModelDataConverter (com.alibaba.alink.operator.common.timeseries.DeepARModelDataConverter)1 DeepARModelData (com.alibaba.alink.operator.common.timeseries.DeepARModelDataConverter.DeepARModelData)1 DeepARPreProcessParams (com.alibaba.alink.params.timeseries.DeepARPreProcessParams)1 DeepARTrainParams (com.alibaba.alink.params.timeseries.DeepARTrainParams)1 HasTimeFrequency (com.alibaba.alink.params.timeseries.HasTimeFrequency)1 TimeFrequency (com.alibaba.alink.params.timeseries.HasTimeFrequency.TimeFrequency)1 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 List (java.util.List)1 Configuration (org.apache.flink.configuration.Configuration)1 Params (org.apache.flink.ml.api.misc.param.Params)1 Row (org.apache.flink.types.Row)1