Search in sources:

Example 1 with TranslatorFactory

Use of ai.djl.translate.TranslatorFactory in project djl by deepjavalibrary.

From class BaseModelLoader, method getTranslatorFactory.

/**
 * Picks the translator factory for the given criteria.
 *
 * <p>A factory configured directly on the criteria takes precedence. Otherwise, if the
 * {@code translatorFactory} argument is present, it is treated as a class name and instantiated
 * through the context class loader.
 *
 * @param criteria the criteria describing the model to load
 * @param arguments the model arguments; the {@code translatorFactory} entry, when present, must
 *     be a {@code String} class name
 * @return the selected factory; may be {@code null}, e.g. when neither source supplies one
 */
protected TranslatorFactory getTranslatorFactory(Criteria<?, ?> criteria, Map<String, Object> arguments) {
    TranslatorFactory explicit = criteria.getTranslatorFactory();
    if (explicit != null) {
        return explicit;
    }
    String className = (String) arguments.get("translatorFactory");
    if (className == null) {
        // Nothing configured anywhere: let the caller fall back to its own default.
        return null;
    }
    return ClassLoaderUtils.initClass(ClassLoaderUtils.getContextClassLoader(), className);
}
Also used: DefaultTranslatorFactory (ai.djl.translate.DefaultTranslatorFactory), TranslatorFactory (ai.djl.translate.TranslatorFactory)

Example 2 with TranslatorFactory

Use of ai.djl.translate.TranslatorFactory in project djl by deepjavalibrary.

From class BaseModelLoader, method loadModel.

/**
 * {@inheritDoc}
 *
 * <p>Steps, in order:
 *
 * <ol>
 *   <li>Match an artifact against the criteria filters.
 *   <li>Choose a translator factory: the criteria's, then the {@code translatorFactory}
 *       argument, then {@code defaultFactory}.
 *   <li>Prepare the artifact.
 *   <li>Apply {@code loadServingProperties} to the model directory.
 *   <li>Resolve the engine.
 *   <li>Create and load the model, then wrap it with a translator.
 * </ol>
 *
 * <p>If loading the model or creating its translator fails, the model is closed before the
 * failure propagates, so its resources are not leaked.
 */
@Override
@SuppressWarnings("unchecked")
public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria) throws IOException, ModelNotFoundException, MalformedModelException {
    Artifact artifact = mrl.match(criteria.getFilters());
    if (artifact == null) {
        throw new ModelNotFoundException("No matching filter found");
    }
    Progress progress = criteria.getProgress();
    Map<String, Object> arguments = artifact.getArguments(criteria.getArguments());
    Map<String, String> options = artifact.getOptions(criteria.getOptions());
    try {
        TranslatorFactory factory = getTranslatorFactory(criteria, arguments);
        Class<I> input = criteria.getInputClass();
        Class<O> output = criteria.getOutputClass();
        // Fall back to the default factory when none is configured or the configured one
        // cannot translate between the requested input/output types.
        if (factory == null || !factory.isSupported(input, output)) {
            factory = defaultFactory;
            if (!factory.isSupported(input, output)) {
                throw new ModelNotFoundException(getFactoryLookupErrorMessage(factory));
            }
        }
        mrl.prepare(artifact, progress);
        if (progress != null) {
            progress.reset("Loading", 2);
            progress.update(1);
        }
        Path modelPath = mrl.getRepository().getResourceDirectory(artifact);
        // The resource may be a single file; its parent directory is then the model directory.
        Path modelDir = Files.isRegularFile(modelPath) ? modelPath.getParent() : modelPath;
        if (modelDir == null) {
            throw new AssertionError("Directory should not be null.");
        }
        loadServingProperties(modelDir, arguments, options);
        Application application = criteria.getApplication();
        if (application != Application.UNDEFINED) {
            arguments.put("application", application.getPath());
        }
        String engine = resolveEngineName(criteria, arguments);
        String modelName = criteria.getModelName();
        if (modelName == null) {
            modelName = artifact.getName();
        }
        Model model = createModel(modelDir, modelName, criteria.getDevice(), criteria.getBlock(), arguments, engine);
        try {
            model.load(modelPath, null, options);
            Translator<I, O> translator = (Translator<I, O>) factory.newInstance(input, output, model, arguments);
            return new ZooModel<>(model, translator);
        } catch (Throwable t) {
            // The model never reached the caller, so nobody else can close it. Release it here.
            // The precise rethrow below keeps the method's declared exception types unchanged.
            try {
                model.close();
            } catch (RuntimeException closeFailure) {
                t.addSuppressed(closeFailure);
            }
            throw t;
        }
    } catch (TranslateException e) {
        throw new ModelNotFoundException("No matching translator found", e);
    } finally {
        if (progress != null) {
            progress.end();
        }
    }
}

/**
 * Determines and validates the engine used to load the model.
 *
 * <p>Preference order:
 *
 * <ol>
 *   <li>The engine set on the criteria.
 *   <li>The {@code engine} argument, expected to come from {@code serving.properties}.
 *   <li>If a model zoo is registered for this repository's group id, an engine that zoo
 *       supports: the default engine when the zoo lists it, otherwise the last listed engine
 *       for which {@code Engine.hasEngine} is true.
 * </ol>
 *
 * <p>When no zoo is registered, the result may be {@code null}, presumably leaving
 * {@code createModel} to use the default engine.
 *
 * @param criteria the criteria being loaded
 * @param arguments the model arguments, after serving properties have been applied
 * @return the engine name, or {@code null} if none could be determined
 * @throws ModelNotFoundException if the zoo lists no available engine, or the chosen engine is
 *     not available
 */
private String resolveEngineName(Criteria<?, ?> criteria, Map<String, Object> arguments)
        throws ModelNotFoundException {
    String engine = criteria.getEngine();
    if (engine == null) {
        // get engine from serving.properties
        engine = (String) arguments.get("engine");
    }
    if (engine == null) {
        ModelZoo modelZoo = ModelZoo.getModelZoo(mrl.getGroupId());
        if (modelZoo != null) {
            String defaultEngine = Engine.getDefaultEngineName();
            for (String supportedEngine : modelZoo.getSupportedEngines()) {
                if (supportedEngine.equals(defaultEngine)) {
                    engine = supportedEngine;
                    break;
                } else if (Engine.hasEngine(supportedEngine)) {
                    engine = supportedEngine;
                }
            }
            if (engine == null) {
                throw new ModelNotFoundException("No supported engine available for model zoo: " + modelZoo.getGroupId());
            }
        }
    }
    if (engine != null && !Engine.hasEngine(engine)) {
        throw new ModelNotFoundException(engine + " is not supported");
    }
    return engine;
}
Also used: Path (java.nio.file.Path), Progress (ai.djl.util.Progress), DefaultTranslatorFactory (ai.djl.translate.DefaultTranslatorFactory), TranslatorFactory (ai.djl.translate.TranslatorFactory), TranslateException (ai.djl.translate.TranslateException), Artifact (ai.djl.repository.Artifact), Translator (ai.djl.translate.Translator), Model (ai.djl.Model), Application (ai.djl.Application)

Aggregations

DefaultTranslatorFactory (ai.djl.translate.DefaultTranslatorFactory): 2, TranslatorFactory (ai.djl.translate.TranslatorFactory): 2, Application (ai.djl.Application): 1, Model (ai.djl.Model): 1, Artifact (ai.djl.repository.Artifact): 1, TranslateException (ai.djl.translate.TranslateException): 1, Translator (ai.djl.translate.Translator): 1, Progress (ai.djl.util.Progress): 1, Path (java.nio.file.Path): 1