Search in sources:

Example 1 with Device

Use of ai.djl.Device in project djl-demo by deepjavalibrary.

Class CanaryTest, method main.

/**
 * Canary entry point: logs the runtime environment, checks the engine selected by the
 * {@code DJL_ENGINE} environment variable, and runs a smoke test for it.
 *
 * <p>{@code DJL_ENGINE} defaults to {@code mxnet-native-auto} when unset. Engines whose name
 * starts with tensorrt, onnxruntime, xgboost, tflite, python, dlr or paddle run their dedicated
 * test and return (dlr additionally runs the fastText and SentencePiece tests). Any other engine
 * falls through to an object-detection inference on a sample image.
 *
 * @param args unused
 * @throws IOException on I/O failure, e.g. while fetching the test image or loading the model
 * @throws ModelException if the detection model cannot be loaded
 * @throws TranslateException if inference fails
 * @throws AssertionError if a {@code -native-cu} engine was requested but the default device is
 *     not a GPU
 */
public static void main(String[] args) throws IOException, ModelException, TranslateException {
    logger.info("");
    logger.info("----------Environment Variables----------");
    System.getenv().forEach((k, v) -> logger.info("{}: {}", k, v));
    logger.info("");
    logger.info("----------Default Engine----------");
    Engine.debugEnvironment();
    logger.info("");
    logger.info("----------Device information----------");
    int gpuCount = CudaUtils.getGpuCount();
    logger.info("GPU Count: {}", gpuCount);
    if (gpuCount > 0) {
        logger.info("CUDA: {}", CudaUtils.getCudaVersionString());
        logger.info("ARCH: {}", CudaUtils.getComputeCapability(0));
    }
    String djlEngine = System.getenv("DJL_ENGINE");
    if (djlEngine == null) {
        djlEngine = "mxnet-native-auto";
    }
    // NDManager is AutoCloseable: release it as soon as the default device has been read.
    Device device;
    try (NDManager manager = NDManager.newBaseManager()) {
        device = manager.getDevice();
    }
    if (djlEngine.contains("-native-cu") && !device.isGpu()) {
        throw new AssertionError("Expecting load engine on GPU.");
    } else if (djlEngine.startsWith("tensorrt")) {
        testTensorrt();
        return;
    } else if (djlEngine.startsWith("onnxruntime")) {
        testOnnxRuntime();
        return;
    } else if (djlEngine.startsWith("xgboost")) {
        testXgboost();
        return;
    } else if (djlEngine.startsWith("tflite")) {
        testTflite();
        return;
    } else if (djlEngine.startsWith("python")) {
        testPython();
        return;
    } else if (djlEngine.startsWith("dlr")) {
        testDlr();
        // similar to DLR, fastText and SentencePiece only support Mac and Ubuntu 16.04+
        testFastText();
        testSentencePiece();
        return;
    } else if (djlEngine.startsWith("paddle")) {
        testPaddle();
        return;
    }
    logger.info("");
    logger.info("----------Test inference----------");
    String url = "https://resources.djl.ai/images/dog_bike_car.jpg";
    Image img = ImageFactory.getInstance().fromUrl(url);
    String backbone = "resnet50";
    Map<String, String> options = null;
    // TensorFlow uses a different backbone and needs an explicit (empty) "Tags" option.
    if ("TensorFlow".equals(Engine.getInstance().getEngineName())) {
        backbone = "mobilenet_v2";
        options = new ConcurrentHashMap<>();
        options.put("Tags", "");
    }
    Criteria<Image, DetectedObjects> criteria = Criteria.builder().optApplication(Application.CV.OBJECT_DETECTION).setTypes(Image.class, DetectedObjects.class).optFilter("backbone", backbone).optOptions(options).build();
    try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria);
            Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
        DetectedObjects detection = predictor.predict(img);
        logger.info("{}", detection);
    }
}
Also used: Device (ai.djl.Device), DetectedObjects (ai.djl.modality.cv.output.DetectedObjects), Image (ai.djl.modality.cv.Image)

Example 2 with Device

Use of ai.djl.Device in project djl-serving by deepjavalibrary.

