Search in sources:

Example 1 with Classification

Use of ai.djl.modality.Classifications.Classification in project djl-serving by deepjavalibrary.

In class ModelServerTest, method testPredictions.

/**
 * Posts {@code testImage} as a raw octet-stream body to every URI in {@code targets}, one at a
 * time, and checks that the first classification in each JSON response has class name
 * {@code "0"}.
 *
 * @param channel the client channel the requests are written to
 * @param targets request URIs to POST the image to, tested sequentially
 * @throws InterruptedException if interrupted while waiting on {@code latch}
 */
private void testPredictions(Channel channel, String[] targets) throws InterruptedException {
    // The Gson target type does not depend on the current target, so build it once
    // instead of creating a new TypeToken on every iteration.
    Type type = new TypeToken<List<Classification>>() {}.getType();
    for (String target : targets) {
        // NOTE(review): presumably re-arms latch/result for the next response; confirm in reset().
        reset();
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, target);
        req.content().writeBytes(testImage);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
        channel.writeAndFlush(req);
        latch.await();
        List<Classification> classifications = JsonUtils.GSON.fromJson(result, type);
        assertEquals(classifications.get(0).getClassName(), "0");
    }
}
Also used: Type (java.lang.reflect.Type), DefaultFullHttpRequest (io.netty.handler.codec.http.DefaultFullHttpRequest), Classification (ai.djl.modality.Classifications.Classification), List (java.util.List), ArrayList (java.util.ArrayList)

Example 2 with Classification

Use of ai.djl.modality.Classifications.Classification in project djl-serving by deepjavalibrary.

In class ModelServerTest, method testInvocationsMultipart.

/**
 * Uploads {@code testImage} to {@code /invocations?model_name=mlp} as a multipart form.
 * The form carries an extra plain attribute alongside the file part. The method then checks
 * that the first classification in the JSON response has class name {@code "0"}.
 *
 * @param channel the client channel the request (and any chunked body) is written to
 * @throws InterruptedException if interrupted while waiting on {@code latch} or the chunked write
 * @throws HttpPostRequestEncoder.ErrorDataEncoderException if the multipart body cannot be encoded
 * @throws IOException propagated from building the multipart body
 */
private void testInvocationsMultipart(Channel channel) throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException, IOException {
    reset();
    String uri = "/invocations?model_name=mlp";
    DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri);

    // File part: a copy of the test image, declared as a PNG named 0.png.
    int imageSize = testImage.length;
    ByteBuf imageBytes = Unpooled.buffer(imageSize);
    imageBytes.writeBytes(testImage);
    MemoryFileUpload upload = new MemoryFileUpload("data", "0.png", "image/png", null, null, imageSize);
    upload.setContent(imageBytes);

    HttpPostRequestEncoder multipart = new HttpPostRequestEncoder(request, true);
    multipart.addBodyAttribute("test", "test");
    multipart.addBodyHttpData(upload);

    channel.writeAndFlush(multipart.finalizeRequest());
    // When the encoder decides the body is chunked, the encoder itself supplies the chunks.
    boolean chunked = multipart.isChunked();
    if (chunked) {
        channel.writeAndFlush(multipart).sync();
    }

    latch.await();
    List<Classification> ranked =
            JsonUtils.GSON.fromJson(result, new TypeToken<List<Classification>>() {}.getType());
    String topClass = ranked.get(0).getClassName();
    assertEquals(topClass, "0");
}
Also used: Type (java.lang.reflect.Type), DefaultFullHttpRequest (io.netty.handler.codec.http.DefaultFullHttpRequest), HttpPostRequestEncoder (io.netty.handler.codec.http.multipart.HttpPostRequestEncoder), Classification (ai.djl.modality.Classifications.Classification), MemoryFileUpload (io.netty.handler.codec.http.multipart.MemoryFileUpload), List (java.util.List), ArrayList (java.util.ArrayList), ByteBuf (io.netty.buffer.ByteBuf)

