use of ai.djl.Device in project djl-demo by deepjavalibrary.
the class CanaryTest method main.
/**
 * Canary entry point.
 *
 * <p>Logs the environment variables, the default engine debug information and the GPU details,
 * then runs a test selected by the {@code DJL_ENGINE} environment variable (defaulting to
 * {@code "mxnet-native-auto"} when unset). Engines with a dedicated test method return right
 * after that test; every other value falls through to a generic object-detection inference on a
 * sample image.
 *
 * @param args command line arguments (not used)
 * @throws IOException if the sample image or the model cannot be read
 * @throws ModelException if the model cannot be found or loaded
 * @throws TranslateException if inference fails
 */
public static void main(String[] args) throws IOException, ModelException, TranslateException {
    logger.info("");
    logger.info("----------Environment Variables----------");
    System.getenv().forEach((k, v) -> logger.info(k + ": " + v));

    logger.info("");
    logger.info("----------Default Engine----------");
    Engine.debugEnvironment();

    logger.info("");
    logger.info("----------Device information----------");
    int gpuCount = CudaUtils.getGpuCount();
    logger.info("GPU Count: {}", gpuCount);
    if (gpuCount > 0) {
        logger.info("CUDA: {}", CudaUtils.getCudaVersionString());
        logger.info("ARCH: {}", CudaUtils.getComputeCapability(0));
    }

    String djlEngine = System.getenv("DJL_ENGINE");
    if (djlEngine == null) {
        djlEngine = "mxnet-native-auto";
    }

    // The manager is only needed to look up the default device; close it right away
    // instead of leaking it for the remainder of the run.
    Device device;
    try (NDManager manager = NDManager.newBaseManager()) {
        device = manager.getDevice();
    }

    if (djlEngine.contains("-native-cu") && !device.isGpu()) {
        throw new AssertionError("Expecting load engine on GPU.");
    } else if (djlEngine.startsWith("tensorrt")) {
        testTensorrt();
        return;
    } else if (djlEngine.startsWith("onnxruntime")) {
        testOnnxRuntime();
        return;
    } else if (djlEngine.startsWith("xgboost")) {
        testXgboost();
        return;
    } else if (djlEngine.startsWith("tflite")) {
        testTflite();
        return;
    } else if (djlEngine.startsWith("python")) {
        testPython();
        return;
    } else if (djlEngine.startsWith("dlr")) {
        testDlr();
        // similar to DLR, fastText and SentencePiece only support Mac and Ubuntu 16.04+
        testFastText();
        testSentencePiece();
        return;
    } else if (djlEngine.startsWith("paddle")) {
        testPaddle();
        return;
    }

    logger.info("");
    logger.info("----------Test inference----------");
    String url = "https://resources.djl.ai/images/dog_bike_car.jpg";
    Image img = ImageFactory.getInstance().fromUrl(url);

    // TensorFlow uses a different backbone and needs an explicit (empty) "Tags" option.
    String backbone = "resnet50";
    Map<String, String> options = null;
    if ("TensorFlow".equals(Engine.getInstance().getEngineName())) {
        backbone = "mobilenet_v2";
        options = new ConcurrentHashMap<>();
        options.put("Tags", "");
    }

    Criteria<Image, DetectedObjects> criteria =
            Criteria.builder()
                    .optApplication(Application.CV.OBJECT_DETECTION)
                    .setTypes(Image.class, DetectedObjects.class)
                    .optFilter("backbone", backbone)
                    .optOptions(options)
                    .build();

    try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
            Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
        DetectedObjects detection = predictor.predict(img);
        logger.info("{}", detection);
    }
}
use of ai.djl.Device in project djl-serving by deepjavalibrary.
the class PyProcess method predict.
/**
 * Forwards an inference request over {@code connection} and returns its reply.
 *
 * <p>Before sending, the request is completed with a {@code handler} property (taken from
 * {@code pyEnv} when the caller did not set one and a handler is configured) and a
 * {@code device_id} property holding the id of the model's device.
 *
 * @param inputs the request to send; its properties are modified in place
 * @return the output returned by {@code connection.send}
 * @throws TranslateException if sending fails, times out, or the calling thread is interrupted
 */
