Search in sources:

Example 6 with Shape

Use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.

Method testOp in class ToTensorMapperTest.

/**
 * Converts a six-element vector column "vec" into a 2x3 float tensor with {@link ToTensorBatchOp}
 * and checks that the collected value equals the same data reshaped locally to (2, 3) as a {@link FloatTensor}.
 */
@Test
public void testOp() throws Exception {
    final DoubleTensor source = DoubleTensor.of(TensorUtil.getTensor("FLOAT#6#0.0 0.1 1.0 1.1 2.0 2.1 "));
    final FloatTensor expected = FloatTensor.of(source.reshape(new Shape(2L, 3L)));

    final ToTensorBatchOp toTensor = new ToTensorBatchOp()
        .setSelectedCol("vec")
        .setTensorShape(2, 3)
        .setTensorDataType("float");

    // The callback only runs once BatchOperator.execute() triggers the job.
    new MemSourceBatchOp(new Row[] { Row.of(source.toVector()) }, new String[] { "vec" })
        .link(toTensor)
        .lazyCollect(collected ->
            Assert.assertEquals(expected, TensorUtil.getTensor(collected.get(0).getField(0))));

    BatchOperator.execute();
}
Also used : MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) DoubleTensor(com.alibaba.alink.common.linalg.tensor.DoubleTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) ToTensorBatchOp(com.alibaba.alink.operator.batch.dataproc.ToTensorBatchOp) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) List(java.util.List) Row(org.apache.flink.types.Row) Test(org.junit.Test)

Example 7 with Shape

Use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.

Method testTensor in class BaseTFSavedModelPredictRowFlatMapperTest.

