Search in sources :

Example 6 with ShuffleBatchOp

use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by alibaba.

the class BaseDLTableModelTrainBatchOp method linkFrom.

/**
 * Links this operator to its upstream inputs and delegates the work to a {@code DLLauncherBatchOp}.
 *
 * <p>The first input is optionally projected onto the selected columns and then passed through a
 * {@code ShuffleBatchOp}; all remaining inputs are forwarded to the launcher unchanged. User
 * parameters supplied as a JSON object string are copied into the parameter map handed to the
 * launcher. The launcher's output table becomes this operator's output table.
 *
 * @param inputs upstream operators; {@code inputs[0]} is the data that gets projected and shuffled
 * @return this operator, with its output table set
 */
@Override
public T linkFrom(BatchOperator<?>... inputs) {
    initDLSystemParams();
    BatchOperator<?> input = inputs[0];
    if (null != getSelectedCols()) {
        input = input.select(getSelectedCols());
    }
    input = new ShuffleBatchOp().linkFrom(input);
    ExternalFilesConfig externalFiles = getUserFiles()
        .addFilePaths(resPyFiles)
        .addRenameMap(getMainScriptFile(), userMainScriptRename);

    Map<String, String> algoParams = new HashMap<>();
    String userParamsStr = getUserParams();
    if (StringUtils.isNotEmpty(userParamsStr)) {
        TypeToken<Map<String, String>> typeToken = new TypeToken<Map<String, String>>() {
        };
        Map<String, String> userParams = new Gson().fromJson(userParamsStr, typeToken.getType());
        // Gson returns null for a blank document or the JSON literal "null"; treat that as
        // "no user parameters" instead of failing with a NullPointerException.
        if (userParams != null) {
            algoParams.putAll(userParams);
        }
    }

    DLLauncherBatchOp dlLauncherBatchOp = new DLLauncherBatchOp()
        .setOutputSchemaStr(MODEL_SCHEMA_STR)
        .setNumWorkers(getNumWorkers())
        .setNumPSs(numPss)
        .setEntryFunc(entryFuncName)
        .setPythonEnv(getPythonEnv())
        .setUserFiles(externalFiles)
        .setMainScriptFile(mainScriptFileName)
        .setUserParams(JsonConverter.toJson(algoParams))
        .setIntraOpParallelism(getIntraOpParallelism())
        .setMLEnvironmentId(getMLEnvironmentId());

    // Slot 0 carries the projected/shuffled data; the other inputs are copied over as-is.
    BatchOperator<?>[] tfInputs = new BatchOperator<?>[inputs.length];
    tfInputs[0] = input;
    System.arraycopy(inputs, 1, tfInputs, 1, inputs.length - 1);
    BatchOperator<?> tfModel = dlLauncherBatchOp.linkFrom(tfInputs);
    this.setOutputTable(tfModel.getOutputTable());
    return (T) this;
}
Also used : HashMap(java.util.HashMap) Gson(com.google.gson.Gson) BatchOperator(com.alibaba.alink.operator.batch.BatchOperator) TypeToken(com.google.gson.reflect.TypeToken) Map(java.util.Map) ShuffleBatchOp(com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp)

Example 7 with ShuffleBatchOp

use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by alibaba.

the class BertTextClassifierTest method test.

/**
 * Fits a {@code BertTextClassifier} on the CSV data at
 * {@code DLTestConstants.CHN_SENTI_CORP_HTL_PATH} and prints predictions for the first 300 rows.
 *
 * <p>Three plugins are downloaded up front, and parallelism is set to 1 for the duration of the
 * test; the previous value is restored afterwards even if the test body throws.
 */
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Fetch every plugin this test relies on before building the pipeline.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        String url = DLTestConstants.CHN_SENTI_CORP_HTL_PATH;
        String schemaStr = "label bigint, review string";
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schemaStr)
            .setIgnoreFirstLine(true);
        data = data.where("review is not null");
        data = new ShuffleBatchOp().linkFrom(data);

        BertTextClassifier classifier = new BertTextClassifier()
            .setTextCol("review")
            .setLabelCol("label")
            .setNumEpochs(0.01)
            .setNumFineTunedLayers(1)
            .setMaxSeqLength(128)
            .setBertModelName("Base-Chinese")
            .setPredictionCol("pred")
            .setPredictionDetailCol("pred_detail");
        BertClassificationModel model = classifier.fit(data);
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        // Restore on failure too, so the parallelism override does not leak into later tests.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used : PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) ShuffleBatchOp(com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp) Test(org.junit.Test)

