Search in sources :

Example 1 with BoolTensor

use of com.alibaba.alink.common.linalg.tensor.BoolTensor in project Alink by alibaba.

In class TFExampleConversionV2, method javaToFeature:

/**
 * Converts a java object to a tensorflow feature.
 *
 * <p>Floating-point values are written to the feature's float list, integer and boolean values to
 * its int64 list, and string values to its bytes list. Tensor values are reshaped to one dimension
 * and written element by element in index order.
 *
 * @param dt  java object data type. When it is {@code TENSOR}, the concrete tensor type is derived
 *            from the runtime class of {@code val}.
 * @param val given java object. For the scalar floating-point types any {@link Number} is accepted.
 * @return tensorflow feature.
 * @throws RuntimeException if the data type is not supported, which includes {@code BYTE_TENSOR}
 *                          and a {@code TENSOR} whose runtime class is not recognized.
 */
public static Feature javaToFeature(DataTypesV2 dt, Object val) {
    Feature.Builder featureBuilder = Feature.newBuilder();
    FloatList.Builder floatListBuilder = FloatList.newBuilder();
    Int64List.Builder int64ListBuilder = Int64List.newBuilder();
    // When dt is TENSOR, find the exact type first.
    if (DataTypesV2.TENSOR.equals(dt)) {
        if (val instanceof FloatTensor) {
            dt = DataTypesV2.FLOAT_TENSOR;
        } else if (val instanceof DoubleTensor) {
            dt = DataTypesV2.DOUBLE_TENSOR;
        } else if (val instanceof IntTensor) {
            dt = DataTypesV2.INT_TENSOR;
        } else if (val instanceof LongTensor) {
            dt = DataTypesV2.LONG_TENSOR;
        } else if (val instanceof BoolTensor) {
            dt = DataTypesV2.BOOLEAN_TENSOR;
        } else if (val instanceof UByteTensor) {
            dt = DataTypesV2.UBYTE_TENSOR;
        } else if (val instanceof StringTensor) {
            dt = DataTypesV2.STRING_TENSOR;
        } else if (val instanceof ByteTensor) {
            dt = DataTypesV2.BYTE_TENSOR;
        }
        // Any other runtime class leaves dt as TENSOR, which ends up in the default branch below.
    }
    switch(dt) {
        case FLOAT_16:
        case FLOAT_32:
        case FLOAT_64:
            {
                // Go through Number instead of casting to Float: a FLOAT_64 value typically arrives
                // boxed as Double, and a hard (Float) cast would throw ClassCastException for it.
                // The float list stores 32-bit floats, so wider values are narrowed here.
                floatListBuilder.addValue(((Number) val).floatValue());
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case INT_8:
        case INT_16:
        case INT_32:
        case INT_64:
        case UINT_8:
        case UINT_16:
        case UINT_32:
        case UINT_64:
            {
                int64ListBuilder.addValue(castAsLong(val));
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case STRING:
            {
                BytesList.Builder bb = BytesList.newBuilder();
                bb.addValue(castAsBytes(val));
                featureBuilder.setBytesList(bb);
                break;
            }
        case FLOAT_TENSOR:
            {
                FloatTensor floatTensor = (FloatTensor) val;
                long size = floatTensor.size();
                // Reshape to one dimension so elements can be read with a single index.
                floatTensor = floatTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    floatListBuilder.addValue(floatTensor.getFloat(i));
                }
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case DOUBLE_TENSOR:
            {
                DoubleTensor doubleTensor = (DoubleTensor) val;
                long size = doubleTensor.size();
                doubleTensor = doubleTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    // Doubles are narrowed to float because they are stored in the float list.
                    floatListBuilder.addValue((float) doubleTensor.getDouble(i));
                }
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case INT_TENSOR:
            {
                IntTensor intTensor = (IntTensor) val;
                long size = intTensor.size();
                intTensor = intTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(intTensor.getInt(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case LONG_TENSOR:
            {
                LongTensor longTensor = (LongTensor) val;
                long size = longTensor.size();
                longTensor = longTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(longTensor.getLong(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case BOOLEAN_TENSOR:
            {
                BoolTensor boolTensor = (BoolTensor) val;
                long size = boolTensor.size();
                boolTensor = boolTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    // Booleans are encoded as 1 (true) / 0 (false) in the int64 list.
                    int64ListBuilder.addValue(boolTensor.getBoolean(i) ? 1 : 0);
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case UBYTE_TENSOR:
            {
                UByteTensor ubyteTensor = (UByteTensor) val;
                long size = ubyteTensor.size();
                ubyteTensor = ubyteTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(ubyteTensor.getUByte(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case STRING_TENSOR:
            {
                StringTensor stringTensor = (StringTensor) val;
                long size = stringTensor.size();
                stringTensor = stringTensor.reshape(new Shape(size));
                BytesList.Builder bb = BytesList.newBuilder();
                for (long i = 0; i < size; i += 1) {
                    bb.addValue(castAsBytes(stringTensor.getString(i)));
                }
                featureBuilder.setBytesList(bb);
                break;
            }
        // BYTE_TENSOR has no conversion here; it deliberately shares the default failure path.
        case BYTE_TENSOR:
        default:
            throw new RuntimeException("Unsupported data type for TF");
    }
    return featureBuilder.build();
}
Also used : DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) Feature(org.tensorflow.proto.example.Feature) FloatList(org.tensorflow.proto.example.FloatList) Int64List(org.tensorflow.proto.example.Int64List) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) UByteTensor(com.alibaba.alink.common.linalg.tensor.UByteTensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) UByteTensor(com.alibaba.alink.common.linalg.tensor.UByteTensor) ByteTensor(com.alibaba.alink.common.linalg.tensor.ByteTensor)

Example 2 with BoolTensor

use of com.alibaba.alink.common.linalg.tensor.BoolTensor in project Alink by alibaba.

In class StringTFTensorConversionImpl, method encodeBatchTensor:

/**
 * Encodes a batched TensorFlow tensor as one string per slice along {@code batchAxis}.
 *
 * <p>The tensor is converted to the matching Alink tensor type, unstacked along {@code batchAxis},
 * and every slice is converted back to a TensorFlow tensor, passed through
 * {@code TF2TensorUtils.squeezeTensor}, copied into a one-dimensional array and joined with
 * {@code SEP_CHAR}. Handled dtypes: TString, TBool, TInt32, TInt64, TFloat32 and TFloat64.
 *
 * @param tensor    the TensorFlow tensor to encode.
 * @param batchAxis the axis along which the tensor is unstacked.
 * @return one joined string per slice.
 * @throws UnsupportedOperationException if the tensor's dtype is none of the handled ones.
 */
@Override
public String[] encodeBatchTensor(Tensor<?> tensor, int batchAxis) {
    // The value is not used below; the lookup is kept so the index access on the shape array
    // still happens exactly as before.
    long batchSize = tensor.shape().asArray()[batchAxis];

    if (TString.DTYPE.equals(tensor.dataType())) {
        StringTensor source = (StringTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        StringTensor[] slices = unstack(source, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(slice -> StdArrays.array1dCopyOf((TString) slice.data(), String.class))
            .map(values -> String.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TBool.DTYPE.equals(tensor.dataType())) {
        BoolTensor source = (BoolTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        BoolTensor[] slices = unstack(source, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(slice -> StdArrays.array1dCopyOf((TBool) slice.data()))
            .map(values -> Booleans.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TInt32.DTYPE.equals(tensor.dataType())) {
        IntTensor source = (IntTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        IntTensor[] slices = unstack(source, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(slice -> StdArrays.array1dCopyOf((TInt32) slice.data()))
            .map(values -> Ints.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TInt64.DTYPE.equals(tensor.dataType())) {
        LongTensor source = (LongTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        LongTensor[] slices = unstack(source, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(slice -> StdArrays.array1dCopyOf((TInt64) slice.data()))
            .map(values -> Longs.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TFloat32.DTYPE.equals(tensor.dataType())) {
        FloatTensor source = (FloatTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        FloatTensor[] slices = unstack(source, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(slice -> StdArrays.array1dCopyOf((TFloat32) slice.data()))
            .map(values -> Floats.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TFloat64.DTYPE.equals(tensor.dataType())) {
        DoubleTensor source = (DoubleTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        DoubleTensor[] slices = unstack(source, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(slice -> StdArrays.array1dCopyOf((TFloat64) slice.data()))
            .map(values -> Doubles.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    // None of the handled dtypes matched.
    throw new UnsupportedOperationException("Unsupported dtype: " + tensor.dataType());
}
Also used : TString(org.tensorflow.types.TString) Booleans(com.google.common.primitives.Booleans) Arrays(java.util.Arrays) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) TFloat32(org.tensorflow.types.TFloat32) TFTensorConversionUtils(com.alibaba.alink.common.dl.utils.TFTensorConversionUtils) StdArrays(org.tensorflow.ndarray.StdArrays) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) TType(org.tensorflow.types.family.TType) Tensor(org.tensorflow.Tensor) TInt32(org.tensorflow.types.TInt32) Longs(com.google.common.primitives.Longs) TString(org.tensorflow.types.TString) Floats(com.google.common.primitives.Floats) TBool(org.tensorflow.types.TBool) TFloat64(org.tensorflow.types.TFloat64) TF2TensorUtils(com.alibaba.alink.common.dl.utils.TF2TensorUtils) TensorInfo(org.tensorflow.proto.framework.TensorInfo) Preconditions(org.apache.flink.util.Preconditions) StringUtils(org.apache.flink.util.StringUtils) Ints(com.google.common.primitives.Ints) DataType(org.tensorflow.proto.framework.DataType) StandardCharsets(java.nio.charset.StandardCharsets) List(java.util.List) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) Doubles(com.google.common.primitives.Doubles) Tensor.unstack(com.alibaba.alink.common.linalg.tensor.Tensor.unstack) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TInt64(org.tensorflow.types.TInt64) Shape(org.tensorflow.ndarray.Shape) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) TFTensorConversionUtils(com.alibaba.alink.common.dl.utils.TFTensorConversionUtils) TF2TensorUtils(com.alibaba.alink.common.dl.utils.TF2TensorUtils) TString(org.tensorflow.types.TString) TFloat32(org.tensorflow.types.TFloat32) TInt32(org.tensorflow.types.TInt32) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) 
IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor)

Example 3 with BoolTensor

use of com.alibaba.alink.common.linalg.tensor.BoolTensor in project Alink by alibaba.

In class TensorTFTensorConversionImpl, method parseBoolTensor:

/**
 * Stacks the given tensors along axis 0 into a single {@link BoolTensor} of the requested shape
 * and converts the result to a TensorFlow boolean tensor.
 *
 * @param tensors the tensors to stack; they are cast to {@code BoolTensor[]}.
 * @param shape   the shape of the stacked result.
 * @return the stacked data as a TensorFlow tensor.
 */
@Override
public Tensor<TBool> parseBoolTensor(com.alibaba.alink.common.linalg.tensor.Tensor<?>[] tensors, long[] shape) {
    // Allocate the destination first, then fill it by stacking the inputs along axis 0.
    com.alibaba.alink.common.linalg.tensor.Shape targetShape = new com.alibaba.alink.common.linalg.tensor.Shape(shape);
    BoolTensor merged = new BoolTensor(targetShape);
    BoolTensor[] parts = castArrayType(tensors, BoolTensor[].class);
    stack(parts, 0, merged);
    Tensor<?> converted = TFTensorConversionUtils.toTFTensor(merged);
    // noinspection unchecked
    return (Tensor<TBool>) converted;
}
Also used : DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Tensor(org.tensorflow.Tensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor)

Example 4 with BoolTensor

use of com.alibaba.alink.common.linalg.tensor.BoolTensor in project Alink by alibaba.

In class TensorTFTensorConversionTest, method testBoolTensor:

/**
 * Round-trips a 3x4 boolean tensor through {@code TensorTFTensorConversionImpl}: parses an Alink
 * {@link BoolTensor} into a TensorFlow tensor, compares it with a tensor built directly from the
 * same array, then encodes it back and compares element by element.
 */
@Test
public void testBoolTensor() {
    // Fixed seed keeps the generated data reproducible between runs.
    Random random = new Random(2021);
    boolean[][] arr = new boolean[3][4];
    for (int i = 0; i < 3; i += 1) {
        for (int j = 0; j < 4; j += 1) {
            arr[i][j] = random.nextBoolean();
        }
    }
    BoolTensor data = new BoolTensor(arr);
    TensorInfo tensorInfo = createTensorInfo(DataType.DT_BOOL, new long[] { 3, 4 });
    // TensorFlow tensors are AutoCloseable; close them deterministically instead of leaking them
    // when the test finishes or an assertion fails.
    // noinspection unchecked
    try (Tensor<TBool> std = TBool.tensorOf(StdArrays.ndCopyOf(arr));
         Tensor<TBool> tensor = (Tensor<TBool>) TensorTFTensorConversionImpl.getInstance().parseTensor(data, tensorInfo)) {
        Assert.assertArrayEquals(std.shape().asArray(), tensor.shape().asArray());
        Assert.assertArrayEquals(StdArrays.array2dCopyOf(std.data()), StdArrays.array2dCopyOf(tensor.data()));
        BoolTensor encoded = (BoolTensor) TensorTFTensorConversionImpl.getInstance().encodeTensor(tensor);
        for (int i = 0; i < 3; i += 1) {
            for (int j = 0; j < 4; j += 1) {
                // JUnit order is (expected, actual): the source data is the expectation.
                Assert.assertEquals(data.getBoolean(i, j), encoded.getBoolean(i, j));
            }
        }
    }
}
Also used : TBool(org.tensorflow.types.TBool) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Tensor(org.tensorflow.Tensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) Random(java.util.Random) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) TensorInfo(org.tensorflow.proto.framework.TensorInfo) Test(org.junit.Test)

Aggregations

BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor)4 DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor)4 FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor)4 IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor)4 LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor)4 StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor)4 Tensor (org.tensorflow.Tensor)3 TensorInfo (org.tensorflow.proto.framework.TensorInfo)2 TBool (org.tensorflow.types.TBool)2 TF2TensorUtils (com.alibaba.alink.common.dl.utils.TF2TensorUtils)1 TFTensorConversionUtils (com.alibaba.alink.common.dl.utils.TFTensorConversionUtils)1 ByteTensor (com.alibaba.alink.common.linalg.tensor.ByteTensor)1 Shape (com.alibaba.alink.common.linalg.tensor.Shape)1 Tensor.unstack (com.alibaba.alink.common.linalg.tensor.Tensor.unstack)1 UByteTensor (com.alibaba.alink.common.linalg.tensor.UByteTensor)1 Booleans (com.google.common.primitives.Booleans)1 Doubles (com.google.common.primitives.Doubles)1 Floats (com.google.common.primitives.Floats)1 Ints (com.google.common.primitives.Ints)1 Longs (com.google.common.primitives.Longs)1