Class PyProcess, method predict.

/**
 * Sends one inference request to the Python worker over {@code connection}.
 *
 * <p>If the request has no {@code handler} property, the handler configured in {@code pyEnv}
 * (when non-null) is added. The id of the device backing this model's NDManager is always added
 * as the {@code device_id} property.
 *
 * @param inputs the request; mutated in place with the {@code handler}/{@code device_id}
 *     properties described above
 * @return the response returned by {@code connection.send}
 * @throws TranslateException wrapping an {@link ExecutionException}, {@link TimeoutException} or
 *     {@link InterruptedException} raised while sending
 */
Output predict(Input inputs) throws TranslateException {
    try {
        if (inputs.getProperty("handler", null) == null) {
            String handler = pyEnv.getHandler();
            if (handler != null) {
                inputs.addProperty("handler", handler);
            }
        }
        Device device = model.getNDManager().getDevice();
        inputs.addProperty("device_id", String.valueOf(device.getDeviceId()));
        return connection.send(inputs);
    } catch (InterruptedException e) {
        // Restore the interrupt status so code further up the stack can still observe it.
        Thread.currentThread().interrupt();
        throw new TranslateException(e);
    } catch (ExecutionException | TimeoutException e) {
        throw new TranslateException(e);
    }
}
Also used: TranslateException (ai.djl.translate.TranslateException), Device (ai.djl.Device), ExecutionException (java.util.concurrent.ExecutionException), TimeoutException (java.util.concurrent.TimeoutException)

Example 3 with Device

Use of ai.djl.Device in project djl-serving by deepjavalibrary.

Class ModelServer, method initWorkflows.

/**
 * Registers the workflows listed by {@code configManager.getLoadWorkflows()} at server startup.
 *
 * <p>The setting is split on commas and/or spaces. Each entry must match
 * {@code MODEL_STORE_PATTERN}: group 2 is an optional {@code name[:devices]} endpoint and group 3
 * is the workflow definition URL. When no name is given, it is inferred from the URL. Each
 * workflow is registered asynchronously, and one worker is scaled up per requested device. If
 * registration fails, the error is logged and the server is stopped after a 3 second pause. When
 * {@code configManager.waitModelLoading()} is true, this method blocks until registration
 * completes.
 *
 * @throws IOException if a workflow definition cannot be read
 * @throws URISyntaxException if a workflow URL cannot be converted to a URI
 * @throws ModelNotFoundException declared for callers; may be raised while building a workflow
 * @throws BadWorkflowException if a workflow definition is invalid
 * @throws MalformedModelException declared for callers; may be raised while building a workflow
 * @throws AssertionError if an entry does not match {@code MODEL_STORE_PATTERN}
 */
