use of ai.djl.Application in project djl by deepjavalibrary.
the class Criteria method loadModel.
/**
 * Finds and loads the {@link ZooModel} that satisfies this criteria.
 *
 * <p>The search runs in two phases. First the candidate model zoos are collected: either the
 * explicitly configured zoo (rejected with an exception if it conflicts with the requested
 * groupId or engine), or every zoo returned by {@code ModelZoo.listModelZoo()} that passes the
 * groupId and engine filters. Then each candidate's loaders are tried in order, skipping those
 * whose artifactId or application does not match; the first loader that loads successfully
 * provides the result.
 *
 * @return the model that matches the criteria
 * @throws IOException for various exceptions loading data from the repository
 * @throws ModelNotFoundException if no model with the specified criteria is found
 * @throws MalformedModelException if the model data is malformed
 */
public ZooModel<I, O> loadModel() throws IOException, ModelNotFoundException, MalformedModelException {
    Logger logger = LoggerFactory.getLogger(ModelZoo.class);
    logger.debug("Loading model with {}", this);

    // Phase 1: decide which model zoos are worth searching.
    List<ModelZoo> candidates = new ArrayList<>();
    if (modelZoo == null) {
        for (ModelZoo candidate : ModelZoo.listModelZoo()) {
            boolean groupAccepted = groupId == null || candidate.getGroupId().equals(groupId);
            if (!groupAccepted) {
                // filter out ModelZoo by groupId
                logger.debug("Ignore ModelZoo {} by groupId: {}", candidate.getGroupId(), groupId);
                continue;
            }
            Set<String> engines = candidate.getSupportedEngines();
            if (engine == null || engines.contains(engine)) {
                candidates.add(candidate);
            } else {
                logger.debug("Ignore ModelZoo {} by engine: {}", candidate.getGroupId(), engine);
            }
        }
    } else {
        // An explicit zoo was requested: a mismatch is an error, not something to skip.
        logger.debug("Searching model in specified model zoo: {}", modelZoo.getGroupId());
        boolean groupConflict = groupId != null && !modelZoo.getGroupId().equals(groupId);
        if (groupConflict) {
            throw new ModelNotFoundException("groupId conflict with ModelZoo criteria." + modelZoo.getGroupId() + " v.s. " + groupId);
        }
        Set<String> engines = modelZoo.getSupportedEngines();
        boolean engineUnsupported = engine != null && !engines.contains(engine);
        if (engineUnsupported) {
            throw new ModelNotFoundException("ModelZoo doesn't support specified engine: " + engine);
        }
        candidates.add(modelZoo);
    }

    // Phase 2: try every matching loader; remember the most recent failure for the final error.
    Exception mostRecentFailure = null;
    for (ModelZoo candidate : candidates) {
        String zooGroupId = candidate.getGroupId();
        for (ModelLoader loader : candidate.getModelLoaders()) {
            Application loaderApp = loader.getApplication();
            String loaderArtifact = loader.getArtifactId();
            logger.debug("Checking ModelLoader: {}", loader);

            boolean artifactMismatch = artifactId != null && !artifactId.equals(loaderArtifact);
            if (artifactMismatch) {
                // filter out by model loader artifactId
                logger.debug("artifactId mismatch for ModelLoader: {}:{}", zooGroupId, loaderArtifact);
                continue;
            }

            // The application filter only applies when both sides declare a concrete application.
            boolean bothDefined = application != Application.UNDEFINED && loaderApp != Application.UNDEFINED;
            if (bothDefined && !loaderApp.matches(application)) {
                // filter out ModelLoader by application
                logger.debug("application mismatch for ModelLoader: {}:{}", zooGroupId, loaderArtifact);
                continue;
            }

            try {
                return loader.loadModel(this);
            } catch (ModelNotFoundException e) {
                mostRecentFailure = e;
                logger.trace("", e);
                logger.debug("{} for ModelLoader: {}:{}", e.getMessage(), zooGroupId, loaderArtifact);
            }
        }
    }
    throw new ModelNotFoundException("No matching model with specified Input/Output type found.", mostRecentFailure);
}
use of ai.djl.Application in project djl by deepjavalibrary.
the class ModelZoo method listModels.
/**
 * Returns the available {@link Application} and their model artifact metadata.
 *
 * <p>When the criteria names a specific model zoo, only that zoo is searched; otherwise every
 * zoo returned by {@code listModelZoo()} is a candidate. A candidate zoo is skipped when its own
 * groupId differs from the requested groupId or when it does not support the requested engine.
 * Within each remaining zoo, loaders are filtered by artifactId and application, and the
 * artifacts of the matching loaders are grouped by the loader's application. A criteria whose
 * constraints exclude every zoo yields an empty map rather than an exception.
 *
 * @param criteria the requirements for the model
 * @return the available {@link Application} and their model artifact metadata, ordered by
 *     application path
 * @throws IOException if failed to download to repository metadata
 * @throws ModelNotFoundException if failed to parse repository metadata
 */
