Example use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in the Alink project by Alibaba.
From class BertTextPairClassifierTrainBatchOpTest, method test().
/**
 * Trains a BERT text-pair classifier on the MRPC training file and checks that the
 * training operator emits more than one row.
 *
 * <p>The test temporarily switches the batch parallelism to 2. The previous value is
 * restored in a {@code finally} block so that a training failure or a failed assertion
 * does not leave the changed parallelism in place for tests that run afterwards.
 *
 * @throws Exception if reading the remote data, training, or counting the result fails
 */
@Category(DLTest.class)
@Test
public void test() throws Exception {
    // Remember the current parallelism of the default environment so it can be put back.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(2);
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality bigint, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";

        // Tab-separated source; the first line is skipped and quoting is disabled.
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schemaStr)
            .setFieldDelimiter("\t")
            .setIgnoreFirstLine(true)
            .setQuoteChar(null);
        data = new ShuffleBatchOp().linkFrom(data);

        // Custom training configuration, passed to the operator as JSON.
        Map<String, Map<String, Object>> customConfig = new HashMap<>();
        customConfig.put("train_config", ImmutableMap.of("optimizer_config", ImmutableMap.of("learning_rate", 0.01)));

        BertTextPairClassifierTrainBatchOp train = new BertTextPairClassifierTrainBatchOp()
            .setTextCol("f_string_1")
            .setTextPairCol("f_string_2")
            .setLabelCol("f_quality")
            .setNumEpochs(0.1)
            .setMaxSeqLength(32)
            .setNumFineTunedLayers(1)
            .setCustomConfigJson(JsonConverter.toJson(customConfig))
            .setBertModelName("Base-Uncased")
            .linkFrom(data);

        Assert.assertTrue(train.count() > 1);
    } finally {
        // Always undo the parallelism change, even when the body above throws.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Example use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in the Alink project by Alibaba.
From class BertTextPairRegressorTrainBatchOpTest, method test().
/**
 * Trains a BERT text-pair regressor on the MRPC training file (label column read as
 * {@code double}) and checks that the training operator emits more than one row.
 *
 * <p>The test temporarily switches the batch parallelism to 2. The previous value is
 * restored in a {@code finally} block so that a training failure or a failed assertion
 * does not leave the changed parallelism in place for tests that run afterwards.
 *
 * @throws Exception if reading the remote data, training, or counting the result fails
 */
@Category(DLTest.class)
@Test
public void test() throws Exception {
    // Remember the current parallelism of the default environment so it can be put back.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(2);
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";

        // Tab-separated source; the first line is skipped and quoting is disabled.
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schemaStr)
            .setFieldDelimiter("\t")
            .setIgnoreFirstLine(true)
            .setQuoteChar(null);
        data = new ShuffleBatchOp().linkFrom(data);

        // Custom training configuration, passed to the operator as JSON.
        Map<String, Map<String, Object>> customConfig = new HashMap<>();
        customConfig.put("train_config", ImmutableMap.of("optimizer_config", ImmutableMap.of("learning_rate", 0.01)));

        BertTextPairRegressorTrainBatchOp train = new BertTextPairRegressorTrainBatchOp()
            .setTextCol("f_string_1")
            .setTextPairCol("f_string_2")
            .setLabelCol("f_quality")
            .setNumEpochs(0.1)
            .setMaxSeqLength(32)
            .setNumFineTunedLayers(1)
            .setCustomConfigJson(JsonConverter.toJson(customConfig))
            .setBertModelName("Base-Uncased")
            .linkFrom(data);

        Assert.assertTrue(train.count() > 1);
    } finally {
        // Always undo the parallelism change, even when the body above throws.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Example use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in the Alink project by Alibaba.
From class NegativeItemSamplingBatchOp, method linkFrom().
/**
 * Links this operator to a two-column input and publishes a shuffled result as this
 * operator's output table.
 *
 * <p>Only {@code inputs[0]} is read. Its second column is selected and de-duplicated,
 * and both the original input and that distinct column are handed to
 * {@code negativeSampling}.
 *
 * @param inputs upstream operators; the first one must have exactly two columns
 * @return this operator
 * @throws IllegalArgumentException presumably, when the first input does not have exactly
 *         two columns (depends on which {@code Preconditions} class is imported; verify)
 */
@Override
public NegativeItemSamplingBatchOp linkFrom(BatchOperator<?>... inputs) {
    final BatchOperator<?> interactions = inputs[0];
    setMLEnvironmentId(interactions.getMLEnvironmentId());
    Preconditions.checkArgument(2 == interactions.getColNames().length);

    // The second column supplies the pool of distinct values used for sampling.
    final String secondColName = interactions.getColNames()[1];
    final BatchOperator<?> candidateItems = interactions.select(secondColName).distinct();

    // NOTE(review): the return value is ignored, so negativeSampling presumably sets this
    // operator's output as a side effect before the shuffle below is linked; confirm.
    negativeSampling(interactions, candidateItems);

    final ShuffleBatchOp shuffler = new ShuffleBatchOp();
    this.setOutputTable(this.link(shuffler).getOutputTable());
    return this;
}
Example use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in the Alink project by Alibaba.
From class BertTextPairClassifierTest, method test().
/**
 * Fits a {@code BertTextPairClassifier} on the MRPC training file, applies the resulting
 * model to the first 300 rows and prints the predictions.
 *
 * <p>Three plugins are downloaded up front through the global plugin downloader. The
 * batch parallelism is then switched to 1 for the duration of the test and restored in a
 * {@code finally} block, so a failure during fit/transform/print does not leave the
 * changed parallelism in place for tests that run afterwards.
 *
 * @throws Exception if plugin download, data loading, training, or prediction fails
 */
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Download the plugins identified by the three register keys below.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    // NOTE(review): the checkpoint requested here is BASE_CHINESE while the estimator below
    // is configured with "Base-Uncased"; confirm that this mismatch is intentional.
    registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    // Remember the current parallelism of the default environment so it can be put back.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality bigint, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";

        // Tab-separated source; the first line is skipped and quoting is disabled.
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schemaStr)
            .setFieldDelimiter("\t")
            .setIgnoreFirstLine(true)
            .setQuoteChar(null);
        data = new ShuffleBatchOp().linkFrom(data);

        BertTextPairClassifier classifier = new BertTextPairClassifier()
            .setTextCol("f_string_1")
            .setTextPairCol("f_string_2")
            .setLabelCol("f_quality")
            .setNumEpochs(0.1)
            .setMaxSeqLength(32)
            .setNumFineTunedLayers(1)
            .setBertModelName("Base-Uncased")
            .setPredictionCol("pred")
            .setPredictionDetailCol("pred_detail");

        BertClassificationModel model = classifier.fit(data);
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        // Always undo the parallelism change, even when the body above throws.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Example use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in the Alink project by Alibaba.
From class BertTextPairRegressorTest, method test().
/**
 * Fits a {@code BertTextPairRegressor} on the MRPC training file (label column read as
 * {@code double}), applies the resulting model to the first 300 rows and prints the
 * predictions.
 *
 * <p>Three plugins are downloaded up front through the global plugin downloader. The
 * batch parallelism is then switched to 1 for the duration of the test and restored in a
 * {@code finally} block, so a failure during fit/transform/print does not leave the
 * changed parallelism in place for tests that run afterwards.
 *
 * @throws Exception if plugin download, data loading, training, or prediction fails
 */
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Download the plugins identified by the three register keys below.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    // NOTE(review): the checkpoint requested here is BASE_CHINESE while the estimator below
    // is configured with "Base-Uncased"; confirm that this mismatch is intentional.
    registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    // Remember the current parallelism of the default environment so it can be put back.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";

        // Tab-separated source; the first line is skipped and quoting is disabled.
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schemaStr)
            .setFieldDelimiter("\t")
            .setIgnoreFirstLine(true)
            .setQuoteChar(null);
        data = new ShuffleBatchOp().linkFrom(data);

        BertTextPairRegressor regressor = new BertTextPairRegressor()
            .setTextCol("f_string_1")
            .setTextPairCol("f_string_2")
            .setLabelCol("f_quality")
            .setNumEpochs(0.1)
            .setMaxSeqLength(32)
            .setNumFineTunedLayers(1)
            .setBertModelName("Base-Uncased")
            .setPredictionCol("pred");

        BertRegressionModel model = regressor.fit(data);
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        // Always undo the parallelism change, even when the body above throws.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Aggregations