Search in sources :

Example 1 with Application

use of ai.djl.Application in project djl by deepjavalibrary.

the class Criteria method loadModel.

/**
 * Load the {@link ZooModel} that matches this criteria.
 *
 * <p>The search runs in two phases. First the candidate model zoos are selected: when a
 * model zoo was set on this criteria only that zoo is used (and a conflicting groupId or
 * an unsupported engine fails immediately); otherwise every zoo returned by
 * {@code ModelZoo.listModelZoo()} is filtered by groupId and engine. Then each loader of
 * each candidate zoo is filtered by artifactId and application, and the first loader that
 * loads successfully wins. A {@link ModelNotFoundException} raised by a loader does not
 * abort the search; the most recent one is attached as the cause of the final failure.
 *
 * @return the model that matches the criteria
 * @throws IOException for various exceptions loading data from the repository
 * @throws ModelNotFoundException if no model with the specified criteria is found
 * @throws MalformedModelException if the model data is malformed
 */
public ZooModel<I, O> loadModel() throws IOException, ModelNotFoundException, MalformedModelException {
    Logger logger = LoggerFactory.getLogger(ModelZoo.class);
    logger.debug("Loading model with {}", this);
    List<ModelZoo> list = new ArrayList<>();
    if (modelZoo != null) {
        // An explicitly requested zoo must be consistent with the rest of the criteria;
        // a mismatch is a hard failure rather than a silent skip.
        logger.debug("Searching model in specified model zoo: {}", modelZoo.getGroupId());
        if (groupId != null && !modelZoo.getGroupId().equals(groupId)) {
            // Separator added after "criteria" so the two groupIds are readable in the message.
            throw new ModelNotFoundException("groupId conflict with ModelZoo criteria: " + modelZoo.getGroupId() + " v.s. " + groupId);
        }
        Set<String> supportedEngine = modelZoo.getSupportedEngines();
        if (engine != null && !supportedEngine.contains(engine)) {
            throw new ModelNotFoundException("ModelZoo doesn't support specified engine: " + engine);
        }
        list.add(modelZoo);
    } else {
        // No zoo requested: consider every known zoo that passes the groupId/engine filters.
        for (ModelZoo zoo : ModelZoo.listModelZoo()) {
            if (groupId != null && !zoo.getGroupId().equals(groupId)) {
                // filter out ModelZoo by groupId
                logger.debug("Ignore ModelZoo {} by groupId: {}", zoo.getGroupId(), groupId);
                continue;
            }
            Set<String> supportedEngine = zoo.getSupportedEngines();
            if (engine != null && !supportedEngine.contains(engine)) {
                logger.debug("Ignore ModelZoo {} by engine: {}", zoo.getGroupId(), engine);
                continue;
            }
            list.add(zoo);
        }
    }
    // Remember the most recent loader failure so it can be reported as the cause.
    Exception lastException = null;
    for (ModelZoo zoo : list) {
        String loaderGroupId = zoo.getGroupId();
        for (ModelLoader loader : zoo.getModelLoaders()) {
            Application app = loader.getApplication();
            String loaderArtifactId = loader.getArtifactId();
            logger.debug("Checking ModelLoader: {}", loader);
            if (artifactId != null && !artifactId.equals(loaderArtifactId)) {
                // filter out by model loader artifactId
                logger.debug("artifactId mismatch for ModelLoader: {}:{}", loaderGroupId, loaderArtifactId);
                continue;
            }
            // The application filter only applies when both sides are defined.
            if (application != Application.UNDEFINED && app != Application.UNDEFINED && !app.matches(application)) {
                // filter out ModelLoader by application
                logger.debug("application mismatch for ModelLoader: {}:{}", loaderGroupId, loaderArtifactId);
                continue;
            }
            try {
                return loader.loadModel(this);
            } catch (ModelNotFoundException e) {
                // This loader has no match; keep trying the remaining loaders.
                lastException = e;
                logger.trace("", e);
                logger.debug("{} for ModelLoader: {}:{}", e.getMessage(), loaderGroupId, loaderArtifactId);
            }
        }
    }
    throw new ModelNotFoundException("No matching model with specified Input/Output type found.", lastException);
}
Also used : ArrayList(java.util.ArrayList) Logger(org.slf4j.Logger) Application(ai.djl.Application) MalformedURLException(java.net.MalformedURLException) IOException(java.io.IOException) MalformedModelException(ai.djl.MalformedModelException)

