Search in sources:

Example 1 with StringTensor

use of com.alibaba.alink.common.linalg.tensor.StringTensor in project Alink by alibaba.

the class TFExampleConversionV2 method javaToFeature.

/**
 * Converts a Java object to a TensorFlow {@link Feature}.
 *
 * <p>Scalar floating-point types are written as a one-element {@code FloatList}, scalar integer
 * types as a one-element {@code Int64List}, and strings as a one-element {@code BytesList}.
 * Tensor values are reshaped to rank 1 and every element is appended, in order, to the matching
 * list. When {@code dt} is the generic {@code TENSOR} type, the concrete tensor type is first
 * resolved from the runtime class of {@code val}.
 *
 * <p>{@code FloatList} stores 32-bit floats, so {@code FLOAT_64} scalars and {@code DoubleTensor}
 * elements are narrowed to {@code float}. {@code BoolTensor} elements are encoded as 1/0 in an
 * {@code Int64List}.
 *
 * @param dt  java object data type.
 * @param val given java object.
 * @return tensorflow feature.
 * @throws RuntimeException if {@code dt} (after resolving {@code TENSOR}) is not supported,
 *                          including {@code BYTE_TENSOR}.
 */
public static Feature javaToFeature(DataTypesV2 dt, Object val) {
    Feature.Builder featureBuilder = Feature.newBuilder();
    FloatList.Builder floatListBuilder = FloatList.newBuilder();
    Int64List.Builder int64ListBuilder = Int64List.newBuilder();
    // When dt is TENSOR, find the exact type first.
    if (DataTypesV2.TENSOR.equals(dt)) {
        if (val instanceof FloatTensor) {
            dt = DataTypesV2.FLOAT_TENSOR;
        } else if (val instanceof DoubleTensor) {
            dt = DataTypesV2.DOUBLE_TENSOR;
        } else if (val instanceof IntTensor) {
            dt = DataTypesV2.INT_TENSOR;
        } else if (val instanceof LongTensor) {
            dt = DataTypesV2.LONG_TENSOR;
        } else if (val instanceof BoolTensor) {
            dt = DataTypesV2.BOOLEAN_TENSOR;
        } else if (val instanceof UByteTensor) {
            dt = DataTypesV2.UBYTE_TENSOR;
        } else if (val instanceof StringTensor) {
            dt = DataTypesV2.STRING_TENSOR;
        } else if (val instanceof ByteTensor) {
            dt = DataTypesV2.BYTE_TENSOR;
        }
    }
    switch (dt) {
        case FLOAT_16:
        case FLOAT_32:
        case FLOAT_64:
            {
                // Accept any Number: a plain (Float) cast throws ClassCastException for Double
                // values, which is what FLOAT_64 data normally carries.
                floatListBuilder.addValue(((Number) val).floatValue());
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case INT_8:
        case INT_16:
        case INT_32:
        case INT_64:
        case UINT_8:
        case UINT_16:
        case UINT_32:
        case UINT_64:
            {
                int64ListBuilder.addValue(castAsLong(val));
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case STRING:
            {
                BytesList.Builder bb = BytesList.newBuilder();
                bb.addValue(castAsBytes(val));
                featureBuilder.setBytesList(bb);
                break;
            }
        case FLOAT_TENSOR:
            {
                FloatTensor floatTensor = (FloatTensor) val;
                long size = floatTensor.size();
                // Flatten to rank 1 so elements can be read with a single index.
                floatTensor = floatTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    floatListBuilder.addValue(floatTensor.getFloat(i));
                }
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case DOUBLE_TENSOR:
            {
                DoubleTensor doubleTensor = (DoubleTensor) val;
                long size = doubleTensor.size();
                doubleTensor = doubleTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    // FloatList only holds 32-bit floats; precision is narrowed here.
                    floatListBuilder.addValue((float) doubleTensor.getDouble(i));
                }
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case INT_TENSOR:
            {
                IntTensor intTensor = (IntTensor) val;
                long size = intTensor.size();
                intTensor = intTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(intTensor.getInt(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case LONG_TENSOR:
            {
                LongTensor longTensor = (LongTensor) val;
                long size = longTensor.size();
                longTensor = longTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(longTensor.getLong(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case BOOLEAN_TENSOR:
            {
                BoolTensor boolTensor = (BoolTensor) val;
                long size = boolTensor.size();
                boolTensor = boolTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    // Example protos have no boolean list; encode true/false as 1/0.
                    int64ListBuilder.addValue(boolTensor.getBoolean(i) ? 1 : 0);
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case UBYTE_TENSOR:
            {
                UByteTensor ubyteTensor = (UByteTensor) val;
                long size = ubyteTensor.size();
                ubyteTensor = ubyteTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(ubyteTensor.getUByte(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case STRING_TENSOR:
            {
                StringTensor stringTensor = (StringTensor) val;
                long size = stringTensor.size();
                stringTensor = stringTensor.reshape(new Shape(size));
                BytesList.Builder bb = BytesList.newBuilder();
                for (long i = 0; i < size; i += 1) {
                    bb.addValue(castAsBytes(stringTensor.getString(i)));
                }
                featureBuilder.setBytesList(bb);
                break;
            }
        case BYTE_TENSOR:
            // BYTE_TENSOR is explicitly listed but deliberately unsupported.
        default:
            throw new RuntimeException("Unsupported data type for TF");
    }
    return featureBuilder.build();
}
Also used : DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) Feature(org.tensorflow.proto.example.Feature) FloatList(org.tensorflow.proto.example.FloatList) Int64List(org.tensorflow.proto.example.Int64List) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) UByteTensor(com.alibaba.alink.common.linalg.tensor.UByteTensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) ByteTensor(com.alibaba.alink.common.linalg.tensor.ByteTensor)

Example 2 with StringTensor

use of com.alibaba.alink.common.linalg.tensor.StringTensor in project Alink by alibaba.

the class StringTFTensorConversionImpl method encodeBatchTensor.

/**
 * Splits a TF tensor along {@code batchAxis} and encodes each slice as one string.
 *
 * <p>Each slice is converted back to a TF tensor and squeezed, its values are copied to a 1-D
 * Java array, and they are joined with {@code SEP_CHAR}. Supported dtypes are TString, TBool,
 * TInt32, TInt64, TFloat32 and TFloat64.
 *
 * @param tensor    the batched TF tensor.
 * @param batchAxis the axis to split on.
 * @return one encoded string per slice along {@code batchAxis}.
 * @throws UnsupportedOperationException for any other dtype.
 */
@Override
public String[] encodeBatchTensor(Tensor<?> tensor, int batchAxis) {
    long[] dims = tensor.shape().asArray();
    // NOTE(review): this value is never used; its only effect is an array-index check on batchAxis.
    long numBatches = dims[batchAxis];

    if (TString.DTYPE.equals(tensor.dataType())) {
        StringTensor whole = (StringTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        StringTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TString) squeezed.data(), String.class))
            .map(values -> String.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TBool.DTYPE.equals(tensor.dataType())) {
        BoolTensor whole = (BoolTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        BoolTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TBool) squeezed.data()))
            .map(values -> Booleans.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TInt32.DTYPE.equals(tensor.dataType())) {
        IntTensor whole = (IntTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        IntTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TInt32) squeezed.data()))
            .map(values -> Ints.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TInt64.DTYPE.equals(tensor.dataType())) {
        LongTensor whole = (LongTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        LongTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TInt64) squeezed.data()))
            .map(values -> Longs.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TFloat32.DTYPE.equals(tensor.dataType())) {
        FloatTensor whole = (FloatTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        FloatTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TFloat32) squeezed.data()))
            .map(values -> Floats.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TFloat64.DTYPE.equals(tensor.dataType())) {
        DoubleTensor whole = (DoubleTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        DoubleTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TFloat64) squeezed.data()))
            .map(values -> Doubles.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    throw new UnsupportedOperationException("Unsupported dtype: " + tensor.dataType());
}
Also used : TString(org.tensorflow.types.TString) Booleans(com.google.common.primitives.Booleans) Arrays(java.util.Arrays) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) TFloat32(org.tensorflow.types.TFloat32) TFTensorConversionUtils(com.alibaba.alink.common.dl.utils.TFTensorConversionUtils) StdArrays(org.tensorflow.ndarray.StdArrays) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) TType(org.tensorflow.types.family.TType) Tensor(org.tensorflow.Tensor) TInt32(org.tensorflow.types.TInt32) Longs(com.google.common.primitives.Longs) TString(org.tensorflow.types.TString) Floats(com.google.common.primitives.Floats) TBool(org.tensorflow.types.TBool) TFloat64(org.tensorflow.types.TFloat64) TF2TensorUtils(com.alibaba.alink.common.dl.utils.TF2TensorUtils) TensorInfo(org.tensorflow.proto.framework.TensorInfo) Preconditions(org.apache.flink.util.Preconditions) StringUtils(org.apache.flink.util.StringUtils) Ints(com.google.common.primitives.Ints) DataType(org.tensorflow.proto.framework.DataType) StandardCharsets(java.nio.charset.StandardCharsets) List(java.util.List) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) Doubles(com.google.common.primitives.Doubles) Tensor.unstack(com.alibaba.alink.common.linalg.tensor.Tensor.unstack) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TInt64(org.tensorflow.types.TInt64) Shape(org.tensorflow.ndarray.Shape) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) TFTensorConversionUtils(com.alibaba.alink.common.dl.utils.TFTensorConversionUtils) TF2TensorUtils(com.alibaba.alink.common.dl.utils.TF2TensorUtils) TString(org.tensorflow.types.TString) TFloat32(org.tensorflow.types.TFloat32) TInt32(org.tensorflow.types.TInt32) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) 
IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor)

