
Example 1 with ModelException

Use of ai.djl.ModelException in project djl-demo by deepjavalibrary.

From class Handler, method handleRequest.

@Override
public void handleRequest(InputStream is, OutputStream os, Context context) throws IOException {
    LambdaLogger logger = context.getLogger();
    String input = Utils.toString(is);
    try {
        // Parse the request: image URL, model-zoo artifact ID and optional filters.
        Request request = GSON.fromJson(input, Request.class);
        String url = request.getInputImageUrl();
        String artifactId = request.getArtifactId();
        Map<String, String> filters = request.getFilters();
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optArtifactId(artifactId)
                .optFilters(filters)
                .build();
        // loadModel() can throw ModelException subclasses, e.g. ModelNotFoundException when no model matches.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image image = ImageFactory.getInstance().fromUrl(url);
            // Return the five most likely classes as JSON.
            List<Classifications.Classification> result = predictor.predict(image).topK(5);
            os.write(GSON.toJson(result).getBytes(StandardCharsets.UTF_8));
        }
    } catch (RuntimeException | ModelException | TranslateException e) {
        logger.log("Failed to handle input: " + input);
        logger.log(e.toString());
        // Serialize with Gson so quotes or backslashes in the message cannot produce invalid JSON.
        String msg = GSON.toJson(Collections.singletonMap("status", "invoke failed: " + e));
        os.write(msg.getBytes(StandardCharsets.UTF_8));
    }
}
Also used : Classifications(ai.djl.modality.Classifications) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) Image(ai.djl.modality.cv.Image) LambdaLogger(com.amazonaws.services.lambda.runtime.LambdaLogger)
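The handler keeps only the five highest-scoring classes via `predictor.predict(image).topK(5)`. To illustrate what a top-K selection does (this is not DJL's implementation of `Classifications.topK`), here is a plain-Java sketch over a hypothetical `Entry` record standing in for a (class name, probability) pair. It keeps a min-heap of size k so that the weakest retained entry is always the one evicted.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

public class TopKSketch {

    /** Hypothetical stand-in for one (className, probability) classification. */
    public record Entry(String className, double probability) {}

    /** Returns the k entries with the highest probability, best first. */
    public static List<Entry> topK(List<Entry> entries, int k) {
        // Min-heap ordered by probability: the weakest retained entry is at the head.
        PriorityQueue<Entry> heap =
                new PriorityQueue<>(Comparator.comparingDouble(Entry::probability));
        for (Entry e : entries) {
            heap.offer(e);
            if (heap.size() > k) {
                heap.poll(); // evict the current weakest entry
            }
        }
        // The heap is not sorted internally, so sort the survivors best first.
        List<Entry> result = new ArrayList<>(heap);
        result.sort(Comparator.comparingDouble(Entry::probability).reversed());
        return result;
    }

    public static void main(String[] args) {
        List<Entry> scores = List.of(
                new Entry("cat", 0.60), new Entry("dog", 0.25),
                new Entry("fox", 0.10), new Entry("owl", 0.05));
        System.out.println(topK(scores, 2));
    }
}
```

The heap keeps memory at O(k) and runs in O(n log k), which matters little for a handful of classes but scales to large label sets such as ImageNet's 1000 classes.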

Example 2 with ModelException

Use of ai.djl.ModelException in project djl-demo by deepjavalibrary.

From class InferController, method mnistImage.

@PostMapping("/mnistImage")
public ResultBean mnistImage(@RequestParam(value = "imageFile") MultipartFile imageFile) {
    try (InputStream ins = imageFile.getInputStream()) {
        // Run inference on the uploaded image (an MNIST digit, per the endpoint name).
        String result = inferService.getImageInfo(ins);
        // Echo the upload back as a Base64 data URI so the page can preview it.
        String base64Img = Base64.encodeBase64String(imageFile.getBytes());
        return ResultBean.success()
                .add("result", result)
                .add("base64Img", "data:image/jpeg;base64," + base64Img);
    } catch (IOException | ModelException | TranslateException e) {
        logger.error(e.getMessage(), e);
        return ResultBean.failure().add("errors", e.getMessage());
    }
}
Also used : ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) InputStream(java.io.InputStream) IOException(java.io.IOException) PostMapping(org.springframework.web.bind.annotation.PostMapping)
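The controller labels every upload as `image/jpeg` and uses commons-codec's `Base64.encodeBase64String`. A small sketch, assuming a hypothetical `DataUris` helper, builds the same kind of `data:` URI with the JDK's own `java.util.Base64` and takes the MIME type as a parameter, falling back to a generic type when none is known.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

public final class DataUris {

    private DataUris() {}

    /** Builds a "data:<type>;base64,<payload>" URI; falls back to a generic type when mimeType is missing. */
    public static String of(String mimeType, byte[] bytes) {
        String type = (mimeType == null || mimeType.isBlank())
                ? "application/octet-stream"
                : mimeType;
        // The JDK's basic encoder emits standard Base64 without line breaks.
        return "data:" + type + ";base64," + Base64.getEncoder().encodeToString(bytes);
    }

    public static void main(String[] args) {
        byte[] payload = "hi".getBytes(StandardCharsets.US_ASCII);
        System.out.println(of("image/png", payload)); // data:image/png;base64,aGk=
    }
}
```

In the controller this could be called as `DataUris.of(imageFile.getContentType(), imageFile.getBytes())`, so a PNG upload is previewed as `image/png` rather than being declared JPEG.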

Example 3 with ModelException

Use of ai.djl.ModelException in project djl by deepjavalibrary.

From class TrtTest, method testSerializedEngine.

@Test
public void testSerializedEngine() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    Device device = engine.defaultDevice();
    if (!device.isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }
    // Serialized TensorRT engines are GPU-specific, so load the file built for this compute capability.
    String sm = CudaUtils.getComputeCapability(device.getDeviceId());
    Criteria<float[], float[]> criteria = Criteria.builder()
            .setTypes(float[].class, float[].class)
            .optModelPath(Paths.get("src/test/resources/identity_" + sm + ".trt"))
            .optTranslator(new MyTranslator())
            .optEngine("TensorRT")
            .build();
    try (ZooModel<float[], float[]> model = criteria.loadModel();
            Predictor<float[], float[]> predictor = model.newPredictor()) {
        // The model is an identity network, so the output must equal the input.
        float[] data = new float[] {1, 2, 3, 4};
        float[] ret = predictor.predict(data);
        Assert.assertEquals(ret, data);
    }
}
Also used : Device(ai.djl.Device) SkipException(org.testng.SkipException) Engine(ai.djl.engine.Engine) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) IOException(java.io.IOException) Test(org.testng.annotations.Test)
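TensorRT's documentation notes that serialized engines are generally not portable across GPU models, which is presumably why this test keeps one `identity_<sm>.trt` file per compute capability. A sketch of resolving that file up front, so a caller can skip cleanly when no engine was serialized for the current GPU instead of letting model loading fail, might look like this. `EngineFiles` and its methods are hypothetical, and `sm` is treated as an opaque string such as the one returned by `CudaUtils.getComputeCapability`.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;

public final class EngineFiles {

    private EngineFiles() {}

    /** Path of the serialized engine for compute capability sm, mirroring the test's naming scheme. */
    public static Path enginePath(Path dir, String sm) {
        return dir.resolve("identity_" + sm + ".trt");
    }

    /** Returns the engine file if one was serialized for sm, otherwise empty. */
    public static Optional<Path> serializedEngineFor(Path dir, String sm) {
        Path candidate = enginePath(dir, sm);
        return Files.isRegularFile(candidate) ? Optional.of(candidate) : Optional.empty();
    }
}
```

A test could then call `serializedEngineFor(Paths.get("src/test/resources"), sm)` and throw `SkipException` when the result is empty.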

Example 4 with ModelException

Use of ai.djl.ModelException in project djl by deepjavalibrary.

From class TrtTest, method testTrtOnnx.

@Test
public void testTrtOnnx() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    if (!engine.defaultDevice().isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }
    // Load an ONNX identity model through the TensorRT engine.
    Criteria<float[], float[]> criteria = Criteria.builder()
            .setTypes(float[].class, float[].class)
            .optModelPath(Paths.get("src/test/resources/identity.onnx"))
            .optTranslator(new MyTranslator())
            .optEngine("TensorRT")
            .build();
    try (ZooModel<float[], float[]> model = criteria.loadModel();
            Predictor<float[], float[]> predictor = model.newPredictor()) {
        float[] data = new float[] {1, 2, 3, 4};
        float[] ret = predictor.predict(data);
        Assert.assertEquals(ret, data);
    }
}
Also used : SkipException(org.testng.SkipException) Engine(ai.djl.engine.Engine) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) IOException(java.io.IOException) Test(org.testng.annotations.Test)
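The TensorRT tests repeat the same guard: try to obtain the engine, turn any failure into a skip, then skip again when there is no GPU. One way to factor that out is sketched below in plain Java. `SkipGuards` and its nested `Skip` exception are hypothetical; `Skip` stands in for TestNG's `SkipException` so the sketch has no test-framework dependency.

```java
import java.util.concurrent.Callable;

public final class SkipGuards {

    private SkipGuards() {}

    /** Hypothetical stand-in for org.testng.SkipException. */
    public static final class Skip extends RuntimeException {
        public Skip(String reason) {
            super(reason);
        }
    }

    /** Runs probe and returns its value; any exception becomes a Skip carrying reason. */
    public static <T> T orSkip(Callable<T> probe, String reason) {
        try {
            return probe.call();
        } catch (Exception e) {
            throw new Skip(reason);
        }
    }

    /** Throws Skip with reason unless condition holds. */
    public static void skipUnless(boolean condition, String reason) {
        if (!condition) {
            throw new Skip(reason);
        }
    }
}
```

In TrtTest this could read `Engine engine = orSkip(() -> Engine.getEngine("TensorRT"), "...")` followed by `skipUnless(engine.defaultDevice().isGpu(), "...")`, with `Skip` replaced by `SkipException`. Like the original `catch (Exception ignore)`, it does not intercept `Error`s such as a failed native-library link.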

Example 5 with ModelException

Use of ai.djl.ModelException in project djl by deepjavalibrary.

From class TrtTest, method testTrtUff.

@Test
public void testTrtUff() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    if (!engine.defaultDevice().isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }
    // Class labels "0".."9", one per digit.
    List<String> synset = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
    ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
            .optFlag(Image.Flag.GRAYSCALE)
            .optSynset(synset)
            .optApplySoftmax(true)
            .addTransform(new ToTensor())
            .optBatchifier(null)
            .build();
    Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optModelUrls("https://resources.djl.ai/test-models/tensorrt/lenet5.zip")
            .optTranslator(translator)
            .optEngine("TensorRT")
            .build();
    try (ZooModel<Image, Classifications> model = criteria.loadModel();
            Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Path path = Paths.get("../../examples/src/test/resources/0.png");
        Image image = ImageFactory.getInstance().fromFile(path);
        Classifications ret = predictor.predict(image);
        // 0.png is expected to be classified as the digit "0".
        Assert.assertEquals(ret.best().getClassName(), "0");
    }
}
Also used : Path(java.nio.file.Path) ImageClassificationTranslator(ai.djl.modality.cv.translator.ImageClassificationTranslator) Classifications(ai.djl.modality.Classifications) ToTensor(ai.djl.modality.cv.transform.ToTensor) Image(ai.djl.modality.cv.Image) SkipException(org.testng.SkipException) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) IOException(java.io.IOException) Engine(ai.djl.engine.Engine) Test(org.testng.annotations.Test)
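In testTrtUff, `optApplySoftmax(true)` asks the translator to turn the model's raw outputs into probabilities, and `ret.best()` then picks the most likely label from the synset. Softmax is monotonic, so the winning class is the same with or without it; what it changes is the reported probability. The two steps are sketched below in plain Java. This is an illustration, not DJL's code, and `SoftmaxSketch` is a hypothetical name.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public final class SoftmaxSketch {

    private SoftmaxSketch() {}

    /** Numerically stable softmax: subtract the max logit before exponentiating. */
    public static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Returns the synset label of the highest-probability class. */
    public static String best(double[] logits, List<String> synset) {
        double[] probs = softmax(logits);
        int bestIdx = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[bestIdx]) {
                bestIdx = i;
            }
        }
        return synset.get(bestIdx);
    }

    public static void main(String[] args) {
        // Same synset construction as testTrtUff: the labels "0".."9".
        List<String> synset = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
        double[] logits = {5.0, 1.0, 0.5, 0, 0, 0, 0, 0, 0, 0.2};
        System.out.println(best(logits, synset)); // 0
    }
}
```

Subtracting the maximum logit leaves the result mathematically unchanged but keeps `Math.exp` from overflowing on large outputs.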

Aggregations

ModelException (ai.djl.ModelException): 6
TranslateException (ai.djl.translate.TranslateException): 6
IOException (java.io.IOException): 4
Engine (ai.djl.engine.Engine): 3
Classifications (ai.djl.modality.Classifications): 3
Image (ai.djl.modality.cv.Image): 3
SkipException (org.testng.SkipException): 3
Test (org.testng.annotations.Test): 3
ToTensor (ai.djl.modality.cv.transform.ToTensor): 2
LambdaLogger (com.amazonaws.services.lambda.runtime.LambdaLogger): 2
Device (ai.djl.Device): 1
ImageFactory (ai.djl.modality.cv.ImageFactory): 1
ImageClassificationTranslator (ai.djl.modality.cv.translator.ImageClassificationTranslator): 1
ByteArrayInputStream (java.io.ByteArrayInputStream): 1
InputStream (java.io.InputStream): 1
Path (java.nio.file.Path): 1
PostMapping (org.springframework.web.bind.annotation.PostMapping): 1