Use of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba: class TensorFlowStreamOpTest, method testWithAutoWorkersPSs.
@Test
public void testWithAutoWorkersPSs() throws Exception {
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	// Make sure the TF 1.15 deep-learning plugin is present before launching the DL job.
	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	// Remember the current parallelism so it can be restored even when the test body
	// throws; otherwise parallelism=3 would leak into other tests sharing the default
	// MLEnvironment.
	int savedStreamParallelism = MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().getParallelism();
	StreamOperator.setParallelism(3);
	// NOTE(review): this static field is mutated and never restored — presumably other
	// tests tolerate the 30s start timeout; confirm or save/restore it as well.
	DLLauncherStreamOp.DL_CLUSTER_START_TIME = 30 * 1000;
	try {
		// 1000 random rows with 10 numeric columns, plus a synthetic 0/1 label.
		StreamOperator<?> source = new RandomTableSourceStreamOp()
			.setMaxRows(1000L)
			.setNumCols(10);
		String[] colNames = source.getColNames();
		source = source.select("*, case when RAND() > 0.5 then 1. else 0. end as label");
		source = source.link(new TypeConvertStreamOp()
			.setSelectedCols("num")
			.setTargetType(TargetType.DOUBLE));
		String label = "label";
		// Parameters forwarded to the user Python script via JSON.
		Map<String, Object> userParams = new HashMap<>();
		userParams.put("featureCols", JsonConverter.toJson(colNames));
		userParams.put("labelCol", label);
		userParams.put("batch_size", 16);
		userParams.put("num_epochs", 1);
		TensorFlowStreamOp tensorFlowStreamOp = new TensorFlowStreamOp()
			.setUserFiles(new String[] {"res:///tf_dnn_stream.py"})
			.setMainScriptFile("res:///tf_dnn_stream.py")
			.setUserParams(JsonConverter.toJson(userParams))
			.setOutputSchemaStr("model_id long, model_info string")
			.linkFrom(source);
		tensorFlowStreamOp.print();
		StreamOperator.execute();
	} finally {
		// Always restore the parallelism, even if the DL job fails.
		StreamOperator.setParallelism(savedStreamParallelism);
	}
}
Use of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba: class BertTextClassifierTest, method test.
@Test
public void test() throws Exception {
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	// Download every plugin the BERT pipeline needs: the TF 1.15 runtime, the TF
	// predictor, and the pretrained Base-Chinese checkpoint.
	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	// Remember the current parallelism so it can be restored even when training or
	// prediction throws; otherwise parallelism=1 would leak into other tests.
	int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
	BatchOperator.setParallelism(1);
	try {
		String url = DLTestConstants.CHN_SENTI_CORP_HTL_PATH;
		String schemaStr = "label bigint, review string";
		BatchOperator<?> data = new CsvSourceBatchOp()
			.setFilePath(url)
			.setSchemaStr(schemaStr)
			.setIgnoreFirstLine(true);
		data = data.where("review is not null");
		data = new ShuffleBatchOp().linkFrom(data);
		// Tiny epoch count (0.01) keeps this a smoke test rather than a real fine-tune.
		BertTextClassifier classifier = new BertTextClassifier()
			.setTextCol("review")
			.setLabelCol("label")
			.setNumEpochs(0.01)
			.setNumFineTunedLayers(1)
			.setMaxSeqLength(128)
			.setBertModelName("Base-Chinese")
			.setPredictionCol("pred")
			.setPredictionDetailCol("pred_detail");
		BertClassificationModel model = classifier.fit(data);
		BatchOperator<?> predict = model.transform(data.firstN(300));
		predict.print();
	} finally {
		// Always restore the parallelism, even on failure.
		BatchOperator.setParallelism(savedParallelism);
	}
}
Use of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba: class BertTextRegressorTest, method test.
@Test
public void test() throws Exception {
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	// Download every plugin the BERT pipeline needs: the TF 1.15 runtime, the TF
	// predictor, and the pretrained Base-Chinese checkpoint.
	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.CKPT);
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	// Remember the current parallelism so it can be restored even when training or
	// prediction throws; otherwise parallelism=1 would leak into other tests.
	int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
	BatchOperator.setParallelism(1);
	try {
		String url = DLTestConstants.CHN_SENTI_CORP_HTL_PATH;
		// Same corpus as the classifier test, but the label is read as double so the
		// task becomes regression.
		String schemaStr = "label double, review string";
		BatchOperator<?> data = new CsvSourceBatchOp()
			.setFilePath(url)
			.setSchemaStr(schemaStr)
			.setIgnoreFirstLine(true);
		data = data.where("review is not null");
		data = new ShuffleBatchOp().linkFrom(data);
		// Tiny epoch count (0.01) keeps this a smoke test rather than a real fine-tune.
		BertTextRegressor regressor = new BertTextRegressor()
			.setTextCol("review")
			.setLabelCol("label")
			.setNumEpochs(0.01)
			.setNumFineTunedLayers(1)
			.setMaxSeqLength(128)
			.setBertModelName("Base-Chinese")
			.setPredictionCol("pred");
		BertRegressionModel model = regressor.fit(data);
		BatchOperator<?> predict = model.transform(data.firstN(300));
		predict.print();
	} finally {
		// Always restore the parallelism, even on failure.
		BatchOperator.setParallelism(savedParallelism);
	}
}
Use of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba: class TFTableModelClassificationFlatModelMapperTest, method test.
@Category(DLTest.class)
@Test
public void test() throws Exception {
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	// The flat-model mapper runs TF inference through the predictor plugin.
	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	// Four fixed rows; the dataset below samples 1000 rows from them.
	List<Row> baseData = Arrays.asList(
		Row.of((float) 1.2, 3.4, 10, 3L, "bad"),
		Row.of((float) 1.2, 3.4, 2, 5L, "good"),
		Row.of((float) 1.2, 3.4, 6, 8L, "bad"),
		Row.of((float) 1.2, 3.4, 3, 2L, "good"));
	String dataSchemaStr = "f float, d double, i int, l long, label string";
	// Fixed seed makes the sampled input reproducible across runs.
	Random random = new Random(2021);
	List<Row> data = new ArrayList<>();
	for (int i = 0; i < 1000; i += 1) {
		data.add(baseData.get(random.nextInt(baseData.size())));
	}
	// Copy the bundled model to a temp file so AkSourceBatchOp can read it from disk.
	InputStream resourceAsStream =
		getClass().getClassLoader().getResourceAsStream("tf_table_model_binary_class_model.ak");
	// Use a JUnit assertion instead of the `assert` keyword: the latter is silently
	// skipped unless the JVM runs with -ea.
	Assert.assertNotNull(resourceAsStream);
	File modelFile = Files.createTempFile("tf_table_model_binary_class_model", ".ak").toFile();
	// Best-effort cleanup of the temp copy when the JVM exits.
	modelFile.deleteOnExit();
	FileUtils.copyInputStreamToFile(resourceAsStream, modelFile);
	BatchOperator<?> modelOp = new AkSourceBatchOp().setFilePath(modelFile.toString());
	List<Row> modelRows = modelOp.collect();
	Params params = new Params();
	params.set(HasPredictionCol.PREDICTION_COL, "pred");
	params.set(HasPredictionDetailCol.PREDICTION_DETAIL_COL, "pred_detail");
	params.set(HasReservedColsDefaultAsNull.RESERVED_COLS, new String[] {"l", "label"});
	TFTableModelClassificationFlatModelMapper mapper = new TFTableModelClassificationFlatModelMapper(
		modelOp.getSchema(), CsvUtil.schemaStr2Schema(dataSchemaStr), params);
	mapper.loadModel(modelRows);
	// Collect flatMap output into a plain list for the assertions below.
	List<Row> list = new ArrayList<>();
	ListCollector<Row> collector = new ListCollector<>(list);
	mapper.open();
	for (Row row : data) {
		mapper.flatMap(row, collector);
	}
	mapper.close();
	// Output schema: the two reserved columns followed by prediction and detail.
	Assert.assertEquals(
		TableSchema.builder()
			.field("l", Types.LONG)
			.field("label", Types.STRING)
			.field("pred", Types.STRING)
			.field("pred_detail", Types.STRING)
			.build(),
		mapper.getOutputSchema());
	Assert.assertEquals(data.size(), list.size());
	for (int i = 0; i < data.size(); i += 1) {
		Assert.assertEquals(4, list.get(i).getArity());
		// Reserved columns "l" (index 3) and "label" (index 4) pass through unchanged.
		Assert.assertEquals(data.get(i).getField(3), list.get(i).getField(0));
		Assert.assertEquals(data.get(i).getField(4), list.get(i).getField(1));
	}
}
Use of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba: class TFTableModelClassificationModelMapperTest, method test.
@Category(DLTest.class)
@Test
public void test() throws Exception {
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	// The model mapper runs TF inference through the predictor plugin.
	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	// Four fixed rows; the dataset below samples 1000 rows from them.
	List<Row> baseData = Arrays.asList(
		Row.of((float) 1.2, 3.4, 10, 3L, "bad"),
		Row.of((float) 1.2, 3.4, 2, 5L, "good"),
		Row.of((float) 1.2, 3.4, 6, 8L, "bad"),
		Row.of((float) 1.2, 3.4, 3, 2L, "good"));
	String dataSchemaStr = "f float, d double, i int, l long, label string";
	// Fixed seed makes the sampled input reproducible across runs.
	Random random = new Random(2021);
	List<Row> data = new ArrayList<>();
	for (int i = 0; i < 1000; i += 1) {
		data.add(baseData.get(random.nextInt(baseData.size())));
	}
	// Copy the bundled model to a temp file so AkSourceBatchOp can read it from disk.
	InputStream resourceAsStream =
		getClass().getClassLoader().getResourceAsStream("tf_table_model_binary_class_model.ak");
	// Use a JUnit assertion instead of the `assert` keyword: the latter is silently
	// skipped unless the JVM runs with -ea.
	Assert.assertNotNull(resourceAsStream);
	File modelFile = Files.createTempFile("tf_table_model_binary_class_model", ".ak").toFile();
	// Best-effort cleanup of the temp copy when the JVM exits.
	modelFile.deleteOnExit();
	FileUtils.copyInputStreamToFile(resourceAsStream, modelFile);
	BatchOperator<?> modelOp = new AkSourceBatchOp().setFilePath(modelFile.toString());
	List<Row> modelRows = modelOp.collect();
	Params params = new Params();
	params.set(HasPredictionCol.PREDICTION_COL, "pred");
	params.set(HasPredictionDetailCol.PREDICTION_DETAIL_COL, "pred_detail");
	params.set(HasReservedColsDefaultAsNull.RESERVED_COLS, new String[] {"l", "label"});
	TFTableModelClassificationModelMapper mapper = new TFTableModelClassificationModelMapper(
		modelOp.getSchema(), CsvUtil.schemaStr2Schema(dataSchemaStr), params);
	mapper.loadModel(modelRows);
	mapper.open();
	// Output schema: the two reserved columns followed by prediction and detail.
	Assert.assertEquals(
		TableSchema.builder()
			.field("l", Types.LONG)
			.field("label", Types.STRING)
			.field("pred", Types.STRING)
			.field("pred_detail", Types.STRING)
			.build(),
		mapper.getOutputSchema());
	for (Row row : data) {
		Row output = mapper.map(row);
		Assert.assertEquals(4, output.getArity());
		// Reserved columns "l" (index 3) and "label" (index 4) pass through unchanged.
		Assert.assertEquals(row.getField(3), output.getField(0));
		Assert.assertEquals(row.getField(4), output.getField(1));
	}
	mapper.close();
}
Aggregations