Example 3 with StringTensor

use of com.alibaba.alink.common.linalg.tensor.StringTensor in project Alink by alibaba.

the class TensorTFTensorConversionImpl method parseStringTensor.

/**
 * Stacks Alink string tensors into one {@link StringTensor} of the given shape and converts the
 * result to a TF {@code TString} tensor.
 *
 * @param tensors the parts to stack; each element is cast to {@link StringTensor}.
 * @param shape   the shape of the stacked result.
 * @return the stacked data as a TF string tensor.
 */
@Override
public Tensor<TString> parseStringTensor(com.alibaba.alink.common.linalg.tensor.Tensor<?>[] tensors, long[] shape) {
    com.alibaba.alink.common.linalg.tensor.Shape targetShape =
        new com.alibaba.alink.common.linalg.tensor.Shape(shape);
    // Destination buffer; stack(...) below (axis argument 0) fills it in place.
    StringTensor result = new StringTensor(targetShape);
    StringTensor[] pieces = castArrayType(tensors, StringTensor[].class);
    stack(pieces, 0, result);
    @SuppressWarnings("unchecked")
    Tensor<TString> converted = (Tensor<TString>) TFTensorConversionUtils.toTFTensor(result);
    return converted;
}
Also used : DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Tensor(org.tensorflow.Tensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor)

