Use of ai.djl.ModelException in the project djl by deepjavalibrary.
From the class TrtTest, method testTrtOnnx.
@Test
public void testTrtOnnx() throws ModelException, IOException, TranslateException {
    // TensorRT is an optional native engine; skip instead of failing when the
    // platform cannot load it or no GPU is present.
    Engine trt;
    try {
        trt = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your os configuration doesn't support TensorRT.");
    }
    if (!trt.defaultDevice().isGpu()) {
        throw new SkipException("TensorRT only support GPU.");
    }

    // The identity ONNX model must echo its input unchanged.
    Criteria<float[], float[]> criteria =
            Criteria.builder()
                    .setTypes(float[].class, float[].class)
                    .optModelPath(Paths.get("src/test/resources/identity.onnx"))
                    .optTranslator(new MyTranslator())
                    .optEngine("TensorRT")
                    .build();

    try (ZooModel<float[], float[]> model = criteria.loadModel();
            Predictor<float[], float[]> predictor = model.newPredictor()) {
        float[] input = {1, 2, 3, 4};
        float[] output = predictor.predict(input);
        Assert.assertEquals(output, input);
    }
}
Use of ai.djl.ModelException in the project djl by deepjavalibrary.
From the class AbstractBenchmark, method runBenchmark.
/**
 * Executes the benchmark loop: loads the requested engine, runs {@code predict}
 * repeatedly for the configured duration, and logs throughput, latency
 * percentiles and (optionally) memory metrics after each pass.
 *
 * @param args input raw arguments
 * @return {@code true} if the benchmark completed successfully, {@code false} on
 *     any parse or execution error (the error is logged, not rethrown)
 */
