Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.
From class BaseTFSavedModelPredictRowMapperTest, method testTensor:
/**
 * Downloads the TF predictor plugin and a zipped TF SavedModel, then feeds 1000 rows of
 * {@code (label LONG, image FLOAT_TENSOR[28, 28])} through {@link BaseTFSavedModelPredictRowMapper}.
 * Checks that the output schema appends {@code classes LONG, probabilities FLOAT_TENSOR} to the
 * input columns, that both input columns are passed through unchanged, and that every
 * {@code probabilities} tensor has shape {@code [10]}.
 *
 * <p>The shared execution environment's parallelism is restored and the mapper is closed even
 * when an assertion fails, so later tests are not affected.
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

	// Remember the shared environment's parallelism so it can be put back in the finally below.
	int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
	MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
	try {
		List<Row> rows = new ArrayList<>();
		for (int i = 0; i < 1000; i += 1) {
			// 0L rather than 0: the schema below declares "label LONG", so the value must be a Long.
			rows.add(Row.of(0L, new FloatTensor(new Shape(28, 28))));
		}
		BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG, image FLOAT_TENSOR");

		// Fetch the zipped SavedModel into a fresh temp directory and unpack it next to the archive.
		String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
		String workDir = PythonFileUtils.createTempDir("temp_").toString();
		String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
		String localModelPath = workDir + File.separator + fn;
		System.out.println("localModelPath: " + localModelPath);
		if (localModelPath.endsWith(".zip")) {
			File target = new File(localModelPath).getParentFile();
			ZipFileUtil.unZip(new File(localModelPath), target);
			// The archive is expected to contain a directory named like the zip without ".zip".
			localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
			Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
		}

		Params params = new Params();
		params.set(HasModelPath.MODEL_PATH, localModelPath);
		params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
		params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG, probabilities FLOAT_TENSOR");

		BaseTFSavedModelPredictRowMapper baseTFSavedModelPredictRowMapper =
			new BaseTFSavedModelPredictRowMapper(data.getSchema(), params);
		baseTFSavedModelPredictRowMapper.open();
		try {
			Assert.assertEquals(
				TableSchema.builder()
					.field("label", Types.LONG)
					.field("image", TensorTypes.FLOAT_TENSOR)
					.field("classes", Types.LONG)
					.field("probabilities", TensorTypes.FLOAT_TENSOR)
					.build(),
				baseTFSavedModelPredictRowMapper.getOutputSchema());
			for (Row row : rows) {
				Row output = baseTFSavedModelPredictRowMapper.map(row);
				Assert.assertEquals(row.getField(0), output.getField(0));
				Assert.assertEquals(row.getField(1), output.getField(1));
				// JUnit's order is (expected, actual); keep it so failure messages read correctly.
				Assert.assertArrayEquals(new long[] { 10 }, ((FloatTensor) output.getField(3)).shape());
			}
		} finally {
			// Always close the mapper, even when an assertion above fails.
			baseTFSavedModelPredictRowMapper.close();
		}
	} finally {
		MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(savedParallelism);
	}
}
Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.
From class TFSavedModelPredictMapperTest, method testTensor:
/**
 * Downloads the TF predictor plugin, then feeds 1000 rows of batched tensors
 * {@code (label LONG_TENSOR[batchSize], image FLOAT_TENSOR[batchSize, 28, 28])} through
 * {@link TFSavedModelPredictMapper}, which is given the zipped SavedModel URL directly as its model
 * path. Checks the output schema, pass-through of both input columns, and that {@code classes} has
 * shape {@code [batchSize]} and {@code probabilities} has shape {@code [batchSize, 10]}.
 *
 * <p>The shared execution environment's parallelism is restored and the mapper is closed even
 * when an assertion fails, so later tests are not affected.
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

	// Remember the shared environment's parallelism so it can be put back in the finally below.
	int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
	MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
	try {
		int batchSize = 3;
		List<Row> rows = new ArrayList<>();
		for (int i = 0; i < 1000; i += 1) {
			rows.add(Row.of(new LongTensor(new Shape(batchSize)), new FloatTensor(new Shape(batchSize, 28, 28))));
		}
		BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG_TENSOR, image FLOAT_TENSOR");

		Params params = new Params();
		params.set(HasModelPath.MODEL_PATH, "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip");
		params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
		params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG_TENSOR, probabilities FLOAT_TENSOR");

		TFSavedModelPredictMapper tfSavedModelPredictMapper = new TFSavedModelPredictMapper(data.getSchema(), params);
		tfSavedModelPredictMapper.open();
		try {
			Assert.assertEquals(
				TableSchema.builder()
					.field("label", TensorTypes.LONG_TENSOR)
					.field("image", TensorTypes.FLOAT_TENSOR)
					.field("classes", TensorTypes.LONG_TENSOR)
					.field("probabilities", TensorTypes.FLOAT_TENSOR)
					.build(),
				tfSavedModelPredictMapper.getOutputSchema());
			for (Row row : rows) {
				Row output = tfSavedModelPredictMapper.map(row);
				Assert.assertEquals(row.getField(0), output.getField(0));
				Assert.assertEquals(row.getField(1), output.getField(1));
				// JUnit's order is (expected, actual); keep it so failure messages read correctly.
				Assert.assertArrayEquals(new long[] { batchSize }, ((LongTensor) output.getField(2)).shape());
				Assert.assertArrayEquals(new long[] { batchSize, 10 }, ((FloatTensor) output.getField(3)).shape());
			}
		} finally {
			// Always close the mapper, even when an assertion above fails.
			tfSavedModelPredictMapper.close();
		}
	} finally {
		MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(savedParallelism);
	}
}
Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.
From class BertTextPairClassifierTest, method test:
/**
 * Downloads the TF 1.15 environment, the TF predictor plugin and the BERT checkpoint, then
 * fine-tunes a {@link BertTextPairClassifier} on the shuffled MRPC {@code train.tsv} for 0.1 epoch
 * with parallelism 1. Finally it prints predictions for the first 300 rows.
 *
 * <p>The global parallelism is restored in a {@code finally}, so a failure while fitting or
 * printing does not leak parallelism 1 into later tests.
 */