Example 4 with StringTensor

use of com.alibaba.alink.common.linalg.tensor.StringTensor in project Alink by alibaba.

the class TensorTFTensorConversionTest method testStringTensor.

/**
 * Round-trips a 3x4 array of random strings through {@code parseTensor} / {@code encodeTensor}.
 *
 * <p>Shape and contents of the parsed TF tensor are compared with a reference tensor built by
 * {@code TString.tensorOf}, and the re-encoded Alink tensor is compared element-wise with the input.
 */
@Test
public void testStringTensor() {
    String[][] arr = new String[3][4];
    for (int i = 0; i < 3; i += 1) {
        for (int j = 0; j < 4; j += 1) {
            arr[i][j] = RandomStringUtils.random(6);
        }
    }
    StringTensor data = new StringTensor(arr);
    TensorInfo tensorInfo = createTensorInfo(DataType.DT_STRING, new long[] { 3, 4 });
    // TF Java tensors are AutoCloseable; release both when the test finishes.
    // noinspection unchecked
    try (Tensor<TString> std = TString.tensorOf(StdArrays.ndCopyOf(arr));
         Tensor<TString> tensor =
             (Tensor<TString>) TensorTFTensorConversionImpl.getInstance().parseTensor(data, tensorInfo)) {
        Assert.assertArrayEquals(std.shape().asArray(), tensor.shape().asArray());
        Assert.assertArrayEquals(StdArrays.array2dCopyOf(std.data(), String.class), StdArrays.array2dCopyOf(tensor.data(), String.class));
        StringTensor encoded = (StringTensor) TensorTFTensorConversionImpl.getInstance().encodeTensor(tensor);
        for (int i = 0; i < 3; i += 1) {
            for (int j = 0; j < 4; j += 1) {
                // JUnit convention: expected (the original input) first, actual second.
                Assert.assertEquals(data.getString(i, j), encoded.getString(i, j));
            }
        }
    }
}
Also used : TString(org.tensorflow.types.TString) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Tensor(org.tensorflow.Tensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) TString(org.tensorflow.types.TString) TensorInfo(org.tensorflow.proto.framework.TensorInfo) Test(org.junit.Test)

