Search in sources :

Example 6 with ModelException

use of ai.djl.ModelException in project djl by deepjavalibrary.

The example below is from the class TrtTest, method testTrtOnnx.

/**
 * Verifies that an ONNX identity model loads and runs through the TensorRT engine.
 *
 * <p>The test is skipped (via {@link SkipException}) when the TensorRT engine cannot be
 * loaded on this machine or when the default device is not a GPU.
 *
 * @throws ModelException if the model fails to load
 * @throws IOException if the model file cannot be read
 * @throws TranslateException if prediction fails
 */
@Test
public void testTrtOnnx() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception e) {
        // Chain the cause so the skip reason is diagnosable in the test report
        // (the original discarded it).
        throw new SkipException("Your os configuration doesn't support TensorRT.", e);
    }
    if (!engine.defaultDevice().isGpu()) {
        throw new SkipException("TensorRT only support GPU.");
    }
    Criteria<float[], float[]> criteria =
            Criteria.builder()
                    .setTypes(float[].class, float[].class)
                    .optModelPath(Paths.get("src/test/resources/identity.onnx"))
                    .optTranslator(new MyTranslator())
                    .optEngine("TensorRT")
                    .build();
    try (ZooModel<float[], float[]> model = criteria.loadModel();
        Predictor<float[], float[]> predictor = model.newPredictor()) {
        float[] data = new float[] { 1, 2, 3, 4 };
        // Identity model: the output must equal the input exactly.
        float[] ret = predictor.predict(data);
        Assert.assertEquals(ret, data);
    }
}
Also used : SkipException(org.testng.SkipException) Engine(ai.djl.engine.Engine) SkipException(org.testng.SkipException) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) IOException(java.io.IOException) Test(org.testng.annotations.Test)

Example 7 with ModelException

use of ai.djl.ModelException in project djl by deepjavalibrary.

The example below is from the class AbstractBenchmark, method runBenchmark.

/**
 * Execute benchmark.
 *
 * @param args input raw arguments
 * @return if example execution complete successfully
 */
/**
 * Execute benchmark.
 *
 * <p>Parses the command-line options, loads the requested engine, then repeatedly runs the
 * {@code predict} loop until the requested duration has elapsed, logging throughput, latency
 * percentiles and (optionally, when {@code -Dcollect-memory=true}) memory metrics after each
 * pass.
 *
 * @param args input raw arguments
 * @return if example execution complete successfully
 */
