Usage of com.alibaba.alink.common.dl.utils.TF2TensorUtils in the Alink project by Alibaba.
Class StringTFTensorConversionImpl, method encodeBatchTensor.
/**
 * Splits {@code tensor} along {@code batchAxis} and encodes every resulting slice as one string.
 *
 * <p>Each slice produced by {@code unstack} is processed in four steps:
 * <ol>
 *   <li>it is converted back to a TensorFlow tensor;</li>
 *   <li>the tensor is squeezed with {@code TF2TensorUtils.squeezeTensor};</li>
 *   <li>its elements are copied into a Java array via {@code StdArrays.array1dCopyOf};</li>
 *   <li>the array is joined with {@code SEP_CHAR}.</li>
 * </ol>
 * Supported dtypes are TString, TBool, TInt32, TInt64, TFloat32 and TFloat64.
 * NOTE(review): {@code array1dCopyOf} presumably requires each squeezed slice to be 1-D — confirm
 * what happens for slices that squeeze down to a scalar.
 *
 * <p>The per-slice TensorFlow tensors created here hold native memory and are closed before this
 * method returns. {@code tensor} itself belongs to the caller and is not closed.
 *
 * @param tensor    the batched TensorFlow tensor to encode
 * @param batchAxis index of the dimension to split on; must lie in {@code [0, rank)}
 * @return one encoded string per slice returned by {@code unstack}
 * @throws IllegalArgumentException      if {@code batchAxis} is outside {@code [0, rank)}
 * @throws UnsupportedOperationException if the dtype of {@code tensor} is not supported
 */
@Override
public String[] encodeBatchTensor(Tensor<?> tensor, int batchAxis) {
    // Explicit check replaces the former unused shape[batchAxis] lookup, which only failed with
    // an ArrayIndexOutOfBoundsException as a side effect.
    int rank = tensor.shape().asArray().length;
    if (batchAxis < 0 || batchAxis >= rank) {
        throw new IllegalArgumentException(
            "batchAxis " + batchAxis + " is out of range for a tensor of rank " + rank);
    }

    if (TString.DTYPE.equals(tensor.dataType())) {
        StringTensor[] slices =
            unstack((StringTensor) TFTensorConversionUtils.fromTFTensor(tensor), batchAxis, null);
        return Arrays.stream(slices)
            .map(slice -> encodeSlice(TFTensorConversionUtils.toTFTensor(slice),
                d -> String.join(SEP_CHAR, StdArrays.array1dCopyOf((TString) d.data(), String.class))))
            .toArray(String[]::new);
    }
    if (TBool.DTYPE.equals(tensor.dataType())) {
        BoolTensor[] slices =
            unstack((BoolTensor) TFTensorConversionUtils.fromTFTensor(tensor), batchAxis, null);
        return Arrays.stream(slices)
            .map(slice -> encodeSlice(TFTensorConversionUtils.toTFTensor(slice),
                d -> Booleans.join(SEP_CHAR, StdArrays.array1dCopyOf((TBool) d.data()))))
            .toArray(String[]::new);
    }
    if (TInt32.DTYPE.equals(tensor.dataType())) {
        IntTensor[] slices =
            unstack((IntTensor) TFTensorConversionUtils.fromTFTensor(tensor), batchAxis, null);
        return Arrays.stream(slices)
            .map(slice -> encodeSlice(TFTensorConversionUtils.toTFTensor(slice),
                d -> Ints.join(SEP_CHAR, StdArrays.array1dCopyOf((TInt32) d.data()))))
            .toArray(String[]::new);
    }
    if (TInt64.DTYPE.equals(tensor.dataType())) {
        LongTensor[] slices =
            unstack((LongTensor) TFTensorConversionUtils.fromTFTensor(tensor), batchAxis, null);
        return Arrays.stream(slices)
            .map(slice -> encodeSlice(TFTensorConversionUtils.toTFTensor(slice),
                d -> Longs.join(SEP_CHAR, StdArrays.array1dCopyOf((TInt64) d.data()))))
            .toArray(String[]::new);
    }
    if (TFloat32.DTYPE.equals(tensor.dataType())) {
        FloatTensor[] slices =
            unstack((FloatTensor) TFTensorConversionUtils.fromTFTensor(tensor), batchAxis, null);
        return Arrays.stream(slices)
            .map(slice -> encodeSlice(TFTensorConversionUtils.toTFTensor(slice),
                d -> Floats.join(SEP_CHAR, StdArrays.array1dCopyOf((TFloat32) d.data()))))
            .toArray(String[]::new);
    }
    if (TFloat64.DTYPE.equals(tensor.dataType())) {
        DoubleTensor[] slices =
            unstack((DoubleTensor) TFTensorConversionUtils.fromTFTensor(tensor), batchAxis, null);
        return Arrays.stream(slices)
            .map(slice -> encodeSlice(TFTensorConversionUtils.toTFTensor(slice),
                d -> Doubles.join(SEP_CHAR, StdArrays.array1dCopyOf((TFloat64) d.data()))))
            .toArray(String[]::new);
    }
    throw new UnsupportedOperationException("Unsupported dtype: " + tensor.dataType());
}

/**
 * Squeezes one per-slice TensorFlow tensor, encodes it, and closes the native tensors involved.
 *
 * @param slice   a TensorFlow tensor created for a single batch slice; this method takes
 *                ownership of it and always closes it
 * @param encoder turns the squeezed tensor into its string form; it must copy the data out,
 *                because the tensor is closed as soon as this method returns
 * @return the encoded slice
 */
private static String encodeSlice(Tensor<?> slice,
                                  java.util.function.Function<Tensor<?>, String> encoder) {
    try (Tensor<?> owned = slice) {
        Tensor<?> squeezed = TF2TensorUtils.squeezeTensor(owned);
        try {
            return encoder.apply(squeezed);
        } finally {
            // squeezeTensor might return its argument unchanged; never close the same tensor twice.
            if (squeezed != owned) {
                squeezed.close();
            }
        }
    }
}
Aggregations