Use of com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp in project Alink by alibaba.
From class DeepARTrainBatchOp, method linkFrom:
@Override
public DeepARTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> input = checkAndGetFirst(inputs);

    BatchOperator<?> preprocessed = new DeepARPreProcessBatchOp(getParams().clone())
        .setOutputCols("tensor", "v", "y")
        .setMLEnvironmentId(getMLEnvironmentId())
        .linkFrom(input);

    Map<String, Object> modelConfig = new HashMap<>();
    modelConfig.put("window", getWindow());
    modelConfig.put("stride", getStride());

    Map<String, String> userParams = new HashMap<>();
    userParams.put("tensorCol", "tensor");
    userParams.put("labelCol", "y");
    userParams.put("batch_size", String.valueOf(getBatchSize()));
    userParams.put("num_epochs", String.valueOf(getNumEpochs()));
    userParams.put("model_config", JsonConverter.toJson(modelConfig));

    TFTableModelTrainBatchOp tfTableModelTrainBatchOp = new TFTableModelTrainBatchOp(getParams().clone())
        .setSelectedCols("tensor", "y")
        .setUserFiles(new String[] {"res:///tf_algos/deepar_entry.py"})
        .setMainScriptFile("res:///tf_algos/deepar_entry.py")
        .setUserParams(JsonConverter.toJson(userParams))
        .setMLEnvironmentId(getMLEnvironmentId())
        .linkFrom(preprocessed);

    final Params params = getParams();

    setOutput(
        tfTableModelTrainBatchOp.getDataSet()
            .reduceGroup(new RichGroupReduceFunction<Row, Row>() {
                private transient TimeFrequency frequency;

                @Override
                public void open(Configuration parameters) throws Exception {
                    frequency = getRuntimeContext().getBroadcastVariableWithInitializer(
                        "frequency",
                        new BroadcastVariableInitializer<TimeFrequency, TimeFrequency>() {
                            @Override
                            public TimeFrequency initializeBroadcastVariable(Iterable<TimeFrequency> data) {
                                return data.iterator().next();
                            }
                        });
                }

                @Override
                public void reduce(Iterable<Row> values, Collector<Row> out) throws Exception {
                    List<Row> all = new ArrayList<>();
                    for (Row val : values) {
                        all.add(val);
                    }
                    new DeepARModelDataConverter().save(
                        new DeepARModelData(params.clone().set(HasTimeFrequency.TIME_FREQUENCY, frequency), all),
                        out);
                }
            })
            .withBroadcastSet(
                preprocessed.getSideOutput(0).getDataSet().map(new MapFunction<Row, TimeFrequency>() {
                    @Override
                    public TimeFrequency map(Row value) throws Exception {
                        return (TimeFrequency) value.getField(0);
                    }
                }),
                "frequency"),
        new DeepARModelDataConverter().getModelSchema());

    return this;
}
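For orientation, a minimal usage sketch of the wrapping operator. The setWindow/setStride/setBatchSize/setNumEpochs setters are inferred from the getters used in linkFrom above; the CSV source and the setTimeCol/setSelectedCol column setters are assumptions for illustration, not taken from this snippet:

// Hypothetical usage sketch; column setters and the CSV source are assumptions.
BatchOperator<?> series = new CsvSourceBatchOp()
    .setFilePath("series.csv")                       // placeholder path
    .setSchemaStr("ts TIMESTAMP, val DOUBLE");

DeepARTrainBatchOp deepArModel = new DeepARTrainBatchOp()
    .setTimeCol("ts")          // assumption: time column setter
    .setSelectedCol("val")     // assumption: value column setter
    .setWindow(12)             // counterpart of getWindow() used in linkFrom
    .setStride(1)              // counterpart of getStride()
    .setBatchSize(32)          // counterpart of getBatchSize()
    .setNumEpochs(10)          // counterpart of getNumEpochs()
    .linkFrom(series);

BatchOperator.execute();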
Use of com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp in project Alink by alibaba.
From class EasyTransferConfigTrainBatchOp, method linkFrom:
@Override
public EasyTransferConfigTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = inputs[0];
    Params params = getParams();

    if (null != getSelectedCols()) {
        in = in.select(getSelectedCols());
    }

    String pythonEnv = getPythonEnv();
    if (StringUtils.isNullOrWhitespaceOnly(pythonEnv)) {
        pythonEnv = DLEnvConfig.getTF115DefaultPythonEnv();
    }

    Map<String, String> userParams = JsonConverter.fromJson(
        getUserParams(),
        new TypeToken<Map<String, String>>() {}.getType());
    userParams.put("config_json", getConfigJson());

    ExternalFilesConfig externalFilesConfig = params.contains(HasUserFiles.USER_FILES)
        ? ExternalFilesConfig.fromJson(params.get(HasUserFiles.USER_FILES))
        : new ExternalFilesConfig();
    externalFilesConfig.addFilePaths(resPyFiles);

    TFTableModelTrainBatchOp tfTableModelTrainBatchOp = new TFTableModelTrainBatchOp()
        .setUserFiles(externalFilesConfig)
        .setMainScriptFile(mainScriptFileName)
        .setNumWorkers(getNumWorkers())
        .setNumPSs(getNumPSs())
        .setUserParams(JsonConverter.toJson(userParams))
        .setPythonEnv(pythonEnv)
        .setIntraOpParallelism(getIntraOpParallelism())
        .setMLEnvironmentId(getMLEnvironmentId());

    BatchOperator<?>[] tfInputs = new BatchOperator<?>[inputs.length];
    tfInputs[0] = in;
    System.arraycopy(inputs, 1, tfInputs, 1, inputs.length - 1);

    BatchOperator<?> tfModel = tfTableModelTrainBatchOp.linkFrom(tfInputs);
    this.setOutputTable(tfModel.getOutputTable());
    return this;
}
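A hedged sketch of the user-side call that this method serves; the input source, schema, column names, and config JSON content are placeholders, and the setters are assumed from the getters read in linkFrom above:

// Hypothetical usage sketch; source, schema, and config JSON are placeholders.
BatchOperator<?> trainData = new CsvSourceBatchOp()
    .setFilePath("train.csv")                        // placeholder path
    .setSchemaStr("content string, label bigint");

String configJson = "{}";                            // placeholder EasyTransfer config

BatchOperator<?> model = new EasyTransferConfigTrainBatchOp()
    .setConfigJson(configJson)                       // forwarded as userParams["config_json"]
    .setSelectedCols("content", "label")             // optional; triggers the in.select(...) branch
    .setNumWorkers(2)
    .setNumPSs(1)
    .linkFrom(trainData);

BatchOperator.execute();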
Use of com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp in project Alink by alibaba.
From class LSTNetTrainBatchOp, method linkFrom:
@Override
public LSTNetTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> input = checkAndGetFirst(inputs);

    BatchOperator<?> preprocessed = new LSTNetPreProcessBatchOp(getParams().clone())
        .setOutputCols("tensor", "y")
        .setMLEnvironmentId(getMLEnvironmentId())
        .linkFrom(input);

    Map<String, Object> modelConfig = new HashMap<>();
    modelConfig.put("window", getWindow());
    modelConfig.put("horizon", getHorizon());

    Map<String, String> userParams = new HashMap<>();
    userParams.put("tensorCol", "tensor");
    userParams.put("labelCol", "y");
    userParams.put("batch_size", String.valueOf(getBatchSize()));
    userParams.put("num_epochs", String.valueOf(getNumEpochs()));
    userParams.put("model_config", JsonConverter.toJson(modelConfig));

    TFTableModelTrainBatchOp tfTableModelTrainBatchOp = new TFTableModelTrainBatchOp(getParams().clone())
        .setSelectedCols("tensor", "y")
        .setUserFiles(new String[] {"res:///tf_algos/lstnet_entry.py"})
        .setMainScriptFile("res:///tf_algos/lstnet_entry.py")
        .setUserParams(JsonConverter.toJson(userParams))
        .linkFrom(preprocessed)
        .setMLEnvironmentId(getMLEnvironmentId());

    setOutputTable(tfTableModelTrainBatchOp.getOutputTable());
    return this;
}
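Usage mirrors the DeepAR sketch above, with a horizon parameter in place of stride; a condensed, hypothetical sketch (column setters again assumed, not taken from this snippet):

// Hypothetical; differs from the DeepAR sketch only in the horizon parameter.
LSTNetTrainBatchOp lstNetModel = new LSTNetTrainBatchOp()
    .setTimeCol("ts")          // assumption: time column setter
    .setSelectedCol("val")     // assumption: value column setter
    .setWindow(24)             // counterpart of getWindow() used in linkFrom
    .setHorizon(12)            // counterpart of getHorizon()
    .setBatchSize(32)
    .setNumEpochs(10)
    .linkFrom(series);         // 'series' as in the DeepAR sketch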
Use of com.alibaba.alink.operator.batch.tensorflow.TFTableModelTrainBatchOp in project Alink by alibaba.
From class TFTableModelPredictStreamOpTest, method test:
@Category(DLTest.class)
@Test
public void test() throws Exception {
    int savedStreamParallelism = MLEnvironmentFactory.getDefault()
        .getStreamExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(3);

    BatchOperator<?> source = new RandomTableSourceBatchOp()
        .setNumRows(100L)
        .setNumCols(10);
    String[] colNames = source.getColNames();
    source = source.select("*, case when RAND() > 0.5 then 1. else 0. end as label");
    String label = "label";

    StreamOperator<?> streamSource = new RandomTableSourceStreamOp()
        .setNumCols(10)
        .setMaxRows(100L);

    Map<String, Object> userParams = new HashMap<>();
    userParams.put("featureCols", JsonConverter.toJson(colNames));
    userParams.put("labelCol", label);
    userParams.put("batch_size", 16);
    userParams.put("num_epochs", 1);

    TFTableModelTrainBatchOp tfTableModelTrainBatchOp = new TFTableModelTrainBatchOp()
        .setUserFiles(new String[] {"res:///tf_dnn_train.py"})
        .setMainScriptFile("res:///tf_dnn_train.py")
        .setUserParams(JsonConverter.toJson(userParams))
        .setNumWorkers(2)
        .setNumPSs(1)
        .linkFrom(source);

    TFTableModelPredictStreamOp tfTableModelPredictStreamOp = new TFTableModelPredictStreamOp(tfTableModelTrainBatchOp)
        .setOutputSchemaStr("logits double")
        .setOutputSignatureDefs(new String[] {"logits"})
        .setSignatureDefKey("predict")
        .setSelectedCols(colNames)
        .linkFrom(streamSource);

    tfTableModelPredictStreamOp.print();
    StreamOperator.execute();

    StreamOperator.setParallelism(savedStreamParallelism);
}
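The test exercises the streaming predictor; a batch-side counterpart presumably looks similar. A hedged sketch, assuming a TFTableModelPredictBatchOp with the same signature-related setters as the stream operator and a (model, data) linkFrom order (assumptions, not shown in this test):

// Hypothetical batch prediction sketch; operator name, setters, and link order
// are assumptions mirroring TFTableModelPredictStreamOp above.
BatchOperator<?> batchPrediction = new TFTableModelPredictBatchOp()
    .setOutputSchemaStr("logits double")
    .setOutputSignatureDefs(new String[] {"logits"})
    .setSignatureDefKey("predict")
    .setSelectedCols(colNames)
    .linkFrom(tfTableModelTrainBatchOp, source);     // assumed order: model first, then data

batchPrediction.print();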