private void initWorkflows() throws IOException, URISyntaxException, ModelNotFoundException, BadWorkflowException, MalformedModelException {
    ModelManager modelManager = ModelManager.getInstance();
    Set<String> startupWorkflows = modelManager.getStartupWorkflows();
    String loadWorkflows = configManager.getLoadWorkflows();
    if (loadWorkflows == null || loadWorkflows.isEmpty()) {
        return;
    }
    String[] urls = loadWorkflows.split("[, ]+");
    for (String url : urls) {
        logger.info("Initializing workflow: {}", url);
        Matcher matcher = MODEL_STORE_PATTERN.matcher(url);
        if (!matcher.matches()) {
            throw new AssertionError("Invalid model store url: " + url);
        }
        String endpoint = matcher.group(2);
        String workflowUrlString = matcher.group(3);
        // One worker with a null device unless devices are listed explicitly;
        // presumably null lets scaleWorkers choose a default device — confirm.
        Device[] devices = { null };
        String workflowName;
        if (endpoint != null) {
            String[] tokens = endpoint.split(":", -1);
            workflowName = tokens[0];
            if (tokens.length > 1) {
                devices = parseDevices(tokens[1], Engine.getInstance());
            }
        } else {
            workflowName = ModelInfo.inferModelNameFromUrl(workflowUrlString);
        }
        URL workflowUrl = new URL(workflowUrlString);
        Workflow workflow;
        // Close the definition stream once it has been parsed.
        try (InputStream is = workflowUrl.openStream()) {
            workflow = WorkflowDefinition.parse(workflowUrl.toURI(), is).toWorkflow();
        }
        Device[] finalDevices = devices;
        CompletableFuture<Void> f = modelManager.registerWorkflow(workflow).thenAccept(v -> {
            for (Device device : finalDevices) {
                modelManager.scaleWorkers(workflow, device, 1, -1);
            }
        }).exceptionally(t -> {
            logger.error("Failed register workflow", t);
            // Pause before stop(); presumably gives the server time to send pending responses
            // (e.g. health check) before shutting down — confirm.
            try {
                Thread.sleep(3000);
            } catch (InterruptedException e) {
                // Preserve the interrupt for the executing thread; shut down regardless.
                Thread.currentThread().interrupt();
            }
            stop();
            return null;
        });
        if (configManager.waitModelLoading()) {
            f.join();
        }
        startupWorkflows.add(workflowName);
    }
}
Also used : Unit(ai.djl.metric.Unit) Arrays(java.util.Arrays) Connector(ai.djl.serving.util.Connector) URL(java.net.URL) URISyntaxException(java.net.URISyntaxException) LoggerFactory(org.slf4j.LoggerFactory) Device(ai.djl.Device) Workflow(ai.djl.serving.workflow.Workflow) DefaultParser(org.apache.commons.cli.DefaultParser) ModelInfo(ai.djl.serving.wlm.ModelInfo) GeneralSecurityException(java.security.GeneralSecurityException) Matcher(java.util.regex.Matcher) ModelManager(ai.djl.serving.models.ModelManager) NeuronUtils(ai.djl.serving.util.NeuronUtils) BadWorkflowException(ai.djl.serving.workflow.BadWorkflowException) Path(java.nio.file.Path) Slf4JLoggerFactory(io.netty.util.internal.logging.Slf4JLoggerFactory) Artifact(ai.djl.repository.Artifact) Set(java.util.Set) ServerStartupException(ai.djl.serving.http.ServerStartupException) Collectors(java.util.stream.Collectors) ServerChannel(io.netty.channel.ServerChannel) Objects(java.util.Objects) List(java.util.List) ParseException(org.apache.commons.cli.ParseException) MRL(ai.djl.repository.MRL) Metric(ai.djl.metric.Metric) Pattern(java.util.regex.Pattern) IntStream(java.util.stream.IntStream) ModelNotFoundException(ai.djl.repository.zoo.ModelNotFoundException) ChannelOption(io.netty.channel.ChannelOption) DependencyManager(ai.djl.serving.plugins.DependencyManager) Options(org.apache.commons.cli.Options) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) CompletableFuture(java.util.concurrent.CompletableFuture) HelpFormatter(org.apache.commons.cli.HelpFormatter) ArrayList(java.util.ArrayList) WorkflowDefinition(ai.djl.serving.workflow.WorkflowDefinition) ChannelFutureListener(io.netty.channel.ChannelFutureListener) CommandLine(org.apache.commons.cli.CommandLine) Engine(ai.djl.engine.Engine) ServerGroups(ai.djl.serving.util.ServerGroups) EventLoopGroup(io.netty.channel.EventLoopGroup) Properties(java.util.Properties) Logger(org.slf4j.Logger) SslContext(io.netty.handler.ssl.SslContext) 
MalformedURLException(java.net.MalformedURLException) Files(java.nio.file.Files) IOException(java.io.IOException) MalformedModelException(ai.djl.MalformedModelException) File(java.io.File) ChannelFuture(io.netty.channel.ChannelFuture) ExecutionException(java.util.concurrent.ExecutionException) Repository(ai.djl.repository.Repository) FolderScanPluginManager(ai.djl.serving.plugins.FolderScanPluginManager) ConfigManager(ai.djl.serving.util.ConfigManager) FilenameUtils(ai.djl.repository.FilenameUtils) CudaUtils(ai.djl.util.cuda.CudaUtils) ServerBootstrap(io.netty.bootstrap.ServerBootstrap) InternalLoggerFactory(io.netty.util.internal.logging.InternalLoggerFactory) Dimension(ai.djl.metric.Dimension) InputStream(java.io.InputStream) Matcher(java.util.regex.Matcher) Device(ai.djl.Device) Workflow(ai.djl.serving.workflow.Workflow) ModelManager(ai.djl.serving.models.ModelManager) URL(java.net.URL)