@Test
public void test() throws Exception {
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	// Pre-fetch the checkpoint of the same model selected by setBertModelName("Base-Uncased") below.
	// Downloading BASE_CHINESE here did not cover the model the classifier actually uses.
	registerKey = BertResources.getRegisterKey(ModelName.BASE_UNCASED, ResourceType.CKPT);
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

	int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
	BatchOperator.setParallelism(1);
	try {
		String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
		String schemaStr = "f_quality bigint, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
		BatchOperator<?> data = new CsvSourceBatchOp()
			.setFilePath(url)
			.setSchemaStr(schemaStr)
			.setFieldDelimiter("\t")
			.setIgnoreFirstLine(true)
			.setQuoteChar(null);
		data = new ShuffleBatchOp().linkFrom(data);

		BertTextPairClassifier classifier = new BertTextPairClassifier()
			.setTextCol("f_string_1")
			.setTextPairCol("f_string_2")
			.setLabelCol("f_quality")
			.setNumEpochs(0.1)
			.setMaxSeqLength(32)
			.setNumFineTunedLayers(1)
			.setBertModelName("Base-Uncased")
			.setPredictionCol("pred")
			.setPredictionDetailCol("pred_detail");
		BertClassificationModel model = classifier.fit(data);
		BatchOperator<?> predict = model.transform(data.firstN(300));
		predict.print();
	} finally {
		BatchOperator.setParallelism(savedParallelism);
	}
}
Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.
From class BertTextPairRegressorTest, method test:
/**
 * Downloads the TF 1.15 environment, the TF predictor plugin and the BERT checkpoint, then
 * fine-tunes a {@link BertTextPairRegressor} on the shuffled MRPC {@code train.tsv} for 0.1 epoch
 * with parallelism 1, reading the label as a double. Finally it prints predictions for the first
 * 300 rows.
 *
 * <p>The global parallelism is restored in a {@code finally}, so a failure while fitting or
 * printing does not leak parallelism 1 into later tests.
 */
@Test
public void test() throws Exception {
	AlinkGlobalConfiguration.setPrintProcessInfo(true);
	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
	// Pre-fetch the checkpoint of the same model selected by setBertModelName("Base-Uncased") below.
	// Downloading BASE_CHINESE here did not cover the model the regressor actually uses.
	registerKey = BertResources.getRegisterKey(ModelName.BASE_UNCASED, ResourceType.CKPT);
	pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

	int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
	BatchOperator.setParallelism(1);
	try {
		String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
		String schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
		BatchOperator<?> data = new CsvSourceBatchOp()
			.setFilePath(url)
			.setSchemaStr(schemaStr)
			.setFieldDelimiter("\t")
			.setIgnoreFirstLine(true)
			.setQuoteChar(null);
		data = new ShuffleBatchOp().linkFrom(data);

		BertTextPairRegressor regressor = new BertTextPairRegressor()
			.setTextCol("f_string_1")
			.setTextPairCol("f_string_2")
			.setLabelCol("f_quality")
			.setNumEpochs(0.1)
			.setMaxSeqLength(32)
			.setNumFineTunedLayers(1)
			.setBertModelName("Base-Uncased")
			.setPredictionCol("pred");
		BertRegressionModel model = regressor.fit(data);
		BatchOperator<?> predict = model.transform(data.firstN(300));
		predict.print();
	} finally {
		BatchOperator.setParallelism(savedParallelism);
	}
}
Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.
From class Chap25, method c_1_3:
/**
 * Prints three things to standard output:
 * the configured plugin directory, the auto-plugin-download setting, and the result of
 * {@code listAvailablePlugins()} on the globally configured {@link PluginDownloader}.
 */
static void c_1_3() throws Exception {
	System.out.println(AlinkGlobalConfiguration.getPluginDir());

	// Label and flag on a single line, exactly as the separate print/println pair would emit them.
	System.out.println("Auto Plugin Download : " + AlinkGlobalConfiguration.getAutoPluginDownload());

	PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
	System.out.println(pluginDownloader.listAvailablePlugins());
}
Aggregations