public final boolean runBenchmark(String[] args) {
    Options options = Arguments.getOptions();
    try {
        if (Arguments.hasHelp(args)) {
            Arguments.printHelp("usage: djl-bench [-p MODEL-PATH] -s INPUT-SHAPES [OPTIONS]", options);
            return true;
        }
        DefaultParser parser = new DefaultParser();
        CommandLine cmd = parser.parse(options, args, null, false);
        Arguments arguments = new Arguments(cmd);
        String engineName = arguments.getEngine();
        Engine engine = Engine.getEngine(engineName);

        // getVersion() forces the native library to load, so this brackets
        // the engine initialization cost.
        long init = System.nanoTime();
        String version = engine.getVersion();
        long loaded = System.nanoTime();
        logger.info(String.format("Load %s (%s) in %.3f ms.", engineName, version, (loaded - init) / 1_000_000f));

        Duration duration = Duration.ofSeconds(arguments.getDuration());
        Object devices;
        if (this instanceof MultithreadedBenchmark) {
            devices = engine.getDevices(arguments.getMaxGpus());
        } else {
            devices = engine.defaultDevice();
        }
        if (arguments.getDuration() != 0) {
            logger.info("Running {} on: {}, duration: {} minutes.", getClass().getSimpleName(), devices, duration.toMinutes());
        } else {
            logger.info("Running {} on: {}.", getClass().getSimpleName(), devices);
        }

        int numOfThreads = arguments.getThreads();
        int iteration = arguments.getIteration();
        if (this instanceof MultithreadedBenchmark) {
            // Each thread needs enough iterations for stable percentiles.
            int expected = 10 * numOfThreads;
            if (iteration < expected) {
                iteration = expected;
                logger.info("Iteration is too small for multi-threading benchmark. Adjust to: {}", iteration);
            }
        }

        while (!duration.isNegative()) {
            // Reset Metrics for each test loop.
            Metrics metrics = new Metrics();
            progressBar = new ProgressBar("Iteration", iteration);
            float[] lastResult = predict(arguments, metrics, iteration);
            if (lastResult == null) {
                return false;
            }
            long begin = metrics.getMetric("start").get(0).getValue().longValue();
            long end = metrics.getMetric("end").get(0).getValue().longValue();
            long totalTime = end - begin;
            if (lastResult.length > 3) {
                logger.info("Inference result: [{}, {}, {} ...]", lastResult[0], lastResult[1], lastResult[2]);
            } else {
                // Arrays.toString renders the element values; logging the array
                // reference directly would print only its identity hash code.
                logger.info("Inference result: {}", java.util.Arrays.toString(lastResult));
            }

            String throughput = String.format("%.2f", iteration * 1000d / totalTime);
            logger.info("Throughput: {}, completed {} iteration in {} ms.", throughput, iteration, totalTime);
            if (metrics.hasMetric("LoadModel")) {
                long loadModelTime = metrics.getMetric("LoadModel").get(0).getValue().longValue();
                logger.info("Model loading time: {} ms.", String.format("%.3f", loadModelTime / 1000f));
            }

            // Percentiles are only meaningful with more than one sample.
            if (metrics.hasMetric("Inference") && iteration > 1) {
                // Metric values appear to be stored in microseconds; /1000f converts to ms.
                float totalP50 = metrics.percentile("Total", 50).getValue().longValue() / 1000f;
                float totalP90 = metrics.percentile("Total", 90).getValue().longValue() / 1000f;
                float totalP99 = metrics.percentile("Total", 99).getValue().longValue() / 1000f;
                float p50 = metrics.percentile("Inference", 50).getValue().longValue() / 1000f;
                float p90 = metrics.percentile("Inference", 90).getValue().longValue() / 1000f;
                float p99 = metrics.percentile("Inference", 99).getValue().longValue() / 1000f;
                float preP50 = metrics.percentile("Preprocess", 50).getValue().longValue() / 1000f;
                float preP90 = metrics.percentile("Preprocess", 90).getValue().longValue() / 1000f;
                float preP99 = metrics.percentile("Preprocess", 99).getValue().longValue() / 1000f;
                float postP50 = metrics.percentile("Postprocess", 50).getValue().longValue() / 1000f;
                float postP90 = metrics.percentile("Postprocess", 90).getValue().longValue() / 1000f;
                float postP99 = metrics.percentile("Postprocess", 99).getValue().longValue() / 1000f;
                logger.info(String.format("total P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", totalP50, totalP90, totalP99));
                logger.info(String.format("inference P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", p50, p90, p99));
                logger.info(String.format("preprocess P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", preP50, preP90, preP99));
                logger.info(String.format("postprocess P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", postP50, postP90, postP99));

                // Memory collection is opt-in via -Dcollect-memory=true.
                if (Boolean.getBoolean("collect-memory")) {
                    float heapBeforeModel = metrics.getMetric("Heap").get(0).getValue().longValue();
                    float heapBeforeInference = metrics.getMetric("Heap").get(1).getValue().longValue();
                    float heap = metrics.percentile("Heap", 90).getValue().longValue();
                    float nonHeap = metrics.percentile("NonHeap", 90).getValue().longValue();
                    int mb = 1024 * 1024;
                    logger.info(String.format("heap (base): %.3f MB", heapBeforeModel / mb));
                    logger.info(String.format("heap (model): %.3f MB", heapBeforeInference / mb));
                    logger.info(String.format("heap P90: %.3f MB", heap / mb));
                    logger.info(String.format("nonHeap P90: %.3f MB", nonHeap / mb));

                    // rss/cpu metrics are collected via /proc and are unavailable on Windows.
                    if (!System.getProperty("os.name").startsWith("Win")) {
                        float rssBeforeModel = metrics.getMetric("rss").get(0).getValue().longValue();
                        float rssBeforeInference = metrics.getMetric("rss").get(1).getValue().longValue();
                        float rss = metrics.percentile("rss", 90).getValue().longValue();
                        float cpu = metrics.percentile("cpu", 90).getValue().longValue();
                        logger.info(String.format("cpu P90: %.3f %%", cpu));
                        logger.info(String.format("rss (base): %.3f MB", rssBeforeModel / mb));
                        logger.info(String.format("rss (model): %.3f MB", rssBeforeInference / mb));
                        logger.info(String.format("rss P90: %.3f MB", rss / mb));
                    }
                }
            }
            MemoryTrainingListener.dumpMemoryInfo(metrics, arguments.getOutputDir());

            // NOTE(review): assumes the "start" metric was recorded with
            // System.currentTimeMillis() so the subtraction is valid — confirm in predict().
            long delta = System.currentTimeMillis() - begin;
            duration = duration.minus(Duration.ofMillis(delta));
            if (!duration.isNegative()) {
                // Parameterized logging for consistency with the rest of this method
                // (the original concatenated the message string).
                logger.info("{} minutes left", duration.toMinutes());
            }
        }
        return true;
    } catch (ParseException e) {
        Arguments.printHelp(e.getMessage(), options);
    } catch (TranslateException | ModelException | IOException | ClassNotFoundException t) {
        logger.error("Unexpected error", t);
    }
    return false;
}
Use of ai.djl.ModelException in the project djl-demo by deepjavalibrary.
From the class Handler, method handleRequest.
@Override
public void handleRequest(InputStream is, OutputStream os, Context context) throws IOException {
    LambdaLogger logger = context.getLogger();
    String input = Utils.toString(is);
    try {
        // The request carries a data-URL image; the part after the comma is base64.
        Request request = GSON.fromJson(input, Request.class);
        String base64Img = request.getImageData().split(",")[1];
        byte[] imgBytes = Base64.getDecoder().decode(base64Img);
        Image img;
        try (ByteArrayInputStream bis = new ByteArrayInputStream(imgBytes)) {
            ImageFactory factory = ImageFactory.getInstance();
            img = factory.fromInputStream(bis);
        }

        Translator<Image, Classifications> translator =
                ImageClassificationTranslator.builder()
                        .addTransform(new ToTensor())
                        .optFlag(Image.Flag.GRAYSCALE)
                        .optApplySoftmax(true)
                        .build();
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optModelUrls("https://djl-ai.s3.amazonaws.com/resources/demo/pytorch/doodle_mobilenet.zip")
                        .optTranslator(translator)
                        .build();

        // Close the model as well as the predictor: the original leaked the
        // ZooModel (never closed), holding native resources after each invocation.
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            List<Classifications.Classification> result = predictor.predict(img).topK(5);
            os.write(GSON.toJson(result).getBytes(StandardCharsets.UTF_8));
        }
    } catch (RuntimeException | ModelException | TranslateException e) {
        // Best-effort error response: report the failure as JSON instead of
        // letting the Lambda invocation crash.
        logger.log("Failed handle input: " + input);
        logger.log(e.toString());
        String msg = "{\"status\": \"invoke failed: " + e.toString() + "\"}";
        os.write(msg.getBytes(StandardCharsets.UTF_8));
    }
}
Aggregations