/**
 * Downloads the TF predictor plugin and a SavedModel archive, runs
 * {@link BaseTFSavedModelPredictRowFlatMapper} over 1000 rows holding a default-constructed 28x28
 * {@link FloatTensor}, and checks the output schema, the pass-through of the input columns, and the
 * shape of the "probabilities" output.
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    List<Row> rows = new ArrayList<>();
    for (int i = 0; i < 1000; i += 1) {
        // "label" is declared LONG below, so store a Long (0L) rather than an Integer.
        rows.add(Row.of(0L, new FloatTensor(new Shape(28, 28))));
    }
    BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG, image FLOAT_TENSOR");

    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    System.out.println("localModelPath: " + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        // Unzip next to the archive; the model directory must be the archive path without ".zip".
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }

    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG, probabilities FLOAT_TENSOR");

    BaseTFSavedModelPredictRowFlatMapper mapper = new BaseTFSavedModelPredictRowFlatMapper(data.getSchema(), params);
    mapper.open();
    List<Row> outputs = new ArrayList<>();
    try {
        Assert.assertEquals(TableSchema.builder().field("label", Types.LONG).field("image", TensorTypes.FLOAT_TENSOR).field("classes", Types.LONG).field("probabilities", TensorTypes.FLOAT_TENSOR).build(), mapper.getOutputSchema());
        ListCollector<Row> collector = new ListCollector<>(outputs);
        for (Row row : rows) {
            mapper.flatMap(row, collector);
        }
    } finally {
        // Always release the mapper, even when open-time checks or flatMap fail.
        mapper.close();
    }

    // The per-index comparison below assumes exactly one output row per input row.
    Assert.assertEquals(rows.size(), outputs.size());
    for (int i = 0; i < rows.size(); i += 1) {
        Row row = rows.get(i);
        Row output = outputs.get(i);
        Assert.assertEquals(row.getField(0), output.getField(0));
        Assert.assertEquals(row.getField(1), output.getField(1));
        // JUnit's assertArrayEquals takes (expected, actual).
        Assert.assertArrayEquals(new long[] { 10 }, ((FloatTensor) output.getField(3)).shape());
    }
}
Also used : Shape(com.alibaba.alink.common.linalg.tensor.Shape) ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) ListCollector(org.apache.flink.api.common.functions.util.ListCollector) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 8 with Shape

Use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.

Method testTensor in class BaseTFSavedModelPredictRowMapperTest.

/**
 * Downloads the TF predictor plugin and a SavedModel archive, runs
 * {@link BaseTFSavedModelPredictRowMapper} over 1000 rows holding a default-constructed 28x28
 * {@link FloatTensor}, and checks the output schema, the pass-through of the input columns, and the
 * shape of the "probabilities" output.
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    List<Row> rows = new ArrayList<>();
    for (int i = 0; i < 1000; i += 1) {
        // "label" is declared LONG below, so store a Long (0L) rather than an Integer.
        rows.add(Row.of(0L, new FloatTensor(new Shape(28, 28))));
    }
    BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG, image FLOAT_TENSOR");

    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    System.out.println("localModelPath: " + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        // Unzip next to the archive; the model directory must be the archive path without ".zip".
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }

    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG, probabilities FLOAT_TENSOR");

    BaseTFSavedModelPredictRowMapper mapper = new BaseTFSavedModelPredictRowMapper(data.getSchema(), params);
    mapper.open();
    try {
        Assert.assertEquals(TableSchema.builder().field("label", Types.LONG).field("image", TensorTypes.FLOAT_TENSOR).field("classes", Types.LONG).field("probabilities", TensorTypes.FLOAT_TENSOR).build(), mapper.getOutputSchema());
        for (Row row : rows) {
            Row output = mapper.map(row);
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
            // JUnit's assertArrayEquals takes (expected, actual).
            Assert.assertArrayEquals(new long[] { 10 }, ((FloatTensor) output.getField(3)).shape());
        }
    } finally {
        // Always release the mapper, even when an assertion fails mid-loop.
        mapper.close();
    }
}
Also used : Shape(com.alibaba.alink.common.linalg.tensor.Shape) ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 9 with Shape

Use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.

Method testTensor in class TFSavedModelPredictMapperTest.

/**
 * Runs {@link TFSavedModelPredictMapper} directly against a remote SavedModel archive URL on 1000 rows.
 * Each row holds a batch of {@code batchSize} labels and 28x28 images.
 * Checks the output schema, the pass-through of the input columns, and the shapes of the
 * "classes" and "probabilities" outputs.
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    int batchSize = 3;
    List<Row> rows = new ArrayList<>();
    for (int i = 0; i < 1000; i += 1) {
        rows.add(Row.of(new LongTensor(new Shape(batchSize)), new FloatTensor(new Shape(batchSize, 28, 28))));
    }
    BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG_TENSOR, image FLOAT_TENSOR");

    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip");
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG_TENSOR, probabilities FLOAT_TENSOR");

    TFSavedModelPredictMapper mapper = new TFSavedModelPredictMapper(data.getSchema(), params);
    mapper.open();
    try {
        Assert.assertEquals(TableSchema.builder().field("label", TensorTypes.LONG_TENSOR).field("image", TensorTypes.FLOAT_TENSOR).field("classes", TensorTypes.LONG_TENSOR).field("probabilities", TensorTypes.FLOAT_TENSOR).build(), mapper.getOutputSchema());
        for (Row row : rows) {
            Row output = mapper.map(row);
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
            // JUnit's assertArrayEquals takes (expected, actual).
            Assert.assertArrayEquals(new long[] { batchSize }, ((LongTensor) output.getField(2)).shape());
            Assert.assertArrayEquals(new long[] { batchSize, 10 }, ((FloatTensor) output.getField(3)).shape());
        }
    } finally {
        // Always release the mapper, even when an assertion fails mid-loop.
        mapper.close();
    }
}
Also used : LongTensor(com.alibaba.alink.common.linalg.tensor.LongTensor) Shape(com.alibaba.alink.common.linalg.tensor.Shape) ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 10 with Shape

Use of com.alibaba.alink.common.linalg.tensor.Shape in project Alink by Alibaba.

Method test in class BertEmbeddingExtractorMapperTest.

/**
 * Feeds random hidden states of shape (totalLayers, length, embeddingSize) into
 * {@link BertEmbeddingExtractorMapper} with {@code layer = -1}. Checks that field 1 of the result equals
 * {@code Floats.join(SEP_CHAR, embed)}, where {@code embed} is the average over the {@code length} token
 * positions of layer {@code shape[0] + layer}.
 */
