Search in sources:

Example 11 with Shape

Usage of com.alibaba.alink.common.linalg.tensor.Shape in the Alink project by Alibaba.

From class ToTensorMapperTest, method testReshape.

@Test
public void testReshape() throws Exception {
    // A single dense-vector input column named "vec".
    final TableSchema inputSchema = new TableSchema(
        new String[] { "vec" },
        new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR });
    // Convert "vec" into a tensor with shape [2, 3].
    final Params toTensorParams = new Params()
        .set(ToTensorParams.SELECTED_COL, "vec")
        .set(ToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L });
    final Mapper mapper = new ToTensorMapper(inputSchema, toTensorParams);

    // Six values, so a [2, 3] reshape is valid.
    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    // The expected value is the same data reshaped to [2, 3]; it is computed before mapping, as in the original.
    final DoubleTensor expected = source.reshape(new Shape(2L, 3L));

    final Row inputRow = Row.of(source.toVector());
    final Row outputRow = mapper.map(inputRow);
    final Tensor<?> actual = (Tensor<?>) outputRow.getField(0);

    Assert.assertEquals(expected, actual);
}
Also used: Mapper(com.alibaba.alink.common.mapper.Mapper) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) Tensor(com.alibaba.alink.common.linalg.tensor.Tensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TableSchema(org.apache.flink.table.api.TableSchema) ToTensorParams(com.alibaba.alink.params.dataproc.ToTensorParams) Params(org.apache.flink.ml.api.misc.param.Params) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Test(org.junit.Test)

Example 12 with Shape

Usage of com.alibaba.alink.common.linalg.tensor.Shape in the Alink project by Alibaba.

From class ToTensorMapperTest, method testFloatType.

@Test
public void testFloatType() throws Exception {
    // A single dense-vector input column named "vec".
    final TableSchema inputSchema = new TableSchema(
        new String[] { "vec" },
        new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR });
    // Convert "vec" into a [2, 3] tensor, requesting FLOAT as the output element type.
    final Params toTensorParams = new Params()
        .set(ToTensorParams.SELECTED_COL, "vec")
        .set(ToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L })
        .set(ToTensorParams.TENSOR_DATA_TYPE, DataType.FLOAT);
    final Mapper mapper = new ToTensorMapper(inputSchema, toTensorParams);

    // The declared output column type must reflect the requested data type.
    final TypeInformation<?> outputType = mapper.getOutputSchema().getFieldTypes()[0];
    Assert.assertEquals(TensorTypes.FLOAT_TENSOR, outputType);

    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final DoubleTensor reshaped = source.reshape(new Shape(2L, 3L));
    final FloatTensor expected = FloatTensor.of(reshaped);

    final Row outputRow = mapper.map(Row.of(source.toVector()));
    final Tensor<?> actual = (Tensor<?>) outputRow.getField(0);

    Assert.assertEquals(expected, actual);
}
Also used: Mapper(com.alibaba.alink.common.mapper.Mapper) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) Tensor(com.alibaba.alink.common.linalg.tensor.Tensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TableSchema(org.apache.flink.table.api.TableSchema) ToTensorParams(com.alibaba.alink.params.dataproc.ToTensorParams) Params(org.apache.flink.ml.api.misc.param.Params) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Test(org.junit.Test)

Example 13 with Shape

Usage of com.alibaba.alink.common.linalg.tensor.Shape in the Alink project by Alibaba.

From class BaseTFSavedModelPredictMapperTest, method testTensor.

@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Ensure the TF predictor plugin is present locally before the mapper tries to load it.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    // Each row holds one batch: a LONG label tensor of shape [batchSize] and a FLOAT image
    // tensor of shape [batchSize, 28, 28].
    int batchSize = 3;
    int numRows = 1000;
    List<Row> rows = new ArrayList<>(numRows);
    for (int i = 0; i < numRows; i += 1) {
        Row row = Row.of(new LongTensor(new Shape(batchSize)), new FloatTensor(new Shape(batchSize, 28, 28)));
        rows.add(row);
    }
    BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG_TENSOR, image FLOAT_TENSOR");

    // Download the zipped SavedModel and unpack it next to the archive.
    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    System.out.println("localModelPath:" + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        // The model directory is expected to share the archive's name minus the ".zip" suffix.
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }

    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG_TENSOR, probabilities FLOAT_TENSOR");
    BaseTFSavedModelPredictMapper baseTFSavedModelPredictMapper = new BaseTFSavedModelPredictMapper(data.getSchema(), params);
    baseTFSavedModelPredictMapper.open();
    // Close the mapper even when an assertion fails, so the loaded model is always released.
    try {
        // Output = input columns followed by the configured prediction columns.
        Assert.assertEquals(TableSchema.builder().field("label", TensorTypes.LONG_TENSOR).field("image", TensorTypes.FLOAT_TENSOR).field("classes", TensorTypes.LONG_TENSOR).field("probabilities", TensorTypes.FLOAT_TENSOR).build(), baseTFSavedModelPredictMapper.getOutputSchema());
        for (Row row : rows) {
            Row output = baseTFSavedModelPredictMapper.map(row);
            // Input columns pass through unchanged.
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
            // JUnit's contract is assertArrayEquals(expected, actual).
            Assert.assertArrayEquals(new long[] { batchSize }, ((LongTensor) output.getField(2)).shape());
            Assert.assertArrayEquals(new long[] { batchSize, 10 }, ((FloatTensor) output.getField(3)).shape());
        }
    } finally {
        baseTFSavedModelPredictMapper.close();
    }
}
Also used : LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Aggregations

FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor)13 Shape (com.alibaba.alink.common.linalg.tensor.Shape)13 Test (org.junit.Test)12 Params (org.apache.flink.ml.api.misc.param.Params)10 DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor)8 Row (org.apache.flink.types.Row)7 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)6 TableSchema (org.apache.flink.table.api.TableSchema)6 Tensor (com.alibaba.alink.common.linalg.tensor.Tensor)5 Mapper (com.alibaba.alink.common.mapper.Mapper)5 TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation)5 PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader)4 RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey)4 StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor)4 DLTest (com.alibaba.alink.testutil.categories.DLTest)4 ArrayList (java.util.ArrayList)4 Category (org.junit.experimental.categories.Category)4 LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor)3 ToTensorParams (com.alibaba.alink.params.dataproc.ToTensorParams)3 File (java.io.File)3