Use of ai.djl.translate.TranslateException in project djl-demo by deepjavalibrary.
In class PaintView, method runInference:
/**
 * Crops the region returned by {@code maxBound.getBound()} out of {@code bitmap}, scales the
 * crop to 64x64, classifies it with {@code predictor}, then shows the (unscaled) crop in
 * {@code imageView} and the top-3 class names with percentages in {@code textView}.
 *
 * <p>Returns without doing anything when the region is empty or does not lie entirely inside
 * {@code bitmap}. A {@link TranslateException} from the predictor is logged, not rethrown.
 */
@SuppressLint("DefaultLocale")
public void runInference() {
    RectF bound = maxBound.getBound();
    int x = (int) bound.left;
    int y = (int) bound.top;
    int width = (int) Math.ceil(bound.width());
    int height = (int) Math.ceil(bound.height());
    // Bitmap.createBitmap(source, x, y, w, h) throws IllegalArgumentException unless the crop
    // rectangle is non-empty and fully inside the source, so reject every other case up front
    // instead of crashing.
    if (width <= 0 || height <= 0) {
        return;
    }
    if (x < 0 || y < 0) {
        return;
    }
    if (x + width > bitmap.getWidth() || y + height > bitmap.getHeight()) {
        return;
    }
    // do crop
    Bitmap bmp = Bitmap.createBitmap(bitmap, x, y, width, height);
    // scale to the 64x64 bitmap that is handed to the predictor
    Bitmap bmp64 = Bitmap.createScaledBitmap(bmp, 64, 64, true);
    try {
        Classifications classifications = predictor.predict(factory.fromImage(bmp64));
        imageView.setImageBitmap(bmp);
        List<Classifications.Classification> list = classifications.topK(3);
        StringBuilder sb = new StringBuilder();
        for (Classifications.Classification classification : list) {
            // probability is scaled by 100 and shown with two decimals, e.g. "cat: 93.12%"
            sb.append(classification.getClassName())
                    .append(": ")
                    .append(String.format("%.2f%%", 100 * classification.getProbability()))
                    .append("\n");
        }
        textView.setText(sb.toString());
    } catch (TranslateException e) {
        // a null message would be logged literally as "null"; give the log line context
        Log.e("DoodleDraw", "Failed to classify drawing", e);
    }
}
Use of ai.djl.translate.TranslateException in project djl-demo by deepjavalibrary.
In class Handler, method handleRequest:
/**
 * Handles one invocation.
 *
 * <p>Reads the whole input stream as a JSON {@code Request}. Loads a model matching the
 * request's artifact id and filters, then classifies the image at the request's input URL.
 * The top-5 classifications are written to {@code os} as JSON.
 *
 * <p>If parsing, model loading or inference fails with a {@link RuntimeException},
 * {@code ModelException} or {@link TranslateException}, the input and error are logged. A
 * JSON object of the form {@code {"status": "invoke failed: ..."}} is then written instead.
 *
 * @param is request body
 * @param os response body
 * @param context invocation context, used here only for its logger
 * @throws IOException if reading the input or writing the response fails
 */
@Override
public void handleRequest(InputStream is, OutputStream os, Context context) throws IOException {
    LambdaLogger logger = context.getLogger();
    String input = Utils.toString(is);
    try {
        Request request = GSON.fromJson(input, Request.class);
        String url = request.getInputImageUrl();
        String artifactId = request.getArtifactId();
        Map<String, String> filters = request.getFilters();
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optArtifactId(artifactId)
                        .optFilters(filters)
                        .build();
        // NOTE(review): the model is loaded on every invocation; consider caching it if
        // per-request latency matters.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image image = ImageFactory.getInstance().fromUrl(url);
            List<Classifications.Classification> result = predictor.predict(image).topK(5);
            os.write(GSON.toJson(result).getBytes(StandardCharsets.UTF_8));
        }
    } catch (RuntimeException | ModelException | TranslateException e) {
        logger.log("Failed handle input: " + input);
        logger.log(e.toString());
        // Exception text may contain quotes, backslashes or newlines. Serializing it through
        // Gson yields a correctly escaped JSON string literal; naive concatenation can emit
        // malformed JSON to the caller.
        String msg = "{\"status\": " + GSON.toJson("invoke failed: " + e) + "}";
        os.write(msg.getBytes(StandardCharsets.UTF_8));
    }
}
Use of ai.djl.translate.TranslateException in project djl-demo by deepjavalibrary.
In class InferController, method mnistImage:
/**
 * Runs {@code inferService.getImageInfo} on an uploaded image file.
 *
 * <p>On success, the response carries the inference text under {@code "result"}. It also
 * echoes the upload back as a base64 data URI under {@code "base64Img"}. The data URI is
 * always labeled {@code image/jpeg}, whatever the actual upload type.
 *
 * <p>On an I/O, model or translation error, the error is logged. A failure response then
 * carries the exception message under {@code "errors"}.
 */
