Search in sources:

Example 11 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class BaseTFSavedModelPredictMapperTest method testTensor.

/**
 * Feeds in-memory tensor rows through a {@code BaseTFSavedModelPredictMapper} and checks the
 * output schema, the pass-through of the two input columns, and the shapes of the two
 * predicted tensors.
 *
 * <p>The saved model is downloaded from a remote URL and unpacked into a temporary directory,
 * so this test needs network access.
 *
 * @throws Exception if the plugin or model download, the unzip step, or the mapper fails
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Make sure the TF predictor plugin is available locally before the mapper is opened.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    // Build 1000 identical rows: a label tensor of shape (batchSize) and an image tensor of
    // shape (batchSize, 28, 28).
    int batchSize = 3;
    List<Row> rows = new ArrayList<>();
    for (int i = 0; i < 1000; i += 1) {
        Row row = Row.of(new LongTensor((new Shape(batchSize))), new FloatTensor(new Shape(batchSize, 28, 28)));
        rows.add(row);
    }
    BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG_TENSOR, image FLOAT_TENSOR");

    // Download the model archive into a fresh temp directory.
    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    System.out.println("localModelPath:" + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        // Unpack next to the archive; the model directory is expected to carry the archive's
        // name without the ".zip" suffix.
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }

    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG_TENSOR, probabilities FLOAT_TENSOR");
    BaseTFSavedModelPredictMapper baseTFSavedModelPredictMapper = new BaseTFSavedModelPredictMapper(data.getSchema(), params);
    baseTFSavedModelPredictMapper.open();
    // Close the mapper even when an assertion or map() call fails, so a failing test does not
    // leave it open.
    try {
        Assert.assertEquals(TableSchema.builder().field("label", TensorTypes.LONG_TENSOR).field("image", TensorTypes.FLOAT_TENSOR).field("classes", TensorTypes.LONG_TENSOR).field("probabilities", TensorTypes.FLOAT_TENSOR).build(), baseTFSavedModelPredictMapper.getOutputSchema());
        for (Row row : rows) {
            Row output = baseTFSavedModelPredictMapper.map(row);
            // The input columns must be passed through unchanged.
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
            // JUnit's argument order is (expected, actual).
            Assert.assertArrayEquals(new long[] { batchSize }, ((LongTensor) output.getField(2)).shape());
            Assert.assertArrayEquals(new long[] { batchSize, 10 }, ((FloatTensor) output.getField(3)).shape());
        }
    } finally {
        baseTFSavedModelPredictMapper.close();
    }
}
Also used : LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 12 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class BaseTFSavedModelPredictRowMapperTest method testString.

/**
 * Runs a {@code BaseTFSavedModelPredictRowMapper} over rows read from a remote CSV file and
 * checks the output schema and the pass-through of the two input columns.
 *
 * <p>Both the CSV data and the saved model are fetched from remote URLs, so this test needs
 * network access.
 *
 * @throws Exception if a download, the unzip step, data collection, or the mapper fails
 */
@Category(DLTest.class)
@Test
public void testString() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Make sure the TF predictor plugin is available locally before the mapper is opened.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    String url = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/mnist_dense.csv";
    String schema = "label bigint, image string";
    BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema).setFieldDelimiter(";");
    List<Row> rows = data.collect();

    // Download the model archive into a fresh temp directory.
    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    System.out.println("localModelPath: " + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        // Unpack next to the archive; the model directory is expected to carry the archive's
        // name without the ".zip" suffix.
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }

    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes bigint, probabilities string");
    BaseTFSavedModelPredictRowMapper baseTFSavedModelPredictRowMapper = new BaseTFSavedModelPredictRowMapper(data.getSchema(), params);
    baseTFSavedModelPredictRowMapper.open();
    // Close the mapper even when an assertion or map() call fails, so a failing test does not
    // leave it open.
    try {
        Assert.assertEquals(TableSchema.builder().field("label", Types.LONG).field("image", Types.STRING).field("classes", Types.LONG).field("probabilities", Types.STRING).build(), baseTFSavedModelPredictRowMapper.getOutputSchema());
        for (Row row : rows) {
            Row output = baseTFSavedModelPredictRowMapper.map(row);
            // The input columns must be passed through unchanged.
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
        }
    } finally {
        baseTFSavedModelPredictRowMapper.close();
    }
}
Also used : PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) Params(org.apache.flink.ml.api.misc.param.Params) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 13 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class TFSavedModelPredictMapperTest method testString.