public final boolean runBenchmark(String[] args) {
    Options options = Arguments.getOptions();
    try {
        if (Arguments.hasHelp(args)) {
            Arguments.printHelp("usage: djl-bench [-p MODEL-PATH] -s INPUT-SHAPES [OPTIONS]", options);
            return true;
        }
        DefaultParser parser = new DefaultParser();
        CommandLine cmd = parser.parse(options, args, null, false);
        Arguments arguments = new Arguments(cmd);
        String engineName = arguments.getEngine();
        Engine engine = Engine.getEngine(engineName);
        // getVersion() forces engine/native-library initialization; time it explicitly.
        long init = System.nanoTime();
        String version = engine.getVersion();
        long loaded = System.nanoTime();
        logger.info(String.format("Load %s (%s) in %.3f ms.", engineName, version, (loaded - init) / 1_000_000f));
        Duration duration = Duration.ofSeconds(arguments.getDuration());
        Object devices;
        if (this instanceof MultithreadedBenchmark) {
            devices = engine.getDevices(arguments.getMaxGpus());
        } else {
            devices = engine.defaultDevice();
        }
        if (arguments.getDuration() != 0) {
            logger.info("Running {} on: {}, duration: {} minutes.", getClass().getSimpleName(), devices, duration.toMinutes());
        } else {
            logger.info("Running {} on: {}.", getClass().getSimpleName(), devices);
        }
        int numOfThreads = arguments.getThreads();
        int iteration = arguments.getIteration();
        if (this instanceof MultithreadedBenchmark) {
            // Ensure every thread gets a meaningful number of iterations.
            int expected = 10 * numOfThreads;
            if (iteration < expected) {
                iteration = expected;
                logger.info("Iteration is too small for multi-threading benchmark. Adjust to: {}", iteration);
            }
        }
        while (!duration.isNegative()) {
            // Reset Metrics for each test loop.
            Metrics metrics = new Metrics();
            progressBar = new ProgressBar("Iteration", iteration);
            float[] lastResult = predict(arguments, metrics, iteration);
            if (lastResult == null) {
                return false;
            }
            long begin = metrics.getMetric("start").get(0).getValue().longValue();
            long end = metrics.getMetric("end").get(0).getValue().longValue();
            long totalTime = end - begin;
            if (lastResult.length > 3) {
                logger.info("Inference result: [{}, {}, {} ...]", lastResult[0], lastResult[1], lastResult[2]);
            } else {
                // Arrays.toString renders the contents; passing the float[] directly would
                // log the array's identity hash (e.g. "[F@1a2b3c") instead of its values.
                logger.info("Inference result: {}", java.util.Arrays.toString(lastResult));
            }
            String throughput = String.format("%.2f", iteration * 1000d / totalTime);
            logger.info("Throughput: {}, completed {} iteration in {} ms.", throughput, iteration, totalTime);
            if (metrics.hasMetric("LoadModel")) {
                long loadModelTime = metrics.getMetric("LoadModel").get(0).getValue().longValue();
                logger.info("Model loading time: {} ms.", String.format("%.3f", loadModelTime / 1000f));
            }
            if (metrics.hasMetric("Inference") && iteration > 1) {
                // Metric values are recorded in microseconds; divide by 1000 for ms.
                float totalP50 = metrics.percentile("Total", 50).getValue().longValue() / 1000f;
                float totalP90 = metrics.percentile("Total", 90).getValue().longValue() / 1000f;
                float totalP99 = metrics.percentile("Total", 99).getValue().longValue() / 1000f;
                float p50 = metrics.percentile("Inference", 50).getValue().longValue() / 1000f;
                float p90 = metrics.percentile("Inference", 90).getValue().longValue() / 1000f;
                float p99 = metrics.percentile("Inference", 99).getValue().longValue() / 1000f;
                float preP50 = metrics.percentile("Preprocess", 50).getValue().longValue() / 1000f;
                float preP90 = metrics.percentile("Preprocess", 90).getValue().longValue() / 1000f;
                float preP99 = metrics.percentile("Preprocess", 99).getValue().longValue() / 1000f;
                float postP50 = metrics.percentile("Postprocess", 50).getValue().longValue() / 1000f;
                float postP90 = metrics.percentile("Postprocess", 90).getValue().longValue() / 1000f;
                float postP99 = metrics.percentile("Postprocess", 99).getValue().longValue() / 1000f;
                logger.info(String.format("total P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", totalP50, totalP90, totalP99));
                logger.info(String.format("inference P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", p50, p90, p99));
                logger.info(String.format("preprocess P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", preP50, preP90, preP99));
                logger.info(String.format("postprocess P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", postP50, postP90, postP99));
                if (Boolean.getBoolean("collect-memory")) {
                    float heapBeforeModel = metrics.getMetric("Heap").get(0).getValue().longValue();
                    float heapBeforeInference = metrics.getMetric("Heap").get(1).getValue().longValue();
                    float heap = metrics.percentile("Heap", 90).getValue().longValue();
                    float nonHeap = metrics.percentile("NonHeap", 90).getValue().longValue();
                    int mb = 1024 * 1024;
                    logger.info(String.format("heap (base): %.3f MB", heapBeforeModel / mb));
                    logger.info(String.format("heap (model): %.3f MB", heapBeforeInference / mb));
                    logger.info(String.format("heap P90: %.3f MB", heap / mb));
                    logger.info(String.format("nonHeap P90: %.3f MB", nonHeap / mb));
                    // rss/cpu metrics are collected from /proc, which Windows does not have.
                    if (!System.getProperty("os.name").startsWith("Win")) {
                        float rssBeforeModel = metrics.getMetric("rss").get(0).getValue().longValue();
                        float rssBeforeInference = metrics.getMetric("rss").get(1).getValue().longValue();
                        float rss = metrics.percentile("rss", 90).getValue().longValue();
                        float cpu = metrics.percentile("cpu", 90).getValue().longValue();
                        logger.info(String.format("cpu P90: %.3f %%", cpu));
                        logger.info(String.format("rss (base): %.3f MB", rssBeforeModel / mb));
                        logger.info(String.format("rss (model): %.3f MB", rssBeforeInference / mb));
                        logger.info(String.format("rss P90: %.3f MB", rss / mb));
                    }
                }
            }
            MemoryTrainingListener.dumpMemoryInfo(metrics, arguments.getOutputDir());
            long delta = System.currentTimeMillis() - begin;
            duration = duration.minus(Duration.ofMillis(delta));
            if (!duration.isNegative()) {
                // Parameterized logging instead of string concatenation.
                logger.info("{} minutes left", duration.toMinutes());
            }
        }
        return true;
    } catch (ParseException e) {
        Arguments.printHelp(e.getMessage(), options);
    } catch (TranslateException | ModelException | IOException | ClassNotFoundException t) {
        logger.error("Unexpected error", t);
    }
    return false;
}
Also used : Options(org.apache.commons.cli.Options) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) Duration(java.time.Duration) IOException(java.io.IOException) CommandLine(org.apache.commons.cli.CommandLine) Metrics(ai.djl.metric.Metrics) ParseException(org.apache.commons.cli.ParseException) ProgressBar(ai.djl.training.util.ProgressBar) Engine(ai.djl.engine.Engine) DefaultParser(org.apache.commons.cli.DefaultParser)

