Search in sources :

Example 1 with DoubleTensor

use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

the class TFExampleConversionV2 method javaToFeature.

/**
 * convert java object to tensorflow feature.
 *
 * @param dt  java object data type.
 * @param val given java object.
 * @return tensorflow feature.
 */
public static Feature javaToFeature(DataTypesV2 dt, Object val) {
    Feature.Builder featureBuilder = Feature.newBuilder();
    FloatList.Builder floatListBuilder = FloatList.newBuilder();
    Int64List.Builder int64ListBuilder = Int64List.newBuilder();
    // When dt is TENSOR, find the exact type first.
    if (DataTypesV2.TENSOR.equals(dt)) {
        if (val instanceof FloatTensor) {
            dt = DataTypesV2.FLOAT_TENSOR;
        } else if (val instanceof DoubleTensor) {
            dt = DataTypesV2.DOUBLE_TENSOR;
        } else if (val instanceof IntTensor) {
            dt = DataTypesV2.INT_TENSOR;
        } else if (val instanceof LongTensor) {
            dt = DataTypesV2.LONG_TENSOR;
        } else if (val instanceof BoolTensor) {
            dt = DataTypesV2.BOOLEAN_TENSOR;
        } else if (val instanceof UByteTensor) {
            dt = DataTypesV2.UBYTE_TENSOR;
        } else if (val instanceof StringTensor) {
            dt = DataTypesV2.STRING_TENSOR;
        } else if (val instanceof ByteTensor) {
            dt = DataTypesV2.BYTE_TENSOR;
        }
    }
    switch(dt) {
        case FLOAT_16:
        case FLOAT_32:
        case FLOAT_64:
            {
                floatListBuilder.addValue((Float) val);
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case INT_8:
        case INT_16:
        case INT_32:
        case INT_64:
        case UINT_8:
        case UINT_16:
        case UINT_32:
        case UINT_64:
            {
                int64ListBuilder.addValue(castAsLong(val));
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case STRING:
            {
                BytesList.Builder bb = BytesList.newBuilder();
                bb.addValue(castAsBytes(val));
                featureBuilder.setBytesList(bb);
                break;
            }
        case FLOAT_TENSOR:
            {
                FloatTensor floatTensor = (FloatTensor) val;
                long size = floatTensor.size();
                floatTensor = floatTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    floatListBuilder.addValue(floatTensor.getFloat(i));
                }
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case DOUBLE_TENSOR:
            {
                DoubleTensor doubleTensor = (DoubleTensor) val;
                long size = doubleTensor.size();
                doubleTensor = doubleTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    floatListBuilder.addValue((float) doubleTensor.getDouble(i));
                }
                featureBuilder.setFloatList(floatListBuilder);
                break;
            }
        case INT_TENSOR:
            {
                IntTensor intTensor = (IntTensor) val;
                long size = intTensor.size();
                intTensor = intTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(intTensor.getInt(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case LONG_TENSOR:
            {
                LongTensor longTensor = (LongTensor) val;
                long size = longTensor.size();
                longTensor = longTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(longTensor.getLong(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case BOOLEAN_TENSOR:
            {
                BoolTensor boolTensor = (BoolTensor) val;
                long size = boolTensor.size();
                boolTensor = boolTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(boolTensor.getBoolean(i) ? 1 : 0);
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case UBYTE_TENSOR:
            {
                UByteTensor ubyteTensor = (UByteTensor) val;
                long size = ubyteTensor.size();
                ubyteTensor = ubyteTensor.reshape(new Shape(size));
                for (long i = 0; i < size; i += 1) {
                    int64ListBuilder.addValue(ubyteTensor.getUByte(i));
                }
                featureBuilder.setInt64List(int64ListBuilder);
                break;
            }
        case STRING_TENSOR:
            {
                StringTensor stringTensor = (StringTensor) val;
                long size = stringTensor.size();
                stringTensor = stringTensor.reshape(new Shape(size));
                BytesList.Builder bb = BytesList.newBuilder();
                for (long i = 0; i < size; i += 1) {
                    bb.addValue(castAsBytes(stringTensor.getString(i)));
                }
                featureBuilder.setBytesList(bb);
                break;
            }
        case BYTE_TENSOR:
        default:
            throw new RuntimeException("Unsupported data type for TF");
    }
    return featureBuilder.build();
}
Also used : DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) Feature(org.tensorflow.proto.example.Feature) FloatList(org.tensorflow.proto.example.FloatList) Int64List(org.tensorflow.proto.example.Int64List) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) UByteTensor(com.alibaba.alink.common.linalg.tensor.UByteTensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) UByteTensor(com.alibaba.alink.common.linalg.tensor.UByteTensor) ByteTensor(com.alibaba.alink.common.linalg.tensor.ByteTensor)

Example 2 with DoubleTensor

use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

the class VectorToTensorStreamOpTest method testVectorToTensorStreamOp.

@Test
public void testVectorToTensorStreamOp() throws Exception {
    List<Row> data = Collections.singletonList(Row.of("0.0 0.1 1.0 1.1 2.0 2.1"));
    MemSourceStreamOp memSourceStreamOp = new MemSourceStreamOp(data, "vec string");
    CollectSinkStreamOp collectSinkStreamOp = memSourceStreamOp.link(new VectorToTensorStreamOp().setSelectedCol("vec")).link(new CollectSinkStreamOp());
    StreamOperator.execute();
    List<Row> ret = collectSinkStreamOp.getAndRemoveValues();
    Assert.assertEquals(1, ret.size());
    Assert.assertTrue(ret.get(0).getField(0) instanceof DoubleTensor);
    Assert.assertEquals(TensorUtil.getTensor("DOUBLE#6#0.0 0.1 1.0 1.1 2.0 2.1"), ret.get(0).getField(0));
}
Also used : MemSourceStreamOp(com.alibaba.alink.operator.stream.source.MemSourceStreamOp) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) CollectSinkStreamOp(com.alibaba.alink.operator.stream.sink.CollectSinkStreamOp) Row(org.apache.flink.types.Row) Test(org.junit.Test)

Example 3 with DoubleTensor

use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

the class StringTFTensorConversionImpl method encodeBatchTensor.

@Override
public String[] encodeBatchTensor(Tensor<?> tensor, int batchAxis) {
    long[] shape = tensor.shape().asArray();
    long batchSize = shape[batchAxis];
    if (TString.DTYPE.equals(tensor.dataType())) {
        StringTensor stringTensor = (StringTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        StringTensor[] stringTensors = unstack(stringTensor, batchAxis, null);
        return Arrays.stream(stringTensors).map(TFTensorConversionUtils::toTFTensor).map(TF2TensorUtils::squeezeTensor).map(d -> StdArrays.array1dCopyOf((TString) d.data(), String.class)).map(d -> String.join(SEP_CHAR, d)).toArray(String[]::new);
    } else if (TBool.DTYPE.equals(tensor.dataType())) {
        BoolTensor boolTensor = (BoolTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        BoolTensor[] boolTensors = unstack(boolTensor, batchAxis, null);
        return Arrays.stream(boolTensors).map(TFTensorConversionUtils::toTFTensor).map(TF2TensorUtils::squeezeTensor).map(d -> StdArrays.array1dCopyOf((TBool) d.data())).map(d -> Booleans.join(SEP_CHAR, d)).toArray(String[]::new);
    } else if (TInt32.DTYPE.equals(tensor.dataType())) {
        IntTensor intTensor = (IntTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        IntTensor[] intTensors = unstack(intTensor, batchAxis, null);
        return Arrays.stream(intTensors).map(TFTensorConversionUtils::toTFTensor).map(TF2TensorUtils::squeezeTensor).map(d -> StdArrays.array1dCopyOf((TInt32) d.data())).map(d -> Ints.join(SEP_CHAR, d)).toArray(String[]::new);
    } else if (TInt64.DTYPE.equals(tensor.dataType())) {
        LongTensor longTensor = (LongTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        LongTensor[] longTensors = unstack(longTensor, batchAxis, null);
        return Arrays.stream(longTensors).map(TFTensorConversionUtils::toTFTensor).map(TF2TensorUtils::squeezeTensor).map(d -> StdArrays.array1dCopyOf((TInt64) d.data())).map(d -> Longs.join(SEP_CHAR, d)).toArray(String[]::new);
    } else if (TFloat32.DTYPE.equals(tensor.dataType())) {
        FloatTensor floatTensor = (FloatTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        FloatTensor[] floatTensors = unstack(floatTensor, batchAxis, null);
        return Arrays.stream(floatTensors).map(TFTensorConversionUtils::toTFTensor).map(TF2TensorUtils::squeezeTensor).map(d -> StdArrays.array1dCopyOf((TFloat32) d.data())).map(d -> Floats.join(SEP_CHAR, d)).toArray(String[]::new);
    } else if (TFloat64.DTYPE.equals(tensor.dataType())) {
        DoubleTensor doubleTensor = (DoubleTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        DoubleTensor[] doubleTensors = unstack(doubleTensor, batchAxis, null);
        return Arrays.stream(doubleTensors).map(TFTensorConversionUtils::toTFTensor).map(TF2TensorUtils::squeezeTensor).map(d -> StdArrays.array1dCopyOf((TFloat64) d.data())).map(d -> Doubles.join(SEP_CHAR, d)).toArray(String[]::new);
    }
    throw new UnsupportedOperationException("Unsupported dtype: " + tensor.dataType());
}
Also used : TString(org.tensorflow.types.TString) Booleans(com.google.common.primitives.Booleans) Arrays(java.util.Arrays) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) TFloat32(org.tensorflow.types.TFloat32) TFTensorConversionUtils(com.alibaba.alink.common.dl.utils.TFTensorConversionUtils) StdArrays(org.tensorflow.ndarray.StdArrays) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) TType(org.tensorflow.types.family.TType) Tensor(org.tensorflow.Tensor) TInt32(org.tensorflow.types.TInt32) Longs(com.google.common.primitives.Longs) TString(org.tensorflow.types.TString) Floats(com.google.common.primitives.Floats) TBool(org.tensorflow.types.TBool) TFloat64(org.tensorflow.types.TFloat64) TF2TensorUtils(com.alibaba.alink.common.dl.utils.TF2TensorUtils) TensorInfo(org.tensorflow.proto.framework.TensorInfo) Preconditions(org.apache.flink.util.Preconditions) StringUtils(org.apache.flink.util.StringUtils) Ints(com.google.common.primitives.Ints) DataType(org.tensorflow.proto.framework.DataType) StandardCharsets(java.nio.charset.StandardCharsets) List(java.util.List) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) Doubles(com.google.common.primitives.Doubles) Tensor.unstack(com.alibaba.alink.common.linalg.tensor.Tensor.unstack) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TInt64(org.tensorflow.types.TInt64) Shape(org.tensorflow.ndarray.Shape) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) TFTensorConversionUtils(com.alibaba.alink.common.dl.utils.TFTensorConversionUtils) TF2TensorUtils(com.alibaba.alink.common.dl.utils.TF2TensorUtils) TString(org.tensorflow.types.TString) TFloat32(org.tensorflow.types.TFloat32) TInt32(org.tensorflow.types.TInt32) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor)

Example 4 with DoubleTensor

use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

the class TensorTFTensorConversionImpl method parseDoubleTensor.

@Override
public Tensor<TFloat64> parseDoubleTensor(com.alibaba.alink.common.linalg.tensor.Tensor<?>[] tensors, long[] shape) {
    DoubleTensor stackedTensor = new DoubleTensor(new com.alibaba.alink.common.linalg.tensor.Shape(shape));
    stack(castArrayType(tensors, DoubleTensor[].class), 0, stackedTensor);
    // noinspection unchecked
    return (Tensor<TFloat64>) TFTensorConversionUtils.toTFTensor(stackedTensor);
}
Also used : DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Tensor(org.tensorflow.Tensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor)

Example 5 with DoubleTensor

use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in project Alink by alibaba.

the class TFTensorConversionUtilsTest method testToTFTensor.

@Test
public void testToTFTensor() {
    DoubleTensor doubleTensor = new DoubleTensor(new double[][] { { 1., 2. }, { 3., 4. } });
    Tensor<TFloat64> tensor = (Tensor<TFloat64>) TFTensorConversionUtils.toTFTensor(doubleTensor);
    Assert.assertEquals("DOUBLE", tensor.dataType().name());
    Assert.assertArrayEquals(new long[] { 2, 2 }, tensor.shape().asArray());
    double[] arr = new double[4];
    DoubleDataBuffer buffer = DataBuffers.of(arr, false, false);
    tensor.data().read(buffer);
    Assert.assertArrayEquals(new double[] { 1., 2., 3., 4. }, arr, 1e-9);
}
Also used : DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) DoubleDataBuffer(org.tensorflow.ndarray.buffer.DoubleDataBuffer) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Tensor(org.tensorflow.Tensor) TFloat64(org.tensorflow.types.TFloat64) Test(org.junit.Test)

Aggregations

DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor)21 Test (org.junit.Test)18 FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor)15 StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor)10 Mapper (com.alibaba.alink.common.mapper.Mapper)10 TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation)10 Params (org.apache.flink.ml.api.misc.param.Params)10 TableSchema (org.apache.flink.table.api.TableSchema)10 Tensor (com.alibaba.alink.common.linalg.tensor.Tensor)9 Shape (com.alibaba.alink.common.linalg.tensor.Shape)8 ToTensorParams (com.alibaba.alink.params.dataproc.ToTensorParams)7 Row (org.apache.flink.types.Row)5 BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor)4 IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor)4 LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor)4 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)4 Tensor (org.tensorflow.Tensor)4 TFloat64 (org.tensorflow.types.TFloat64)4 VectorToTensorParams (com.alibaba.alink.params.dataproc.VectorToTensorParams)3 List (java.util.List)3