Output predict(Input inputs) throws TranslateException {
    try {
        if (inputs.getProperty("handler", null) == null) {
            String handler = pyEnv.getHandler();
            if (handler != null) {
                inputs.addProperty("handler", handler);
            }
        }
        Device device = model.getNDManager().getDevice();
        inputs.addProperty("device_id", String.valueOf(device.getDeviceId()));
        return connection.send(inputs);
    } catch (InterruptedException e) {
        // Restore the interrupt status so callers further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new TranslateException(e);
    } catch (ExecutionException | TimeoutException e) {
        throw new TranslateException(e);
    }
}
use of ai.djl.Device in project djl-serving by deepjavalibrary.
the class ModelServer method initWorkflows.
/**
 * Registers every workflow listed in the {@code loadWorkflows} configuration value.
 *
 * <p>The value is split on commas and spaces. Each entry must match {@code MODEL_STORE_PATTERN};
 * group 2 is an optional endpoint of the form {@code name[:devices]} and group 3 is the workflow
 * URL. When no endpoint is given, the workflow name is inferred from the URL. After a workflow is
 * registered, one worker is requested per device ({@code null} device when none was specified).
 * If registration fails, the error is logged and the server is stopped after a short pause.
 *
 * @throws IOException if a workflow definition cannot be read
 * @throws URISyntaxException if a workflow URL cannot be converted to a URI
 * @throws ModelNotFoundException declared by the original signature
 * @throws BadWorkflowException if a workflow definition is invalid
 * @throws MalformedModelException declared by the original signature
 */
private void initWorkflows() throws IOException, URISyntaxException, ModelNotFoundException, BadWorkflowException, MalformedModelException {
    ModelManager modelManager = ModelManager.getInstance();
    Set<String> startupWorkflows = modelManager.getStartupWorkflows();
    String loadWorkflows = configManager.getLoadWorkflows();
    if (loadWorkflows == null || loadWorkflows.isEmpty()) {
        return;
    }

    String[] urls = loadWorkflows.split("[, ]+");
    for (String url : urls) {
        logger.info("Initializing workflow: {}", url);
        Matcher matcher = MODEL_STORE_PATTERN.matcher(url);
        if (!matcher.matches()) {
            throw new AssertionError("Invalid model store url: " + url);
        }
        String endpoint = matcher.group(2);
        String workflowUrlString = matcher.group(3);

        // A single null entry requests one worker without pinning a specific device.
        Device[] devices = { null };
        String workflowName;
        if (endpoint != null) {
            String[] tokens = endpoint.split(":", -1);
            workflowName = tokens[0];
            if (tokens.length > 1) {
                devices = parseDevices(tokens[1], Engine.getInstance());
            }
        } else {
            workflowName = ModelInfo.inferModelNameFromUrl(workflowUrlString);
        }

        URL workflowUrl = new URL(workflowUrlString);
        // Resolve the URI before opening the stream so a bad URI never opens a connection.
        java.net.URI workflowUri = workflowUrl.toURI();
        Workflow workflow;
        // Close the definition stream here rather than relying on the parser to do it.
        try (java.io.InputStream is = workflowUrl.openStream()) {
            workflow = WorkflowDefinition.parse(workflowUri, is).toWorkflow();
        }

        Device[] finalDevices = devices;
        CompletableFuture<Void> f = modelManager.registerWorkflow(workflow).thenAccept(v -> {
            for (Device device : finalDevices) {
                modelManager.scaleWorkers(workflow, device, 1, -1);
            }
        }).exceptionally(t -> {
            logger.error("Failed register workflow", t);
            // Pause for 3 seconds before shutting down; presumably this gives pending
            // responses (e.g. a health check) time to go out -- TODO confirm.
            try {
                Thread.sleep(3000);
            } catch (InterruptedException ignored) {
                // Best effort only: proceed to shut down even if the pause is cut short.
            }
            stop();
            return null;
        });

        if (configManager.waitModelLoading()) {
            f.join();
        }
        startupWorkflows.add(workflowName);
    }
}
use of ai.djl.Device in project djl-serving by deepjavalibrary.
the class InferenceRequestHandler method predict.
/**
 * Runs an inference job for the given workflow, loading the model on demand when it is not
 * registered yet.
 *
 * <p>If the workflow is unknown, on-demand loading is only allowed when a model URL pattern is
 * configured. The model URL is read from the {@code model_url} property, falling back to the
 * {@code model_url} content entry, and must match the configured pattern regardless of where it
 * came from. The model is then registered, scaled to one worker on the requested device, and the
 * job is run asynchronously.
 *
 * <p>For an already registered workflow, an HTTP {@code OPTIONS} request is answered with an
 * empty JSON object; any other request runs the job.
 *
 * @param ctx the channel context used to send the response
 * @param req the HTTP request
 * @param input the inference input
 * @param workflowName the name of the workflow or model
 * @param version the requested version
 * @throws ModelNotFoundException if the workflow is unknown and cannot be loaded on demand
 *     (no pattern configured, missing {@code model_url}, or URL rejected by the pattern)
 */
