
Example 1 with TranslateException

Use of ai.djl.translate.TranslateException in project djl-demo by deepjavalibrary.

In class PaintView, method runInference.

@SuppressLint("DefaultLocale")
public void runInference() {
    RectF bound = maxBound.getBound();
    int x = (int) bound.left;
    int y = (int) bound.top;
    int width = (int) Math.ceil(bound.width());
    int height = (int) Math.ceil(bound.height());
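    // width and height must both be > 0
    if (width <= 0 || height <= 0)
        return;
    // the crop rectangle must lie entirely inside the bitmap
    if (x < 0 || y < 0 || x + width > bitmap.getWidth() || y + height > bitmap.getHeight())
        return;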
    // do crop
    Bitmap bmp = Bitmap.createBitmap(bitmap, x, y, width, height);
    // do scaling
    Bitmap bmp64 = Bitmap.createScaledBitmap(bmp, 64, 64, true);
    try {
        Classifications classifications = predictor.predict(factory.fromImage(bmp64));
        imageView.setImageBitmap(bmp);
        List<Classifications.Classification> list = classifications.topK(3);
        StringBuilder sb = new StringBuilder();
        for (Classifications.Classification classification : list) {
            sb.append(classification.getClassName())
                    .append(": ")
                    .append(String.format("%.2f%%", 100 * classification.getProbability()))
                    .append("\n");
        }
        textView.setText(sb.toString());
    } catch (TranslateException e) {
        Log.e("DoodleDraw", "Inference failed", e);
    }
}
Also used : RectF(android.graphics.RectF) Bitmap(android.graphics.Bitmap) Classifications(ai.djl.modality.Classifications) TranslateException(ai.djl.translate.TranslateException) SuppressLint(android.annotation.SuppressLint) Paint(android.graphics.Paint)
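
The loop above turns the three best classifications into "name: 12.34%" lines. The sketch below isolates that logic in plain Java so it runs without Android or DJL; `TopKFormat` and `Scored` are hypothetical stand-ins for `Classifications.topK` and `Classifications.Classification`, not DJL types. It passes `Locale.ROOT` to `String.format`, which is one way to get a fixed decimal separator instead of suppressing the `DefaultLocale` lint warning.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

/** Hypothetical plain-Java version of the top-k formatting in runInference(). */
final class TopKFormat {

    /** Stand-in for Classifications.Classification: a class name and its probability. */
    static final class Scored {
        final String className;
        final double probability;

        Scored(String className, double probability) {
            this.className = className;
            this.probability = probability;
        }
    }

    /** Returns the k entries with the highest probability, best first. */
    static List<Scored> topK(List<Scored> all, int k) {
        List<Scored> sorted = new ArrayList<>(all);
        // sort descending by probability
        sorted.sort((a, b) -> Double.compare(b.probability, a.probability));
        return sorted.subList(0, Math.min(k, sorted.size()));
    }

    /** Formats each entry as "name: 12.34%" on its own line. */
    static String format(List<Scored> entries) {
        StringBuilder sb = new StringBuilder();
        for (Scored s : entries) {
            // Locale.ROOT gives '.' as the decimal separator regardless of device settings
            sb.append(s.className)
                    .append(": ")
                    .append(String.format(Locale.ROOT, "%.2f%%", 100 * s.probability))
                    .append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<Scored> scores = List.of(
                new Scored("cat", 0.10),
                new Scored("house", 0.65),
                new Scored("tree", 0.20),
                new Scored("sun", 0.05));
        System.out.print(format(topK(scores, 3)));
    }
}
```

Running `main` prints `house: 65.00%`, `tree: 20.00%` and `cat: 10.00%` on separate lines; asking for more entries than exist simply returns the whole sorted list.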

Example 2 with TranslateException

Use of ai.djl.translate.TranslateException in project djl-demo by deepjavalibrary.

In class Handler, method handleRequest.

@Override
public void handleRequest(InputStream is, OutputStream os, Context context) throws IOException {
    LambdaLogger logger = context.getLogger();
    String input = Utils.toString(is);
    try {
        Request request = GSON.fromJson(input, Request.class);
        String url = request.getInputImageUrl();
        String artifactId = request.getArtifactId();
        Map<String, String> filters = request.getFilters();
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optArtifactId(artifactId)
                .optFilters(filters)
                .build();
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
            Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image image = ImageFactory.getInstance().fromUrl(url);
            List<Classifications.Classification> result = predictor.predict(image).topK(5);
            os.write(GSON.toJson(result).getBytes(StandardCharsets.UTF_8));
        }
    } catch (RuntimeException | ModelException | TranslateException e) {
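        logger.log("Failed to handle input: " + input);
        logger.log(e.toString());
        // Gson escapes the message, so quotes inside it cannot produce invalid JSON
        String msg = GSON.toJson(Collections.singletonMap("status", "invoke failed: " + e.toString()));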
        os.write(msg.getBytes(StandardCharsets.UTF_8));
    }
}
Also used : Classifications(ai.djl.modality.Classifications) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) Image(ai.djl.modality.cv.Image) LambdaLogger(com.amazonaws.services.lambda.runtime.LambdaLogger)
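
The handler declares the model and the predictor in a single try-with-resources header. Java closes such resources in reverse declaration order, and closes them before an exception leaves the block, so the predictor is released before the model it came from even when `predict` throws. The sketch below demonstrates that ordering with a hypothetical `Tracked` resource standing in for `ZooModel` and `Predictor`.

```java
import java.util.ArrayList;
import java.util.List;

/** Shows the order in which try-with-resources opens and closes its resources. */
final class CloseOrderDemo {

    /** Hypothetical stand-in for ZooModel or Predictor that logs open and close. */
    static final class Tracked implements AutoCloseable {
        private final String name;
        private final List<String> log;

        Tracked(String name, List<String> log) {
            this.name = name;
            this.log = log;
            log.add("open " + name);
        }

        @Override
        public void close() {
            log.add("close " + name);
        }
    }

    /** Opens "model" then "predictor"; optionally fails inside the block. */
    static List<String> run(boolean fail) {
        List<String> log = new ArrayList<>();
        try (Tracked model = new Tracked("model", log);
                Tracked predictor = new Tracked("predictor", log)) {
            if (fail) {
                throw new IllegalStateException("inference failed");
            }
            log.add("predict");
        } catch (IllegalStateException e) {
            // both resources have already been closed when this handler runs
            log.add("caught " + e.getMessage());
        }
        return log;
    }

    public static void main(String[] args) {
        System.out.println(run(false));
        System.out.println(run(true));
    }
}
```

`run(false)` records `open model, open predictor, predict, close predictor, close model`; `run(true)` records both closes before `caught inference failed`.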

Example 3 with TranslateException

Use of ai.djl.translate.TranslateException in project djl-demo by deepjavalibrary.

In class InferController, method mnistImage.

@PostMapping("/mnistImage")
public ResultBean mnistImage(@RequestParam(value = "imageFile") MultipartFile imageFile) {
    try (InputStream ins = imageFile.getInputStream()) {
        String result = inferService.getImageInfo(ins);
        String base64Img = Base64.encodeBase64String(imageFile.getBytes());
        return ResultBean.success().add("result", result).add("base64Img", "data:image/jpeg;base64," + base64Img);
    } catch (IOException | ModelException | TranslateException e) {
        logger.error(e.getMessage(), e);
        return ResultBean.failure().add("errors", e.getMessage());
    }
}
Also used : ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) InputStream(java.io.InputStream) IOException(java.io.IOException) PostMapping(org.springframework.web.bind.annotation.PostMapping)
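
The controller echoes the upload back as a data URI with a hard-coded `image/jpeg` type, which may not match the uploaded file. A dependency-free sketch of building the URI is below, assuming the caller passes the upload's MIME type (for example from `MultipartFile.getContentType()`, which can return `null`). It uses the JDK's `java.util.Base64` rather than the Apache Commons Codec class the controller calls, and `DataUri` is a hypothetical name.

```java
import java.util.Base64;

/** Hypothetical helper that builds a data URI from raw upload bytes. */
final class DataUri {

    /** Falls back to image/jpeg, the type the controller hard-codes, when the type is unknown. */
    static String of(byte[] bytes, String contentType) {
        String type = (contentType == null || contentType.isEmpty()) ? "image/jpeg" : contentType;
        // the basic encoder produces a single line with no line breaks
        return "data:" + type + ";base64," + Base64.getEncoder().encodeToString(bytes);
    }

    public static void main(String[] args) {
        System.out.println(of(new byte[] {1, 2, 3}, "image/png"));
    }
}
```

For the bytes `{1, 2, 3}` and type `image/png` this yields `data:image/png;base64,AQID`.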

Example 4 with TranslateException

Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.

In class FixedBucketSamplerTest, method testFixedBucketSampler.

@Test
public void testFixedBucketSampler() throws IOException, TranslateException {
    FixedBucketSampler fixedBucketSampler = new FixedBucketSampler(10, 10, false);
    TatoebaEnglishFrenchDataset dataset = TatoebaEnglishFrenchDataset.builder()
            .setSampling(fixedBucketSampler)
            .optDataBatchifier(PaddingStackBatchifier.builder()
                    .optIncludeValidLengths(true)
                    .addPad(0, 0, (m) -> m.zeros(new Shape(1)), 10)
                    .build())
            .optLabelBatchifier(PaddingStackBatchifier.builder()
                    .optIncludeValidLengths(true)
                    .addPad(0, 0, (m) -> m.ones(new Shape(1)), 10)
                    .build())
            .optLimit(200)
            .build();
    dataset.prepare();
    Iterator<List<Long>> iterator = fixedBucketSampler.sample(dataset);
    long count = 0;
    Set<Long> indicesSet = new HashSet<>();
    while (iterator.hasNext()) {
        List<Long> indices = iterator.next();
        indicesSet.addAll(indices);
        count += indices.size();
    }
    Assert.assertEquals(count, dataset.size());
    Assert.assertEquals(indicesSet.size(), dataset.size());
    fixedBucketSampler = new FixedBucketSampler(10, 5, true);
    iterator = fixedBucketSampler.sample(dataset);
    count = 0;
    indicesSet.clear();
    while (iterator.hasNext()) {
        List<Long> indices = iterator.next();
        indicesSet.addAll(indices);
        count += indices.size();
    }
    Assert.assertEquals(count, dataset.size());
    Assert.assertEquals(indicesSet.size(), dataset.size());
}
Also used : HashSet(java.util.HashSet) List(java.util.List) TranslateException(ai.djl.translate.TranslateException) Iterator(java.util.Iterator) Assert(org.testng.Assert) TatoebaEnglishFrenchDataset(ai.djl.basicdataset.nlp.TatoebaEnglishFrenchDataset) FixedBucketSampler(ai.djl.basicdataset.utils.FixedBucketSampler) Shape(ai.djl.ndarray.types.Shape) Set(java.util.Set) IOException(java.io.IOException) Test(org.testng.annotations.Test) PaddingStackBatchifier(ai.djl.translate.PaddingStackBatchifier)
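
The test checks two properties of the sampler: the batches together contain as many indices as the dataset has samples, and no index repeats. The sketch below is a simplified length-bucketing sampler written only to illustrate those properties; it is not DJL's `FixedBucketSampler`, and it assumes, as the test suggests, that the three constructor arguments are batch size, number of buckets and a shuffle flag. Indices are sorted by sequence length, split into buckets of similar length, and each bucket is cut into batches, which keeps padding within a batch small.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

/** Simplified length-bucketing sampler for illustration; not DJL's FixedBucketSampler. */
final class BucketSamplerSketch {

    /**
     * Groups sample indices into batches of similar sequence length.
     *
     * @param lengths sequence length of each sample, indexed by sample id
     * @param batchSize maximum number of indices per batch
     * @param numBuckets number of length buckets (must be > 0)
     * @param shuffle whether to shuffle within buckets and across batches
     * @param seed seed for the shuffle
     */
    static List<List<Long>> sample(int[] lengths, int batchSize, int numBuckets, boolean shuffle, long seed) {
        int n = lengths.length;
        List<Long> order = new ArrayList<>();
        for (long i = 0; i < n; i++) {
            order.add(i);
        }
        // sort indices by sequence length so neighboring indices have similar lengths
        order.sort((a, b) -> Integer.compare(lengths[a.intValue()], lengths[b.intValue()]));

        Random rng = new Random(seed);
        List<List<Long>> batches = new ArrayList<>();
        int bucketSize = Math.max(1, (n + numBuckets - 1) / numBuckets); // ceiling division
        for (int start = 0; start < n; start += bucketSize) {
            List<Long> bucket = new ArrayList<>(order.subList(start, Math.min(start + bucketSize, n)));
            if (shuffle) {
                Collections.shuffle(bucket, rng); // vary batch contents within a bucket
            }
            for (int b = 0; b < bucket.size(); b += batchSize) {
                batches.add(new ArrayList<>(bucket.subList(b, Math.min(b + batchSize, bucket.size()))));
            }
        }
        if (shuffle) {
            Collections.shuffle(batches, rng); // vary the order in which batches are visited
        }
        return batches;
    }

    /** Mirrors the test's two assertions: total count and distinct count both equal n. */
    static boolean coversEachIndexOnce(List<List<Long>> batches, int n) {
        long count = 0;
        Set<Long> seen = new HashSet<>();
        for (List<Long> batch : batches) {
            seen.addAll(batch);
            count += batch.size();
        }
        return count == n && seen.size() == n;
    }

    public static void main(String[] args) {
        int[] lengths = new int[200];
        Random r = new Random(42);
        for (int i = 0; i < lengths.length; i++) {
            lengths[i] = 1 + r.nextInt(30);
        }
        List<List<Long>> batches = sample(lengths, 10, 5, true, 7L);
        System.out.println(batches.size() + " batches, covers all: " + coversEachIndexOnce(batches, lengths.length));
    }
}
```

With 200 samples, 5 buckets of 40 indices each split into batches of 10, so `main` prints `20 batches, covers all: true`. Shuffling only reorders indices, which is why the coverage properties hold whether or not it is enabled.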

Example 5 with TranslateException

Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.

In class TrtTest, method testSerializedEngine.

@Test
public void testSerializedEngine() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    Device device = engine.defaultDevice();
    if (!device.isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }
    String sm = CudaUtils.getComputeCapability(device.getDeviceId());
    Criteria<float[], float[]> criteria = Criteria.builder()
            .setTypes(float[].class, float[].class)
            .optModelPath(Paths.get("src/test/resources/identity_" + sm + ".trt"))
            .optTranslator(new MyTranslator())
            .optEngine("TensorRT")
            .build();
    try (ZooModel<float[], float[]> model = criteria.loadModel();
        Predictor<float[], float[]> predictor = model.newPredictor()) {
        float[] data = new float[] { 1, 2, 3, 4 };
        float[] ret = predictor.predict(data);
        Assert.assertEquals(ret, data);
    }
}
Also used : Device(ai.djl.Device) SkipException(org.testng.SkipException) Engine(ai.djl.engine.Engine) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) IOException(java.io.IOException) Test(org.testng.annotations.Test)
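
The test loads a serialized TensorRT engine whose file name embeds the GPU's compute capability, because such engines are built for a specific GPU architecture. If no matching file exists, the likely outcome is a load error rather than a skip. A minimal sketch of checking for the file first is below; `EngineFileLookup` is a hypothetical helper, and the capability string `"86"` is made up for the demo (the actual format returned by `CudaUtils.getComputeCapability` is not shown in the test).

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;

/** Hypothetical helper that looks up a per-GPU serialized engine file before loading it. */
final class EngineFileLookup {

    /** Returns dir/identity_<computeCapability>.trt if it exists as a regular file. */
    static Optional<Path> find(Path dir, String computeCapability) {
        Path candidate = dir.resolve("identity_" + computeCapability + ".trt");
        return Files.isRegularFile(candidate) ? Optional.of(candidate) : Optional.empty();
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("trt");
        // "86" is a made-up capability string used only for this demo
        Files.createFile(dir.resolve("identity_86.trt"));
        System.out.println(find(dir, "86").isPresent());
        System.out.println(find(dir, "75").isPresent());
    }
}
```

`main` prints `true` and then `false`. A test could throw `SkipException` when `find` returns an empty `Optional`, the same way it already handles a missing engine or GPU.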

Aggregations

TranslateException (ai.djl.translate.TranslateException): 20
IOException (java.io.IOException): 9
ModelException (ai.djl.ModelException): 6
Classifications (ai.djl.modality.Classifications): 6
Engine (ai.djl.engine.Engine): 5
Shape (ai.djl.ndarray.types.Shape): 4
Application (ai.djl.Application): 3
Model (ai.djl.Model): 3
Image (ai.djl.modality.cv.Image): 3
NDList (ai.djl.ndarray.NDList): 3
PaddingStackBatchifier (ai.djl.translate.PaddingStackBatchifier): 3
Locale (java.util.Locale): 3
Device (ai.djl.Device): 2
MalformedModelException (ai.djl.MalformedModelException): 2
Arguments (ai.djl.examples.training.util.Arguments): 2
Predictor (ai.djl.inference.Predictor): 2
Metrics (ai.djl.metric.Metrics): 2
Input (ai.djl.modality.Input): 2
ToTensor (ai.djl.modality.cv.transform.ToTensor): 2
TextEmbedding (ai.djl.modality.nlp.embedding.TextEmbedding): 2