Search in sources:

Example 1 with TFTensorConversionUtils

use of com.alibaba.alink.common.dl.utils.TFTensorConversionUtils in project Alink by alibaba.

Method encodeBatchTensor of class StringTFTensorConversionImpl.

/**
 * Encodes every slice of {@code tensor} taken along {@code batchAxis} as one string.
 *
 * <p>The TensorFlow tensor is converted to the matching Alink tensor type and unstacked along
 * {@code batchAxis}. Each resulting slice is converted back to a TensorFlow tensor, squeezed,
 * copied into a 1-D Java array, and joined with {@code SEP_CHAR}.
 *
 * <p>Supported dtypes are {@code TString}, {@code TBool}, {@code TInt32}, {@code TInt64},
 * {@code TFloat32} and {@code TFloat64}.
 *
 * @param tensor    the batched TensorFlow tensor to encode
 * @param batchAxis the axis along which the tensor is split into batch entries; must satisfy
 *                  {@code 0 <= batchAxis < rank}
 * @return one encoded string per slice along {@code batchAxis}, in slice order
 * @throws IllegalArgumentException      if {@code batchAxis} is outside {@code [0, rank)}
 * @throws UnsupportedOperationException if the tensor's dtype is not one of the supported types
 */
@Override
public String[] encodeBatchTensor(Tensor<?> tensor, int batchAxis) {
    // Validate the axis explicitly instead of relying on an out-of-bounds array access.
    // The accepted range is identical to the implicit shape[batchAxis] lookup this replaces.
    final int rank = tensor.shape().asArray().length;
    if (batchAxis < 0 || batchAxis >= rank) {
        throw new IllegalArgumentException(
            "batchAxis " + batchAxis + " is out of range for a tensor of rank " + rank);
    }
    // NOTE(review): the TensorFlow tensors created by toTFTensor / squeezeTensor below are never
    // closed. If they own native memory, this leaks per call. Confirm ownership semantics of
    // TF2TensorUtils.squeezeTensor before adding close() calls.
    // NOTE(review): StdArrays.array1dCopyOf expects a rank-1 array, so each squeezed slice is
    // presumably expected to be 1-D. Confirm the behaviour wanted for higher-rank slices.
    if (TString.DTYPE.equals(tensor.dataType())) {
        StringTensor stringTensor = (StringTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        StringTensor[] stringTensors = unstack(stringTensor, batchAxis, null);
        return Arrays.stream(stringTensors)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(d -> StdArrays.array1dCopyOf((TString) d.data(), String.class))
            .map(d -> String.join(SEP_CHAR, d))
            .toArray(String[]::new);
    } else if (TBool.DTYPE.equals(tensor.dataType())) {
        BoolTensor boolTensor = (BoolTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        BoolTensor[] boolTensors = unstack(boolTensor, batchAxis, null);
        return Arrays.stream(boolTensors)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(d -> StdArrays.array1dCopyOf((TBool) d.data()))
            .map(d -> Booleans.join(SEP_CHAR, d))
            .toArray(String[]::new);
    } else if (TInt32.DTYPE.equals(tensor.dataType())) {
        IntTensor intTensor = (IntTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        IntTensor[] intTensors = unstack(intTensor, batchAxis, null);
        return Arrays.stream(intTensors)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(d -> StdArrays.array1dCopyOf((TInt32) d.data()))
            .map(d -> Ints.join(SEP_CHAR, d))
            .toArray(String[]::new);
    } else if (TInt64.DTYPE.equals(tensor.dataType())) {
        LongTensor longTensor = (LongTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        LongTensor[] longTensors = unstack(longTensor, batchAxis, null);
        return Arrays.stream(longTensors)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(d -> StdArrays.array1dCopyOf((TInt64) d.data()))
            .map(d -> Longs.join(SEP_CHAR, d))
            .toArray(String[]::new);
    } else if (TFloat32.DTYPE.equals(tensor.dataType())) {
        FloatTensor floatTensor = (FloatTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        FloatTensor[] floatTensors = unstack(floatTensor, batchAxis, null);
        return Arrays.stream(floatTensors)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(d -> StdArrays.array1dCopyOf((TFloat32) d.data()))
            .map(d -> Floats.join(SEP_CHAR, d))
            .toArray(String[]::new);
    } else if (TFloat64.DTYPE.equals(tensor.dataType())) {
        DoubleTensor doubleTensor = (DoubleTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        DoubleTensor[] doubleTensors = unstack(doubleTensor, batchAxis, null);
        return Arrays.stream(doubleTensors)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(d -> StdArrays.array1dCopyOf((TFloat64) d.data()))
            .map(d -> Doubles.join(SEP_CHAR, d))
            .toArray(String[]::new);
    }
    throw new UnsupportedOperationException("Unsupported dtype: " + tensor.dataType());
}
Also used: TString (org.tensorflow.types.TString), Booleans (com.google.common.primitives.Booleans), Arrays (java.util.Arrays), DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor), BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor), TFloat32 (org.tensorflow.types.TFloat32), TFTensorConversionUtils (com.alibaba.alink.common.dl.utils.TFTensorConversionUtils), StdArrays (org.tensorflow.ndarray.StdArrays), LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor), TType (org.tensorflow.types.family.TType), Tensor (org.tensorflow.Tensor), TInt32 (org.tensorflow.types.TInt32), Longs (com.google.common.primitives.Longs), Floats (com.google.common.primitives.Floats), TBool (org.tensorflow.types.TBool), TFloat64 (org.tensorflow.types.TFloat64), TF2TensorUtils (com.alibaba.alink.common.dl.utils.TF2TensorUtils), TensorInfo (org.tensorflow.proto.framework.TensorInfo), Preconditions (org.apache.flink.util.Preconditions), StringUtils (org.apache.flink.util.StringUtils), Ints (com.google.common.primitives.Ints), DataType (org.tensorflow.proto.framework.DataType), StandardCharsets (java.nio.charset.StandardCharsets), List (java.util.List), StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor), Doubles (com.google.common.primitives.Doubles), Tensor.unstack (com.alibaba.alink.common.linalg.tensor.Tensor.unstack), FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor), TInt64 (org.tensorflow.types.TInt64), Shape (org.tensorflow.ndarray.Shape), IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor)

Aggregations

TF2TensorUtils (com.alibaba.alink.common.dl.utils.TF2TensorUtils): 1, TFTensorConversionUtils (com.alibaba.alink.common.dl.utils.TFTensorConversionUtils): 1, BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor): 1, DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor): 1, FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor): 1, IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor): 1, LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor): 1, StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor): 1, Tensor.unstack (com.alibaba.alink.common.linalg.tensor.Tensor.unstack): 1, Booleans (com.google.common.primitives.Booleans): 1, Doubles (com.google.common.primitives.Doubles): 1, Floats (com.google.common.primitives.Floats): 1, Ints (com.google.common.primitives.Ints): 1, Longs (com.google.common.primitives.Longs): 1, StandardCharsets (java.nio.charset.StandardCharsets): 1, Arrays (java.util.Arrays): 1, List (java.util.List): 1, Preconditions (org.apache.flink.util.Preconditions): 1, StringUtils (org.apache.flink.util.StringUtils): 1, Tensor (org.tensorflow.Tensor): 1