@Test
public void test() throws Exception {
    int totalLayers = 13;
    int length = 10;
    int embeddingSize = 768;
    long[] shape = new long[] { totalLayers, length, embeddingSize };
    int layer = -1;
    // Fixed seed so a failing run can be reproduced with identical hidden states.
    Random random = new Random(0L);
    float[] buffer = new float[totalLayers * length * embeddingSize];
    for (int i = 0; i < buffer.length; i += 1) {
        buffer[i] = random.nextFloat();
    }

    Params params = new Params();
    params.set(HasLengthCol.LENGTH_COL, "length");
    params.set(HasHiddenStatesCol.HIDDEN_STATES_COL, "hidden_states");
    params.set(HasOutputCol.OUTPUT_COL, "embed");
    params.set(HasReservedColsDefaultAsNull.RESERVED_COLS, new String[] { "text" });
    params.set(HasLayer.LAYER, layer);
    TableSchema dataSchema = TableSchema.builder().field("text", Types.STRING).field("hidden_states", TensorTypes.FLOAT_TENSOR).field("length", TensorTypes.INT_TENSOR).build();

    BertEmbeddingExtractorMapper mapper = new BertEmbeddingExtractorMapper(dataSchema, params);
    mapper.open();
    Row result;
    try {
        result = mapper.map(Row.of("sequence builders", new FloatTensor(buffer).reshape(new Shape(shape)), new IntTensor(new int[] { length })));
        System.out.println(result);
    } finally {
        // Always release the mapper, even if map() throws.
        mapper.close();
    }

    // Recompute the expected embedding in the same summation order; Java arrays start zero-filled.
    float[] embed = new float[embeddingSize];
    int[] p = new int[] { (int) (shape[0] + layer), 0, 0 };
    for (p[1] = 0; p[1] < length; p[1] += 1) {
        for (p[2] = 0; p[2] < embeddingSize; p[2] += 1) {
            int index = BertEmbeddingExtractorMapper.calcIndex(p, shape);
            embed[p[2]] += buffer[index] / length;
        }
    }
    Assert.assertEquals(Floats.join(SEP_CHAR, embed), result.getField(1));
}
Also used : Shape(com.alibaba.alink.common.linalg.tensor.Shape) Random(java.util.Random) TableSchema(org.apache.flink.table.api.TableSchema) BertEmbeddingExtractorMapper(com.alibaba.alink.operator.common.nlp.bert.BertEmbeddingExtractorMapper) Params(org.apache.flink.ml.api.misc.param.Params) IntTensor(com.alibaba.alink.common.linalg.tensor.IntTensor) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) Test(org.junit.Test)

Aggregations

FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor)13 Shape (com.alibaba.alink.common.linalg.tensor.Shape)13 Test (org.junit.Test)12 Params (org.apache.flink.ml.api.misc.param.Params)10 DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor)8 Row (org.apache.flink.types.Row)7 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)6 TableSchema (org.apache.flink.table.api.TableSchema)6 Tensor (com.alibaba.alink.common.linalg.tensor.Tensor)5 Mapper (com.alibaba.alink.common.mapper.Mapper)5 TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation)5 PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader)4 RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey)4 StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor)4 DLTest (com.alibaba.alink.testutil.categories.DLTest)4 ArrayList (java.util.ArrayList)4 Category (org.junit.experimental.categories.Category)4 LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor)3 ToTensorParams (com.alibaba.alink.params.dataproc.ToTensorParams)3 File (java.io.File)3