Example 8 with ModelException

use of ai.djl.ModelException in project djl-demo by deepjavalibrary.

The example below is from the class Handler, method handleRequest.

/**
 * AWS Lambda entry point: decodes a base64 data-URL image from the JSON request body,
 * classifies it with the doodle MobileNet model, and writes the top-5 classifications
 * as JSON to the output stream. On failure a JSON error status is written instead.
 *
 * @param is the request body stream (JSON with a data-URL "imageData" field)
 * @param os the response stream
 * @param context the Lambda execution context
 * @throws IOException if writing the response fails
 */
@Override
public void handleRequest(InputStream is, OutputStream os, Context context) throws IOException {
    LambdaLogger logger = context.getLogger();
    String input = Utils.toString(is);
    try {
        Request request = GSON.fromJson(input, Request.class);
        // imageData is a data URL ("data:image/png;base64,...."); keep only the base64 part.
        String base64Img = request.getImageData().split(",")[1];
        byte[] imgBytes = Base64.getDecoder().decode(base64Img);
        Image img;
        try (ByteArrayInputStream bis = new ByteArrayInputStream(imgBytes)) {
            ImageFactory factory = ImageFactory.getInstance();
            img = factory.fromInputStream(bis);
        }
        Translator<Image, Classifications> translator =
                ImageClassificationTranslator.builder()
                        .addTransform(new ToTensor())
                        .optFlag(Image.Flag.GRAYSCALE)
                        .optApplySoftmax(true)
                        .build();
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optModelUrls("https://djl-ai.s3.amazonaws.com/resources/demo/pytorch/doodle_mobilenet.zip")
                        .optTranslator(translator)
                        .build();
        // Close the model as well as the predictor: the original leaked the ZooModel
        // (an AutoCloseable holding native resources) on every invocation.
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            List<Classifications.Classification> result = predictor.predict(img).topK(5);
            os.write(GSON.toJson(result).getBytes(StandardCharsets.UTF_8));
        }
    } catch (RuntimeException | ModelException | TranslateException e) {
        logger.log("Failed handle input: " + input);
        logger.log(e.toString());
        String msg = "{\"status\": \"invoke failed: " + e.toString() + "\"}";
        os.write(msg.getBytes(StandardCharsets.UTF_8));
    }
}
Also used : Classifications(ai.djl.modality.Classifications) ToTensor(ai.djl.modality.cv.transform.ToTensor) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) Image(ai.djl.modality.cv.Image) ImageFactory(ai.djl.modality.cv.ImageFactory) ByteArrayInputStream(java.io.ByteArrayInputStream) LambdaLogger(com.amazonaws.services.lambda.runtime.LambdaLogger)

Aggregations

ModelException (ai.djl.ModelException)8 TranslateException (ai.djl.translate.TranslateException)7 IOException (java.io.IOException)5 Engine (ai.djl.engine.Engine)4 Classifications (ai.djl.modality.Classifications)3 Image (ai.djl.modality.cv.Image)3 SkipException (org.testng.SkipException)3 Test (org.testng.annotations.Test)3 ToTensor (ai.djl.modality.cv.transform.ToTensor)2 LambdaLogger (com.amazonaws.services.lambda.runtime.LambdaLogger)2 Device (ai.djl.Device)1 Metrics (ai.djl.metric.Metrics)1 Input (ai.djl.modality.Input)1 Output (ai.djl.modality.Output)1 ImageFactory (ai.djl.modality.cv.ImageFactory)1 ImageClassificationTranslator (ai.djl.modality.cv.translator.ImageClassificationTranslator)1 Criteria (ai.djl.repository.zoo.Criteria)1 ProgressBar (ai.djl.training.util.ProgressBar)1 ByteArrayInputStream (java.io.ByteArrayInputStream)1 InputStream (java.io.InputStream)1