Search in sources:

Example 6 with PluginDownloader

Use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.

The class BaseTFSavedModelPredictRowMapperTest, method testTensor.

/**
 * Runs {@code BaseTFSavedModelPredictRowMapper} row by row against a TF SavedModel.
 *
 * <p>The model is downloaded over HTTP and unzipped into a temporary directory. Each of the
 * 1000 input rows holds a LONG label and a freshly constructed 28x28 {@code FloatTensor}
 * image. The test checks three things:
 * <ul>
 *   <li>the mapper's output schema;</li>
 *   <li>that both input columns pass through unchanged;</li>
 *   <li>that the "probabilities" column is a tensor of shape [10].</li>
 * </ul>
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Fetch the TF predictor plugin up front so the mapper can load it in open().
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    // The default environment is shared with other tests: restore its parallelism afterwards.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
    try {
        List<Row> rows = new ArrayList<>();
        for (int i = 0; i < 1000; i += 1) {
            // 0L rather than 0 so the value matches the declared "label LONG" column type.
            rows.add(Row.of(0L, new FloatTensor(new Shape(28, 28))));
        }
        BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG, image FLOAT_TENSOR");

        String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
        String workDir = PythonFileUtils.createTempDir("temp_").toString();
        String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
        String localModelPath = workDir + File.separator + fn;
        System.out.println("localModelPath: " + localModelPath);
        if (localModelPath.endsWith(".zip")) {
            // Unzip next to the archive; the model is expected in a directory named like the zip.
            File target = new File(localModelPath).getParentFile();
            ZipFileUtil.unZip(new File(localModelPath), target);
            localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
            Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
        }

        Params params = new Params();
        params.set(HasModelPath.MODEL_PATH, localModelPath);
        params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
        params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG, probabilities FLOAT_TENSOR");

        BaseTFSavedModelPredictRowMapper mapper = new BaseTFSavedModelPredictRowMapper(data.getSchema(), params);
        mapper.open();
        // close() must run even when an assertion fails.
        try {
            Assert.assertEquals(TableSchema.builder().field("label", Types.LONG).field("image", TensorTypes.FLOAT_TENSOR).field("classes", Types.LONG).field("probabilities", TensorTypes.FLOAT_TENSOR).build(), mapper.getOutputSchema());
            for (Row row : rows) {
                Row output = mapper.map(row);
                // Input columns are passed through unchanged.
                Assert.assertEquals(row.getField(0), output.getField(0));
                Assert.assertEquals(row.getField(1), output.getField(1));
                // JUnit's argument order is (expected, actual).
                Assert.assertArrayEquals(new long[] { 10 }, ((FloatTensor) output.getField(3)).shape());
            }
        } finally {
            mapper.close();
        }
    } finally {
        MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(savedParallelism);
    }
}
Also used: Shape (com.alibaba.alink.common.linalg.tensor.Shape), ArrayList (java.util.ArrayList), Params (org.apache.flink.ml.api.misc.param.Params), MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp), PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader), FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor), Row (org.apache.flink.types.Row), RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey), File (java.io.File), Category (org.junit.experimental.categories.Category), Test (org.junit.Test), DLTest (com.alibaba.alink.testutil.categories.DLTest)

Example 7 with PluginDownloader

Use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.

The class TFSavedModelPredictMapperTest, method testTensor.

/**
 * Runs {@code TFSavedModelPredictMapper} row by row on batched tensor inputs.
 *
 * <p>The model path is an HTTP URL and is handed to the mapper as-is. Each of the 1000 input
 * rows holds a LongTensor label of shape [batchSize] and a FloatTensor image of shape
 * [batchSize, 28, 28]. The test checks four things:
 * <ul>
 *   <li>the output schema;</li>
 *   <li>that both input columns pass through unchanged;</li>
 *   <li>that "classes" has shape [batchSize];</li>
 *   <li>that "probabilities" has shape [batchSize, 10].</li>
 * </ul>
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Fetch the TF predictor plugin up front so the mapper can load it in open().
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    // The default environment is shared with other tests: restore its parallelism afterwards.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
    try {
        int batchSize = 3;
        List<Row> rows = new ArrayList<>();
        for (int i = 0; i < 1000; i += 1) {
            rows.add(Row.of(new LongTensor(new Shape(batchSize)), new FloatTensor(new Shape(batchSize, 28, 28))));
        }
        BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG_TENSOR, image FLOAT_TENSOR");

        Params params = new Params();
        params.set(HasModelPath.MODEL_PATH, "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip");
        params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
        params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG_TENSOR, probabilities FLOAT_TENSOR");

        TFSavedModelPredictMapper mapper = new TFSavedModelPredictMapper(data.getSchema(), params);
        mapper.open();
        // close() must run even when an assertion fails.
        try {
            Assert.assertEquals(TableSchema.builder().field("label", TensorTypes.LONG_TENSOR).field("image", TensorTypes.FLOAT_TENSOR).field("classes", TensorTypes.LONG_TENSOR).field("probabilities", TensorTypes.FLOAT_TENSOR).build(), mapper.getOutputSchema());
            for (Row row : rows) {
                Row output = mapper.map(row);
                // Input columns are passed through unchanged.
                Assert.assertEquals(row.getField(0), output.getField(0));
                Assert.assertEquals(row.getField(1), output.getField(1));
                // JUnit's argument order is (expected, actual).
                Assert.assertArrayEquals(new long[] { batchSize }, ((LongTensor) output.getField(2)).shape());
                Assert.assertArrayEquals(new long[] { batchSize, 10 }, ((FloatTensor) output.getField(3)).shape());
            }
        } finally {
            mapper.close();
        }
    } finally {
        MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(savedParallelism);
    }
}
Also used: LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor), Shape (com.alibaba.alink.common.linalg.tensor.Shape), ArrayList (java.util.ArrayList), Params (org.apache.flink.ml.api.misc.param.Params), MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp), PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader), FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor), Row (org.apache.flink.types.Row), RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey), Category (org.junit.experimental.categories.Category), Test (org.junit.Test), DLTest (com.alibaba.alink.testutil.categories.DLTest)

Example 8 with PluginDownloader

Use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.

The class BertTextPairClassifierTest, method test.

/**
 * Fine-tunes a {@code BertTextPairClassifier} briefly and prints predictions.
 *
 * <p>The training data is the MRPC train.tsv file read from OSS. Training uses 0.1 epochs,
 * max sequence length 32 and one fine-tuned layer, with the "Base-Uncased" BERT model. The
 * test predicts on the first 300 shuffled rows and prints them; it makes no assertions.
 */
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Pre-fetch the plugins that fit/transform rely on: the TF115 DL environment, the TF
    // predictor, and the BERT checkpoint for the model selected below.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    // Must match the setBertModelName("Base-Uncased") call below.
    registerKey = BertResources.getRegisterKey(ModelName.BASE_UNCASED, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    // Restore the shared parallelism even if training or prediction throws.
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality bigint, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
        BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schemaStr).setFieldDelimiter("\t").setIgnoreFirstLine(true).setQuoteChar(null);
        data = new ShuffleBatchOp().linkFrom(data);
        BertTextPairClassifier classifier = new BertTextPairClassifier().setTextCol("f_string_1").setTextPairCol("f_string_2").setLabelCol("f_quality").setNumEpochs(0.1).setMaxSeqLength(32).setNumFineTunedLayers(1).setBertModelName("Base-Uncased").setPredictionCol("pred").setPredictionDetailCol("pred_detail");
        BertClassificationModel model = classifier.fit(data);
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used: PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader), RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey), CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp), ShuffleBatchOp (com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp), Test (org.junit.Test)