@PostMapping("/mnistImage")
public ResultBean mnistImage(@RequestParam(value = "imageFile") MultipartFile imageFile) {
    try (InputStream imageStream = imageFile.getInputStream()) {
        String inference = inferService.getImageInfo(imageStream);
        String dataUri = "data:image/jpeg;base64," + Base64.encodeBase64String(imageFile.getBytes());
        ResultBean ok = ResultBean.success();
        return ok.add("result", inference).add("base64Img", dataUri);
    } catch (IOException | ModelException | TranslateException failure) {
        String reason = failure.getMessage();
        logger.error(reason, failure);
        return ResultBean.failure().add("errors", reason);
    }
}
Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.
In class FixedBucketSamplerTest, method testFixedBucketSampler:
/**
 * Checks that {@link FixedBucketSampler} yields every index of a prepared 200-item Tatoeba
 * dataset exactly once. Two sampler configurations are tested: the one the dataset was built
 * with, {@code (10, 10, false)}, and a second one, {@code (10, 5, true)}.
 */
@Test
public void testFixedBucketSampler() throws IOException, TranslateException {
    FixedBucketSampler fixedBucketSampler = new FixedBucketSampler(10, 10, false);
    TatoebaEnglishFrenchDataset dataset =
            TatoebaEnglishFrenchDataset.builder()
                    .setSampling(fixedBucketSampler)
                    .optDataBatchifier(
                            PaddingStackBatchifier.builder()
                                    .optIncludeValidLengths(true)
                                    .addPad(0, 0, (m) -> m.zeros(new Shape(1)), 10)
                                    .build())
                    .optLabelBatchifier(
                            PaddingStackBatchifier.builder()
                                    .optIncludeValidLengths(true)
                                    .addPad(0, 0, (m) -> m.ones(new Shape(1)), 10)
                                    .build())
                    .optLimit(200)
                    .build();
    dataset.prepare();

    assertSamplesEveryIndexOnce(fixedBucketSampler, dataset);
    assertSamplesEveryIndexOnce(new FixedBucketSampler(10, 5, true), dataset);
}

/**
 * Drains {@code sampler.sample(dataset)} and asserts two things. The total number of sampled
 * indices equals the dataset size. The number of distinct indices also equals the dataset
 * size, so there are no duplicates and no index is missed.
 */
private static void assertSamplesEveryIndexOnce(
        FixedBucketSampler sampler, TatoebaEnglishFrenchDataset dataset) {
    Iterator<List<Long>> iterator = sampler.sample(dataset);
    long count = 0;
    Set<Long> indicesSet = new HashSet<>();
    while (iterator.hasNext()) {
        List<Long> indices = iterator.next();
        indicesSet.addAll(indices);
        count += indices.size();
    }
    Assert.assertEquals(count, dataset.size());
    Assert.assertEquals(indicesSet.size(), dataset.size());
}
Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.
In class TrtTest, method testSerializedEngine:
/**
 * Loads the serialized TensorRT identity model matching the default device's compute
 * capability, and checks that it echoes its input back unchanged.
 *
 * <p>The test is skipped when the TensorRT engine cannot be obtained, or when its default
 * device is not a GPU.
 */
@Test
public void testSerializedEngine() throws ModelException, IOException, TranslateException {
    final Engine trtEngine;
    try {
        trtEngine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        // any failure to obtain the engine means this environment cannot run the test
        throw new SkipException("Your os configuration doesn't support TensorRT.");
    }

    Device target = trtEngine.defaultDevice();
    if (!target.isGpu()) {
        throw new SkipException("TensorRT only support GPU.");
    }

    // the serialized engine file is selected by the device's compute capability
    String computeCapability = CudaUtils.getComputeCapability(target.getDeviceId());
    String modelFile = "src/test/resources/identity_" + computeCapability + ".trt";
    Criteria<float[], float[]> criteria =
            Criteria.builder()
                    .setTypes(float[].class, float[].class)
                    .optModelPath(Paths.get(modelFile))
                    .optTranslator(new MyTranslator())
                    .optEngine("TensorRT")
                    .build();

    try (ZooModel<float[], float[]> model = criteria.loadModel();
            Predictor<float[], float[]> predictor = model.newPredictor()) {
        float[] expected = new float[] {1, 2, 3, 4};
        Assert.assertEquals(predictor.predict(expected), expected);
    }
}
Aggregations