Example 1 with ShuffleBatchOp

use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by alibaba.

the class BertTextPairClassifierTrainBatchOpTest method test.

@Category(DLTest.class)
@Test
public void test() throws Exception {
    // Remember the current parallelism so it is restored even if the test fails.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(2);
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality bigint, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
        BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schemaStr).setFieldDelimiter("\t").setIgnoreFirstLine(true).setQuoteChar(null);
        // Pass the source rows through ShuffleBatchOp before training.
        data = new ShuffleBatchOp().linkFrom(data);
        Map<String, Map<String, Object>> customConfig = new HashMap<>();
        customConfig.put("train_config", ImmutableMap.of("optimizer_config", ImmutableMap.of("learning_rate", 0.01)));
        BertTextPairClassifierTrainBatchOp train = new BertTextPairClassifierTrainBatchOp().setTextCol("f_string_1").setTextPairCol("f_string_2").setLabelCol("f_quality").setNumEpochs(0.1).setMaxSeqLength(32).setNumFineTunedLayers(1).setCustomConfigJson(JsonConverter.toJson(customConfig)).setBertModelName("Base-Uncased").linkFrom(data);
        Assert.assertTrue(train.count() > 1);
    } finally {
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used : HashMap(java.util.HashMap) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) ShuffleBatchOp(com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp) Category(org.junit.experimental.categories.Category) DLTest(com.alibaba.alink.testutil.categories.DLTest) Test(org.junit.Test)

Example 2 with ShuffleBatchOp

use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by alibaba.

the class BertTextPairRegressorTrainBatchOpTest method test.

@Category(DLTest.class)
@Test
public void test() throws Exception {
    // Remember the current parallelism so it is restored even if the test fails.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(2);
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
        BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schemaStr).setFieldDelimiter("\t").setIgnoreFirstLine(true).setQuoteChar(null);
        // Pass the source rows through ShuffleBatchOp before training.
        data = new ShuffleBatchOp().linkFrom(data);
        Map<String, Map<String, Object>> customConfig = new HashMap<>();
        customConfig.put("train_config", ImmutableMap.of("optimizer_config", ImmutableMap.of("learning_rate", 0.01)));
        BertTextPairRegressorTrainBatchOp train = new BertTextPairRegressorTrainBatchOp().setTextCol("f_string_1").setTextPairCol("f_string_2").setLabelCol("f_quality").setNumEpochs(0.1).setMaxSeqLength(32).setNumFineTunedLayers(1).setCustomConfigJson(JsonConverter.toJson(customConfig)).setBertModelName("Base-Uncased").linkFrom(data);
        Assert.assertTrue(train.count() > 1);
    } finally {
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used : HashMap(java.util.HashMap) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) ShuffleBatchOp(com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp) Category(org.junit.experimental.categories.Category) DLTest(com.alibaba.alink.testutil.categories.DLTest) Test(org.junit.Test)

Example 3 with ShuffleBatchOp

use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by alibaba.

the class NegativeItemSamplingBatchOp method linkFrom.

@Override
public NegativeItemSamplingBatchOp linkFrom(BatchOperator<?>... inputs) {
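    BatchOperator<?> userItemPairs = inputs[0];
    setMLEnvironmentId(userItemPairs.getMLEnvironmentId());
    // The input must have exactly two columns: user and item.
    Preconditions.checkArgument(userItemPairs.getColNames().length == 2);
    // Collect the distinct values of the second (item) column.
    BatchOperator<?> distinctItems = userItemPairs.select(userItemPairs.getColNames()[1]).distinct();
    negativeSampling(userItemPairs, distinctItems);
    // Link this operator's current output through ShuffleBatchOp and use the result as the final output table.
    this.setOutputTable(this.link(new ShuffleBatchOp()).getOutputTable());
    return this;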
}
Also used : ShuffleBatchOp(com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp)
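
The examples above all place ShuffleBatchOp in a pipeline to randomize row order, either before training or as the final step of an operator. The following is a minimal plain-Java sketch of that idea only. It does not use Alink or Flink, and the class name ShuffleSketch, the method shuffled and the seed parameter are hypothetical names introduced for illustration. It says nothing about how ShuffleBatchOp is implemented internally.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration class; not part of Alink.
public class ShuffleSketch {

    // Returns a new list holding the same rows in a random order.
    // The input list is copied first, so it is left unchanged.
    // The seed makes the order reproducible across runs.
    static <T> List<T> shuffled(List<T> rows, long seed) {
        List<T> copy = new ArrayList<>(rows);
        Collections.shuffle(copy, new Random(seed));
        return copy;
    }

    public static void main(String[] args) {
        List<String> rows = Arrays.asList("row1", "row2", "row3", "row4", "row5");
        // Same rows, randomized order; the original list keeps its order.
        System.out.println(shuffled(rows, 42L));
        System.out.println(rows);
    }
}
```

The sketch shuffles a copy so the caller's data is untouched, which mirrors how the examples assign the shuffled result to a new operator reference rather than mutating the source.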

Example 4 with ShuffleBatchOp

use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by alibaba.

the class BertTextPairClassifierTest method test.

@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    // Remember the current parallelism so it is restored even if the test fails.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality bigint, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
        BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schemaStr).setFieldDelimiter("\t").setIgnoreFirstLine(true).setQuoteChar(null);
        // Pass the source rows through ShuffleBatchOp before fitting.
        data = new ShuffleBatchOp().linkFrom(data);
        BertTextPairClassifier classifier = new BertTextPairClassifier().setTextCol("f_string_1").setTextPairCol("f_string_2").setLabelCol("f_quality").setNumEpochs(0.1).setMaxSeqLength(32).setNumFineTunedLayers(1).setBertModelName("Base-Uncased").setPredictionCol("pred").setPredictionDetailCol("pred_detail");
        BertClassificationModel model = classifier.fit(data);
        // Predict on the first 300 rows only.
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used : PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) ShuffleBatchOp(com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp) Test(org.junit.Test)

Example 5 with ShuffleBatchOp

use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by alibaba.

the class BertTextPairRegressorTest method test.

@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    // Remember the current parallelism so it is restored even if the test fails.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
        BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schemaStr).setFieldDelimiter("\t").setIgnoreFirstLine(true).setQuoteChar(null);
        // Pass the source rows through ShuffleBatchOp before fitting.
        data = new ShuffleBatchOp().linkFrom(data);
        BertTextPairRegressor regressor = new BertTextPairRegressor().setTextCol("f_string_1").setTextPairCol("f_string_2").setLabelCol("f_quality").setNumEpochs(0.1).setMaxSeqLength(32).setNumFineTunedLayers(1).setBertModelName("Base-Uncased").setPredictionCol("pred");
        BertRegressionModel model = regressor.fit(data);
        // Predict on the first 300 rows only.
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used : PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) ShuffleBatchOp(com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp) Test(org.junit.Test)

Aggregations

ShuffleBatchOp (com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp): 9
CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp): 7
Test (org.junit.Test): 7
PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader): 4
RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey): 4
HashMap (java.util.HashMap): 4
Map (java.util.Map): 4
DLTest (com.alibaba.alink.testutil.categories.DLTest): 3
ImmutableMap (com.google.common.collect.ImmutableMap): 3
Category (org.junit.experimental.categories.Category): 3
BatchOperator (com.alibaba.alink.operator.batch.BatchOperator): 1
Gson (com.google.gson.Gson): 1
TypeToken (com.google.gson.reflect.TypeToken): 1