Search in sources :

Example 1 with LongTensor

use of com.alibaba.alink.common.linalg.tensor.LongTensor in project Alink by alibaba.

Class TFExampleConversionV2, method javaToFeature.

/**
 * Converts a Java object to a TensorFlow {@link Feature}.
 *
 * <p>Floating-point scalars are stored in a {@code FloatList}, integer scalars in an
 * {@code Int64List} (via {@code castAsLong}) and strings in a {@code BytesList} (via
 * {@code castAsBytes}). Tensors are reshaped to one dimension and their elements are appended in
 * flattened order:
 * <ul>
 *   <li>float and double tensors go to a {@code FloatList} (doubles are narrowed to float);</li>
 *   <li>int, long, boolean and unsigned-byte tensors go to an {@code Int64List} (booleans as 1/0);</li>
 *   <li>string tensors go to a {@code BytesList}.</li>
 * </ul>
 *
 * <p>When {@code dt} is {@code TENSOR}, the concrete tensor type is first resolved from the
 * runtime class of {@code val}.
 *
 * @param dt  java object data type.
 * @param val given java object.
 * @return tensorflow feature.
 * @throws RuntimeException if the (resolved) data type is not supported. This includes
 *     {@code BYTE_TENSOR}, and {@code TENSOR} when {@code val} is none of the recognised tensor
 *     classes.
 */
public static Feature javaToFeature(DataTypesV2 dt, Object val) {
    Feature.Builder featureBuilder = Feature.newBuilder();
    FloatList.Builder floatListBuilder = FloatList.newBuilder();
    Int64List.Builder int64ListBuilder = Int64List.newBuilder();
    // When dt is TENSOR, find the exact type first.
    if (DataTypesV2.TENSOR.equals(dt)) {
        if (val instanceof FloatTensor) {
            dt = DataTypesV2.FLOAT_TENSOR;
        } else if (val instanceof DoubleTensor) {
            dt = DataTypesV2.DOUBLE_TENSOR;
        } else if (val instanceof IntTensor) {
            dt = DataTypesV2.INT_TENSOR;
        } else if (val instanceof LongTensor) {
            dt = DataTypesV2.LONG_TENSOR;
        } else if (val instanceof BoolTensor) {
            dt = DataTypesV2.BOOLEAN_TENSOR;
        } else if (val instanceof UByteTensor) {
            dt = DataTypesV2.UBYTE_TENSOR;
        } else if (val instanceof StringTensor) {
            dt = DataTypesV2.STRING_TENSOR;
        } else if (val instanceof ByteTensor) {
            dt = DataTypesV2.BYTE_TENSOR;
        }
        // Otherwise dt stays TENSOR and falls through to the unsupported-type error below.
    }
    switch (dt) {
        case FLOAT_16:
        case FLOAT_32:
        case FLOAT_64: {
            // Read the value through Number so that a boxed Double (presumably how FLOAT_64 values
            // arrive) is narrowed to float. A plain (Float) cast throws ClassCastException for it.
            floatListBuilder.addValue(((Number) val).floatValue());
            featureBuilder.setFloatList(floatListBuilder);
            break;
        }
        case INT_8:
        case INT_16:
        case INT_32:
        case INT_64:
        case UINT_8:
        case UINT_16:
        case UINT_32:
        case UINT_64: {
            int64ListBuilder.addValue(castAsLong(val));
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case STRING: {
            BytesList.Builder bb = BytesList.newBuilder();
            bb.addValue(castAsBytes(val));
            featureBuilder.setBytesList(bb);
            break;
        }
        case FLOAT_TENSOR: {
            FloatTensor floatTensor = (FloatTensor) val;
            long size = floatTensor.size();
            // Flatten to 1-D so elements can be read with a single linear index.
            floatTensor = floatTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                floatListBuilder.addValue(floatTensor.getFloat(i));
            }
            featureBuilder.setFloatList(floatListBuilder);
            break;
        }
        case DOUBLE_TENSOR: {
            DoubleTensor doubleTensor = (DoubleTensor) val;
            long size = doubleTensor.size();
            doubleTensor = doubleTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                // FloatList only holds floats, so double precision is lost here.
                floatListBuilder.addValue((float) doubleTensor.getDouble(i));
            }
            featureBuilder.setFloatList(floatListBuilder);
            break;
        }
        case INT_TENSOR: {
            IntTensor intTensor = (IntTensor) val;
            long size = intTensor.size();
            intTensor = intTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                int64ListBuilder.addValue(intTensor.getInt(i));
            }
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case LONG_TENSOR: {
            LongTensor longTensor = (LongTensor) val;
            long size = longTensor.size();
            longTensor = longTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                int64ListBuilder.addValue(longTensor.getLong(i));
            }
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case BOOLEAN_TENSOR: {
            BoolTensor boolTensor = (BoolTensor) val;
            long size = boolTensor.size();
            boolTensor = boolTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                // Booleans are encoded as 1 (true) / 0 (false).
                int64ListBuilder.addValue(boolTensor.getBoolean(i) ? 1 : 0);
            }
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case UBYTE_TENSOR: {
            UByteTensor ubyteTensor = (UByteTensor) val;
            long size = ubyteTensor.size();
            ubyteTensor = ubyteTensor.reshape(new Shape(size));
            for (long i = 0; i < size; i += 1) {
                // NOTE(review): if getUByte returns a signed byte, values above 127 are
                // sign-extended here. Confirm whether masking with 0xFF is intended.
                int64ListBuilder.addValue(ubyteTensor.getUByte(i));
            }
            featureBuilder.setInt64List(int64ListBuilder);
            break;
        }
        case STRING_TENSOR: {
            StringTensor stringTensor = (StringTensor) val;
            long size = stringTensor.size();
            stringTensor = stringTensor.reshape(new Shape(size));
            BytesList.Builder bb = BytesList.newBuilder();
            for (long i = 0; i < size; i += 1) {
                bb.addValue(castAsBytes(stringTensor.getString(i)));
            }
            featureBuilder.setBytesList(bb);
            break;
        }
        case BYTE_TENSOR:
        default:
            throw new RuntimeException("Unsupported data type for TF");
    }
    return featureBuilder.build();
}
Also used : DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) Feature(org.tensorflow.proto.example.Feature) FloatList(org.tensorflow.proto.example.FloatList) Int64List(org.tensorflow.proto.example.Int64List) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) UByteTensor(com.alibaba.alink.common.linalg.tensor.UByteTensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) ByteTensor(com.alibaba.alink.common.linalg.tensor.ByteTensor)