Example 3 with Classification

Use of ai.djl.modality.Classifications.Classification in project djl-serving by deepjavalibrary.

In class ModelServerTest, method testInvocations.

/**
 * Posts {@code testImage} as a raw octet-stream body to {@code /invocations} and checks that
 * the first classification in the JSON response has class name {@code "0"}.
 *
 * @param channel the client channel the request is written to
 * @throws InterruptedException if interrupted while waiting on {@code latch}
 */
private void testInvocations(Channel channel) throws InterruptedException {
    reset();
    DefaultFullHttpRequest request =
            new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/invocations");
    request.content().writeBytes(testImage);
    int bodyLength = request.content().readableBytes();
    HttpUtil.setContentLength(request, bodyLength);
    request.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
    channel.writeAndFlush(request);

    // NOTE(review): presumably counted down once the response has been stored in result.
    latch.await();
    Type responseType = new TypeToken<List<Classification>>() {}.getType();
    List<Classification> ranked = JsonUtils.GSON.fromJson(result, responseType);
    Classification top = ranked.get(0);
    assertEquals(top.getClassName(), "0");
}
Also used: Type (java.lang.reflect.Type), DefaultFullHttpRequest (io.netty.handler.codec.http.DefaultFullHttpRequest), Classification (ai.djl.modality.Classifications.Classification), List (java.util.List), ArrayList (java.util.ArrayList)

Example 4 with Classification

Use of ai.djl.modality.Classifications.Classification in project djl by deepjavalibrary.

In class CustomTranslatorTest, method testSsdTranslator.

/**
 * Reloads the MXNet zoo SSD model by URL as a generic {@code Input}/{@code Output} model.
 * The translator is taken from {@code SingleShotDetectionTranslatorFactory}. The test then
 * checks that predicting on {@code dog_bike_car.jpg} returns code 200 and that the first
 * entry of the JSON result has class name {@code "car"}.
 *
 * @throws IOException if the test image cannot be read, or propagated from model loading
 * @throws ModelException propagated from loading either model
 * @throws TranslateException propagated from prediction
 */
@Test
public void testSsdTranslator() throws IOException, ModelException, TranslateException {
    TestRequirements.engine("MXNet");

    // Load the zoo model only to learn where its artifacts live on disk.
    Criteria<Image, DetectedObjects> c =
            Criteria.builder()
                    .setTypes(Image.class, DetectedObjects.class)
                    .optArtifactId("ai.djl.mxnet:ssd")
                    .build();
    String modelUrl;
    try (ZooModel<Image, DetectedObjects> model = c.loadModel()) {
        modelUrl = model.getModelPath().toUri().toURL().toString();
    }

    // Reload the same artifacts as a raw Input/Output model; the translator is configured
    // entirely through arguments, including the factory class name.
    Criteria<Input, Output> criteria =
            Criteria.builder()
                    .setTypes(Input.class, Output.class)
                    .optModelUrls(modelUrl)
                    .optArgument("width", 512)
                    .optArgument("height", 512)
                    .optArgument("resize", true)
                    .optArgument("rescale", true)
                    .optArgument("synsetFileName", "classes.txt")
                    .optArgument(
                            "translatorFactory",
                            "ai.djl.modality.cv.translator.SingleShotDetectionTranslatorFactory")
                    .optModelName("ssd_512_resnet50_v1_voc")
                    .build();

    Path imageFile = Paths.get("../examples/src/test/resources/dog_bike_car.jpg");
    // Files.readAllBytes opens, fully reads and closes the file in a single call.
    byte[] buf = Files.readAllBytes(imageFile);

    try (ZooModel<Input, Output> model = criteria.loadModel();
            Predictor<Input, Output> predictor = model.newPredictor()) {
        Input input = new Input();
        input.add(buf);
        Output output = predictor.predict(input);
        Assert.assertEquals(output.getCode(), 200);
        String content = output.getAsString(0);
        Type type = new TypeToken<List<Classification>>() {}.getType();
        List<Classification> result = JsonUtils.GSON.fromJson(content, type);
        Assert.assertEquals(result.get(0).getClassName(), "car");
    }
}
Also used: Path (java.nio.file.Path), ByteArrayInputStream (java.io.ByteArrayInputStream), InputStream (java.io.InputStream), DetectedObjects (ai.djl.modality.cv.output.DetectedObjects), Image (ai.djl.modality.cv.Image), Input (ai.djl.modality.Input), Type (java.lang.reflect.Type), Output (ai.djl.modality.Output), Classification (ai.djl.modality.Classifications.Classification), NDList (ai.djl.ndarray.NDList), List (java.util.List), Test (org.testng.annotations.Test)

