Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.
The class TrtTest, method testTrtOnnx.
@Test
public void testTrtOnnx() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    if (!engine.defaultDevice().isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }

    Criteria<float[], float[]> criteria =
            Criteria.builder()
                    .setTypes(float[].class, float[].class)
                    .optModelPath(Paths.get("src/test/resources/identity.onnx"))
                    .optTranslator(new MyTranslator())
                    .optEngine("TensorRT")
                    .build();

    try (ZooModel<float[], float[]> model = criteria.loadModel();
            Predictor<float[], float[]> predictor = model.newPredictor()) {
        float[] data = {1, 2, 3, 4};
        float[] ret = predictor.predict(data);
        Assert.assertEquals(ret, data);
    }
}
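The test relies on a MyTranslator helper that this snippet does not show. For an identity ONNX model it only needs to move a float[] into and out of an NDList, so a minimal sketch could look like the following (an illustration under that assumption, not the repository's actual class; returning a null batchifier so the tensor is fed through without an added batch dimension is also an assumption):

import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

// Hypothetical stand-in for the MyTranslator referenced by the test above.
final class MyTranslator implements Translator<float[], float[]> {

    @Override
    public NDList processInput(TranslatorContext ctx, float[] input) {
        // Wrap the raw floats in an NDArray owned by the predictor's NDManager.
        return new NDList(ctx.getNDManager().create(input));
    }

    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        // The identity model echoes its input, so read the single output back.
        return list.singletonOrThrow().toFloatArray();
    }

    @Override
    public Batchifier getBatchifier() {
        return null; // no batch dimension; the model takes the tensor as-is
    }
}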
Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.
The class AbstractBenchmark, method runBenchmark.
/**
 * Executes the benchmark.
 *
 * @param args the raw input arguments
 * @return {@code true} if the benchmark run completes successfully
 */
public final boolean runBenchmark(String[] args) {
    Options options = Arguments.getOptions();
    try {
        if (Arguments.hasHelp(args)) {
            Arguments.printHelp(
                    "usage: djl-bench [-p MODEL-PATH] -s INPUT-SHAPES [OPTIONS]", options);
            return true;
        }
        DefaultParser parser = new DefaultParser();
        CommandLine cmd = parser.parse(options, args, null, false);
        Arguments arguments = new Arguments(cmd);
        String engineName = arguments.getEngine();
        Engine engine = Engine.getEngine(engineName);

        long init = System.nanoTime();
        String version = engine.getVersion();
        long loaded = System.nanoTime();
        logger.info(
                String.format(
                        "Load %s (%s) in %.3f ms.",
                        engineName, version, (loaded - init) / 1_000_000f));

        Duration duration = Duration.ofSeconds(arguments.getDuration());
        Object devices;
        if (this instanceof MultithreadedBenchmark) {
            devices = engine.getDevices(arguments.getMaxGpus());
        } else {
            devices = engine.defaultDevice();
        }
        if (arguments.getDuration() != 0) {
            logger.info(
                    "Running {} on: {}, duration: {} minutes.",
                    getClass().getSimpleName(),
                    devices,
                    duration.toMinutes());
        } else {
            logger.info("Running {} on: {}.", getClass().getSimpleName(), devices);
        }

        int numOfThreads = arguments.getThreads();
        int iteration = arguments.getIteration();
        if (this instanceof MultithreadedBenchmark) {
            int expected = 10 * numOfThreads;
            if (iteration < expected) {
                iteration = expected;
                logger.info(
                        "Iteration is too small for multi-threading benchmark. Adjust to: {}",
                        iteration);
            }
        }

        while (!duration.isNegative()) {
            // Reset Metrics for each test loop.
            Metrics metrics = new Metrics();
            progressBar = new ProgressBar("Iteration", iteration);
            float[] lastResult = predict(arguments, metrics, iteration);
            if (lastResult == null) {
                return false;
            }

            long begin = metrics.getMetric("start").get(0).getValue().longValue();
            long end = metrics.getMetric("end").get(0).getValue().longValue();
            long totalTime = end - begin;

            if (lastResult.length > 3) {
                logger.info(
                        "Inference result: [{}, {}, {} ...]",
                        lastResult[0], lastResult[1], lastResult[2]);
            } else {
                logger.info("Inference result: {}", lastResult);
            }

            String throughput = String.format("%.2f", iteration * 1000d / totalTime);
            logger.info(
                    "Throughput: {}, completed {} iterations in {} ms.",
                    throughput, iteration, totalTime);

            if (metrics.hasMetric("LoadModel")) {
                long loadModelTime =
                        metrics.getMetric("LoadModel").get(0).getValue().longValue();
                logger.info(
                        "Model loading time: {} ms.",
                        String.format("%.3f", loadModelTime / 1000f));
            }

            if (metrics.hasMetric("Inference") && iteration > 1) {
                float totalP50 = metrics.percentile("Total", 50).getValue().longValue() / 1000f;
                float totalP90 = metrics.percentile("Total", 90).getValue().longValue() / 1000f;
                float totalP99 = metrics.percentile("Total", 99).getValue().longValue() / 1000f;
                float p50 = metrics.percentile("Inference", 50).getValue().longValue() / 1000f;
                float p90 = metrics.percentile("Inference", 90).getValue().longValue() / 1000f;
                float p99 = metrics.percentile("Inference", 99).getValue().longValue() / 1000f;
                float preP50 = metrics.percentile("Preprocess", 50).getValue().longValue() / 1000f;
                float preP90 = metrics.percentile("Preprocess", 90).getValue().longValue() / 1000f;
                float preP99 = metrics.percentile("Preprocess", 99).getValue().longValue() / 1000f;
                float postP50 =
                        metrics.percentile("Postprocess", 50).getValue().longValue() / 1000f;
                float postP90 =
                        metrics.percentile("Postprocess", 90).getValue().longValue() / 1000f;
                float postP99 =
                        metrics.percentile("Postprocess", 99).getValue().longValue() / 1000f;
                logger.info(
                        String.format(
                                "total P50: %.3f ms, P90: %.3f ms, P99: %.3f ms",
                                totalP50, totalP90, totalP99));
                logger.info(
                        String.format(
                                "inference P50: %.3f ms, P90: %.3f ms, P99: %.3f ms",
                                p50, p90, p99));
                logger.info(
                        String.format(
                                "preprocess P50: %.3f ms, P90: %.3f ms, P99: %.3f ms",
                                preP50, preP90, preP99));
                logger.info(
                        String.format(
                                "postprocess P50: %.3f ms, P90: %.3f ms, P99: %.3f ms",
                                postP50, postP90, postP99));

                if (Boolean.getBoolean("collect-memory")) {
                    float heapBeforeModel =
                            metrics.getMetric("Heap").get(0).getValue().longValue();
                    float heapBeforeInference =
                            metrics.getMetric("Heap").get(1).getValue().longValue();
                    float heap = metrics.percentile("Heap", 90).getValue().longValue();
                    float nonHeap = metrics.percentile("NonHeap", 90).getValue().longValue();
                    int mb = 1024 * 1024;
                    logger.info(String.format("heap (base): %.3f MB", heapBeforeModel / mb));
                    logger.info(String.format("heap (model): %.3f MB", heapBeforeInference / mb));
                    logger.info(String.format("heap P90: %.3f MB", heap / mb));
                    logger.info(String.format("nonHeap P90: %.3f MB", nonHeap / mb));

                    if (!System.getProperty("os.name").startsWith("Win")) {
                        float rssBeforeModel =
                                metrics.getMetric("rss").get(0).getValue().longValue();
                        float rssBeforeInference =
                                metrics.getMetric("rss").get(1).getValue().longValue();
                        float rss = metrics.percentile("rss", 90).getValue().longValue();
                        float cpu = metrics.percentile("cpu", 90).getValue().longValue();
                        logger.info(String.format("cpu P90: %.3f %%", cpu));
                        logger.info(String.format("rss (base): %.3f MB", rssBeforeModel / mb));
                        logger.info(
                                String.format("rss (model): %.3f MB", rssBeforeInference / mb));
                        logger.info(String.format("rss P90: %.3f MB", rss / mb));
                    }
                }
            }
            MemoryTrainingListener.dumpMemoryInfo(metrics, arguments.getOutputDir());

            long delta = System.currentTimeMillis() - begin;
            duration = duration.minus(Duration.ofMillis(delta));
            if (!duration.isNegative()) {
                logger.info(duration.toMinutes() + " minutes left");
            }
        }
        return true;
    } catch (ParseException e) {
        Arguments.printHelp(e.getMessage(), options);
    } catch (TranslateException | ModelException | IOException | ClassNotFoundException t) {
        logger.error("Unexpected error", t);
    }
    return false;
}
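The percentile block above leans on DJL's ai.djl.metric.Metrics class, which keeps every sample recorded under a given name and can compute percentiles over them; the benchmark records times in microseconds, hence the repeated / 1000f conversions to milliseconds. A standalone sketch of that pattern (the metric name matches the one used above, but the sample values are invented for illustration):

import ai.djl.metric.Metrics;

public final class MetricsDemo {

    public static void main(String[] args) {
        Metrics metrics = new Metrics();
        // Record one "Inference" sample per iteration, in microseconds,
        // mirroring what predict() does inside runBenchmark.
        for (long micros : new long[] {900, 1_000, 1_100, 1_250, 4_800}) {
            metrics.addMetric("Inference", micros);
        }
        float p50 = metrics.percentile("Inference", 50).getValue().longValue() / 1000f;
        float p90 = metrics.percentile("Inference", 90).getValue().longValue() / 1000f;
        System.out.printf("inference P50: %.3f ms, P90: %.3f ms%n", p50, p90);
    }
}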
Use of ai.djl.translate.TranslateException in project djl-demo by deepjavalibrary.
The class ConsumerLoop, method run.
@Override
public void run() {
    try {
        consumer.subscribe(topics);
        while (true) {
            ConsumerRecords<Void, String> records =
                    consumer.poll(Duration.ofMillis(Long.MAX_VALUE));
            for (ConsumerRecord<Void, String> record : records) {
                Map<String, Object> data = new HashMap<>();
                data.put("partition", record.partition());
                data.put("offset", record.offset());
                data.put("value", record.value());
                // make prediction on text data
                Classifications result = predictor.predict(record.value());
                data.put("prediction", result.toString());
                System.out.println("content: " + data.get("value"));
                System.out.println("prediction: " + data.get("prediction"));
            }
        }
    } catch (WakeupException | TranslateException e) {
        // ignore for shutdown
    } finally {
        consumer.close();
    }
}
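The WakeupException branch is Kafka's standard shutdown idiom: it only fires when another thread calls wakeup() on the consumer while poll() is blocked. A typical trigger, sketched here rather than taken from the demo (the shutdown() method name is hypothetical):

// Sketch: lets another thread stop the loop above. wakeup() is the one
// KafkaConsumer method documented as safe to call from a different thread;
// it makes the blocked poll() throw WakeupException, which run() swallows
// before closing the consumer in its finally block.
public void shutdown() {
    consumer.wakeup();
}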
Use of ai.djl.translate.TranslateException in project djl-demo by deepjavalibrary.
The class Handler, method handleRequest.
@Override
public void handleRequest(InputStream is, OutputStream os, Context context) throws IOException {
    LambdaLogger logger = context.getLogger();
    String input = Utils.toString(is);
    try {
        Request request = GSON.fromJson(input, Request.class);
        String base64Img = request.getImageData().split(",")[1];
        byte[] imgBytes = Base64.getDecoder().decode(base64Img);
        Image img;
        try (ByteArrayInputStream bis = new ByteArrayInputStream(imgBytes)) {
            ImageFactory factory = ImageFactory.getInstance();
            img = factory.fromInputStream(bis);
        }
        Translator<Image, Classifications> translator =
                ImageClassificationTranslator.builder()
                        .addTransform(new ToTensor())
                        .optFlag(Image.Flag.GRAYSCALE)
                        .optApplySoftmax(true)
                        .build();
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optModelUrls(
                                "https://djl-ai.s3.amazonaws.com/resources/demo/pytorch/doodle_mobilenet.zip")
                        .optTranslator(translator)
                        .build();
        // Close the model as well as the predictor so native resources are
        // released on every invocation instead of leaking.
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            List<Classifications.Classification> result = predictor.predict(img).topK(5);
            os.write(GSON.toJson(result).getBytes(StandardCharsets.UTF_8));
        }
    } catch (RuntimeException | ModelException | TranslateException e) {
        logger.log("Failed to handle input: " + input);
        logger.log(e.toString());
        String msg = "{\"status\": \"invoke failed: " + e + "\"}";
        os.write(msg.getBytes(StandardCharsets.UTF_8));
    }
}
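The handler splits getImageData() on the first comma, which implies the client sends the image as a data URL inside the JSON body. The request payload presumably looks roughly like this (the imageData field name is inferred from the Request.getImageData() getter, and the base64 content is truncated):

{"imageData": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg..."}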
Use of ai.djl.translate.TranslateException in project build-your-own-social-media-analytics-with-apache-kafka by scholzj.
The class TopologyProducer, method buildTopology.
@Produces
public Topology buildTopology() {
    final TweetSerde tweetSerde = new TweetSerde();

    try {
        Criteria<String, Classifications> criteria =
                Criteria.builder()
                        .optApplication(Application.NLP.SENTIMENT_ANALYSIS)
                        .setTypes(String.class, Classifications.class)
                        .build();
        predictor = ModelZoo.loadModel(criteria).newPredictor();
    } catch (IOException | ModelNotFoundException | MalformedModelException e) {
        LOG.error("Failed to load model", e);
        throw new RuntimeException("Failed to load model", e);
    }

    final StreamsBuilder builder = new StreamsBuilder();
    builder.stream(SOURCE_TOPIC, Consumed.with(Serdes.ByteArray(), tweetSerde))
            .flatMapValues(value -> {
                if (value.getRetweetedStatus() != null) {
                    // We ignore retweets => we do not want an alert for every retweet.
                    return List.of();
                } else {
                    String tweet = value.getText();
                    try {
                        Classifications classifications = predictor.predict(tweet);
                        String statusUrl =
                                "https://twitter.com/" + value.getUser().getScreenName()
                                        + "/status/" + value.getId();
                        String alert =
                                String.format(
                                        "The following tweet was classified as %s with %2.2f%% probability: %s",
                                        classifications.best().getClassName()
                                                .toLowerCase(Locale.ENGLISH),
                                        classifications.best().getProbability() * 100,
                                        statusUrl);
                        // We only care about strong results where the probability is > 50%.
                        if (classifications.best().getProbability() > 0.50) {
                            LOG.infov("Tweeting: {0}", alert);
                            return List.of(alert);
                        } else {
                            LOG.infov("Not tweeting: {0}", alert);
                            return List.of();
                        }
                    } catch (TranslateException e) {
                        LOG.errorv("Failed to classify the tweet {0}", value);
                        return List.of();
                    }
                }
            })
            .peek((key, value) -> LOG.infov("{0}", value))
            .to(TARGET_TOPIC, Produced.with(Serdes.ByteArray(), Serdes.String()));

    return builder.build();
}
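For comparison, the same SENTIMENT_ANALYSIS criteria can be exercised outside Kafka Streams. A minimal standalone sketch under that assumption (the input sentence is illustrative, and the printed label set depends on which model the zoo resolves):

import ai.djl.Application;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

public final class SentimentDemo {

    public static void main(String[] args) throws Exception {
        Criteria<String, Classifications> criteria =
                Criteria.builder()
                        .optApplication(Application.NLP.SENTIMENT_ANALYSIS)
                        .setTypes(String.class, Classifications.class)
                        .build();
        try (ZooModel<String, Classifications> model = criteria.loadModel();
                Predictor<String, Classifications> predictor = model.newPredictor()) {
            Classifications result = predictor.predict("I love this!");
            // best() returns the highest-probability class, as used in buildTopology.
            System.out.println(result.best().getClassName());
            System.out.printf("%.2f%%%n", result.best().getProbability() * 100);
        }
    }
}

Unlike the topology above, which keeps one long-lived predictor field, this version closes the model and predictor when done. Note that a DJL Predictor is not thread-safe, so a Kafka Streams topology running with multiple stream threads would need one predictor per thread rather than the shared field shown here.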