Example 2 with LongTensor

use of com.alibaba.alink.common.linalg.tensor.LongTensor in project Alink by alibaba.

Class StringTFTensorConversionImpl, method encodeBatchTensor.

/**
 * Splits a TF tensor along {@code batchAxis} and encodes every slice as one string.
 *
 * <p>Each slice is converted back to a TF tensor and passed through
 * {@code TF2TensorUtils.squeezeTensor}. It is then copied into a one-dimensional Java array, and
 * the elements are joined with {@code SEP_CHAR}. Supported dtypes are string, bool, int32, int64,
 * float32 and float64.
 *
 * @param tensor    the batched TF tensor.
 * @param batchAxis the axis to split along.
 * @return one encoded string per slice along {@code batchAxis}.
 * @throws UnsupportedOperationException if the tensor's dtype is not one of the supported ones.
 */
@Override
public String[] encodeBatchTensor(Tensor<?> tensor, int batchAxis) {
    // NOTE(review): the size read here is never used afterwards. The array access only makes an
    // out-of-range batchAxis fail with ArrayIndexOutOfBoundsException.
    final long[] dims = tensor.shape().asArray();
    final long numSlices = dims[batchAxis];
    final Object dtype = tensor.dataType();

    // NOTE(review): the per-slice TF tensors created below are never closed. Confirm whether this
    // leaks native memory.
    if (TString.DTYPE.equals(dtype)) {
        final StringTensor whole = (StringTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        final StringTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TString) squeezed.data(), String.class))
            .map(values -> String.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TBool.DTYPE.equals(dtype)) {
        final BoolTensor whole = (BoolTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        final BoolTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TBool) squeezed.data()))
            .map(values -> Booleans.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TInt32.DTYPE.equals(dtype)) {
        final IntTensor whole = (IntTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        final IntTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TInt32) squeezed.data()))
            .map(values -> Ints.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TInt64.DTYPE.equals(dtype)) {
        final LongTensor whole = (LongTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        final LongTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TInt64) squeezed.data()))
            .map(values -> Longs.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TFloat32.DTYPE.equals(dtype)) {
        final FloatTensor whole = (FloatTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        final FloatTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TFloat32) squeezed.data()))
            .map(values -> Floats.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    if (TFloat64.DTYPE.equals(dtype)) {
        final DoubleTensor whole = (DoubleTensor) TFTensorConversionUtils.fromTFTensor(tensor);
        final DoubleTensor[] slices = unstack(whole, batchAxis, null);
        return Arrays.stream(slices)
            .map(TFTensorConversionUtils::toTFTensor)
            .map(TF2TensorUtils::squeezeTensor)
            .map(squeezed -> StdArrays.array1dCopyOf((TFloat64) squeezed.data()))
            .map(values -> Doubles.join(SEP_CHAR, values))
            .toArray(String[]::new);
    }
    throw new UnsupportedOperationException("Unsupported dtype: " + dtype);
}
Also used : TString(org.tensorflow.types.TString) Booleans(com.google.common.primitives.Booleans) Arrays(java.util.Arrays) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) TFloat32(org.tensorflow.types.TFloat32) TFTensorConversionUtils(com.alibaba.alink.common.dl.utils.TFTensorConversionUtils) StdArrays(org.tensorflow.ndarray.StdArrays) LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) TType(org.tensorflow.types.family.TType) Tensor(org.tensorflow.Tensor) TInt32(org.tensorflow.types.TInt32) Longs(com.google.common.primitives.Longs) Floats(com.google.common.primitives.Floats) TBool(org.tensorflow.types.TBool) TFloat64(org.tensorflow.types.TFloat64) TF2TensorUtils(com.alibaba.alink.common.dl.utils.TF2TensorUtils) TensorInfo(org.tensorflow.proto.framework.TensorInfo) Preconditions(org.apache.flink.util.Preconditions) StringUtils(org.apache.flink.util.StringUtils) Ints(com.google.common.primitives.Ints) DataType(org.tensorflow.proto.framework.DataType) StandardCharsets(java.nio.charset.StandardCharsets) List(java.util.List) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) Doubles(com.google.common.primitives.Doubles) Tensor.unstack(com.alibaba.alink.common.linalg.tensor.Tensor.unstack) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) TInt64(org.tensorflow.types.TInt64) Shape(org.tensorflow.ndarray.Shape) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor)