/**
 * Runs a {@code TFSavedModelPredictMapper} configured with a remote model URL over rows read
 * from a remote CSV file, and checks the output schema and the pass-through of the two input
 * columns.
 *
 * <p>Both the CSV data and the model path are remote URLs, so this test needs network access.
 *
 * @throws Exception if the plugin download, data collection, or the mapper fails
 */
@Category(DLTest.class)
@Test
public void testString() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Make sure the TF predictor plugin is available locally before the mapper is opened.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    String url = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/mnist_dense.csv";
    String schema = "label bigint, image string";
    BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema).setFieldDelimiter(";");
    List<Row> rows = data.collect();

    Params params = new Params();
    // The model path is handed to the mapper as a URL; no local download is done here.
    params.set(HasModelPath.MODEL_PATH, "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip");
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes bigint, probabilities string");
    TFSavedModelPredictMapper tfSavedModelPredictMapper = new TFSavedModelPredictMapper(data.getSchema(), params);
    tfSavedModelPredictMapper.open();
    // Close the mapper even when an assertion or map() call fails, so a failing test does not
    // leave it open.
    try {
        Assert.assertEquals(TableSchema.builder().field("label", Types.LONG).field("image", Types.STRING).field("classes", Types.LONG).field("probabilities", Types.STRING).build(), tfSavedModelPredictMapper.getOutputSchema());
        for (Row row : rows) {
            Row output = tfSavedModelPredictMapper.map(row);
            // The input columns must be passed through unchanged.
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
        }
    } finally {
        tfSavedModelPredictMapper.close();
    }
}
Also used : PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) Params(org.apache.flink.ml.api.misc.param.Params) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 14 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class TensorFlow2StreamOpTest method testAllReduce.

/**
 * Runs a {@code TensorFlow2StreamOp} job with three workers and zero parameter servers over a
 * random stream source, printing the resulting stream.
 *
 * <p>The test changes the global stream parallelism; the previous value is restored in a
 * {@code finally} block so a failing job does not affect later tests.
 *
 * @throws Exception if the plugin download or the stream job fails
 */
