Use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by Alibaba.
In class BaseDLTableModelTrainBatchOp, method linkFrom.
/**
 * Connects the training input to a {@link DLLauncherBatchOp} that runs the configured Python
 * entry script, and adopts the launcher's output table as this operator's output.
 *
 * <p>The first input is optionally projected onto {@code getSelectedCols()} and then passed
 * through a {@link ShuffleBatchOp}. Any further inputs are forwarded to the launcher unchanged.
 *
 * <p>If {@code getUserParams()} is non-empty, it is parsed as a JSON object of string-to-string
 * entries. Those entries are forwarded to the launcher as its user params.
 *
 * @param inputs the training data followed by optional auxiliary inputs; must contain at least
 *     one element (an empty array fails on {@code inputs[0]})
 * @return this operator
 */
@Override
@SuppressWarnings("unchecked") // safe by the self-type convention of T
public T linkFrom(BatchOperator<?>... inputs) {
    initDLSystemParams();

    BatchOperator<?> input = inputs[0];
    if (null != getSelectedCols()) {
        input = input.select(getSelectedCols());
    }
    // Shuffle the (projected) training rows before they reach the launcher.
    input = new ShuffleBatchOp().linkFrom(input);

    ExternalFilesConfig externalFiles = getUserFiles()
        .addFilePaths(resPyFiles)
        .addRenameMap(getMainScriptFile(), userMainScriptRename);

    Map<String, String> algoParams = new HashMap<>();
    String userParamsStr = getUserParams();
    if (StringUtils.isNoneEmpty(userParamsStr)) {
        TypeToken<Map<String, String>> typeToken = new TypeToken<Map<String, String>>() {
        };
        Map<String, String> userParams = new Gson().fromJson(userParamsStr, typeToken.getType());
        // Gson yields null for whitespace-only input or the JSON literal "null"; without this
        // guard such a value (which passes isNoneEmpty) would cause a NullPointerException.
        if (null != userParams) {
            algoParams.putAll(userParams);
        }
    }

    DLLauncherBatchOp dlLauncherBatchOp = new DLLauncherBatchOp()
        .setOutputSchemaStr(MODEL_SCHEMA_STR)
        .setNumWorkers(getNumWorkers())
        .setNumPSs(numPss)
        .setEntryFunc(entryFuncName)
        .setPythonEnv(getPythonEnv())
        .setUserFiles(externalFiles)
        .setMainScriptFile(mainScriptFileName)
        .setUserParams(JsonConverter.toJson(algoParams))
        .setIntraOpParallelism(getIntraOpParallelism())
        .setMLEnvironmentId(getMLEnvironmentId());

    // Same inputs as the caller passed, except the first is replaced by the shuffled one.
    BatchOperator<?>[] tfInputs = new BatchOperator<?>[inputs.length];
    tfInputs[0] = input;
    System.arraycopy(inputs, 1, tfInputs, 1, inputs.length - 1);

    BatchOperator<?> tfModel = dlLauncherBatchOp.linkFrom(tfInputs);
    this.setOutputTable(tfModel.getOutputTable());
    return (T) this;
}
Use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by Alibaba.
In class BertTextClassifierTest, method test.
/**
 * Smoke test for {@code BertTextClassifier}: trains on the CSV at
 * {@code DLTestConstants.CHN_SENTI_CORP_HTL_PATH} and prints predictions for 300 rows.
 *
 * <p>Runs at parallelism 1 and restores the previous global parallelism afterwards, even if
 * training or prediction throws.
 */
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Fetch the TF 1.15 environment, the TF predictor, and the Base-Chinese BERT checkpoint.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        String url = DLTestConstants.CHN_SENTI_CORP_HTL_PATH;
        String schemaStr = "label bigint, review string";
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schemaStr)
            .setIgnoreFirstLine(true);
        data = data.where("review is not null");
        data = new ShuffleBatchOp().linkFrom(data);

        BertTextClassifier classifier = new BertTextClassifier()
            .setTextCol("review")
            .setLabelCol("label")
            .setNumEpochs(0.01)
            .setNumFineTunedLayers(1)
            .setMaxSeqLength(128)
            .setBertModelName("Base-Chinese")
            .setPredictionCol("pred")
            .setPredictionDetailCol("pred_detail");
        BertClassificationModel model = classifier.fit(data);
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        // Restore the global setting even on failure so later tests are not affected.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by Alibaba.
In class BertTextRegressorTest, method test.
/**
 * Smoke test for {@code BertTextRegressor}: trains on the CSV at
 * {@code DLTestConstants.CHN_SENTI_CORP_HTL_PATH}, reading the label as a double, and prints
 * predictions for 300 rows.
 *
 * <p>Runs at parallelism 1 and restores the previous global parallelism afterwards, even if
 * training or prediction throws.
 */
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Fetch the TF 1.15 environment, the TF predictor, and the Base-Chinese BERT checkpoint.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        String url = DLTestConstants.CHN_SENTI_CORP_HTL_PATH;
        String schemaStr = "label double, review string";
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schemaStr)
            .setIgnoreFirstLine(true);
        data = data.where("review is not null");
        data = new ShuffleBatchOp().linkFrom(data);

        BertTextRegressor regressor = new BertTextRegressor()
            .setTextCol("review")
            .setLabelCol("label")
            .setNumEpochs(0.01)
            .setNumFineTunedLayers(1)
            .setMaxSeqLength(128)
            .setBertModelName("Base-Chinese")
            .setPredictionCol("pred");
        BertRegressionModel model = regressor.fit(data);
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        // Restore the global setting even on failure so later tests are not affected.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by Alibaba.
In class BertTextRegressorTrainBatchOpTest, method test.
/**
 * Trains {@code BertTextRegressorTrainBatchOp} on the CSV at
 * {@code DLTestConstants.CHN_SENTI_CORP_HTL_PATH}. It passes a custom learning rate through
 * {@code setCustomConfigJson} and asserts that the resulting table has more than one row.
 *
 * <p>Runs at parallelism 2 and restores the previous global parallelism afterwards, even if
 * training or the assertion fails.
 */
@Category(DLTest.class)
@Test
public void test() throws Exception {
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(2);
    try {
        String url = DLTestConstants.CHN_SENTI_CORP_HTL_PATH;
        String schema = "label double, review string";
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schema)
            .setIgnoreFirstLine(true);
        data = data.where("review is not null");
        data = new ShuffleBatchOp().linkFrom(data);

        // Serialized as {"train_config": {"optimizer_config": {"learning_rate": 0.01}}}.
        Map<String, Map<String, Object>> customConfig = new HashMap<>();
        customConfig.put("train_config",
            ImmutableMap.of("optimizer_config", ImmutableMap.of("learning_rate", 0.01)));

        BertTextRegressorTrainBatchOp train = new BertTextRegressorTrainBatchOp()
            .setTextCol("review")
            .setLabelCol("label")
            .setNumEpochs(0.05)
            .setNumFineTunedLayers(1)
            .setMaxSeqLength(128)
            .setBertModelName("Base-Chinese")
            .setCustomConfigJson(JsonConverter.toJson(customConfig))
            .linkFrom(data);
        Assert.assertTrue(train.count() > 1);
    } finally {
        // Restore the global setting even on failure so later tests are not affected.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Aggregations