Example 3 with LongTensor

use of com.alibaba.alink.common.linalg.tensor.LongTensor in project Alink by alibaba.

Class TensorTFTensorConversionImpl, method parseLongTensor.

/**
 * Stacks Alink {@code LongTensor}s along axis 0 into one tensor of the given shape and converts
 * the result to a TensorFlow INT64 tensor.
 *
 * <p>Presumably {@code shape} is the full stacked shape, i.e. it includes the leading axis that
 * indexes {@code tensors}. TODO confirm against callers.
 *
 * @param tensors Alink tensors that must all be {@code LongTensor}s (they are array-cast to
 *                {@code LongTensor[]}).
 * @param shape   shape of the stacked destination tensor.
 * @return the stacked data as a TF {@code Tensor<TInt64>}.
 */
@Override
public Tensor<TInt64> parseLongTensor(com.alibaba.alink.common.linalg.tensor.Tensor<?>[] tensors, long[] shape) {
    final com.alibaba.alink.common.linalg.tensor.Shape stackedShape =
        new com.alibaba.alink.common.linalg.tensor.Shape(shape);
    final LongTensor destination = new LongTensor(stackedShape);
    final LongTensor[] parts = castArrayType(tensors, LongTensor[].class);
    stack(parts, 0, destination);
    @SuppressWarnings("unchecked")
    final Tensor<TInt64> converted = (Tensor<TInt64>) TFTensorConversionUtils.toTFTensor(destination);
    return converted;
}
Also used : LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Tensor(org.tensorflow.Tensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor)

Example 4 with LongTensor

use of com.alibaba.alink.common.linalg.tensor.LongTensor in project Alink by alibaba.

Class TFSavedModelPredictMapperTest, method testTensor.

