Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba.
Class BaseTFSavedModelPredictMapperTest, method testTensor.
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Fetch the TF predictor plugin so the mapper can load it.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    // Remember the global batch parallelism so this test does not leak its setting into others.
    int savedBatchParallelism =
        MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
    try {
        int batchSize = 3;
        int numRows = 1000;
        List<Row> rows = new ArrayList<>(numRows);
        for (int i = 0; i < numRows; i += 1) {
            // Each row holds a batch of labels and a batch of 28x28 images (all zero-initialized tensors).
            rows.add(Row.of(new LongTensor(new Shape(batchSize)),
                new FloatTensor(new Shape(batchSize, 28, 28))));
        }
        BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG_TENSOR, image FLOAT_TENSOR");

        // Download the saved model archive into a fresh temp directory and unpack it next to the zip.
        String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
        String workDir = PythonFileUtils.createTempDir("temp_").toString();
        String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
        String localModelPath = workDir + File.separator + fn;
        System.out.println("localModelPath:" + localModelPath);
        if (localModelPath.endsWith(".zip")) {
            File target = new File(localModelPath).getParentFile();
            ZipFileUtil.unZip(new File(localModelPath), target);
            // The archive is expected to unpack into a directory named like the zip without its suffix.
            localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
            Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
        }

        Params params = new Params();
        params.set(HasModelPath.MODEL_PATH, localModelPath);
        params.set(HasSelectedCols.SELECTED_COLS, new String[] {"image"});
        params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR,
            "classes LONG_TENSOR, probabilities FLOAT_TENSOR");
        BaseTFSavedModelPredictMapper baseTFSavedModelPredictMapper =
            new BaseTFSavedModelPredictMapper(data.getSchema(), params);
        baseTFSavedModelPredictMapper.open();
        try {
            // Output schema = input columns followed by the requested prediction columns.
            Assert.assertEquals(
                TableSchema.builder()
                    .field("label", TensorTypes.LONG_TENSOR)
                    .field("image", TensorTypes.FLOAT_TENSOR)
                    .field("classes", TensorTypes.LONG_TENSOR)
                    .field("probabilities", TensorTypes.FLOAT_TENSOR)
                    .build(),
                baseTFSavedModelPredictMapper.getOutputSchema());
            for (Row row : rows) {
                Row output = baseTFSavedModelPredictMapper.map(row);
                // Input columns are passed through unchanged.
                Assert.assertEquals(row.getField(0), output.getField(0));
                Assert.assertEquals(row.getField(1), output.getField(1));
                // JUnit's order is (expected, actual); the original had them swapped.
                Assert.assertArrayEquals(new long[] {batchSize},
                    ((LongTensor) output.getField(2)).shape());
                Assert.assertArrayEquals(new long[] {batchSize, 10},
                    ((FloatTensor) output.getField(3)).shape());
            }
        } finally {
            // Release the mapper even when an assertion fails.
            baseTFSavedModelPredictMapper.close();
        }
    } finally {
        MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(savedBatchParallelism);
    }
}
Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba.
Class BaseTFSavedModelPredictRowMapperTest, method testString.
@Category(DLTest.class)
@Test
public void testString() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Fetch the TF predictor plugin so the mapper can load it.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    // Remember the global batch parallelism so this test does not leak its setting into others.
    int savedBatchParallelism =
        MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
    try {
        // Read the MNIST sample as (label, image-as-string) rows.
        String url = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/mnist_dense.csv";
        String schema = "label bigint, image string";
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schema)
            .setFieldDelimiter(";");
        List<Row> rows = data.collect();

        // Download the saved model archive into a fresh temp directory and unpack it next to the zip.
        String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
        String workDir = PythonFileUtils.createTempDir("temp_").toString();
        String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
        String localModelPath = workDir + File.separator + fn;
        System.out.println("localModelPath: " + localModelPath);
        if (localModelPath.endsWith(".zip")) {
            File target = new File(localModelPath).getParentFile();
            ZipFileUtil.unZip(new File(localModelPath), target);
            // The archive is expected to unpack into a directory named like the zip without its suffix.
            localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
            Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
        }

        Params params = new Params();
        params.set(HasModelPath.MODEL_PATH, localModelPath);
        params.set(HasSelectedCols.SELECTED_COLS, new String[] {"image"});
        params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes bigint, probabilities string");
        BaseTFSavedModelPredictRowMapper baseTFSavedModelPredictRowMapper =
            new BaseTFSavedModelPredictRowMapper(data.getSchema(), params);
        baseTFSavedModelPredictRowMapper.open();
        try {
            // Output schema = input columns followed by the requested prediction columns.
            Assert.assertEquals(
                TableSchema.builder()
                    .field("label", Types.LONG)
                    .field("image", Types.STRING)
                    .field("classes", Types.LONG)
                    .field("probabilities", Types.STRING)
                    .build(),
                baseTFSavedModelPredictRowMapper.getOutputSchema());
            for (Row row : rows) {
                Row output = baseTFSavedModelPredictRowMapper.map(row);
                // Input columns are passed through unchanged.
                Assert.assertEquals(row.getField(0), output.getField(0));
                Assert.assertEquals(row.getField(1), output.getField(1));
            }
        } finally {
            // Release the mapper even when an assertion fails.
            baseTFSavedModelPredictRowMapper.close();
        }
    } finally {
        MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(savedBatchParallelism);
    }
}
Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba.
Class TFSavedModelPredictMapperTest, method testString.
@Category(DLTest.class)
@Test
public void testString() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Fetch the TF predictor plugin so the mapper can load it.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    // Remember the global batch parallelism so this test does not leak its setting into others.
    int savedBatchParallelism =
        MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
    try {
        // Read the MNIST sample as (label, image-as-string) rows.
        String url = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/mnist_dense.csv";
        String schema = "label bigint, image string";
        BatchOperator<?> data = new CsvSourceBatchOp()
            .setFilePath(url)
            .setSchemaStr(schema)
            .setFieldDelimiter(";");
        List<Row> rows = data.collect();

        Params params = new Params();
        // The remote zip URL is passed directly; this test does not download or unpack it itself.
        params.set(HasModelPath.MODEL_PATH,
            "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip");
        params.set(HasSelectedCols.SELECTED_COLS, new String[] {"image"});
        params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes bigint, probabilities string");
        TFSavedModelPredictMapper tfSavedModelPredictMapper =
            new TFSavedModelPredictMapper(data.getSchema(), params);
        tfSavedModelPredictMapper.open();
        try {
            // Output schema = input columns followed by the requested prediction columns.
            Assert.assertEquals(
                TableSchema.builder()
                    .field("label", Types.LONG)
                    .field("image", Types.STRING)
                    .field("classes", Types.LONG)
                    .field("probabilities", Types.STRING)
                    .build(),
                tfSavedModelPredictMapper.getOutputSchema());
            for (Row row : rows) {
                Row output = tfSavedModelPredictMapper.map(row);
                // Input columns are passed through unchanged.
                Assert.assertEquals(row.getField(0), output.getField(0));
                Assert.assertEquals(row.getField(1), output.getField(1));
            }
        } finally {
            // Release the mapper even when an assertion fails.
            tfSavedModelPredictMapper.close();
        }
    } finally {
        MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(savedBatchParallelism);
    }
}
Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba.
Class TensorFlow2StreamOpTest, method testAllReduce.
@Test
public void testAllReduce() throws Exception {
    // Remember the global stream parallelism so it can be restored even if the job fails.
    int savedStreamParallelism =
        MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().getParallelism();
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Fetch the TF 2.3.1 plugin used by TensorFlow2StreamOp.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF231);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    StreamOperator.setParallelism(3);
    try {
        // NOTE(review): this static is mutated and never restored; consider restoring it as well.
        DLLauncherStreamOp.DL_CLUSTER_START_TIME = 30 * 1000;

        StreamOperator<?> source = new RandomTableSourceStreamOp().setMaxRows(1000L).setNumCols(10);
        // Capture feature column names before the label column is appended.
        String[] colNames = source.getColNames();
        source = source.select("*, case when RAND() > 0.5 then 1. else 0. end as label");
        source = source.link(new TypeConvertStreamOp().setSelectedCols("num").setTargetType(TargetType.DOUBLE));

        String label = "label";
        Map<String, Object> userParams = new HashMap<>();
        userParams.put("featureCols", JsonConverter.toJson(colNames));
        userParams.put("labelCol", label);
        userParams.put("batch_size", 16);
        userParams.put("num_epochs", 1);

        // Explicit cluster layout: 3 workers and no parameter servers.
        TensorFlow2StreamOp tensorFlow2StreamOp = new TensorFlow2StreamOp()
            .setUserFiles(new String[] {"res:///tf_dnn_stream.py"})
            .setMainScriptFile("res:///tf_dnn_stream.py")
            .setUserParams(JsonConverter.toJson(userParams))
            .setNumWorkers(3)
            .setNumPSs(0)
            .setOutputSchemaStr("model_id long, model_info string")
            .linkFrom(source);
        tensorFlow2StreamOp.print();
        StreamOperator.execute();
    } finally {
        // Previously skipped when execute() threw, leaking parallelism 3 into later tests.
        StreamOperator.setParallelism(savedStreamParallelism);
    }
}
Usage of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba.
Class TensorFlow2StreamOpTest, method testWithAutoWorkersPSs.
@Test
public void testWithAutoWorkersPSs() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Fetch the TF 2.3.1 plugin used by TensorFlow2StreamOp.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF231);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    // Remember the global stream parallelism so it can be restored even if the job fails.
    int savedStreamParallelism =
        MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().getParallelism();
    StreamOperator.setParallelism(3);
    try {
        // NOTE(review): this static is mutated and never restored; consider restoring it as well.
        DLLauncherStreamOp.DL_CLUSTER_START_TIME = 30 * 1000;

        StreamOperator<?> source = new RandomTableSourceStreamOp().setMaxRows(1000L).setNumCols(10);
        // Capture feature column names before the label column is appended.
        String[] colNames = source.getColNames();
        source = source.select("*, case when RAND() > 0.5 then 1. else 0. end as label");
        source = source.link(new TypeConvertStreamOp().setSelectedCols("num").setTargetType(TargetType.DOUBLE));

        String label = "label";
        Map<String, Object> userParams = new HashMap<>();
        userParams.put("featureCols", JsonConverter.toJson(colNames));
        userParams.put("labelCol", label);
        userParams.put("batch_size", 16);
        userParams.put("num_epochs", 1);

        // Unlike testAllReduce, the worker/PS counts are not set here.
        TensorFlow2StreamOp tensorFlow2StreamOp = new TensorFlow2StreamOp()
            .setUserFiles(new String[] {"res:///tf_dnn_stream.py"})
            .setMainScriptFile("res:///tf_dnn_stream.py")
            .setUserParams(JsonConverter.toJson(userParams))
            .setOutputSchemaStr("model_id long, model_info string")
            .linkFrom(source);
        tensorFlow2StreamOp.print();
        StreamOperator.execute();
    } finally {
        // Previously skipped when execute() threw, leaking parallelism 3 into later tests.
        StreamOperator.setParallelism(savedStreamParallelism);
    }
}
Aggregations