Example 2 with Application

use of ai.djl.Application in project djl by deepjavalibrary.

the class ModelZoo method listModels.

/**
 * Returns the available {@link Application} and their model artifact metadata.
 *
 * <p>When the criteria names a model zoo, only that zoo is searched; otherwise every zoo
 * returned by {@code listModelZoo()} is a candidate. Each candidate zoo is skipped when its
 * groupId differs from the criteria's groupId or when it does not support the criteria's
 * engine. Loaders of the remaining zoos are filtered by artifactId and application, and
 * their artifacts are grouped by the loader's application. The returned map is ordered by
 * {@code Application::getPath}.
 *
 * @param criteria the requirements for the model
 * @return the available {@link Application} and their model artifact metadata
 * @throws IOException if failed to download to repository metadata
 * @throws ModelNotFoundException if failed to parse repository metadata
 */
public static Map<Application, List<Artifact>> listModels(Criteria<?, ?> criteria) throws IOException, ModelNotFoundException {
    String artifactId = criteria.getArtifactId();
    ModelZoo modelZoo = criteria.getModelZoo();
    String groupId = criteria.getGroupId();
    String engine = criteria.getEngine();
    Application application = criteria.getApplication();
    @SuppressWarnings("PMD.UseConcurrentHashMap") Map<Application, List<Artifact>> models = new TreeMap<>(Comparator.comparing(Application::getPath));
    // Select candidate zoos: an explicitly requested zoo narrows the search to itself.
    List<ModelZoo> candidates = new ArrayList<>();
    if (modelZoo != null) {
        candidates.add(modelZoo);
    } else {
        for (ModelZoo zoo : listModelZoo()) {
            candidates.add(zoo);
        }
    }
    for (ModelZoo zoo : candidates) {
        // Filter on the zoo actually being visited (previously the criteria's zoo was
        // re-checked on every iteration, so the visited zoo was never filtered).
        if (groupId != null && !zoo.getGroupId().equals(groupId)) {
            continue;
        }
        Set<String> supportedEngine = zoo.getSupportedEngines();
        if (engine != null && !supportedEngine.contains(engine)) {
            continue;
        }
        List<ModelLoader> list = zoo.getModelLoaders();
        for (ModelLoader loader : list) {
            Application app = loader.getApplication();
            String loaderArtifactId = loader.getArtifactId();
            if (artifactId != null && !artifactId.equals(loaderArtifactId)) {
                // filter out by model loader artifactId
                continue;
            }
            // The application filter only applies when both sides are defined.
            if (application != Application.UNDEFINED && app != Application.UNDEFINED && !app.matches(application)) {
                // filter out ModelLoader by application
                continue;
            }
            final List<Artifact> artifacts = loader.listModels();
            // Several loaders may share an application; accumulate into one list per key.
            models.computeIfAbsent(app, key -> new ArrayList<>()).addAll(artifacts);
        }
    }
    return models;
}
Also used : TreeMap(java.util.TreeMap) Artifact(ai.djl.repository.Artifact) ArrayList(java.util.ArrayList) List(java.util.List) Application(ai.djl.Application)

Example 3 with Application

use of ai.djl.Application in project djl by deepjavalibrary.

the class LocalRepository method getResources.

/**
 * {@inheritDoc}
 */
@Override
public List<MRL> getResources() {
    // Walks the tree under `path` and turns every regular file named metadata.json into
    // an MRL. The first path segment relative to `path` selects the kind: "dataset" or
    // "model"; any other segment is ignored. Failures are logged and skipped so that one
    // unreadable file does not hide the remaining resources.
    List<MRL> list = new ArrayList<>();
    // Files.walk keeps directories open until the stream is closed, so it must be
    // managed by try-with-resources.
    try (java.util.stream.Stream<Path> files = Files.walk(path)) {
        files.forEach(f -> {
            if (f.endsWith("metadata.json") && Files.isRegularFile(f)) {
                Path relative = path.relativize(f);
                String type = relative.getName(0).toString();
                try (Reader reader = Files.newBufferedReader(f)) {
                    Metadata metadata = JsonUtils.GSON.fromJson(reader, Metadata.class);
                    Application application = metadata.getApplication();
                    String groupId = metadata.getGroupId();
                    String artifactId = metadata.getArtifactId();
                    if ("dataset".equals(type)) {
                        list.add(dataset(application, groupId, artifactId));
                    } else if ("model".equals(type)) {
                        list.add(model(application, groupId, artifactId));
                    }
                } catch (IOException e) {
                    logger.warn("Failed to read metadata.json", e);
                }
            }
        });
    } catch (IOException | java.io.UncheckedIOException e) {
        // Files.walk is lazy: I/O errors hit while traversing are reported as
        // UncheckedIOException from forEach, not as IOException from walk itself.
        logger.warn("", e);
    }
    return list;
}
Also used : Path(java.nio.file.Path) ArrayList(java.util.ArrayList) Reader(java.io.Reader) IOException(java.io.IOException) Application(ai.djl.Application)