Example 8 with ShuffleBatchOp

use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by alibaba.

the class BertTextRegressorTest method test.

/**
 * Fits a {@code BertTextRegressor} on the CSV data at
 * {@code DLTestConstants.CHN_SENTI_CORP_HTL_PATH} and prints predictions for the first 300 rows.
 *
 * <p>Three plugins are downloaded up front, and parallelism is set to 1 for the duration of the
 * test; the previous value is restored afterwards even if the test body throws.
 */
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Fetch every plugin this test relies on before building the pipeline.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        String url = DLTestConstants.CHN_SENTI_CORP_HTL_PATH;
        String schemaStr = "label double, review string";
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schemaStr)
            .setIgnoreFirstLine(true);
        data = data.where("review is not null");
        data = new ShuffleBatchOp().linkFrom(data);

        BertTextRegressor regressor = new BertTextRegressor()
            .setTextCol("review")
            .setLabelCol("label")
            .setNumEpochs(0.01)
            .setNumFineTunedLayers(1)
            .setMaxSeqLength(128)
            .setBertModelName("Base-Chinese")
            .setPredictionCol("pred");
        BertRegressionModel model = regressor.fit(data);
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        // Restore on failure too, so the parallelism override does not leak into later tests.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used : PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) ShuffleBatchOp(com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp) Test(org.junit.Test)

Example 9 with ShuffleBatchOp

use of com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp in project Alink by alibaba.

the class BertTextRegressorTrainBatchOpTest method test.

/**
 * Runs {@code BertTextRegressorTrainBatchOp} on the CSV data at
 * {@code DLTestConstants.CHN_SENTI_CORP_HTL_PATH} with a custom learning rate and asserts that the
 * train operator's output has more than one row.
 *
 * <p>Parallelism is set to 2 for the duration of the test; the previous value is restored
 * afterwards even if the test body throws or the assertion fails.
 */
@Category(DLTest.class)
@Test
public void test() throws Exception {
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(2);
    try {
        String url = DLTestConstants.CHN_SENTI_CORP_HTL_PATH;
        String schema = "label double, review string";
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schema)
            .setIgnoreFirstLine(true);
        data = data.where("review is not null");
        data = new ShuffleBatchOp().linkFrom(data);

        // Nested config serialized to JSON: train_config -> optimizer_config -> learning_rate.
        Map<String, Map<String, Object>> customConfig = new HashMap<>();
        customConfig.put("train_config", ImmutableMap.of("optimizer_config", ImmutableMap.of("learning_rate", 0.01)));

        BertTextRegressorTrainBatchOp train = new BertTextRegressorTrainBatchOp()
            .setTextCol("review")
            .setLabelCol("label")
            .setNumEpochs(0.05)
            .setNumFineTunedLayers(1)
            .setMaxSeqLength(128)
            .setBertModelName("Base-Chinese")
            .setCustomConfigJson(JsonConverter.toJson(customConfig))
            .linkFrom(data);
        Assert.assertTrue(train.count() > 1);
    } finally {
        // Restore on failure too, so the parallelism override does not leak into later tests.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used : HashMap(java.util.HashMap) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) ShuffleBatchOp(com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Aggregations

ShuffleBatchOp (com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp)9 CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp)7 Test (org.junit.Test)7 PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader)4 RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey)4 HashMap (java.util.HashMap)4 Map (java.util.Map)4 DLTest (com.alibaba.alink.testutil.categories.DLTest)3 ImmutableMap (com.google.common.collect.ImmutableMap)3 Category (org.junit.experimental.categories.Category)3 BatchOperator (com.alibaba.alink.operator.batch.BatchOperator)1 Gson (com.google.gson.Gson)1 TypeToken (com.google.gson.reflect.TypeToken)1