Search in sources:

Example 6 with LongTensor

Use of com.alibaba.alink.common.linalg.tensor.LongTensor in project Alink by Alibaba.

The class BaseTFSavedModelPredictMapperTest, method testTensor.

/**
 * Runs {@link BaseTFSavedModelPredictMapper} on a TF SavedModel downloaded at runtime and checks
 * the columns it appends.
 *
 * <p>Each input row has a {@code LONG_TENSOR} label of shape {@code [batchSize]} and a
 * {@code FLOAT_TENSOR} image of shape {@code [batchSize, 28, 28]}. Only the {@code image} column is
 * selected as model input. The test asserts that:
 * <ul>
 *   <li>the output schema is the input schema followed by
 *       {@code classes LONG_TENSOR, probabilities FLOAT_TENSOR};</li>
 *   <li>the input columns pass through unchanged;</li>
 *   <li>{@code classes} has shape {@code [batchSize]};</li>
 *   <li>{@code probabilities} has shape {@code [batchSize, 10]}.</li>
 * </ul>
 *
 * <p>Requires network access. Both the TF predictor plugin and the model archive are downloaded by
 * this test.
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Fetch the TF predictor plugin up front so the mapper can load it.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    int batchSize = 3;
    int numRows = 1000;
    List<Row> rows = new ArrayList<>(numRows);
    for (int i = 0; i < numRows; i++) {
        rows.add(Row.of(
            new LongTensor(new Shape(batchSize)),
            new FloatTensor(new Shape(batchSize, 28, 28))));
    }
    // In this test the source operator is only used to derive the mapper's input schema.
    BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG_TENSOR, image FLOAT_TENSOR");

    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    System.out.println("localModelPath:" + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        // Unzip next to the archive; the model directory is expected to be the archive name
        // without the ".zip" suffix.
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }

    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG_TENSOR, probabilities FLOAT_TENSOR");

    BaseTFSavedModelPredictMapper mapper = new BaseTFSavedModelPredictMapper(data.getSchema(), params);
    mapper.open();
    // close() runs in finally so the mapper is released even when an assertion below fails.
    try {
        Assert.assertEquals(
            TableSchema.builder()
                .field("label", TensorTypes.LONG_TENSOR)
                .field("image", TensorTypes.FLOAT_TENSOR)
                .field("classes", TensorTypes.LONG_TENSOR)
                .field("probabilities", TensorTypes.FLOAT_TENSOR)
                .build(),
            mapper.getOutputSchema());
        for (Row row : rows) {
            Row output = mapper.map(row);
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
            // JUnit's assertArrayEquals takes (expected, actual).
            Assert.assertArrayEquals(new long[] { batchSize }, ((LongTensor) output.getField(2)).shape());
            Assert.assertArrayEquals(new long[] { batchSize, 10 }, ((FloatTensor) output.getField(3)).shape());
        }
    } finally {
        mapper.close();
    }
}
Also used : LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Aggregations

FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor)6 LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor)6 BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor)4 DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor)4 IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor)4 StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor)4 Shape (com.alibaba.alink.common.linalg.tensor.Shape)3 Test (org.junit.Test)3 Tensor (org.tensorflow.Tensor)3 PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader)2 RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey)2 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)2 DLTest (com.alibaba.alink.testutil.categories.DLTest)2 ArrayList (java.util.ArrayList)2 Params (org.apache.flink.ml.api.misc.param.Params)2 Row (org.apache.flink.types.Row)2 Category (org.junit.experimental.categories.Category)2 TensorInfo (org.tensorflow.proto.framework.TensorInfo)2 TF2TensorUtils (com.alibaba.alink.common.dl.utils.TF2TensorUtils)1 TFTensorConversionUtils (com.alibaba.alink.common.dl.utils.TFTensorConversionUtils)1