Example 4 with Application

use of ai.djl.Application in project djl by deepjavalibrary.

the class ListModelsTest method testListModelsWithApplication.

/**
 * Lists models with an NLP application filter and checks every returned application
 * matches NLP (or is undefined) and never matches CV.
 */
@Test
public void testListModelsWithApplication() throws ModelException, IOException {
    // Register two repository locations: the local test resources and the model-zoo
    // test repository, the latter expressed as an absolute URL.
    Path repoDir = Paths.get("../model-zoo/src/test/resources/mlrepo");
    Path resolved = repoDir.toRealPath().toAbsolutePath();
    String repoUrl = resolved.toUri().toURL().toExternalForm();
    String locations = "src/test/resources," + repoUrl;
    System.setProperty("ai.djl.repository.zoo.location", locations);

    Criteria<?, ?> nlpCriteria = Criteria.builder().optApplication(NLP.ANY).build();
    Map<Application, List<Artifact>> listed = ModelZoo.listModels(nlpCriteria);
    for (Application app : listed.keySet()) {
        boolean nlpOrUndefined = app.matches(NLP.ANY) || app.matches(Application.UNDEFINED);
        Assert.assertTrue(nlpOrUndefined);
        boolean cv = app.matches(CV.ANY);
        Assert.assertFalse(cv);
    }
}
Also used : Path(java.nio.file.Path) List(java.util.List) Application(ai.djl.Application) Test(org.testng.annotations.Test)

Example 5 with Application

use of ai.djl.Application in project djl by deepjavalibrary.

the class ListModelsTest method testListModels.

/**
 * Lists all models without criteria and checks that at least one artifact is reported
 * under the undefined application.
 */
@Test
public void testListModels() throws ModelException, IOException {
    // Register two repository locations: the local test resources and the model-zoo
    // test repository, the latter expressed as an absolute URL.
    Path repoDir = Paths.get("../model-zoo/src/test/resources/mlrepo");
    Path resolved = repoDir.toRealPath().toAbsolutePath();
    String repoUrl = resolved.toUri().toURL().toExternalForm();
    String locations = "src/test/resources," + repoUrl;
    System.setProperty("ai.djl.repository.zoo.location", locations);

    Map<Application, List<Artifact>> listed = ModelZoo.listModels();
    // Print the discovered applications to aid debugging of repository discovery.
    Object[] applications = listed.keySet().toArray();
    System.out.println(Arrays.toString(applications));
    List<Artifact> undefinedArtifacts = listed.get(Application.UNDEFINED);
    Assert.assertFalse(undefinedArtifacts.isEmpty());
}
Also used : Path(java.nio.file.Path) List(java.util.List) Application(ai.djl.Application) Artifact(ai.djl.repository.Artifact) Test(org.testng.annotations.Test)

Aggregations

Application (ai.djl.Application): 9, Path (java.nio.file.Path): 4, List (java.util.List): 4, Artifact (ai.djl.repository.Artifact): 3, ArrayList (java.util.ArrayList): 3, Test (org.testng.annotations.Test): 3, IOException (java.io.IOException): 2, MalformedModelException (ai.djl.MalformedModelException): 1, Model (ai.djl.Model): 1, Classifications (ai.djl.modality.Classifications): 1, DefaultTranslatorFactory (ai.djl.translate.DefaultTranslatorFactory): 1, TranslateException (ai.djl.translate.TranslateException): 1, Translator (ai.djl.translate.Translator): 1, TranslatorFactory (ai.djl.translate.TranslatorFactory): 1, Progress (ai.djl.util.Progress): 1, Reader (java.io.Reader): 1, MalformedURLException (java.net.MalformedURLException): 1, TreeMap (java.util.TreeMap): 1, Logger (org.slf4j.Logger): 1