private void predict(ChannelHandlerContext ctx, FullHttpRequest req, Input input, String workflowName, String version) throws ModelNotFoundException {
    ModelManager modelManager = ModelManager.getInstance();
    ConfigManager config = ConfigManager.getInstance();
    Workflow workflow = modelManager.getWorkflow(workflowName, version, true);
    if (workflow == null) {
        String regex = config.getModelUrlPattern();
        if (regex == null) {
            throw new ModelNotFoundException("Model or workflow not found: " + workflowName);
        }
        String modelUrl = input.getProperty("model_url", null);
        if (modelUrl == null) {
            modelUrl = input.getAsString("model_url");
            if (modelUrl == null) {
                throw new ModelNotFoundException("Parameter model_url is required.");
            }
        }
        // Validate the URL no matter which source supplied it; previously a URL passed as a
        // property skipped the allow-list check entirely.
        if (!modelUrl.matches(regex)) {
            throw new ModelNotFoundException("Permission denied: " + modelUrl);
        }
        String engineName = input.getProperty("engine_name", null);
        String deviceName = input.getProperty("device", "-1");
        Engine engine = engineName != null ? Engine.getEngine(engineName) : Engine.getInstance();
        Device device = Device.fromName(deviceName, engine);
        logger.info("Loading model {} from: {}", workflowName, modelUrl);
        ModelInfo modelInfo =
                new ModelInfo(
                        workflowName,
                        modelUrl,
                        version,
                        engineName,
                        config.getJobQueueSize(),
                        config.getMaxIdleTime(),
                        config.getMaxBatchDelay(),
                        config.getBatchSize());
        Workflow wf = new Workflow(modelInfo);
        // NOTE(review): no failure handler is attached to this chain, so a failed registration
        // is not reported from here -- consider adding one.
        modelManager
                .registerWorkflow(wf)
                .thenApply(p -> modelManager.scaleWorkers(wf, device, 1, -1))
                .thenAccept(p -> runJob(modelManager, ctx, p, input));
        return;
    }
    if (HttpMethod.OPTIONS.equals(req.method())) {
        NettyUtils.sendJsonResponse(ctx, "{}");
        return;
    }
    runJob(modelManager, ctx, workflow, input);
}
use of ai.djl.Device in project djl by deepjavalibrary.
the class EvaluatorTrainingListener method updateEvaluators.
/**
 * Feeds the labels and predictions of a batch into the given accumulators of every evaluator
 * attached to the trainer.
 *
 * <p>For each evaluator, the batch is walked device by device; the labels and predictions
 * recorded for that device are passed to {@code updateAccumulator} once per accumulator name.
 *
 * @param trainer the trainer whose evaluators are updated
 * @param batchData the batch holding per-device labels and predictions
 * @param accumulators the names of the accumulators to update
 */
private void updateEvaluators(Trainer trainer, BatchData batchData, String[] accumulators) {
    for (Evaluator metric : trainer.getEvaluators()) {
        batchData.getLabels().forEach((device, labels) -> {
            NDList predictions = batchData.getPredictions().get(device);
            for (String name : accumulators) {
                metric.updateAccumulator(name, labels, predictions);
            }
        });
    }
}
Aggregations