Example 9 with PluginDownloader

Use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.

The class BertTextPairRegressorTest, method test.

/**
 * Fine-tunes a {@code BertTextPairRegressor} briefly and prints predictions.
 *
 * <p>The training data is the MRPC train.tsv file read from OSS, with the label read as a
 * double. Training uses 0.1 epochs, max sequence length 32 and one fine-tuned layer, with the
 * "Base-Uncased" BERT model. The test predicts on the first 300 shuffled rows and prints them;
 * it makes no assertions.
 */
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Pre-fetch the plugins that fit/transform rely on: the TF115 DL environment, the TF
    // predictor, and the BERT checkpoint for the model selected below.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = DLEnvConfig.getRegisterKey(Version.TF115);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    // Must match the setBertModelName("Base-Uncased") call below.
    registerKey = BertResources.getRegisterKey(ModelName.BASE_UNCASED, ResourceType.CKPT);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    // Restore the shared parallelism even if training or prediction throws.
    try {
        String url = "http://alink-algo-packages.oss-cn-hangzhou-zmf.aliyuncs.com/data/MRPC/train.tsv";
        String schemaStr = "f_quality double, f_id_1 string, f_id_2 string, f_string_1 string, f_string_2 string";
        BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schemaStr).setFieldDelimiter("\t").setIgnoreFirstLine(true).setQuoteChar(null);
        data = new ShuffleBatchOp().linkFrom(data);
        BertTextPairRegressor regressor = new BertTextPairRegressor().setTextCol("f_string_1").setTextPairCol("f_string_2").setLabelCol("f_quality").setNumEpochs(0.1).setMaxSeqLength(32).setNumFineTunedLayers(1).setBertModelName("Base-Uncased").setPredictionCol("pred");
        BertRegressionModel model = regressor.fit(data);
        BatchOperator<?> predict = model.transform(data.firstN(300));
        predict.print();
    } finally {
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used: PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader), RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey), CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp), ShuffleBatchOp (com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp), Test (org.junit.Test)

Example 10 with PluginDownloader

Use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by Alibaba.

The class Chap25, method c_1_3.

/**
 * Prints the plugin-related global settings to standard output.
 *
 * <p>Output, in order: the configured plugin directory, whether automatic plugin download is
 * enabled, and the result of asking the plugin downloader for its available plugins.
 */
static void c_1_3() throws Exception {
    // Configured plugin directory.
    System.out.println(AlinkGlobalConfiguration.getPluginDir());

    // Label first, then the flag value, on the same line.
    System.out.print("Auto Plugin Download : ");
    System.out.println(AlinkGlobalConfiguration.getAutoPluginDownload());

    // Ask the downloader which plugins it can provide and dump the answer.
    System.out.println(AlinkGlobalConfiguration.getPluginDownloader().listAvailablePlugins());
}
Also used: PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader)

Aggregations

PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader): 21; RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey): 19; Test (org.junit.Test): 19; DLTest (com.alibaba.alink.testutil.categories.DLTest): 15; Params (org.apache.flink.ml.api.misc.param.Params): 12; Row (org.apache.flink.types.Row): 12; Category (org.junit.experimental.categories.Category): 12; File (java.io.File): 10; CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp): 8; ArrayList (java.util.ArrayList): 8; FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor): 4; Shape (com.alibaba.alink.common.linalg.tensor.Shape): 4; ShuffleBatchOp (com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp): 4; MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp): 4; AkSourceBatchOp (com.alibaba.alink.operator.batch.source.AkSourceBatchOp): 3; TypeConvertStreamOp (com.alibaba.alink.operator.stream.dataproc.TypeConvertStreamOp): 3; RandomTableSourceStreamOp (com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp): 3; InputStream (java.io.InputStream): 3; HashMap (java.util.HashMap): 3; Random (java.util.Random): 3