Example 4 with Device

Use of ai.djl.Device in project djl-serving by deepjavalibrary.

Class InferenceRequestHandler, method predict.

/**
 * Runs an inference request against a workflow, loading the model on demand if allowed.
 *
 * <p>If the workflow is not registered and {@code config.getModelUrlPattern()} is set, the model
 * is loaded from {@code model_url}. The URL is read from the request properties or, failing that,
 * from the request body, and it must match the configured pattern. The optional
 * {@code engine_name} and {@code device} properties (device defaults to {@code "-1"}) select where
 * it is loaded. The model is then registered, scaled to one worker and the job is run. If loading
 * fails, the error is logged and sent back on {@code ctx}. For a registered workflow, OPTIONS
 * requests are answered with {@code {}}; all other requests run the job directly.
 *
 * @param ctx the channel to respond on
 * @param req the HTTP request
 * @param input the parsed inference input
 * @param workflowName the workflow (or model) name
 * @param version the workflow version
 * @throws ModelNotFoundException if the workflow is unknown and dynamic loading is disabled,
 *     {@code model_url} is missing, or {@code model_url} does not match the allowed pattern
 */
private void predict(ChannelHandlerContext ctx, FullHttpRequest req, Input input, String workflowName, String version) throws ModelNotFoundException {
    ModelManager modelManager = ModelManager.getInstance();
    ConfigManager config = ConfigManager.getInstance();
    Workflow workflow = modelManager.getWorkflow(workflowName, version, true);
    if (workflow == null) {
        String regex = config.getModelUrlPattern();
        if (regex == null) {
            throw new ModelNotFoundException("Model or workflow not found: " + workflowName);
        }
        String modelUrl = input.getProperty("model_url", null);
        if (modelUrl == null) {
            modelUrl = input.getAsString("model_url");
            if (modelUrl == null) {
                throw new ModelNotFoundException("Parameter model_url is required.");
            }
        }
        // Enforce the allow-list regardless of where model_url came from; both the request
        // properties and the body come from the request.
        if (!modelUrl.matches(regex)) {
            throw new ModelNotFoundException("Permission denied: " + modelUrl);
        }
        String engineName = input.getProperty("engine_name", null);
        String deviceName = input.getProperty("device", "-1");
        Engine engine = engineName != null ? Engine.getEngine(engineName) : Engine.getInstance();
        Device device = Device.fromName(deviceName, engine);
        logger.info("Loading model {} from: {}", workflowName, modelUrl);
        ModelInfo modelInfo = new ModelInfo(workflowName, modelUrl, version, engineName, config.getJobQueueSize(), config.getMaxIdleTime(), config.getMaxBatchDelay(), config.getBatchSize());
        Workflow wf = new Workflow(modelInfo);
        modelManager.registerWorkflow(wf)
                .thenApply(p -> modelManager.scaleWorkers(wf, device, 1, -1))
                .thenAccept(p -> runJob(modelManager, ctx, p, input))
                .exceptionally(t -> {
                    // Without this handler a failed load would leave the client without a response.
                    Throwable cause = t.getCause() != null ? t.getCause() : t;
                    logger.error("Failed to load model {}", workflowName, cause);
                    NettyUtils.sendError(ctx, cause);
                    return null;
                });
        return;
    }
    if (HttpMethod.OPTIONS.equals(req.method())) {
        NettyUtils.sendJsonResponse(ctx, "{}");
        return;
    }
    runJob(modelManager, ctx, workflow, input);
}
Also used : ModelNotFoundException(ai.djl.repository.zoo.ModelNotFoundException) HttpVersion(io.netty.handler.codec.http.HttpVersion) Output(ai.djl.modality.Output) LoggerFactory(org.slf4j.LoggerFactory) Device(ai.djl.Device) ModelException(ai.djl.ModelException) Workflow(ai.djl.serving.workflow.Workflow) Input(ai.djl.modality.Input) ChannelHandlerContext(io.netty.channel.ChannelHandlerContext) NettyUtils(ai.djl.serving.util.NettyUtils) ModelInfo(ai.djl.serving.wlm.ModelInfo) ModelManager(ai.djl.serving.models.ModelManager) TranslateException(ai.djl.translate.TranslateException) Map(java.util.Map) Engine(ai.djl.engine.Engine) WlmException(ai.djl.serving.wlm.util.WlmException) Logger(org.slf4j.Logger) BytesSupplier(ai.djl.ndarray.BytesSupplier) HttpMethod(io.netty.handler.codec.http.HttpMethod) Set(java.util.Set) HttpResponseStatus(io.netty.handler.codec.http.HttpResponseStatus) FullHttpRequest(io.netty.handler.codec.http.FullHttpRequest) FullHttpResponse(io.netty.handler.codec.http.FullHttpResponse) ConfigManager(ai.djl.serving.util.ConfigManager) DefaultFullHttpResponse(io.netty.handler.codec.http.DefaultFullHttpResponse) QueryStringDecoder(io.netty.handler.codec.http.QueryStringDecoder) Metric(ai.djl.metric.Metric) Pattern(java.util.regex.Pattern) ModelInfo(ai.djl.serving.wlm.ModelInfo) Device(ai.djl.Device) ModelNotFoundException(ai.djl.repository.zoo.ModelNotFoundException) Workflow(ai.djl.serving.workflow.Workflow) ModelManager(ai.djl.serving.models.ModelManager) ConfigManager(ai.djl.serving.util.ConfigManager) Engine(ai.djl.engine.Engine)

