Use of ai.djl.ModelException in project djl-demo by deepjavalibrary.
The class Handler, method handleRequest:
@Override
public void handleRequest(InputStream is, OutputStream os, Context context) throws IOException {
    LambdaLogger logger = context.getLogger();
    String input = Utils.toString(is);
    try {
        // Parse the Lambda payload: image URL, model zoo artifact id, and optional filters
        Request request = GSON.fromJson(input, Request.class);
        String url = request.getInputImageUrl();
        String artifactId = request.getArtifactId();
        Map<String, String> filters = request.getFilters();

        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optArtifactId(artifactId)
                        .optFilters(filters)
                        .build();

        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image image = ImageFactory.getInstance().fromUrl(url);
            List<Classifications.Classification> result = predictor.predict(image).topK(5);
            os.write(GSON.toJson(result).getBytes(StandardCharsets.UTF_8));
        }
    } catch (RuntimeException | ModelException | TranslateException e) {
        logger.log("Failed to handle input: " + input);
        logger.log(e.toString());
        String msg = "{\"status\": \"invoke failed: " + e.toString() + "\"}";
        os.write(msg.getBytes(StandardCharsets.UTF_8));
    }
}
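The Request class is not included in this snippet. Below is a minimal sketch of a payload POJO matching the getters called above; the field names are assumed from the getter names and Gson's default field mapping, not taken from the project source:

// Hypothetical request payload, inferred from the getters used in handleRequest.
// With Gson's default naming it would deserialize JSON such as:
// {"inputImageUrl": "https://...", "artifactId": "resnet", "filters": {"layers": "50"}}
public class Request {

    private String inputImageUrl;        // URL of the image to classify
    private String artifactId;           // model zoo artifact id
    private Map<String, String> filters; // optional model zoo filters

    public String getInputImageUrl() {
        return inputImageUrl;
    }

    public String getArtifactId() {
        return artifactId;
    }

    public Map<String, String> getFilters() {
        return filters;
    }
}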
Use of ai.djl.ModelException in project djl-demo by deepjavalibrary.
The class InferController, method mnistImage:
@PostMapping("/mnistImage")
public ResultBean mnistImage(@RequestParam(value = "imageFile") MultipartFile imageFile) {
    try (InputStream ins = imageFile.getInputStream()) {
        // Run inference on the uploaded image and echo it back as a base64 data URI
        String result = inferService.getImageInfo(ins);
        String base64Img = Base64.encodeBase64String(imageFile.getBytes());
        return ResultBean.success()
                .add("result", result)
                .add("base64Img", "data:image/jpeg;base64," + base64Img);
    } catch (IOException | ModelException | TranslateException e) {
        logger.error(e.getMessage(), e);
        return ResultBean.failure().add("errors", e.getMessage());
    }
}
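The InferService referenced here is not shown. Below is a minimal sketch of what getImageInfo could look like, assuming it loads an MNIST-trained image classification model from the DJL model zoo; the application, filter, and return value are assumptions, not taken from the project:

// Hypothetical sketch of InferService.getImageInfo: classify an uploaded digit image.
// Uses ai.djl.Application, ai.djl.modality.Classifications, ai.djl.modality.cv.Image/ImageFactory,
// ai.djl.repository.zoo.Criteria/ZooModel, and ai.djl.inference.Predictor.
public String getImageInfo(InputStream ins) throws IOException, ModelException, TranslateException {
    Image image = ImageFactory.getInstance().fromInputStream(ins);
    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                    .optFilter("dataset", "mnist") // assumption: select an MNIST-trained model
                    .build();
    try (ZooModel<Image, Classifications> model = criteria.loadModel();
            Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Classifications classifications = predictor.predict(image);
        return classifications.best().getClassName(); // predicted digit as a string
    }
}

In a real service the model would typically be loaded once and reused across requests rather than reloaded on every call.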
Use of ai.djl.ModelException in project djl by deepjavalibrary.
The class TrtTest, method testSerializedEngine:
@Test
public void testSerializedEngine() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    Device device = engine.defaultDevice();
    if (!device.isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }

    // Pick the serialized engine file that matches this GPU's compute capability
    String sm = CudaUtils.getComputeCapability(device.getDeviceId());
    Criteria<float[], float[]> criteria =
            Criteria.builder()
                    .setTypes(float[].class, float[].class)
                    .optModelPath(Paths.get("src/test/resources/identity_" + sm + ".trt"))
                    .optTranslator(new MyTranslator())
                    .optEngine("TensorRT")
                    .build();

    try (ZooModel<float[], float[]> model = criteria.loadModel();
            Predictor<float[], float[]> predictor = model.newPredictor()) {
        float[] data = new float[] {1, 2, 3, 4};
        float[] ret = predictor.predict(data);
        Assert.assertEquals(ret, data);
    }
}
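MyTranslator is referenced in these tests but not defined in the snippet. Below is a minimal sketch of a pass-through float[] translator that would work with the identity models above; the actual class in the DJL test sources may differ:

// Hypothetical sketch: wrap a float[] into a single NDArray for the model and
// read the model output back into a float[]. Batching is disabled to match the
// single-sample predictions in the tests.
private static final class MyTranslator implements Translator<float[], float[]> {

    @Override
    public NDList processInput(TranslatorContext ctx, float[] input) {
        NDArray array = ctx.getNDManager().create(input);
        return new NDList(array);
    }

    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        return list.singletonOrThrow().toFloatArray();
    }

    @Override
    public Batchifier getBatchifier() {
        return null;
    }
}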
Use of ai.djl.ModelException in project djl by deepjavalibrary.
The class TrtTest, method testTrtOnnx:
@Test
public void testTrtOnnx() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    if (!engine.defaultDevice().isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }

    // Load an ONNX identity model directly with the TensorRT engine
    Criteria<float[], float[]> criteria =
            Criteria.builder()
                    .setTypes(float[].class, float[].class)
                    .optModelPath(Paths.get("src/test/resources/identity.onnx"))
                    .optTranslator(new MyTranslator())
                    .optEngine("TensorRT")
                    .build();

    try (ZooModel<float[], float[]> model = criteria.loadModel();
            Predictor<float[], float[]> predictor = model.newPredictor()) {
        float[] data = new float[] {1, 2, 3, 4};
        float[] ret = predictor.predict(data);
        Assert.assertEquals(ret, data);
    }
}
Use of ai.djl.ModelException in project djl by deepjavalibrary.
The class TrtTest, method testTrtUff:
@Test
public void testTrtUff() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    if (!engine.defaultDevice().isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }

    // Digit labels "0" through "9" for the LeNet-5 MNIST model
    List<String> synset = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
    ImageClassificationTranslator translator =
            ImageClassificationTranslator.builder()
                    .optFlag(Image.Flag.GRAYSCALE)
                    .optSynset(synset)
                    .optApplySoftmax(true)
                    .addTransform(new ToTensor())
                    .optBatchifier(null)
                    .build();
    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optModelUrls("https://resources.djl.ai/test-models/tensorrt/lenet5.zip")
                    .optTranslator(translator)
                    .optEngine("TensorRT")
                    .build();

    try (ZooModel<Image, Classifications> model = criteria.loadModel();
            Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Path path = Paths.get("../../examples/src/test/resources/0.png");
        Image image = ImageFactory.getInstance().fromFile(path);
        Classifications ret = predictor.predict(image);
        Assert.assertEquals(ret.best().getClassName(), "0");
    }
}