Use of `com.alibaba.alink.common.linalg.tensor.LongTensor` in project Alink by Alibaba.
In class `BaseTFSavedModelPredictMapperTest`, method `testTensor`:
/**
 * End-to-end check of {@link BaseTFSavedModelPredictMapper} on tensor-typed columns.
 *
 * <p>Downloads the TF predictor plugin and a zipped SavedModel, then does the following:
 * <ul>
 *   <li>Builds rows of (LONG_TENSOR label, FLOAT_TENSOR image) with shapes [batchSize] and
 *       [batchSize, 28, 28].</li>
 *   <li>Runs the mapper directly on each row, feeding only the {@code image} column.</li>
 *   <li>Asserts that the output schema is the input schema plus {@code classes} /
 *       {@code probabilities}.</li>
 *   <li>Asserts that the input columns pass through unchanged and that the outputs have shapes
 *       [batchSize] and [batchSize, 10].</li>
 * </ul>
 *
 * <p>Requires network access (plugin and model download).
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Make sure the TF predictor plugin is available locally before the mapper tries to load it.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);

    final int batchSize = 3;
    final int numRows = 1000;
    List<Row> rows = new ArrayList<>(numRows);
    for (int i = 0; i < numRows; i++) {
        rows.add(Row.of(
            new LongTensor(new Shape(batchSize)),
            new FloatTensor(new Shape(batchSize, 28, 28))));
    }
    BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG_TENSOR, image FLOAT_TENSOR");

    // Fetch the SavedModel archive into a fresh temp directory and unpack it next to the zip.
    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    System.out.println("localModelPath:" + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        // The archive is expected to unpack into a directory named after the zip (minus ".zip").
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }

    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] {"image"});
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG_TENSOR, probabilities FLOAT_TENSOR");

    final long[] expectedClassesShape = {batchSize};
    final long[] expectedProbabilitiesShape = {batchSize, 10};

    BaseTFSavedModelPredictMapper baseTFSavedModelPredictMapper =
        new BaseTFSavedModelPredictMapper(data.getSchema(), params);
    baseTFSavedModelPredictMapper.open();
    // Close the mapper even when an assertion fails, so the loaded model/plugin is released.
    try {
        Assert.assertEquals(
            TableSchema.builder()
                .field("label", TensorTypes.LONG_TENSOR)
                .field("image", TensorTypes.FLOAT_TENSOR)
                .field("classes", TensorTypes.LONG_TENSOR)
                .field("probabilities", TensorTypes.FLOAT_TENSOR)
                .build(),
            baseTFSavedModelPredictMapper.getOutputSchema());
        for (Row row : rows) {
            Row output = baseTFSavedModelPredictMapper.map(row);
            // Input columns are passed through untouched.
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
            // JUnit order is (expected, actual); keep it so failure messages are accurate.
            Assert.assertArrayEquals(expectedClassesShape, ((LongTensor) output.getField(2)).shape());
            Assert.assertArrayEquals(
                expectedProbabilitiesShape, ((FloatTensor) output.getField(3)).shape());
        }
    } finally {
        baseTFSavedModelPredictMapper.close();
    }
}
Aggregations