Usage of com.alibaba.alink.common.linalg.tensor.DoubleTensor in the Alink project by Alibaba.
Class ToTensorMapperTest, method testDefaultType.
@Test
public void testDefaultType() throws Exception {
    // Input: a single dense-vector column named "vec".
    final TableSchema inputSchema = new TableSchema(
        new String[] {"vec"},
        new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR});
    // Only the shape is configured; TENSOR_DATA_TYPE is left unset so the mapper's default is used.
    final Params params = new Params()
        .set(ToTensorParams.SELECTED_COL, "vec")
        .set(ToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L});
    final Mapper toTensor = new ToTensorMapper(inputSchema, params);

    // With no explicit data type, the declared output column type is the generic TENSOR type.
    Assert.assertEquals(TensorTypes.TENSOR, toTensor.getOutputSchema().getFieldTypes()[0]);

    // Six values, fed in as a flat vector and expected back as a 2x3 tensor.
    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final DoubleTensor expect = source.reshape(new Shape(2L, 3L));
    final Row input = Row.of(source.toVector());
    final Tensor<?> actual = (Tensor<?>) toTensor.map(input).getField(0);
    Assert.assertEquals(expect, actual);
}
Usage of com.alibaba.alink.common.linalg.tensor.DoubleTensor in the Alink project by Alibaba.
Class ToTensorMapperTest, method testOp.
@Test
public void testOp() throws Exception {
    // Six values, fed in as a flat vector; expected back as a 2x3 FLOAT tensor.
    final DoubleTensor tensor = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final FloatTensor expect = FloatTensor.of(tensor.reshape(new Shape(2L, 3L)));

    final Row[] rows = new Row[] {Row.of(tensor.toVector())};
    final MemSourceBatchOp memSourceBatchOp = new MemSourceBatchOp(rows, new String[] {"vec"});

    // The callback's parameter is named `collected` so it no longer shadows the local `rows` above.
    // The explicit size check gives a clear failure if the operator emits no rows or extra rows.
    memSourceBatchOp
        .link(new ToTensorBatchOp()
            .setSelectedCol("vec")
            .setTensorShape(2, 3)
            .setTensorDataType("float"))
        .lazyCollect(collected -> {
            Assert.assertEquals(1, collected.size());
            Assert.assertEquals(expect, TensorUtil.getTensor(collected.get(0).getField(0)));
        });

    // Runs the job; the lazyCollect callback above is invoked as part of this execution.
    BatchOperator.execute();
}
Usage of com.alibaba.alink.common.linalg.tensor.DoubleTensor in the Alink project by Alibaba.
Class ToTensorMapperTest, method testHandleInvalidSkip.
@Test
public void testHandleInvalidSkip() throws Exception {
    // Request an INT tensor with HANDLE_INVALID = SKIP. The fractional inputs below are presumably
    // not convertible to INT, and the test expects the mapper to emit null instead of failing.
    final Params skipInvalid = new Params()
        .set(ToTensorParams.SELECTED_COL, "vec")
        .set(ToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L})
        .set(ToTensorParams.TENSOR_DATA_TYPE, DataType.INT)
        .set(ToTensorParams.HANDLE_INVALID, HandleInvalidMethod.SKIP);
    final Mapper toTensor = new ToTensorMapper(
        new TableSchema(new String[] {"vec"}, new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR}),
        skipInvalid);

    Assert.assertEquals(TensorTypes.INT_TENSOR, toTensor.getOutputSchema().getFieldTypes()[0]);

    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final Row output = toTensor.map(Row.of(source.toVector()));
    final Tensor<?> actual = (Tensor<?>) output.getField(0);
    Assert.assertNull(actual);
}
Usage of com.alibaba.alink.common.linalg.tensor.DoubleTensor in the Alink project by Alibaba.
Class ToTensorMapperTest, method testHandleInvalidError.
@Test(expected = IllegalArgumentException.class)
public void testHandleInvalidError() throws Exception {
    // Same INT conversion as the SKIP case, but HANDLE_INVALID is not set. The test expects the
    // mapping call to throw IllegalArgumentException for these fractional inputs.
    final TableSchema inputSchema =
        new TableSchema(new String[] {"vec"}, new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR});
    final Params intTensor = new Params()
        .set(ToTensorParams.SELECTED_COL, "vec")
        .set(ToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L})
        .set(ToTensorParams.TENSOR_DATA_TYPE, DataType.INT);
    final Mapper toTensor = new ToTensorMapper(inputSchema, intTensor);

    Assert.assertEquals(TensorTypes.INT_TENSOR, toTensor.getOutputSchema().getFieldTypes()[0]);

    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final Row input = Row.of(source.toVector());
    // Expected to throw; the result is intentionally unused.
    toTensor.map(input).getField(0);
}
Usage of com.alibaba.alink.common.linalg.tensor.DoubleTensor in the Alink project by Alibaba.
Class ToTensorMapperTest, method testStringType.
@Test
public void testStringType() throws Exception {
    // Input: one STRING column; target data type is STRING. No TENSOR_SHAPE is configured.
    final TableSchema inputSchema = new TableSchema(new String[] {"str"}, new TypeInformation<?>[] {Types.STRING});
    final Params params = new Params()
        .set(ToTensorParams.SELECTED_COL, "str")
        .set(ToTensorParams.TENSOR_DATA_TYPE, DataType.STRING);
    final Mapper toTensor = new ToTensorMapper(inputSchema, params);

    Assert.assertEquals(TensorTypes.STRING_TENSOR, toTensor.getOutputSchema().getFieldTypes()[0]);

    // The string form of a DoubleTensor is the input; the expected result wraps that same string.
    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final String serialized = source.toString();
    final StringTensor expect = new StringTensor(serialized);
    final Tensor<?> actual = (Tensor<?>) toTensor.map(Row.of(serialized)).getField(0);
    Assert.assertEquals(expect, actual);
}
Aggregations