Search in sources :

Example 21 with PluginDownloader

Use of com.alibaba.alink.common.io.plugin.PluginDownloader in the Alink project by Alibaba.

The example is the method test of the class TFTableModelRegressionModelMapperTest.

/**
 * Loads a bundled TF table-model regression model and checks that the mapper
 * reserves the "s" and "label" columns and appends a "pred" column.
 *
 * <p>The test first downloads the plugin identified by
 * {@code TFPredictorClassLoaderFactory.getRegisterKey()}, then copies the
 * classpath resource {@code tf_table_model_regression_model.ak} to a temporary
 * file, reads it with {@code AkSourceBatchOp}, and feeds 1000 rows sampled from
 * four base rows through the mapper.
 *
 * @throws Exception if the plugin download, model extraction, model loading or
 *                   prediction fails
 */
@Category(DLTest.class)
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());

    List<Row> baseData = Arrays.asList(Row.of(1.2, 3.4, 10L, 3L, "yes", 0.), Row.of(1.2, 3.4, 2L, 5L, "no", 0.2), Row.of(1.2, 3.4, 6L, 8L, "no", 0.4), Row.of(1.2, 3.4, 3L, 2L, "yes", 1.0));
    String dataSchemaStr = "f double, d double, i long, l long, s string, label double";

    // Fixed seed so that a failing run can be reproduced with the same input rows.
    Random random = new Random(0L);
    List<Row> data = new ArrayList<>();
    for (int i = 0; i < 1000; i += 1) {
        data.add(baseData.get(random.nextInt(baseData.size())));
    }

    String resourceName = "tf_table_model_regression_model.ak";
    File modelFile = Files.createTempFile("tf_table_model_regression_model", ".ak").toFile();
    // Best-effort cleanup: the temp file is only needed for the duration of this JVM.
    modelFile.deleteOnExit();
    String modelPath = modelFile.toString();
    // try-with-resources guarantees the stream is closed whether or not the copy helper closes it.
    try (InputStream resourceAsStream = getClass().getClassLoader().getResourceAsStream(resourceName)) {
        // A JUnit assertion is used instead of the `assert` keyword, which is a no-op without -ea.
        Assert.assertNotNull("Test resource not found on classpath: " + resourceName, resourceAsStream);
        FileUtils.copyInputStreamToFile(resourceAsStream, modelFile);
    }

    BatchOperator<?> modelOp = new AkSourceBatchOp().setFilePath(modelPath);
    List<Row> modelRows = modelOp.collect();

    Params params = new Params();
    params.set(HasPredictionCol.PREDICTION_COL, "pred");
    params.set(HasReservedColsDefaultAsNull.RESERVED_COLS, new String[] { "s", "label" });

    TFTableModelRegressionModelMapper mapper = new TFTableModelRegressionModelMapper(modelOp.getSchema(), CsvUtil.schemaStr2Schema(dataSchemaStr), params);
    mapper.loadModel(modelRows);
    mapper.open();
    try {
        Assert.assertEquals(TableSchema.builder().field("s", Types.STRING).field("label", Types.DOUBLE).field("pred", Types.DOUBLE).build(), mapper.getOutputSchema());
        for (Row row : data) {
            Row output = mapper.map(row);
            Assert.assertEquals(3, output.getArity());
            // Input columns 4 ("s") and 5 ("label") are the reserved columns, emitted first.
            Assert.assertEquals(row.getField(4), output.getField(0));
            Assert.assertEquals(row.getField(5), output.getField(1));
        }
    } finally {
        // Close the mapper even when an assertion above fails.
        mapper.close();
    }
}
Also used : InputStream(java.io.InputStream) ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) AkSourceBatchOp(com.alibaba.alink.operator.batch.source.AkSourceBatchOp) Random(java.util.Random) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Aggregations

PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader) 21, RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey) 19, Test (org.junit.Test) 19, DLTest (com.alibaba.alink.testutil.categories.DLTest) 15, Params (org.apache.flink.ml.api.misc.param.Params) 12, Row (org.apache.flink.types.Row) 12, Category (org.junit.experimental.categories.Category) 12, File (java.io.File) 10, CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) 8, ArrayList (java.util.ArrayList) 8, FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor) 4, Shape (com.alibaba.alink.common.linalg.tensor.Shape) 4, ShuffleBatchOp (com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp) 4, MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp) 4, AkSourceBatchOp (com.alibaba.alink.operator.batch.source.AkSourceBatchOp) 3, TypeConvertStreamOp (com.alibaba.alink.operator.stream.dataproc.TypeConvertStreamOp) 3, RandomTableSourceStreamOp (com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp) 3, InputStream (java.io.InputStream) 3, HashMap (java.util.HashMap) 3, Random (java.util.Random) 3