
Example 11 with DoubleTensor

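Use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

From class ToTensorMapperTest, method testDefaultType. With no tensor data type set, the output column has the generic TensorTypes.TENSOR type, and the six-element dense vector is reshaped into a DoubleTensor of shape [2, 3].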

@Test
public void testDefaultType() throws Exception {
    final Mapper mapper = new ToTensorMapper(
        new TableSchema(new String[] { "vec" }, new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR }),
        new Params()
            .set(ToTensorParams.SELECTED_COL, "vec")
            .set(ToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L }));
    Assert.assertEquals(TensorTypes.TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);
    final DoubleTensor tensor = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final DoubleTensor expect = tensor.reshape(new Shape(2L, 3L));
    final Tensor<?> result = (Tensor<?>) mapper.map(Row.of(tensor.toVector())).getField(0);
    Assert.assertEquals(expect, result);
}
Also used : Mapper(com.alibaba.alink.common.mapper.Mapper) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) Tensor(com.alibaba.alink.common.linalg.tensor.Tensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TableSchema(org.apache.flink.table.api.TableSchema) ToTensorParams(com.alibaba.alink.params.dataproc.ToTensorParams) Params(org.apache.flink.ml.api.misc.param.Params) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Test(org.junit.Test)
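The expectation in testDefaultType is `tensor.reshape(new Shape(2L, 3L))`: six values laid out as two rows of three. The stdlib-only sketch below shows what such a reshape means, assuming row-major fill (each row filled before the next), which is the usual convention. `ReshapeSketch` and its methods are hypothetical, not Alink's API, and `parseValues` assumes the `TYPE#SHAPE#VALUES` layout suggested by the test literal `"FLOAT#6#0.0 0.1 ..."`.

```java
import java.util.Arrays;

// Hypothetical illustration of a row-major reshape; not Alink's implementation.
final class ReshapeSketch {

    // Assumption: the encoded string has the layout TYPE#SHAPE#VALUES with
    // whitespace-separated values. The TYPE and SHAPE parts are ignored here.
    static double[] parseValues(String encoded) {
        String[] parts = encoded.split("#", 3);
        if (parts.length != 3) {
            throw new IllegalArgumentException("expected TYPE#SHAPE#VALUES: " + encoded);
        }
        String body = parts[2].trim();
        if (body.isEmpty()) {
            return new double[0];
        }
        return Arrays.stream(body.split("\\s+")).mapToDouble(Double::parseDouble).toArray();
    }

    // Copies a flat buffer into rows x cols, filling row 0 first, then row 1, and so on.
    static double[][] reshape(double[] flat, int rows, int cols) {
        if ((long) rows * cols != flat.length) {
            throw new IllegalArgumentException(
                "size mismatch: " + flat.length + " values vs " + rows + "x" + cols);
        }
        double[][] out = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            System.arraycopy(flat, r * cols, out[r], 0, cols);
        }
        return out;
    }

    public static void main(String[] args) {
        double[] flat = parseValues("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 ");
        // prints [[0.0, 0.1, 1.0], [1.1, 2.0, 2.1]]
        System.out.println(Arrays.deepToString(reshape(flat, 2, 3)));
    }
}
```

The size check mirrors the requirement that a shape's element count must match the data: six values fit [2, 3] but not, say, [4, 2].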

Example 12 with DoubleTensor

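Use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

From class ToTensorMapperTest, method testOp. The same vector is converted by the batch operator ToTensorBatchOp with tensor data type "float"; the collected result must equal a FloatTensor of shape [2, 3].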

@Test
public void testOp() throws Exception {
    final DoubleTensor tensor = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final FloatTensor expect = FloatTensor.of(tensor.reshape(new Shape(2L, 3L)));
    Row[] rows = new Row[] { Row.of(tensor.toVector()) };
    MemSourceBatchOp memSourceBatchOp = new MemSourceBatchOp(rows, new String[] { "vec" });
    memSourceBatchOp
        .link(new ToTensorBatchOp()
            .setSelectedCol("vec")
            .setTensorShape(2, 3)
            .setTensorDataType("float"))
        .lazyCollect(new Consumer<List<Row>>() {

        @Override
        public void accept(List<Row> rows) {
            Assert.assertEquals(expect, TensorUtil.getTensor(rows.get(0).getField(0)));
        }
    });
    BatchOperator.execute();
}
Also used : MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) ToTensorBatchOp(com.alibaba.alink.operator.batch.dataproc.ToTensorBatchOp) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) List(java.util.List) Row(org.apache.flink.types.Row) Test(org.junit.Test)
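In testOp the assertion lives inside a callback handed to `lazyCollect`, and nothing is checked until `BatchOperator.execute()` runs. The stdlib-only sketch below models that deferred pattern as a conceptual analogy; `LazyCollectSketch`, its `lazyCollect`/`execute` methods and `toFloat` are hypothetical stand-ins, not Alink's implementation. `toFloat` narrows the double values to float, loosely mirroring the `setTensorDataType("float")` step.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Supplier;

// Conceptual model of deferred collection; not Alink's lazyCollect/execute implementation.
final class LazyCollectSketch {
    // Callbacks that have been registered but not yet run.
    private static final List<Runnable> PENDING = new ArrayList<>();

    // Queues a callback; the result is computed and delivered only when execute() runs.
    static <T> void lazyCollect(Supplier<T> result, Consumer<T> callback) {
        PENDING.add(() -> callback.accept(result.get()));
    }

    // Runs every queued callback in registration order, then empties the queue.
    static void execute() {
        List<Runnable> toRun = new ArrayList<>(PENDING);
        PENDING.clear();
        toRun.forEach(Runnable::run);
    }

    // Narrows each double to the nearest float value.
    static float[] toFloat(double[] values) {
        float[] out = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (float) values[i];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] source = { 0.0, 0.1, 1.0, 1.1, 2.0, 2.1 };
        lazyCollect(() -> toFloat(source),
            floats -> System.out.println("collected " + floats.length + " floats"));
        System.out.println("before execute"); // printed first: no callback has run yet
        execute();                            // the callback fires here
    }
}
```

One consequence of this model: if `execute()` is never called, the callback, and any assertion inside it, never runs.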

Example 13 with DoubleTensor

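Use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

From class ToTensorMapperTest, method testHandleInvalidSkip. The mapper is asked for an INT tensor of shape [2, 3] from the double-valued vector; with HANDLE_INVALID set to SKIP, it returns null for that row instead of throwing.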

@Test
public void testHandleInvalidSkip() throws Exception {
    final Mapper mapper = new ToTensorMapper(
        new TableSchema(new String[] { "vec" }, new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR }),
        new Params()
            .set(ToTensorParams.SELECTED_COL, "vec")
            .set(ToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L })
            .set(ToTensorParams.TENSOR_DATA_TYPE, DataType.INT)
            .set(ToTensorParams.HANDLE_INVALID, HandleInvalidMethod.SKIP));
    Assert.assertEquals(TensorTypes.INT_TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);
    final DoubleTensor tensor = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final Tensor<?> result = (Tensor<?>) mapper.map(Row.of(tensor.toVector())).getField(0);
    Assert.assertNull(result);
}
Also used : Mapper(com.alibaba.alink.common.mapper.Mapper) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Tensor(com.alibaba.alink.common.linalg.tensor.Tensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TableSchema(org.apache.flink.table.api.TableSchema) ToTensorParams(com.alibaba.alink.params.dataproc.ToTensorParams) Params(org.apache.flink.ml.api.misc.param.Params) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Test(org.junit.Test)
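testHandleInvalidSkip and testHandleInvalidError differ only in the HANDLE_INVALID setting: with SKIP the mapper yields null, while without it an IllegalArgumentException is expected. The sketch below models that policy switch in plain Java. `HandleInvalidSketch`, `Policy` and `toIntStrict` are hypothetical names; in particular, the tests do not say why the INT conversion is rejected, so refusing fractional values here is only an assumed trigger.

```java
import java.util.function.Function;

// Hypothetical model of an ERROR/SKIP policy; not Alink's HandleInvalidMethod implementation.
final class HandleInvalidSketch {
    enum Policy { ERROR, SKIP }

    // Applies the conversion. Under SKIP a failed conversion yields null;
    // under ERROR the IllegalArgumentException propagates to the caller.
    static <I, O> O convert(I input, Function<I, O> conversion, Policy policy) {
        try {
            return conversion.apply(input);
        } catch (IllegalArgumentException e) {
            if (policy == Policy.SKIP) {
                return null;
            }
            throw e;
        }
    }

    // Assumed trigger for invalid input: reject any value with a fractional part (or NaN).
    static int[] toIntStrict(double[] values) {
        int[] out = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            double v = values[i];
            if (v != Math.rint(v)) {
                throw new IllegalArgumentException("not an integer: " + v);
            }
            out[i] = (int) v;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] values = { 0.0, 0.1, 1.0, 1.1, 2.0, 2.1 };
        int[] skipped = convert(values, HandleInvalidSketch::toIntStrict, Policy.SKIP);
        System.out.println(skipped == null ? "row skipped (null)" : "row converted");
    }
}
```

Keeping the policy outside the conversion function means the same converter serves both tests; only the caller decides whether a failure becomes a null or an exception.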

Example 14 with DoubleTensor

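Use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

From class ToTensorMapperTest, method testHandleInvalidError. The same INT conversion is requested without a HANDLE_INVALID setting; the test passes only if map throws an IllegalArgumentException.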

@Test(expected = IllegalArgumentException.class)
public void testHandleInvalidError() throws Exception {
    final Mapper mapper = new ToTensorMapper(
        new TableSchema(new String[] { "vec" }, new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR }),
        new Params()
            .set(ToTensorParams.SELECTED_COL, "vec")
            .set(ToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L })
            .set(ToTensorParams.TENSOR_DATA_TYPE, DataType.INT));
    Assert.assertEquals(TensorTypes.INT_TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);
    final DoubleTensor tensor = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    // No HANDLE_INVALID is set, so this call is expected to throw IllegalArgumentException.
    mapper.map(Row.of(tensor.toVector())).getField(0);
}
Also used : Mapper(com.alibaba.alink.common.mapper.Mapper) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) TableSchema(org.apache.flink.table.api.TableSchema) ToTensorParams(com.alibaba.alink.params.dataproc.ToTensorParams) Params(org.apache.flink.ml.api.misc.param.Params) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Test(org.junit.Test)
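Because testHandleInvalidError sets no HANDLE_INVALID yet still expects an exception, the mapper's default appears to be to raise on invalid input. The `@Test(expected = IllegalArgumentException.class)` annotation makes JUnit 4 pass the test only if that exception type, or a subclass of it, escapes the method. The helper below is a hypothetical stdlib-only stand-in for that check, not JUnit's API; JUnit 4.13 and later also ship `Assert.assertThrows` for the same purpose.

```java
// Hypothetical stand-in for JUnit 4's @Test(expected = ...); not JUnit's API.
final class ExpectedExceptionSketch {

    // Runs the action and returns the thrown exception if it is an instance of
    // `type` (subclasses included); any other outcome fails with AssertionError.
    static <T extends Throwable> T expectThrows(Class<T> type, Runnable action) {
        try {
            action.run();
        } catch (Throwable t) {
            if (type.isInstance(t)) {
                return type.cast(t);
            }
            throw new AssertionError("unexpected exception: " + t, t);
        }
        // Reaching this line means nothing was thrown, which fails the check.
        throw new AssertionError("expected " + type.getName() + " to be thrown");
    }

    public static void main(String[] args) {
        // NumberFormatException extends IllegalArgumentException, so it satisfies the check.
        IllegalArgumentException e =
            expectThrows(IllegalArgumentException.class, () -> Integer.parseInt("0.1"));
        System.out.println("caught " + e.getClass().getSimpleName());
    }
}
```

Note that with `expected`, the test cannot tell which statement threw: in testHandleInvalidError an IllegalArgumentException from the setup lines would also make it pass.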

Example 15 with DoubleTensor

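Use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

From class ToTensorMapperTest, method testStringType. A STRING column is converted with tensor data type STRING; the expected result is a StringTensor wrapping the entire input string.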

@Test
public void testStringType() throws Exception {
    final Mapper mapper = new ToTensorMapper(
        new TableSchema(new String[] { "str" }, new TypeInformation<?>[] { Types.STRING }),
        new Params()
            .set(ToTensorParams.SELECTED_COL, "str")
            .set(ToTensorParams.TENSOR_DATA_TYPE, DataType.STRING));
    Assert.assertEquals(TensorTypes.STRING_TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);
    final DoubleTensor tensor = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final StringTensor expect = new StringTensor(tensor.toString());
    final Tensor<?> result = (Tensor<?>) mapper.map(Row.of(tensor.toString())).getField(0);
    Assert.assertEquals(expect, result);
}
Also used : Mapper(com.alibaba.alink.common.mapper.Mapper) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Tensor(com.alibaba.alink.common.linalg.tensor.Tensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TableSchema(org.apache.flink.table.api.TableSchema) ToTensorParams(com.alibaba.alink.params.dataproc.ToTensorParams) Params(org.apache.flink.ml.api.misc.param.Params) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Test(org.junit.Test)
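In testStringType the input is `tensor.toString()`, and the expected value is a StringTensor built from that same string, so the mapper keeps the text whole rather than splitting it into numbers. Assuming `new StringTensor(String)` creates a single-element (scalar) tensor, the hypothetical class below models such a value as an empty shape plus one element, with equality on both shape and content. It is not Alink's StringTensor.

```java
import java.util.Arrays;
import java.util.Objects;

// Hypothetical model of a string tensor (shape + flattened values); not Alink's StringTensor.
final class StringScalarSketch {
    final long[] shape;    // empty for a scalar
    final String[] values; // flattened elements

    private StringScalarSketch(long[] shape, String[] values) {
        this.shape = shape;
        this.values = values;
    }

    // Wraps the whole string as one element; nothing inside it is parsed.
    static StringScalarSketch scalar(String value) {
        return new StringScalarSketch(new long[0], new String[] { Objects.requireNonNull(value) });
    }

    // Two tensors are equal when both their shapes and their elements match.
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof StringScalarSketch)) return false;
        StringScalarSketch other = (StringScalarSketch) o;
        return Arrays.equals(shape, other.shape) && Arrays.equals(values, other.values);
    }

    @Override
    public int hashCode() {
        return 31 * Arrays.hashCode(shape) + Arrays.hashCode(values);
    }

    public static void main(String[] args) {
        StringScalarSketch t = scalar("0.0 0.1 1.0 1.1 2.0 2.1");
        // prints: 0 dims, 1 element(s)
        System.out.println(t.shape.length + " dims, " + t.values.length + " element(s)");
    }
}
```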

Aggregations

DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor): 21
Test (org.junit.Test): 18
FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor): 15
StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor): 10
Mapper (com.alibaba.alink.common.mapper.Mapper): 10
TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation): 10
Params (org.apache.flink.ml.api.misc.param.Params): 10
TableSchema (org.apache.flink.table.api.TableSchema): 10
Tensor (com.alibaba.alink.common.linalg.tensor.Tensor): 9
Shape (com.alibaba.alink.common.linalg.tensor.Shape): 8
ToTensorParams (com.alibaba.alink.params.dataproc.ToTensorParams): 7
Row (org.apache.flink.types.Row): 5
BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor): 4
IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor): 4
LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor): 4
MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp): 4
Tensor (org.tensorflow.Tensor): 4
TFloat64 (org.tensorflow.types.TFloat64): 4
VectorToTensorParams (com.alibaba.alink.params.dataproc.VectorToTensorParams): 3
List (java.util.List): 3