@Test
public void testAllReduce() throws Exception {
    int savedStreamParallelism = MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().getParallelism();
    try {
        AlinkGlobalConfiguration.setPrintProcessInfo(true);
        // Make sure the TF 2.3.1 environment plugin is available locally before the job starts.
        PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
        RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF231);
        pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
        StreamOperator.setParallelism(3);
        // NOTE(review): presumably a cluster start timeout in milliseconds (30 s) — confirm
        // against DLLauncherStreamOp. This static field is not restored after the test.
        DLLauncherStreamOp.DL_CLUSTER_START_TIME = 30 * 1000;

        StreamOperator<?> source = new RandomTableSourceStreamOp().setMaxRows(1000L).setNumCols(10);
        // Capture the feature column names before the label column is appended.
        String[] colNames = source.getColNames();
        source = source.select("*, case when RAND() > 0.5 then 1. else 0. end as label");
        source = source.link(new TypeConvertStreamOp().setSelectedCols("num").setTargetType(TargetType.DOUBLE));

        String label = "label";
        Map<String, Object> userParams = new HashMap<>();
        userParams.put("featureCols", JsonConverter.toJson(colNames));
        userParams.put("labelCol", label);
        userParams.put("batch_size", 16);
        userParams.put("num_epochs", 1);

        TensorFlow2StreamOp tensorFlow2StreamOp = new TensorFlow2StreamOp().setUserFiles(new String[] { "res:///tf_dnn_stream.py" }).setMainScriptFile("res:///tf_dnn_stream.py").setUserParams(JsonConverter.toJson(userParams)).setNumWorkers(3).setNumPSs(0).setOutputSchemaStr("model_id long, model_info string").linkFrom(source);
        tensorFlow2StreamOp.print();
        StreamOperator.execute();
    } finally {
        // Always put the global parallelism back, even if the job above failed.
        StreamOperator.setParallelism(savedStreamParallelism);
    }
}
Also used : TypeConvertStreamOp(com.alibaba.alink.operator.stream.dataproc.TypeConvertStreamOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) HashMap(java.util.HashMap) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) RandomTableSourceStreamOp(com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 15 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class TensorFlow2StreamOpTest method testWithAutoWorkersPSs.

/**
 * Runs a {@code TensorFlow2StreamOp} job over a random stream source without setting the
 * number of workers or parameter servers explicitly, printing the resulting stream.
 *
 * <p>The test changes the global stream parallelism; the previous value is restored in a
 * {@code finally} block so a failing job does not affect later tests.
 *
 * @throws Exception if the plugin download or the stream job fails
 */
@Test
public void testWithAutoWorkersPSs() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Make sure the TF 2.3.1 environment plugin is available locally before the job starts.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF231);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    int savedStreamParallelism = MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().getParallelism();
    try {
        StreamOperator.setParallelism(3);
        // NOTE(review): presumably a cluster start timeout in milliseconds (30 s) — confirm
        // against DLLauncherStreamOp. This static field is not restored after the test.
        DLLauncherStreamOp.DL_CLUSTER_START_TIME = 30 * 1000;

        StreamOperator<?> source = new RandomTableSourceStreamOp().setMaxRows(1000L).setNumCols(10);
        // Capture the feature column names before the label column is appended.
        String[] colNames = source.getColNames();
        source = source.select("*, case when RAND() > 0.5 then 1. else 0. end as label");
        source = source.link(new TypeConvertStreamOp().setSelectedCols("num").setTargetType(TargetType.DOUBLE));

        String label = "label";
        Map<String, Object> userParams = new HashMap<>();
        userParams.put("featureCols", JsonConverter.toJson(colNames));
        userParams.put("labelCol", label);
        userParams.put("batch_size", 16);
        userParams.put("num_epochs", 1);

        TensorFlow2StreamOp tensorFlow2StreamOp = new TensorFlow2StreamOp().setUserFiles(new String[] { "res:///tf_dnn_stream.py" }).setMainScriptFile("res:///tf_dnn_stream.py").setUserParams(JsonConverter.toJson(userParams)).setOutputSchemaStr("model_id long, model_info string").linkFrom(source);
        tensorFlow2StreamOp.print();
        StreamOperator.execute();
    } finally {
        // Always put the global parallelism back, even if the job above failed.
        StreamOperator.setParallelism(savedStreamParallelism);
    }
}
Also used : TypeConvertStreamOp(com.alibaba.alink.operator.stream.dataproc.TypeConvertStreamOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) HashMap(java.util.HashMap) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) RandomTableSourceStreamOp(com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Aggregations

PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader): 21; RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey): 19; Test (org.junit.Test): 19; DLTest (com.alibaba.alink.testutil.categories.DLTest): 15; Params (org.apache.flink.ml.api.misc.param.Params): 12; Row (org.apache.flink.types.Row): 12; Category (org.junit.experimental.categories.Category): 12; File (java.io.File): 10; CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp): 8; ArrayList (java.util.ArrayList): 8; FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor): 4; Shape (com.alibaba.alink.common.linalg.tensor.Shape): 4; ShuffleBatchOp (com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp): 4; MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp): 4; AkSourceBatchOp (com.alibaba.alink.operator.batch.source.AkSourceBatchOp): 3; TypeConvertStreamOp (com.alibaba.alink.operator.stream.dataproc.TypeConvertStreamOp): 3; RandomTableSourceStreamOp (com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp): 3; InputStream (java.io.InputStream): 3; HashMap (java.util.HashMap): 3; Random (java.util.Random): 3