Usage of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.
Class TFExampleConversionV2, method javaToFeature.
/**
 * Converts a Java object to a TensorFlow {@link Feature}.
 *
 * <p>Scalars: floating-point types are written to a {@code FloatList} (narrowed to 32-bit float),
 * integer types to an {@code Int64List} via {@code castAsLong}, and {@code STRING} to a
 * {@code BytesList} via {@code castAsBytes}.
 *
 * <p>Tensors are flattened to one dimension with {@code reshape(new Shape(size))} and written
 * element by element. Float and double tensors go to a {@code FloatList}, with doubles narrowed
 * to float. Int, long, boolean and unsigned-byte tensors go to an {@code Int64List}, with
 * booleans written as 1/0. String tensors go to a {@code BytesList}.
 *
 * <p>When {@code dt} is the generic {@code TENSOR} type, the concrete tensor type is chosen from
 * the runtime class of {@code val}.
 *
 * @param dt java object data type.
 * @param val given java object.
 * @return tensorflow feature.
 * @throws RuntimeException if the (resolved) data type is not supported. This includes
 *     {@code BYTE_TENSOR` and a {@code TENSOR} value whose class is not recognized.
 */
public static Feature javaToFeature(DataTypesV2 dt, Object val) {
    // When dt is the generic TENSOR type, resolve the concrete tensor type from val's runtime
    // class. An unrecognized class leaves dt as TENSOR, which reaches the default branch below.
    if (DataTypesV2.TENSOR.equals(dt)) {
        if (val instanceof FloatTensor) {
            dt = DataTypesV2.FLOAT_TENSOR;
        } else if (val instanceof DoubleTensor) {
            dt = DataTypesV2.DOUBLE_TENSOR;
        } else if (val instanceof IntTensor) {
            dt = DataTypesV2.INT_TENSOR;
        } else if (val instanceof LongTensor) {
            dt = DataTypesV2.LONG_TENSOR;
        } else if (val instanceof BoolTensor) {
            dt = DataTypesV2.BOOLEAN_TENSOR;
        } else if (val instanceof UByteTensor) {
            dt = DataTypesV2.UBYTE_TENSOR;
        } else if (val instanceof StringTensor) {
            dt = DataTypesV2.STRING_TENSOR;
        } else if (val instanceof ByteTensor) {
            dt = DataTypesV2.BYTE_TENSOR;
        }
    }

    Feature.Builder featureBuilder = Feature.newBuilder();
    FloatList.Builder floatListBuilder = FloatList.newBuilder();
    Int64List.Builder int64ListBuilder = Int64List.newBuilder();

    switch (dt) {
        case FLOAT_16:
        case FLOAT_32:
        case FLOAT_64: {
            // Go through Number instead of casting to Float. A FLOAT_64 value may be a Double,
            // and a (Float) cast would reject it with ClassCastException. FloatList stores
            // 32-bit floats, so the value is narrowed.
            floatListBuilder.addValue(((Number) val).floatValue());
            featureBuilder.setFloatList(floatListBuilder);
            break;
        }
        case INT_8:
        case INT_16:
        case INT_32:
        case INT_64:
        case UINT_8:
        case UINT_16:
        case UINT_32:
        case UINT_64: {
            int64ListBuilder.addValue(castAsLong(val));
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case STRING: {
            BytesList.Builder bb = BytesList.newBuilder();
            bb.addValue(castAsBytes(val));
            featureBuilder.setBytesList(bb);
            break;
        }
        case FLOAT_TENSOR: {
            FloatTensor floatTensor = (FloatTensor) val;
            long size = floatTensor.size();
            // Flatten to 1-D so elements can be addressed by a single index.
            floatTensor = floatTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                floatListBuilder.addValue(floatTensor.getFloat(i));
            }
            featureBuilder.setFloatList(floatListBuilder);
            break;
        }
        case DOUBLE_TENSOR: {
            DoubleTensor doubleTensor = (DoubleTensor) val;
            long size = doubleTensor.size();
            doubleTensor = doubleTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                // FloatList holds 32-bit floats; doubles are narrowed (precision may be lost).
                floatListBuilder.addValue((float) doubleTensor.getDouble(i));
            }
            featureBuilder.setFloatList(floatListBuilder);
            break;
        }
        case INT_TENSOR: {
            IntTensor intTensor = (IntTensor) val;
            long size = intTensor.size();
            intTensor = intTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                int64ListBuilder.addValue(intTensor.getInt(i));
            }
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case LONG_TENSOR: {
            LongTensor longTensor = (LongTensor) val;
            long size = longTensor.size();
            longTensor = longTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                int64ListBuilder.addValue(longTensor.getLong(i));
            }
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case BOOLEAN_TENSOR: {
            BoolTensor boolTensor = (BoolTensor) val;
            long size = boolTensor.size();
            boolTensor = boolTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                // Booleans have no native TF list type here; encode as 1 (true) / 0 (false).
                int64ListBuilder.addValue(boolTensor.getBoolean(i) ? 1 : 0);
            }
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case UBYTE_TENSOR: {
            UByteTensor ubyteTensor = (UByteTensor) val;
            long size = ubyteTensor.size();
            ubyteTensor = ubyteTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                int64ListBuilder.addValue(ubyteTensor.getUByte(i));
            }
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case STRING_TENSOR: {
            StringTensor stringTensor = (StringTensor) val;
            long size = stringTensor.size();
            stringTensor = stringTensor.reshape(new Shape(size));
            BytesList.Builder bb = BytesList.newBuilder();
            for (long i = 0; i < size; i += 1) {
                bb.addValue(castAsBytes(stringTensor.getString(i)));
            }
            featureBuilder.setBytesList(bb);
            break;
        }
        case BYTE_TENSOR:
            // Byte tensors are recognized above but are not supported for conversion.
        default:
            throw new RuntimeException("Unsupported data type for TF");
    }
    return featureBuilder.build();
}
Usage of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.
Class VectorToTensorMapperTest, method testFloatType.
@Test
public void testFloatType() throws Exception {
    // One dense-vector input column, reshaped to 2x3 and converted to FLOAT.
    TableSchema inputSchema = new TableSchema(
        new String[] {"vec"},
        new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR});
    Params params = new Params()
        .set(VectorToTensorParams.SELECTED_COL, "vec")
        .set(VectorToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L})
        .set(VectorToTensorParams.TENSOR_DATA_TYPE, DataType.FLOAT);
    final Mapper mapper = new VectorToTensorMapper(inputSchema, params);

    // The declared output column type must be the float tensor type.
    Assert.assertEquals(TensorTypes.FLOAT_TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);

    final DoubleTensor source =
        DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final FloatTensor expected = FloatTensor.of(source.reshape(new Shape(2L, 3L)));

    final Row output = mapper.map(Row.of(source.toVector()));
    final Tensor<?> actual = (Tensor<?>) output.getField(0);
    Assert.assertEquals(expected, actual);
}
Usage of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.
Class VectorToTensorMapperTest, method testReshape.
@Test
public void testReshape() throws Exception {
    // No explicit data type: only the 2x3 reshape is requested.
    TableSchema inputSchema = new TableSchema(
        new String[] {"vec"},
        new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR});
    Params params = new Params()
        .set(VectorToTensorParams.SELECTED_COL, "vec")
        .set(VectorToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L});
    final Mapper mapper = new VectorToTensorMapper(inputSchema, params);

    final DoubleTensor source =
        DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final DoubleTensor expected = source.reshape(new Shape(2L, 3L));

    final Row output = mapper.map(Row.of(source.toVector()));
    final Tensor<?> actual = (Tensor<?>) output.getField(0);
    Assert.assertEquals(expected, actual);
}
Usage of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.
Class VectorToTensorMapperTest, method testOp.
@Test
public void testOp() throws Exception {
    final DoubleTensor tensor =
        DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    // Expected: the six values reshaped to 2x3 and converted to a float tensor.
    final FloatTensor expect = FloatTensor.of(tensor.reshape(new Shape(2L, 3L)));

    // Named distinctly from the callback's parameter so neither shadows the other.
    Row[] inputRows = new Row[] {Row.of(tensor.toVector())};
    MemSourceBatchOp memSourceBatchOp = new MemSourceBatchOp(inputRows, new String[] {"vec"});

    VectorToTensorBatchOp toTensor = new VectorToTensorBatchOp()
        .setSelectedCol("vec")
        .setTensorShape(2, 3)
        .setTensorDataType("float");

    // The collect callback is registered lazily and runs when BatchOperator.execute() is called.
    memSourceBatchOp.link(toTensor).lazyCollect(new Consumer<List<Row>>() {
        @Override
        public void accept(List<Row> collectedRows) {
            Assert.assertEquals(expect, TensorUtil.getTensor(collectedRows.get(0).getField(0)));
        }
    });
    BatchOperator.execute();
}
Usage of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.
Class ToTensorMapperTest, method testDefaultType.
@Test
public void testDefaultType() throws Exception {
    // Without a data type parameter, the output column is declared as the generic tensor type.
    TableSchema inputSchema = new TableSchema(
        new String[] {"vec"},
        new TypeInformation<?>[] {VectorTypes.DENSE_VECTOR});
    Params params = new Params()
        .set(ToTensorParams.SELECTED_COL, "vec")
        .set(ToTensorParams.TENSOR_SHAPE, new Long[] {2L, 3L});
    final Mapper mapper = new ToTensorMapper(inputSchema, params);

    Assert.assertEquals(TensorTypes.TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);

    final DoubleTensor source =
        DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final DoubleTensor expected = source.reshape(new Shape(2L, 3L));

    final Row output = mapper.map(Row.of(source.toVector()));
    final Tensor<?> actual = (Tensor<?>) output.getField(0);
    Assert.assertEquals(expected, actual);
}
Aggregations