Example 5 with Classification

Use of ai.djl.modality.Classifications.Classification in project djl by deepjavalibrary.

In class CustomTranslatorTest, method runImageClassification.

/**
 * Loads the model under {@code modelDir} as a generic {@code Input}/{@code Output} model and
 * checks which translator was selected. It then checks one prediction on {@code data}.
 *
 * @param application application passed to the model criteria
 * @param arguments model arguments passed to the model criteria
 * @param translatorName expected simple class name of the model's translator
 * @throws IOException propagated from model loading or prediction
 * @throws ModelException propagated from model loading
 * @throws TranslateException propagated from prediction
 */
private void runImageClassification(Application application, Map<String, Object> arguments, String translatorName) throws IOException, ModelException, TranslateException {
    Criteria<Input, Output> criteria =
            Criteria.builder()
                    .setTypes(Input.class, Output.class)
                    .optApplication(application)
                    .optArguments(arguments)
                    .optModelPath(modelDir)
                    .build();
    try (ZooModel<Input, Output> model = criteria.loadModel();
            Predictor<Input, Output> predictor = model.newPredictor()) {
        // The translator is resolved from the arguments; verify the expected one was chosen.
        String actualTranslator = model.getTranslator().getClass().getSimpleName();
        Assert.assertEquals(actualTranslator, translatorName);

        Input request = new Input();
        request.add("data", data);
        Output response = predictor.predict(request);
        Assert.assertEquals(response.getCode(), 200);

        String json = response.getAsString(0);
        List<Classification> ranked =
                JsonUtils.GSON.fromJson(json, new TypeToken<List<Classification>>() {}.getType());
        Assert.assertEquals(ranked.get(0).getClassName(), "0");
    }
}
Also used: Input (ai.djl.modality.Input), Type (java.lang.reflect.Type), Output (ai.djl.modality.Output), Classification (ai.djl.modality.Classifications.Classification), NDList (ai.djl.ndarray.NDList), List (java.util.List)

Aggregations

Classification (ai.djl.modality.Classifications.Classification): 6, Type (java.lang.reflect.Type): 6, List (java.util.List): 6, DefaultFullHttpRequest (io.netty.handler.codec.http.DefaultFullHttpRequest): 4, ArrayList (java.util.ArrayList): 3, Input (ai.djl.modality.Input): 2, Output (ai.djl.modality.Output): 2, NDList (ai.djl.ndarray.NDList): 2, Image (ai.djl.modality.cv.Image): 1, DetectedObjects (ai.djl.modality.cv.output.DetectedObjects): 1, ByteBuf (io.netty.buffer.ByteBuf): 1, HttpPostRequestEncoder (io.netty.handler.codec.http.multipart.HttpPostRequestEncoder): 1, MemoryFileUpload (io.netty.handler.codec.http.multipart.MemoryFileUpload): 1, ByteArrayInputStream (java.io.ByteArrayInputStream): 1, InputStream (java.io.InputStream): 1, Path (java.nio.file.Path): 1, Test (org.testng.annotations.Test): 1