/**
 * End-to-end check of {@code TFSavedModelPredictMapper} on tensor-typed columns.
 *
 * <p>The test feeds 1000 rows, each holding a LONG_TENSOR label of shape [3] and a FLOAT_TENSOR
 * image of shape [3, 28, 28], through a downloaded saved model. It then verifies four things:
 * <ul>
 *   <li>the output schema;</li>
 *   <li>the input columns pass through unchanged;</li>
 *   <li>the "classes" output has shape [3];</li>
 *   <li>the "probabilities" output has shape [3, 10].</li>
 * </ul>
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Make sure the TF predictor plugin is available locally before the mapper needs it.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    int batchSize = 3;
    List<Row> rows = new ArrayList<>();
    for (int i = 0; i < 1000; i += 1) {
        Row row = Row.of(new LongTensor(new Shape(batchSize)), new FloatTensor(new Shape(batchSize, 28, 28)));
        rows.add(row);
    }
    BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG_TENSOR, image FLOAT_TENSOR");

    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip");
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG_TENSOR, probabilities FLOAT_TENSOR");

    TFSavedModelPredictMapper tfSavedModelPredictMapper = new TFSavedModelPredictMapper(data.getSchema(), params);
    tfSavedModelPredictMapper.open();
    // Close the mapper even when an assertion fails, so its model resources are released.
    try {
        TableSchema expectedSchema = TableSchema.builder()
            .field("label", TensorTypes.LONG_TENSOR)
            .field("image", TensorTypes.FLOAT_TENSOR)
            .field("classes", TensorTypes.LONG_TENSOR)
            .field("probabilities", TensorTypes.FLOAT_TENSOR)
            .build();
        Assert.assertEquals(expectedSchema, tfSavedModelPredictMapper.getOutputSchema());
        for (Row row : rows) {
            Row output = tfSavedModelPredictMapper.map(row);
            // Input columns are reserved as-is in the output.
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
            // JUnit convention: expected value first, actual second.
            Assert.assertArrayEquals(new long[] { batchSize }, ((LongTensor) output.getField(2)).shape());
            Assert.assertArrayEquals(new long[] { batchSize, 10 }, ((FloatTensor) output.getField(3)).shape());
        }
    } finally {
        tfSavedModelPredictMapper.close();
    }
}
Also used : LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 5 with LongTensor

use of com.alibaba.alink.common.linalg.tensor.LongTensor in project Alink by alibaba.

Class TensorTFTensorConversionTest, method testLongTensor.

/**
 * Round-trips a random 3x4 {@code LongTensor} through {@code TensorTFTensorConversionImpl}.
 *
 * <p>The test checks that {@code parseTensor} yields a TF INT64 tensor with the same shape and
 * data as one built directly from the Java array. It also checks that {@code encodeTensor}
 * restores the original values.
 */
@Test
public void testLongTensor() {
    Random random = new Random(2021);
    long[][] arr = new long[3][4];
    for (int i = 0; i < 3; i += 1) {
        for (int j = 0; j < 4; j += 1) {
            arr[i][j] = random.nextLong();
        }
    }
    LongTensor data = new LongTensor(arr);
    TensorInfo tensorInfo = createTensorInfo(DataType.DT_INT64, new long[] { 3, 4 });
    // TF tensors own native memory, so close both once the comparisons are done.
    // noinspection unchecked
    try (Tensor<TInt64> std = TInt64.tensorOf(StdArrays.ndCopyOf(arr));
         Tensor<TInt64> tensor = (Tensor<TInt64>) TensorTFTensorConversionImpl.getInstance().parseTensor(data, tensorInfo)) {
        Assert.assertArrayEquals(std.shape().asArray(), tensor.shape().asArray());
        Assert.assertArrayEquals(StdArrays.array2dCopyOf(std.data()), StdArrays.array2dCopyOf(tensor.data()));
        LongTensor encoded = (LongTensor) TensorTFTensorConversionImpl.getInstance().encodeTensor(tensor);
        for (int i = 0; i < 3; i += 1) {
            for (int j = 0; j < 4; j += 1) {
                // JUnit convention: expected (original data) first, actual (round-tripped) second.
                Assert.assertEquals(data.getLong(i, j), encoded.getLong(i, j));
            }
        }
    }
}
Also used : LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) BoolTensor(com.alibaba.alink.common.linalg.tensor.BoolTensor) StringTensor(com.alibaba.alink.common.linalg.tensor.StringTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Tensor(org.tensorflow.Tensor) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) Random(java.util.Random) TInt64(org.tensorflow.types.TInt64) TensorInfo(org.tensorflow.proto.framework.TensorInfo) Test(org.junit.Test)

Aggregations

FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor)6 LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor)6 BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor)4 DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor)4 IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor)4 StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor)4 Shape (com.alibaba.alink.common.linalg.tensor.Shape)3 Test (org.junit.Test)3 Tensor (org.tensorflow.Tensor)3 PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader)2 RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey)2 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)2 DLTest (com.alibaba.alink.testutil.categories.DLTest)2 ArrayList (java.util.ArrayList)2 Params (org.apache.flink.ml.api.misc.param.Params)2 Row (org.apache.flink.types.Row)2 Category (org.junit.experimental.categories.Category)2 TensorInfo (org.tensorflow.proto.framework.TensorInfo)2 TF2TensorUtils (com.alibaba.alink.common.dl.utils.TF2TensorUtils)1 TFTensorConversionUtils (com.alibaba.alink.common.dl.utils.TFTensorConversionUtils)1