Example 21 with DoubleTensor

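Use of com.alibaba.alink.common.linalg.tensor.DoubleTensor in the Alink project by Alibaba.

From the class KerasSequentialClassifierBatchOpTest, method testTFHubLayer. The test wraps random 96 x 96 x 3 arrays in DoubleTensor values, trains a Keras Sequential classifier that uses a TensorFlow Hub EfficientNet-B0 layer for one epoch, and prints the first 10 predictions.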

@Category(DLTest.class)
@Test
public void testTFHubLayer() throws Exception {
    // Remember the current parallelism so it can be restored after the test.
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        Random random = new Random();
        int n = 1000;          // number of samples
        int nTimesteps = 96;   // used as both image height and width
        int nvars = 3;         // number of channels
        List<Row> rows = new ArrayList<>();
        for (int nn = 0; nn < n; nn += 1) {
            // Each sample is a random 96 x 96 x 3 array with values in [0, 1).
            double[][][] xArr = new double[nTimesteps][nTimesteps][nvars];
            for (int i = 0; i < nTimesteps; i += 1) {
                for (int j = 0; j < nTimesteps; j += 1) {
                    for (int k = 0; k < nvars; k += 1) {
                        xArr[i][j][k] = random.nextFloat();
                    }
                }
            }
            DoubleTensor x = new DoubleTensor(xArr);
            int label = random.nextInt(2);  // random binary label (0 or 1)
            rows.add(Row.of(x, label));
        }
        BatchOperator<?> source = new MemSourceBatchOp(rows, "tensor DOUBLE_TENSOR, label int");

        // Layers: reshape to (96, 96, 3), apply the EfficientNet-B0 classifier from TF Hub, then flatten.
        KerasSequentialClassifierTrainBatchOp trainBatchOp = new KerasSequentialClassifierTrainBatchOp()
            .setTensorCol("tensor")
            .setLabelCol("label")
            .setLayers(
                "Reshape((96, 96, 3))",
                "hub.KerasLayer('https://hub.tensorflow.google.cn/tensorflow/efficientnet/b0/classification/1')",
                "Flatten()")
            .setCheckpointFilePath(PythonFileUtils.createTempDir("keras_sequential_train_").toString())
            .setNumEpochs(1)
            .linkFrom(source);

        // Predict on the training data, keeping the original label for comparison.
        KerasSequentialClassifierPredictBatchOp predictBatchOp = new KerasSequentialClassifierPredictBatchOp()
            .setPredictionCol("pred")
            .setPredictionDetailCol("pred_detail")
            .setReservedCols("label")
            .linkFrom(trainBatchOp, source);
        predictBatchOp.lazyPrint(10);
        BatchOperator.execute();
    } finally {
        // Restore the original parallelism even if training or prediction fails.
        BatchOperator.setParallelism(savedParallelism);
    }
}
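The data-generation loop in the test can be sketched without Alink. This is a minimal, stdlib-only sketch, assuming the nested double[96][96][3] array is consumed in row-major order (the order TensorFlow's reshape uses); whether Alink's DoubleTensor flattens its data this way internally is an assumption, not something the example above confirms. The class RandomImageSamples and its methods are hypothetical names for illustration. It shows why Reshape((96, 96, 3)) corresponds to exactly 96 * 96 * 3 = 27648 values per sample and how a flat index maps back to [i][j][k].

```java
import java.util.Random;

// Hypothetical helper class: mirrors the shapes used in testTFHubLayer without Alink.
final class RandomImageSamples {
    static final int SIDE = 96;     // image height and width, as in the test
    static final int CHANNELS = 3;  // number of channels, as in the test

    // Fills a [SIDE][SIDE][CHANNELS] array with uniform values in [0, 1),
    // like the nested loops in the test.
    static double[][][] randomImage(Random random) {
        double[][][] img = new double[SIDE][SIDE][CHANNELS];
        for (int i = 0; i < SIDE; i++) {
            for (int j = 0; j < SIDE; j++) {
                for (int k = 0; k < CHANNELS; k++) {
                    img[i][j][k] = random.nextFloat();
                }
            }
        }
        return img;
    }

    // Flattens in row-major (C) order, so img[i][j][k] lands at
    // flat[(i * SIDE + j) * CHANNELS + k]. Assumed to match the tensor layout.
    static double[] flattenRowMajor(double[][][] img) {
        double[] flat = new double[SIDE * SIDE * CHANNELS];
        int idx = 0;
        for (double[][] row : img) {
            for (double[] pixel : row) {
                for (double v : pixel) {
                    flat[idx++] = v;
                }
            }
        }
        return flat;
    }

    public static void main(String[] args) {
        // Fixed seed for reproducibility (the test itself uses an unseeded Random).
        Random random = new Random(42);
        double[][][] img = randomImage(random);
        double[] flat = flattenRowMajor(img);
        System.out.println(flat.length); // prints 27648 (96 * 96 * 3)

        int i = 5, j = 7, k = 2;
        System.out.println(flat[(i * SIDE + j) * CHANNELS + k] == img[i][j][k]); // prints true

        int label = random.nextInt(2); // binary label in {0, 1}, as in the test
        System.out.println(label == 0 || label == 1); // prints true
    }
}
```

Because the test creates its Random without a seed, its data changes on every run; passing a fixed seed, as in this sketch, is one way to make such a test reproducible.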
Also used: MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp), DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor), Random (java.util.Random), ArrayList (java.util.ArrayList), Row (org.apache.flink.types.Row), Category (org.junit.experimental.categories.Category), Test (org.junit.Test), DLTest (com.alibaba.alink.testutil.categories.DLTest)

Aggregations

DoubleTensor (com.alibaba.alink.common.linalg.tensor.DoubleTensor): 21
Test (org.junit.Test): 18
FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor): 15
StringTensor (com.alibaba.alink.common.linalg.tensor.StringTensor): 10
Mapper (com.alibaba.alink.common.mapper.Mapper): 10
TypeInformation (org.apache.flink.api.common.typeinfo.TypeInformation): 10
Params (org.apache.flink.ml.api.misc.param.Params): 10
TableSchema (org.apache.flink.table.api.TableSchema): 10
Tensor (com.alibaba.alink.common.linalg.tensor.Tensor): 9
Shape (com.alibaba.alink.common.linalg.tensor.Shape): 8
ToTensorParams (com.alibaba.alink.params.dataproc.ToTensorParams): 7
Row (org.apache.flink.types.Row): 5
BoolTensor (com.alibaba.alink.common.linalg.tensor.BoolTensor): 4
IntTensor (com.alibaba.alink.common.linalg.tensor.IntTensor): 4
LongTensor (com.alibaba.alink.common.linalg.tensor.LongTensor): 4
MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp): 4
Tensor (org.tensorflow.Tensor): 4
TFloat64 (org.tensorflow.types.TFloat64): 4
VectorToTensorParams (com.alibaba.alink.params.dataproc.VectorToTensorParams): 3
List (java.util.List): 3