Usage of com.alibaba.alink.common.linalg.tensor.Tensor in the Alink project by Alibaba.
Class VectorToTensorMapperTest, method testFloatType.
/**
 * Verifies that {@code VectorToTensorMapper} configured with {@code DataType.FLOAT} declares a
 * float-tensor output column and converts a six-element dense vector into a 2x3 float tensor.
 */
@Test
public void testFloatType() throws Exception {
	TableSchema inputSchema = new TableSchema(
		new String[] {"vec"},
		new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR}
	);
	Params params = new Params()
		.set(VectorToTensorParams.SELECTED_COL, "vec")
		.set(VectorToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L})
		.set(VectorToTensorParams.TENSOR_DATA_TYPE, DataType.FLOAT);
	Mapper vectorToTensor = new VectorToTensorMapper(inputSchema, params);

	// The declared output column type must reflect the requested FLOAT data type.
	TypeInformation<?> outputType = vectorToTensor.getOutputSchema().getFieldTypes()[0];
	Assert.assertEquals(TensorTypes.FLOAT_TENSOR, outputType);

	DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
	// Expected value: the same six numbers reshaped to 2x3, then converted to float.
	FloatTensor expected = FloatTensor.of(source.reshape(new Shape(2L, 3L)));
	Row output = vectorToTensor.map(Row.of(source.toVector()));
	Tensor<?> actual = (Tensor<?>) output.getField(0);
	Assert.assertEquals(expected, actual);
}
Usage of com.alibaba.alink.common.linalg.tensor.Tensor in the Alink project by Alibaba.
Class VectorToTensorMapperTest, method testReshape.
/**
 * Verifies that {@code VectorToTensorMapper}, given only a target shape of 2x3, turns a
 * six-element dense vector into a double tensor equal to the source reshaped to 2x3.
 */
@Test
public void testReshape() throws Exception {
	TableSchema inputSchema = new TableSchema(
		new String[] {"vec"},
		new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR}
	);
	Params params = new Params()
		.set(VectorToTensorParams.SELECTED_COL, "vec")
		.set(VectorToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L});
	Mapper vectorToTensor = new VectorToTensorMapper(inputSchema, params);

	DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
	DoubleTensor expected = source.reshape(new Shape(2L, 3L));
	Row output = vectorToTensor.map(Row.of(source.toVector()));
	Tensor<?> actual = (Tensor<?>) output.getField(0);
	Assert.assertEquals(expected, actual);
}
Usage of com.alibaba.alink.common.linalg.tensor.Tensor in the Alink project by Alibaba.
Class ToTensorMapperTest, method testDefaultType.
/**
 * Verifies that {@code ToTensorMapper} without an explicit data type declares the generic
 * {@code TensorTypes.TENSOR} output type and reshapes a six-element dense vector into a 2x3
 * double tensor.
 */
@Test
public void testDefaultType() throws Exception {
	TableSchema inputSchema = new TableSchema(
		new String[] {"vec"},
		new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR}
	);
	Params params = new Params()
		.set(ToTensorParams.SELECTED_COL, "vec")
		.set(ToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L});
	Mapper toTensor = new ToTensorMapper(inputSchema, params);

	// No TENSOR_DATA_TYPE was set, so the declared output type is the generic tensor type.
	TypeInformation<?> outputType = toTensor.getOutputSchema().getFieldTypes()[0];
	Assert.assertEquals(TensorTypes.TENSOR, outputType);

	DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
	DoubleTensor expected = source.reshape(new Shape(2L, 3L));
	Row output = toTensor.map(Row.of(source.toVector()));
	Tensor<?> actual = (Tensor<?>) output.getField(0);
	Assert.assertEquals(expected, actual);
}
Usage of com.alibaba.alink.common.linalg.tensor.Tensor in the Alink project by Alibaba.
Class ToTensorMapperTest, method testHandleInvalidSkip.
/**
 * Verifies that {@code ToTensorMapper} configured with {@code DataType.INT} and
 * {@code HandleInvalidMethod.SKIP} declares an int-tensor output type and yields a null output
 * field for a vector containing fractional values such as 0.1.
 *
 * <p>NOTE(review): presumably the fractional values are treated as invalid for an INT tensor
 * and SKIP turns that into a null field — confirm against {@code ToTensorMapper}.
 */
@Test
public void testHandleInvalidSkip() throws Exception {
	TableSchema inputSchema = new TableSchema(
		new String[] {"vec"},
		new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR}
	);
	Params params = new Params()
		.set(ToTensorParams.SELECTED_COL, "vec")
		.set(ToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L})
		.set(ToTensorParams.TENSOR_DATA_TYPE, DataType.INT)
		.set(ToTensorParams.HANDLE_INVALID, HandleInvalidMethod.SKIP);
	Mapper toIntTensor = new ToTensorMapper(inputSchema, params);

	TypeInformation<?> outputType = toIntTensor.getOutputSchema().getFieldTypes()[0];
	Assert.assertEquals(TensorTypes.INT_TENSOR, outputType);

	DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
	Row output = toIntTensor.map(Row.of(source.toVector()));
	Tensor<?> actual = (Tensor<?>) output.getField(0);
	Assert.assertNull(actual);
}
Usage of com.alibaba.alink.common.linalg.tensor.Tensor in the Alink project by Alibaba.
Class ToTensorMapperTest, method testStringType.
/**
 * Verifies that {@code ToTensorMapper} configured with {@code DataType.STRING} on a string column
 * declares a string-tensor output type and wraps the input text in a {@code StringTensor}.
 *
 * <p>The input text is the {@code toString()} form of a sample double tensor.
 */
@Test
public void testStringType() throws Exception {
	TableSchema inputSchema = new TableSchema(
		new String[] {"str"},
		new TypeInformation<?>[] {Types.STRING}
	);
	Params params = new Params()
		.set(ToTensorParams.SELECTED_COL, "str")
		.set(ToTensorParams.TENSOR_DATA_TYPE, DataType.STRING);
	Mapper toStringTensor = new ToTensorMapper(inputSchema, params);

	TypeInformation<?> outputType = toStringTensor.getOutputSchema().getFieldTypes()[0];
	Assert.assertEquals(TensorTypes.STRING_TENSOR, outputType);

	DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
	StringTensor expected = new StringTensor(source.toString());
	Row output = toStringTensor.map(Row.of(source.toString()));
	Tensor<?> actual = (Tensor<?>) output.getField(0);
	Assert.assertEquals(expected, actual);
}
Aggregations