Usage of `ai.djl.modality.Classifications.Classification` in project djl-serving by deepjavalibrary.
Class `ModelServerTest`, method `testPredictions`.
/**
 * Sends the shared {@code testImage} bytes as an {@code application/octet-stream} POST to every
 * target URI and checks that the first classification in each JSON response is {@code "0"}.
 *
 * <p>Each iteration calls {@code reset()} and then blocks on {@code latch.await()} before reading
 * {@code result}. NOTE(review): presumably the channel's response handler stores the response body
 * in {@code result} and releases the latch — confirm, since an absent response blocks forever.
 *
 * @param channel the channel every request is written to
 * @param targets request URIs; one POST is issued per entry
 * @throws InterruptedException if interrupted while waiting on the latch
 */
private void testPredictions(Channel channel, String[] targets) throws InterruptedException {
    // The deserialization target type does not depend on the loop variable, so capture it once
    // instead of instantiating a fresh anonymous TypeToken subclass for every target.
    Type type = new TypeToken<List<Classification>>() {}.getType();
    for (String target : targets) {
        reset();
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, target);
        req.content().writeBytes(testImage);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
        channel.writeAndFlush(req);
        latch.await();
        List<Classification> classifications = JsonUtils.GSON.fromJson(result, type);
        assertEquals(classifications.get(0).getClassName(), "0");
    }
}
Usage of `ai.djl.modality.Classifications.Classification` in project djl-serving by deepjavalibrary.
Class `ModelServerTest`, method `testInvocationsMultipart`.
/**
 * Uploads {@code testImage} as a multipart form part named {@code data} (file name {@code 0.png},
 * type {@code image/png}), alongside a plain {@code test=test} attribute. The request goes to
 * {@code /invocations?model_name=mlp}, and the test expects the first returned classification to
 * be {@code "0"}.
 *
 * <p>If the encoder reports a chunked body, the remaining chunks are streamed by writing the
 * encoder itself, and that write is awaited. NOTE(review): the response is assumed to land in
 * {@code result} once {@code latch} is released — confirm against the response handler.
 *
 * @param channel the channel the multipart request is written to
 * @throws InterruptedException if interrupted while waiting
 * @throws HttpPostRequestEncoder.ErrorDataEncoderException if the multipart body cannot be encoded
 * @throws IOException if setting the upload content fails
 */
private void testInvocationsMultipart(Channel channel)
        throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException, IOException {
    reset();

    ByteBuf imageBytes = Unpooled.buffer(testImage.length).writeBytes(testImage);
    DefaultFullHttpRequest request =
            new DefaultFullHttpRequest(
                    HttpVersion.HTTP_1_1, HttpMethod.POST, "/invocations?model_name=mlp");

    // Second constructor argument selects multipart encoding.
    HttpPostRequestEncoder form = new HttpPostRequestEncoder(request, true);
    form.addBodyAttribute("test", "test");

    MemoryFileUpload imagePart =
            new MemoryFileUpload("data", "0.png", "image/png", null, null, testImage.length);
    imagePart.setContent(imageBytes);
    form.addBodyHttpData(imagePart);

    channel.writeAndFlush(form.finalizeRequest());
    boolean streamed = form.isChunked();
    if (streamed) {
        // The encoder doubles as the chunked input that carries the rest of the body.
        channel.writeAndFlush(form).sync();
    }

    latch.await();

    Type listType = new TypeToken<List<Classification>>() {}.getType();
    List<Classification> parsed = JsonUtils.GSON.fromJson(result, listType);
    Classification first = parsed.get(0);
    assertEquals(first.getClassName(), "0");
}
Usage of `ai.djl.modality.Classifications.Classification` in project djl-serving by deepjavalibrary.
Class `ModelServerTest`, method `testInvocations`.
/**
 * POSTs the raw {@code testImage} bytes to {@code /invocations} with no model name in the query
 * string. The test expects the first classification in the JSON response to be {@code "0"}.
 *
 * <p>NOTE(review): relies on {@code reset()} and the response handler to prepare and release
 * {@code latch} and to fill {@code result}; both are defined outside this method.
 *
 * @param channel the channel the request is written to
 * @throws InterruptedException if interrupted while waiting on the latch
 */