public static Map<Application, List<Artifact>> listModels(Criteria<?, ?> criteria) throws IOException, ModelNotFoundException {
    String artifactId = criteria.getArtifactId();
    ModelZoo modelZoo = criteria.getModelZoo();
    String groupId = criteria.getGroupId();
    String engine = criteria.getEngine();
    Application application = criteria.getApplication();

    // An explicitly requested zoo restricts the search to that zoo alone, mirroring how
    // Criteria.loadModel() selects its candidates.
    List<ModelZoo> candidates = new ArrayList<>();
    if (modelZoo != null) {
        candidates.add(modelZoo);
    } else {
        for (ModelZoo zoo : listModelZoo()) {
            candidates.add(zoo);
        }
    }

    @SuppressWarnings("PMD.UseConcurrentHashMap")
    Map<Application, List<Artifact>> models = new TreeMap<>(Comparator.comparing(Application::getPath));
    for (ModelZoo zoo : candidates) {
        // Filter on the zoo being visited, not on the criteria's zoo, so that groupId and
        // engine constraints also apply when no specific zoo was requested.
        if (groupId != null && !zoo.getGroupId().equals(groupId)) {
            continue;
        }
        Set<String> supportedEngine = zoo.getSupportedEngines();
        if (engine != null && !supportedEngine.contains(engine)) {
            continue;
        }

        for (ModelLoader loader : zoo.getModelLoaders()) {
            Application app = loader.getApplication();
            String loaderArtifactId = loader.getArtifactId();
            if (artifactId != null && !artifactId.equals(loaderArtifactId)) {
                // filter out by model loader artifactId
                continue;
            }
            if (application != Application.UNDEFINED && app != Application.UNDEFINED && !app.matches(application)) {
                // filter out ModelLoader by application
                continue;
            }
            // Fetch first so a failing loader does not leave an empty entry behind.
            List<Artifact> artifacts = loader.listModels();
            models.computeIfAbsent(app, key -> new ArrayList<>()).addAll(artifacts);
        }
    }
    return models;
}
use of ai.djl.Application in project djl by deepjavalibrary.
the class LocalRepository method getResources.
/**
 * {@inheritDoc}
 */
@Override
public List<MRL> getResources() {
    List<MRL> list = new ArrayList<>();
    // Files.walk holds open directory handles until the stream is closed, so it must be
    // managed by try-with-resources. The stream type is fully qualified to avoid requiring
    // an additional import in this file.
    try (java.util.stream.Stream<Path> walker = Files.walk(path)) {
        walker.forEach(f -> {
            if (f.endsWith("metadata.json") && Files.isRegularFile(f)) {
                // The first path element below the repository root selects the resource kind.
                Path relative = path.relativize(f);
                String type = relative.getName(0).toString();
                try (Reader reader = Files.newBufferedReader(f)) {
                    Metadata metadata = JsonUtils.GSON.fromJson(reader, Metadata.class);
                    Application application = metadata.getApplication();
                    String groupId = metadata.getGroupId();
                    String artifactId = metadata.getArtifactId();
                    if ("dataset".equals(type)) {
                        list.add(dataset(application, groupId, artifactId));
                    } else if ("model".equals(type)) {
                        list.add(model(application, groupId, artifactId));
                    }
                } catch (IOException e) {
                    // Best effort: an unreadable metadata file is skipped, not fatal.
                    logger.warn("Failed to read metadata.json", e);
                }
            }
        });
    } catch (IOException | java.io.UncheckedIOException e) {
        // Errors hit while lazily traversing the tree are reported by the stream as
        // UncheckedIOException; treat them like a failure to start the walk and return
        // whatever was collected so far.
        logger.warn("", e);
    }
    return list;
}
use of ai.djl.Application in project djl by deepjavalibrary.
the class ListModelsTest method testListModelsWithApplication.
@Test
public void testListModelsWithApplication() throws ModelException, IOException {
    // Point the "ai.djl.repository.zoo.location" property at the local test resources plus
    // the model-zoo test repository, expressed as an absolute URL.
    Path mlrepo = Paths.get("../model-zoo/src/test/resources/mlrepo");
    String mlrepoUrl = mlrepo.toRealPath().toAbsolutePath().toUri().toURL().toExternalForm();
    String locations = "src/test/resources," + mlrepoUrl;
    System.setProperty("ai.djl.repository.zoo.location", locations);

    Criteria<?, ?> nlpCriteria = Criteria.builder().optApplication(NLP.ANY).build();
    Map<Application, List<Artifact>> listed = ModelZoo.listModels(nlpCriteria);

    // Every listed application must be NLP (or undefined) and must never match CV.
    for (Map.Entry<Application, List<Artifact>> entry : listed.entrySet()) {
        Application listedApp = entry.getKey();
        boolean nlpOrUndefined = listedApp.matches(NLP.ANY) || listedApp.matches(Application.UNDEFINED);
        Assert.assertTrue(nlpOrUndefined);
        boolean matchesCv = listedApp.matches(CV.ANY);
        Assert.assertFalse(matchesCv);
    }
}
use of ai.djl.Application in project djl by deepjavalibrary.
the class ListModelsTest method testListModels.
@Test
public void testListModels() throws ModelException, IOException {
    // Point the "ai.djl.repository.zoo.location" property at the local test resources plus
    // the model-zoo test repository, expressed as an absolute URL.
    Path mlrepo = Paths.get("../model-zoo/src/test/resources/mlrepo");
    String mlrepoUrl = mlrepo.toRealPath().toAbsolutePath().toUri().toURL().toExternalForm();
    String locations = "src/test/resources," + mlrepoUrl;
    System.setProperty("ai.djl.repository.zoo.location", locations);

    Map<Application, List<Artifact>> listed = ModelZoo.listModels();

    // Print the discovered applications to the test output.
    Object[] applications = listed.keySet().toArray();
    System.out.println(Arrays.toString(applications));

    // The configured repositories are expected to provide artifacts with an undefined application.
    List<Artifact> undefinedArtifacts = listed.get(Application.UNDEFINED);
    boolean noneFound = undefinedArtifacts.isEmpty();
    Assert.assertFalse(noneFound);
}
Aggregations