Usage of `com.alibaba.alink.common.linalg.tensor.DoubleTensor` in the Alink project by Alibaba.
From class `KerasSequentialClassifierBatchOpTest`, method `testTFHubLayer`.
/**
 * Runs a Keras sequential classifier end to end (train, then predict) where the model
 * contains a TensorFlow Hub layer (EfficientNet-B0 classification module).
 *
 * <p>The input is {@code numSamples} rows, each holding a random
 * {@code imageSize x imageSize x numChannels} {@link DoubleTensor} and a random 0/1 label.
 * The test only checks that the pipeline executes; it prints the first predictions and
 * asserts nothing about their values.
 *
 * <p>NOTE(review): the Hub module is referenced by URL, so this presumably needs network
 * access (or a pre-populated TF Hub cache) at run time — confirm for CI.
 *
 * <p>The global batch parallelism is forced to 1 for the run and is always restored
 * afterwards, even when execution fails, so later tests are not affected.
 */
@Category(DLTest.class)
@Test
public void testTFHubLayer() throws Exception {
    int savedParallelism = MLEnvironmentFactory.getDefault().getExecutionEnvironment().getParallelism();
    BatchOperator.setParallelism(1);
    try {
        // Fixed seed so that a failing run can be reproduced with identical input data.
        Random random = new Random(2021L);
        int numSamples = 1000;
        int imageSize = 96;
        int numChannels = 3;

        List<Row> rows = new ArrayList<>(numSamples);
        for (int s = 0; s < numSamples; s += 1) {
            double[][][] pixels = new double[imageSize][imageSize][numChannels];
            for (int i = 0; i < imageSize; i += 1) {
                for (int j = 0; j < imageSize; j += 1) {
                    for (int k = 0; k < numChannels; k += 1) {
                        pixels[i][j][k] = random.nextFloat();
                    }
                }
            }
            DoubleTensor x = new DoubleTensor(pixels);
            int label = random.nextInt(2);
            rows.add(Row.of(x, label));
        }
        BatchOperator<?> source = new MemSourceBatchOp(rows, "tensor DOUBLE_TENSOR, label int");

        // Derive the Reshape target from the generated data so the two cannot drift apart;
        // for the values above this is exactly "Reshape((96, 96, 3))".
        String reshapeLayer = String.format("Reshape((%d, %d, %d))", imageSize, imageSize, numChannels);

        KerasSequentialClassifierTrainBatchOp trainBatchOp = new KerasSequentialClassifierTrainBatchOp()
            .setTensorCol("tensor")
            .setLabelCol("label")
            .setLayers(
                reshapeLayer,
                "hub.KerasLayer('https://hub.tensorflow.google.cn/tensorflow/efficientnet/b0/classification/1')",
                "Flatten()")
            .setCheckpointFilePath(PythonFileUtils.createTempDir("keras_sequential_train_").toString())
            .setNumEpochs(1)
            .linkFrom(source);

        KerasSequentialClassifierPredictBatchOp predictBatchOp = new KerasSequentialClassifierPredictBatchOp()
            .setPredictionCol("pred")
            .setPredictionDetailCol("pred_detail")
            .setReservedCols("label")
            .linkFrom(trainBatchOp, source);
        predictBatchOp.lazyPrint(10);
        BatchOperator.execute();
    } finally {
        // Restore the shared setting even if training/prediction throws.
        BatchOperator.setParallelism(savedParallelism);
    }
}
Aggregations