Use of ai.djl.translate.TranslatorFactory in the djl project by deepjavalibrary.
Class BaseModelLoader, method getTranslatorFactory.
/**
 * Resolves the {@link TranslatorFactory} to use for the given criteria.
 *
 * <p>A factory set directly on the criteria takes precedence. Otherwise the
 * {@code "translatorFactory"} argument is read as a class name and instantiated
 * through the context class loader.
 *
 * @param criteria the criteria that may carry an explicit factory
 * @param arguments the model arguments, optionally holding a {@code "translatorFactory"} entry
 * @return the resolved factory, or {@code null} when the criteria has none and the argument
 *     is absent (NOTE(review): the result also depends on what
 *     {@code ClassLoaderUtils.initClass} returns on failure; confirm)
 */
protected TranslatorFactory getTranslatorFactory(Criteria<?, ?> criteria, Map<String, Object> arguments) {
    TranslatorFactory explicitFactory = criteria.getTranslatorFactory();
    if (explicitFactory != null) {
        return explicitFactory;
    }

    String className = (String) arguments.get("translatorFactory");
    if (className == null) {
        // Nothing configured anywhere; the caller decides on a fallback.
        return null;
    }

    ClassLoader contextLoader = ClassLoaderUtils.getContextClassLoader();
    TranslatorFactory configuredFactory = ClassLoaderUtils.initClass(contextLoader, className);
    return configuredFactory;
}
Use of ai.djl.translate.TranslatorFactory in the djl project by deepjavalibrary.
Class BaseModelLoader, method loadModel.
/**
 * {@inheritDoc}
 */
@Override
@SuppressWarnings("unchecked")
public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria) throws IOException, ModelNotFoundException, MalformedModelException {
    Artifact artifact = mrl.match(criteria.getFilters());
    if (artifact == null) {
        throw new ModelNotFoundException("No matching filter found");
    }
    Progress progress = criteria.getProgress();
    // Arguments and options are obtained from the artifact, seeded with the criteria's values.
    Map<String, Object> arguments = artifact.getArguments(criteria.getArguments());
    Map<String, String> options = artifact.getOptions(criteria.getOptions());
    try {
        // Pick a translator factory: the configured one if it supports the requested
        // input/output types, otherwise this loader's default factory.
        TranslatorFactory factory = getTranslatorFactory(criteria, arguments);
        Class<I> input = criteria.getInputClass();
        Class<O> output = criteria.getOutputClass();
        if (factory == null || !factory.isSupported(input, output)) {
            factory = defaultFactory;
            if (!factory.isSupported(input, output)) {
                throw new ModelNotFoundException(getFactoryLookupErrorMessage(factory));
            }
        }

        mrl.prepare(artifact, progress);
        if (progress != null) {
            progress.reset("Loading", 2);
            progress.update(1);
        }

        // The resource path may be a single file; the model directory is then its parent.
        Path modelPath = mrl.getRepository().getResourceDirectory(artifact);
        Path modelDir = Files.isRegularFile(modelPath) ? modelPath.getParent() : modelPath;
        if (modelDir == null) {
            throw new AssertionError("Directory should not be null.");
        }
        loadServingProperties(modelDir, arguments, options);

        Application application = criteria.getApplication();
        if (application != Application.UNDEFINED) {
            arguments.put("application", application.getPath());
        }

        // Engine precedence: criteria, then the "engine" argument (e.g. from
        // serving.properties), then whatever the model zoo supports.
        String engine = criteria.getEngine();
        if (engine == null) {
            engine = (String) arguments.get("engine");
        }
        if (engine == null) {
            ModelZoo modelZoo = ModelZoo.getModelZoo(mrl.getGroupId());
            if (modelZoo != null) {
                String defaultEngine = Engine.getDefaultEngineName();
                for (String supportedEngine : modelZoo.getSupportedEngines()) {
                    if (supportedEngine.equals(defaultEngine)) {
                        // The default engine wins as soon as the zoo lists it.
                        engine = supportedEngine;
                        break;
                    } else if (Engine.hasEngine(supportedEngine)) {
                        // Remember an available engine, but keep looking for the default.
                        engine = supportedEngine;
                    }
                }
                if (engine == null) {
                    throw new ModelNotFoundException("No supported engine available for model zoo: " + modelZoo.getGroupId());
                }
            }
        }
        // engine may still be null here (no model zoo); createModel receives it as-is.
        if (engine != null && !Engine.hasEngine(engine)) {
            throw new ModelNotFoundException(engine + " is not supported");
        }

        String modelName = criteria.getModelName();
        if (modelName == null) {
            modelName = artifact.getName();
        }

        Model model = createModel(modelDir, modelName, criteria.getDevice(), criteria.getBlock(), arguments, engine);
        // Ownership of the model passes to the returned ZooModel only on success. If
        // loading or translator creation fails, close the model here so that it is not
        // leaked; no caller would ever get a reference to close it.
        boolean handedOff = false;
        try {
            model.load(modelPath, null, options);
            Translator<I, O> translator = (Translator<I, O>) factory.newInstance(input, output, model, arguments);
            ZooModel<I, O> zooModel = new ZooModel<>(model, translator);
            handedOff = true;
            return zooModel;
        } finally {
            if (!handedOff) {
                model.close();
            }
        }
    } catch (TranslateException e) {
        throw new ModelNotFoundException("No matching translator found", e);
    } finally {
        if (progress != null) {
            progress.end();
        }
    }
}
Aggregations