Example 5 with StringTensor

use of com.alibaba.alink.common.linalg.tensor.StringTensor in project Alink by alibaba.

the class KerasSequentialClassifierBatchOpTest method testTFHubLayerStringTensor.

/**
 * Trains and runs a Keras sequential classifier whose first layer is a TF-Hub text-embedding
 * layer, fed with single-string {@link StringTensor} rows and random 0/1 labels.
 *
 * <p>NOTE(review): the hub layer is loaded from a tfhub.dev URL, so this test presumably needs
 * network access.
 */
@Category(DLTest.class)
@Test
public void testTFHubLayerStringTensor() throws Exception {
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        Random random = new Random();
        int n = 1000;
        List<Row> rows = new ArrayList<>();
        for (int i = 0; i < n; i += 1) {
            int length = random.nextInt(8) + 1;
            String arr = RandomStringUtils.randomAlphanumeric(length);
            int label = random.nextInt(2);
            rows.add(Row.of(new StringTensor(arr), label));
        }
        BatchOperator<?> source = new MemSourceBatchOp(rows, "tensor STRING_TENSOR, label int");
        KerasSequentialClassifierTrainBatchOp trainBatchOp = new KerasSequentialClassifierTrainBatchOp()
            .setTensorCol("tensor")
            .setLabelCol("label")
            .setLayers("hub.KerasLayer('https://tfhub.dev/google/nnlm-de-dim50/2', input_shape=[], dtype=tf.string)", "Flatten()")
            .setCheckpointFilePath(PythonFileUtils.createTempDir("keras_sequential_train_").toString())
            .setNumEpochs(1)
            .linkFrom(source);
        KerasSequentialClassifierPredictBatchOp predictBatchOp = new KerasSequentialClassifierPredictBatchOp()
            .setPredictionCol("pred")
            .setPredictionDetailCol("pred_detail")
            .setReservedCols("label")
            .linkFrom(trainBatchOp, source);
        predictBatchOp.lazyPrint(10);
        BatchOperator.execute();
    } finally {
        // BatchOperator.setParallelism is a static setting; restore it even if the job fails so
        // later tests do not inherit parallelism 1.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Also used : MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) Random(java.util.Random) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) ArrayList(java.util.ArrayList) Row(org.apache.flink.types.Row) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Aggregations

StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor)6 DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor)5 FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor)5 BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor)4 IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor)4 LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor)4 Test (org.junit.Test)3 Tensor (org.tensorflow.Tensor)3 TensorInfo (org.tensorflow.proto.framework.TensorInfo)2 TString (org.tensorflow.types.TString)2 TF2TensorUtils (com.alibaba.alink.common.dl.utils.TF2TensorUtils)1 TFTensorConversionUtils (com.alibaba.alink.common.dl.utils.TFTensorConversionUtils)1 ByteTensor (com.alibaba.alink.common.linalg.tensor.ByteTensor)1 Shape (com.alibaba.alink.common.linalg.tensor.Shape)1 Tensor (com.alibaba.alink.common.linalg.tensor.Tensor)1 Tensor.unstack (com.alibaba.alink.common.linalg.tensor.Tensor.unstack)1 UByteTensor (com.alibaba.alink.common.linalg.tensor.UByteTensor)1 Mapper (com.alibaba.alink.common.mapper.Mapper)1 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)1 ToTensorParams (com.alibaba.alink.params.dataproc.ToTensorParams)1