Example 5 with Device

Use of ai.djl.Device in project djl by deepjavalibrary.

Class EvaluatorTrainingListener, method updateEvaluators.

/**
 * Pushes one batch's labels and predictions into every evaluator of the trainer.
 *
 * <p>For each evaluator and each {@link Device} key of the batch's label map, the labels and the
 * predictions stored under the same device are passed to
 * {@code evaluator.updateAccumulator} once per accumulator name.
 *
 * @param trainer supplies the evaluators to update
 * @param batchData per-device labels and predictions, both keyed by {@link Device}
 * @param accumulators names of the accumulators to update on each evaluator
 */
private void updateEvaluators(Trainer trainer, BatchData batchData, String[] accumulators) {
    for (Evaluator evaluator : trainer.getEvaluators()) {
        batchData.getLabels().forEach((Device dev, NDList labelList) -> {
            NDList predictionList = batchData.getPredictions().get(dev);
            for (String name : accumulators) {
                evaluator.updateAccumulator(name, labelList, predictionList);
            }
        });
    }
}
Also used: Device (ai.djl.Device), NDList (ai.djl.ndarray.NDList), Evaluator (ai.djl.training.evaluator.Evaluator)

Aggregations

Device (ai.djl.Device): 43; NDArray (ai.djl.ndarray.NDArray): 24; Shape (ai.djl.ndarray.types.Shape): 16; Test (org.testng.annotations.Test): 13; NDManager (ai.djl.ndarray.NDManager): 12; Model (ai.djl.Model): 9; Engine (ai.djl.engine.Engine): 8; NDList (ai.djl.ndarray.NDList): 8; Block (ai.djl.nn.Block): 8; DefaultTrainingConfig (ai.djl.training.DefaultTrainingConfig): 8; Trainer (ai.djl.training.Trainer): 8; TrainingConfig (ai.djl.training.TrainingConfig): 8; Optimizer (ai.djl.training.optimizer.Optimizer): 8; ModelInfo (ai.djl.serving.wlm.ModelInfo): 6; Workflow (ai.djl.serving.workflow.Workflow): 6; IOException (java.io.IOException): 6; ModelNotFoundException (ai.djl.repository.zoo.ModelNotFoundException): 5; ModelManager (ai.djl.serving.models.ModelManager): 5; ConfigManager (ai.djl.serving.util.ConfigManager): 5; Pattern (java.util.regex.Pattern): 5