Example 6 with StringTensor

Use of com.alibaba.alink.common.linalg.tensor.StringTensor in project Alink by alibaba.

From class ToTensorMapperTest, method testStringType.

@Test
public void testStringType() throws Exception {
    // Mapper that converts the STRING column "str" into a tensor column of data type STRING.
    final Mapper mapper = new ToTensorMapper(
        new TableSchema(new String[] {"str"}, new TypeInformation<?>[] {Types.STRING}),
        new Params()
            .set(ToTensorParams.SELECTED_COL, "str")
            .set(ToTensorParams.TENSOR_DATA_TYPE, DataType.STRING));
    // The output column must be typed as a string tensor.
    Assert.assertEquals(TensorTypes.STRING_TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);
    // Use the string form of a 6-element tensor as input; the result must equal a StringTensor built from that string.
    final DoubleTensor tensor = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final StringTensor expect = new StringTensor(tensor.toString());
    final Tensor<?> result = (Tensor<?>) mapper.map(Row.of(tensor.toString())).getField(0);
    Assert.assertEquals(expect, result);
}
Also used: Mapper (com.alibaba.alink.common.mapper.Mapper), DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor), Tensor (com.alibaba.alink.common.linalg.tensor.Tensor), StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor), FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor), TableSchema (org.apache.flink.table.api.TableSchema), ToTensorParams (com.alibaba.alink.params.dataproc.ToTensorParams), Params (org.apache.flink.ml.api.misc.param.Params), TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation), Test (org.junit.Test)
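The input string in this test, "FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 ", appears to follow a DTYPE#SHAPE#VALUES layout: a data type name, the tensor shape, and the space-separated element values. The sketch below is a hypothetical plain-Java parser for that layout, written only to make the format concrete. It is not Alink's TensorUtil code. The class TensorStringSketch, the ParsedTensor holder, and the assumption that multi-dimensional shapes list their dimensions separated by commas are all illustrative assumptions.

```java
import java.util.Arrays;

/**
 * Hypothetical parser for strings such as "FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 ".
 * Assumptions (not taken from Alink): fields are separated by '#',
 * shape dimensions by ',', and values by whitespace.
 */
final class TensorStringSketch {

    /** Illustrative holder for the three parsed fields. */
    static final class ParsedTensor {
        final String dataType;
        final long[] shape;
        final double[] values;

        ParsedTensor(String dataType, long[] shape, double[] values) {
            this.dataType = dataType;
            this.shape = shape;
            this.values = values;
        }
    }

    static ParsedTensor parse(String s) {
        // Split into at most three fields: data type, shape, values.
        String[] parts = s.split("#", 3);
        if (parts.length != 3) {
            throw new IllegalArgumentException("Expected DTYPE#SHAPE#VALUES: " + s);
        }
        String dataType = parts[0].trim();

        // Assumed: dimensions are comma-separated, e.g. "2,3" for a 2x3 tensor.
        long[] shape = Arrays.stream(parts[1].split(","))
            .map(String::trim)
            .mapToLong(Long::parseLong)
            .toArray();

        // Values are whitespace-separated; the trailing space in the test input is trimmed away.
        String body = parts[2].trim();
        double[] values = body.isEmpty()
            ? new double[0]
            : Arrays.stream(body.split("\\s+")).mapToDouble(Double::parseDouble).toArray();

        // The number of values must equal the product of the dimensions.
        long expected = 1;
        for (long d : shape) {
            expected *= d;
        }
        if (expected != values.length) {
            throw new IllegalArgumentException(
                "Shape " + Arrays.toString(shape) + " needs " + expected + " values, got " + values.length);
        }
        return new ParsedTensor(dataType, shape, values);
    }

    public static void main(String[] args) {
        ParsedTensor t = parse("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 ");
        System.out.println(t.dataType + " " + Arrays.toString(t.shape) + " " + Arrays.toString(t.values));
    }
}
```

Running main prints `FLOAT [6] [0.0, 0.1, 1.0, 1.1, 2.0, 2.1]`. The element-count check is included because a shape that disagrees with the number of values is the most likely malformed input for this kind of format.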

Aggregations

StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor): 6
DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor): 5
FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor): 5
BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor): 4
IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor): 4
LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor): 4
Test (org.junit.Test): 3
Tensor (org.tensorflow.Tensor): 3
TensorInfo (org.tensorflow.proto.framework.TensorInfo): 2
TString (org.tensorflow.types.TString): 2
TF2TensorUtils (com.alibaba.alink.common.dl.utils.TF2TensorUtils): 1
TFTensorConversionUtils (com.alibaba.alink.common.dl.utils.TFTensorConversionUtils): 1
ByteTensor (com.alibaba.alink.common.linalg.tensor.ByteTensor): 1
Shape (com.alibaba.alink.common.linalg.tensor.Shape): 1
Tensor (com.alibaba.alink.common.linalg.tensor.Tensor): 1
Tensor.unstack (com.alibaba.alink.common.linalg.tensor.Tensor.unstack): 1
UByteTensor (com.alibaba.alink.common.linalg.tensor.UByteTensor): 1
Mapper (com.alibaba.alink.common.mapper.Mapper): 1
MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp): 1
ToTensorParams (com.alibaba.alink.params.dataproc.ToTensorParams): 1