private void testInvocations(Channel channel) throws InterruptedException {
    reset();

    DefaultFullHttpRequest request =
            new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/invocations");
    ByteBuf body = request.content();
    body.writeBytes(testImage);
    HttpUtil.setContentLength(request, body.readableBytes());
    request.headers()
            .set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
    channel.writeAndFlush(request);

    latch.await();

    Type listType = new TypeToken<List<Classification>>() {}.getType();
    List<Classification> parsed = JsonUtils.GSON.fromJson(result, listType);
    Classification first = parsed.get(0);
    assertEquals(first.getClassName(), "0");
}
Usage of `ai.djl.modality.Classifications.Classification` in project djl by deepjavalibrary.
Class `CustomTranslatorTest`, method `testSsdTranslator`.
/**
 * Loads the {@code ai.djl.mxnet:ssd} zoo model only to learn its local model path. It then reloads
 * that path as a raw {@code Input}/{@code Output} model whose translator comes from
 * {@code SingleShotDetectionTranslatorFactory}, configured through string arguments. The model is
 * run on {@code dog_bike_car.jpg}; the test expects status code 200 and {@code "car"} as the
 * first entry of the JSON output. Skipped unless the MXNet engine is available.
 */
@Test
public void testSsdTranslator() throws IOException, ModelException, TranslateException {
    TestRequirements.engine("MXNet");

    Criteria<Image, DetectedObjects> zooCriteria =
            Criteria.builder()
                    .setTypes(Image.class, DetectedObjects.class)
                    .optArtifactId("ai.djl.mxnet:ssd")
                    .build();
    String modelUrl;
    try (ZooModel<Image, DetectedObjects> zooModel = zooCriteria.loadModel()) {
        // Only the on-disk location is needed; the typed model is closed right away.
        modelUrl = zooModel.getModelPath().toUri().toURL().toString();
    }

    Criteria<Input, Output> rawCriteria =
            Criteria.builder()
                    .setTypes(Input.class, Output.class)
                    .optModelUrls(modelUrl)
                    .optArgument("width", 512)
                    .optArgument("height", 512)
                    .optArgument("resize", true)
                    .optArgument("rescale", true)
                    .optArgument("synsetFileName", "classes.txt")
                    .optArgument(
                            "translatorFactory",
                            "ai.djl.modality.cv.translator.SingleShotDetectionTranslatorFactory")
                    .optModelName("ssd_512_resnet50_v1_voc")
                    .build();

    Path imagePath = Paths.get("../examples/src/test/resources/dog_bike_car.jpg");
    byte[] imageBytes;
    try (InputStream in = Files.newInputStream(imagePath)) {
        imageBytes = Utils.toByteArray(in);
    }

    try (ZooModel<Input, Output> detector = rawCriteria.loadModel();
            Predictor<Input, Output> predictor = detector.newPredictor()) {
        Input request = new Input();
        request.add(imageBytes);
        Output response = predictor.predict(request);
        Assert.assertEquals(response.getCode(), 200);

        // The output is read as a list of Classification; only the first className is checked.
        // NOTE(review): assumes the detection JSON carries a compatible "className" field.
        String json = response.getAsString(0);
        Type listType = new TypeToken<List<Classification>>() {}.getType();
        List<Classification> detections = JsonUtils.GSON.fromJson(json, listType);
        Assert.assertEquals(detections.get(0).getClassName(), "car");
    }
}
Usage of `ai.djl.modality.Classifications.Classification` in project djl by deepjavalibrary.
Class `CustomTranslatorTest`, method `runImageClassification`.
/**
 * Loads the model at {@code modelDir} as a raw {@code Input}/{@code Output} model for the given
 * application and arguments. It checks that the resolved translator has the expected simple class
 * name, runs a prediction on the {@code data} field (added under the key {@code "data"}), and
 * expects status 200 with {@code "0"} as the first classification.
 *
 * @param application the application passed to the criteria
 * @param arguments translator/model arguments passed to the criteria
 * @param translatorName expected {@code getSimpleName()} of the model's translator class
 * @throws IOException if the model cannot be read
 * @throws ModelException if the model cannot be loaded
 * @throws TranslateException if prediction fails
 */
private void runImageClassification(
        Application application, Map<String, Object> arguments, String translatorName)
        throws IOException, ModelException, TranslateException {
    Criteria<Input, Output> rawCriteria =
            Criteria.builder()
                    .setTypes(Input.class, Output.class)
                    .optApplication(application)
                    .optArguments(arguments)
                    .optModelPath(modelDir)
                    .build();

    try (ZooModel<Input, Output> loaded = rawCriteria.loadModel();
            Predictor<Input, Output> predictor = loaded.newPredictor()) {
        String actualTranslator = loaded.getTranslator().getClass().getSimpleName();
        Assert.assertEquals(actualTranslator, translatorName);

        Input request = new Input();
        request.add("data", data);
        Output response = predictor.predict(request);
        Assert.assertEquals(response.getCode(), 200);

        String json = response.getAsString(0);
        Type listType = new TypeToken<List<Classification>>() {}.getType();
        List<Classification> ranked = JsonUtils.GSON.fromJson(json, listType);
        Assert.assertEquals(ranked.get(0).getClassName(), "0");
    }
}
Aggregations