Search in sources:

Example 1 with Shape

use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by alibaba.

the class TFExampleConversionV2 method javaToFeature.

/**
 * Converts a Java object to a TensorFlow {@link Feature}.
 *
 * <p>Scalar floating-point types are written as a single-element float list, scalar integer types as a
 * single-element int64 list, and strings as a single-element bytes list. Tensor values are reshaped to one
 * dimension and written element by element: float and double tensors go to a float list (doubles are narrowed
 * to {@code float}), int, long, boolean (as 1/0) and unsigned-byte tensors go to an int64 list, and string
 * tensors go to a bytes list.
 *
 * @param dt  java object data type. When it is {@code TENSOR}, the concrete tensor type is resolved from the
 *            runtime class of {@code val}.
 * @param val given java object.
 * @return tensorflow feature.
 * @throws RuntimeException if the (resolved) data type has no conversion here; this includes
 *                          {@code BYTE_TENSOR} and a {@code TENSOR} whose value is not a recognised tensor class.
 */
public static Feature javaToFeature(DataTypesV2 dt, Object val) {
    Feature.Builder featureBuilder = Feature.newBuilder();
    FloatList.Builder floatListBuilder = FloatList.newBuilder();
    Int64List.Builder int64ListBuilder = Int64List.newBuilder();
    // When dt is TENSOR, find the exact type first.
    // If val matches none of the classes below, dt stays TENSOR and the switch falls to the default branch.
    if (DataTypesV2.TENSOR.equals(dt)) {
        if (val instanceof FloatTensor) {
            dt = DataTypesV2.FLOAT_TENSOR;
        } else if (val instanceof DoubleTensor) {
            dt = DataTypesV2.DOUBLE_TENSOR;
        } else if (val instanceof IntTensor) {
            dt = DataTypesV2.INT_TENSOR;
        } else if (val instanceof LongTensor) {
            dt = DataTypesV2.LONG_TENSOR;
        } else if (val instanceof BoolTensor) {
            dt = DataTypesV2.BOOLEAN_TENSOR;
        } else if (val instanceof UByteTensor) {
            dt = DataTypesV2.UBYTE_TENSOR;
        } else if (val instanceof StringTensor) {
            dt = DataTypesV2.STRING_TENSOR;
        } else if (val instanceof ByteTensor) {
            dt = DataTypesV2.BYTE_TENSOR;
        }
    }
    switch(dt) {
        case FLOAT_16:
        case FLOAT_32:
        case FLOAT_64:
            {
                // Go through Number instead of casting straight to Float: a boxed value of another numeric
                // class (e.g. a Double, presumably what FLOAT_64 carries) would make a (Float) cast throw
                // ClassCastException. Float inputs behave exactly as before.
                floatListBuilder.addValue(((Number) val).floatValue());
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case INT_8:
        case INT_16:
        case INT_32:
        case INT_64:
        case UINT_8:
        case UINT_16:
        case UINT_32:
        case UINT_64:
            {
                int64ListBuilder.addValue(castAsLong(val));
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case STRING:
            {
                BytesList.Builder bb = BytesList.newBuilder();
                bb.addValue(castAsBytes(val));
                featureBuilder.setBytesList(bb);
                break;
            }
        case FLOAT_TENSOR:
            {
                FloatTensor floatTensor = (FloatTensor) val;
                long size = floatTensor.size();
                // Reshape to one dimension so elements can be read with a single index.
                floatTensor = floatTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    floatListBuilder.addValue(floatTensor.getFloat(i));
                }
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case DOUBLE_TENSOR:
            {
                DoubleTensor doubleTensor = (DoubleTensor) val;
                long size = doubleTensor.size();
                doubleTensor = doubleTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    // Narrowed to float because the values are stored in a float list.
                    floatListBuilder.addValue((float) doubleTensor.getDouble(i));
                }
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case INT_TENSOR:
            {
                IntTensor intTensor = (IntTensor) val;
                long size = intTensor.size();
                intTensor = intTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(intTensor.getInt(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case LONG_TENSOR:
            {
                LongTensor longTensor = (LongTensor) val;
                long size = longTensor.size();
                longTensor = longTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(longTensor.getLong(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case BOOLEAN_TENSOR:
            {
                BoolTensor boolTensor = (BoolTensor) val;
                long size = boolTensor.size();
                boolTensor = boolTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    // Booleans are encoded as 1 (true) / 0 (false) in the int64 list.
                    int64ListBuilder.addValue(boolTensor.getBoolean(i) ? 1 : 0);
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case UBYTE_TENSOR:
            {
                UByteTensor ubyteTensor = (UByteTensor) val;
                long size = ubyteTensor.size();
                ubyteTensor = ubyteTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(ubyteTensor.getUByte(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case STRING_TENSOR:
            {
                StringTensor stringTensor = (StringTensor) val;
                long size = stringTensor.size();
                stringTensor = stringTensor.reshape(new Shape(size));
                BytesList.Builder bb = BytesList.newBuilder();
                for (long i = 0; i < size; i += 1) {
                    bb.addValue(castAsBytes(stringTensor.getString(i)));
                }
                featureBuilder.setBytesList(bb);
                break;
            }
        // BYTE_TENSOR has no conversion here and deliberately shares the unsupported-type failure.
        case BYTE_TENSOR:
        default:
            throw new RuntimeException("Unsupported data type for TF");
    }
    return featureBuilder.build();
}
Also used : DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) Feature(org.tensorflow.proto.example.Feature) FloatList(org.tensorflow.proto.example.FloatList) Int64List(org.tensorflow.proto.example.Int64List) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) UByteTensor(com.alibaba.alink.common.linalg.tensor.UByteTensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) UByteTensor(com.alibaba.alink.common.linalg.tensor.UByteTensor) ByteTensor(com.alibaba.alink.common.linalg.tensor.ByteTensor)

Example 2 with Shape

use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by alibaba.

the class VectorToTensorMapperTest method testFloatType.

@Test
public void testFloatType() throws Exception {
    // Mapper under test: column "vec" (dense vector) mapped to a tensor of shape {2, 3} with FLOAT data type.
    final TableSchema inputSchema = new TableSchema(
        new String[] { "vec" },
        new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR });
    final Params mapperParams = new Params()
        .set(VectorToTensorParams.SELECTED_COL, "vec")
        .set(VectorToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L })
        .set(VectorToTensorParams.TENSOR_DATA_TYPE, DataType.FLOAT);
    final Mapper mapper = new VectorToTensorMapper(inputSchema, mapperParams);

    // The declared output column type must be the float tensor type.
    Assert.assertEquals(TensorTypes.FLOAT_TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);

    // Expected value: the same six numbers reshaped to 2x3 and converted to a float tensor.
    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final DoubleTensor reshaped = source.reshape(new Shape(2L, 3L));
    final FloatTensor expected = FloatTensor.of(reshaped);

    final Tensor<?> actual = (Tensor<?>) mapper.map(Row.of(source.toVector())).getField(0);
    Assert.assertEquals(expected, actual);
}
Also used : Mapper(com.alibaba.alink.common.mapper.Mapper) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Tensor(com.alibaba.alink.common.linalg.tensor.Tensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TableSchema(org.apache.flink.table.api.TableSchema) VectorToTensorParams(com.alibaba.alink.params.dataproc.VectorToTensorParams) Params(org.apache.flink.ml.api.misc.param.Params) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Test(org.junit.Test)

Example 3 with Shape

use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by alibaba.

the class VectorToTensorMapperTest method testReshape.

@Test
public void testReshape() throws Exception {
    // Mapper under test: column "vec" (dense vector) mapped to a tensor of shape {2, 3}; no data type is set.
    final TableSchema inputSchema = new TableSchema(
        new String[] { "vec" },
        new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR });
    final Params mapperParams = new Params()
        .set(VectorToTensorParams.SELECTED_COL, "vec")
        .set(VectorToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L });
    final Mapper mapper = new VectorToTensorMapper(inputSchema, mapperParams);

    // Expected value: the source double tensor reshaped to 2x3.
    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final DoubleTensor expected = source.reshape(new Shape(2L, 3L));

    final Tensor<?> actual = (Tensor<?>) mapper.map(Row.of(source.toVector())).getField(0);
    Assert.assertEquals(expected, actual);
}
Also used : Mapper(com.alibaba.alink.common.mapper.Mapper) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Tensor(com.alibaba.alink.common.linalg.tensor.Tensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TableSchema(org.apache.flink.table.api.TableSchema) VectorToTensorParams(com.alibaba.alink.params.dataproc.VectorToTensorParams) Params(org.apache.flink.ml.api.misc.param.Params) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Test(org.junit.Test)

Example 4 with Shape

use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by alibaba.

the class VectorToTensorMapperTest method testOp.

@Test
public void testOp() throws Exception {
    // Expected value: the six source numbers reshaped to 2x3 and converted to a float tensor.
    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final FloatTensor expected = FloatTensor.of(source.reshape(new Shape(2L, 3L)));

    // Single-row in-memory source whose only column "vec" holds the vector form of the tensor.
    final Row[] inputRows = new Row[] { Row.of(source.toVector()) };
    final MemSourceBatchOp sourceOp = new MemSourceBatchOp(inputRows, new String[] { "vec" });

    // Callback that compares the first field of the first collected row with the expected tensor.
    final Consumer<List<Row>> verifier = new Consumer<List<Row>>() {

        @Override
        public void accept(List<Row> collected) {
            Assert.assertEquals(expected, TensorUtil.getTensor(collected.get(0).getField(0)));
        }
    };

    sourceOp
        .link(new VectorToTensorBatchOp().setSelectedCol("vec").setTensorShape(2, 3).setTensorDataType("float"))
        .lazyCollect(verifier);
    BatchOperator.execute();
}
Also used : MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) List(java.util.List) Row(org.apache.flink.types.Row) VectorToTensorBatchOp(com.alibaba.alink.operator.batch.dataproc.VectorToTensorBatchOp) Test(org.junit.Test)

Example 5 with Shape

use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by alibaba.

the class ToTensorMapperTest method testDefaultType.

@Test
public void testDefaultType() throws Exception {
    // Mapper under test: column "vec" (dense vector) mapped to a tensor of shape {2, 3}; no data type is set.
    final TableSchema inputSchema = new TableSchema(
        new String[] { "vec" },
        new TypeInformation<?>[] { VectorTypes.DENSE_VECTOR });
    final Params mapperParams = new Params()
        .set(ToTensorParams.SELECTED_COL, "vec")
        .set(ToTensorParams.TENSOR_SHAPE, new Long[] { 2L, 3L });
    final Mapper mapper = new ToTensorMapper(inputSchema, mapperParams);

    // The declared output column type must be the generic tensor type.
    Assert.assertEquals(TensorTypes.TENSOR, mapper.getOutputSchema().getFieldTypes()[0]);

    // Expected value: the source double tensor reshaped to 2x3.
    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final DoubleTensor expected = source.reshape(new Shape(2L, 3L));

    final Tensor<?> actual = (Tensor<?>) mapper.map(Row.of(source.toVector())).getField(0);
    Assert.assertEquals(expected, actual);
}
Also used : Mapper(com.alibaba.alink.common.mapper.Mapper) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Tensor(com.alibaba.alink.common.linalg.tensor.Tensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TableSchema(org.apache.flink.table.api.TableSchema) ToTensorParams(com.alibaba.alink.params.dataproc.ToTensorParams) Params(org.apache.flink.ml.api.misc.param.Params) TypeInformation(org.apache.flink.api.common.typeinfo.TypeInformation) Test(org.junit.Test)

Aggregations

FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor)13 Shape (com.alibaba.alink.common.linalg.tensor.Shape)13 Test (org.junit.Test)12 Params (org.apache.flink.ml.api.misc.param.Params)10 DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor)8 Row (org.apache.flink.types.Row)7 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)6 TableSchema (org.apache.flink.table.api.TableSchema)6 Tensor (com.alibaba.alink.common.linalg.tensor.Tensor)5 Mapper (com.alibaba.alink.common.mapper.Mapper)5 TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation)5 PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader)4 RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey)4 StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor)4 DLTest (com.alibaba.alink.testutil.categories.DLTest)4 ArrayList (java.util.ArrayList)4 Category (org.junit.experimental.categories.Category)4 LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor)3 ToTensorParams (com.alibaba.alink.params.dataproc.ToTensorParams)3 File (java.io.File)3