Search in sources:

Example 21 with Artifact

Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.

Class PikachuDetection, method prepare.

/**
 * {@inheritDoc}
 *
 * <p>Prepares the default artifact of {@code mrl}, then reads {@code index.file} from the
 * {@code train} or {@code test} sub-directory selected by {@code usage}. The index is a JSON object
 * mapping an image file name to a list of floats. Element 4 is read as the object class. Elements
 * 5 and 6 form the rectangle's point (x, y), and elements 7 and 8 are its width and height.
 *
 * @param progress the progress tracker passed to {@code mrl.prepare}
 * @throws IOException if the artifact cannot be prepared, the index file cannot be read, or the
 *     index file is empty or contains a label list that is too short
 * @throws UnsupportedOperationException if {@code usage} is {@code VALIDATION} (or any other
 *     value besides {@code TRAIN}/{@code TEST})
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path root = mrl.getRepository().getResourceDirectory(artifact);
    Path usagePath;
    switch (usage) {
        case TRAIN:
            usagePath = Paths.get("train");
            break;
        case TEST:
            usagePath = Paths.get("test");
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    usagePath = root.resolve(usagePath);
    Path indexFile = usagePath.resolve("index.file");
    try (Reader reader = Files.newBufferedReader(indexFile)) {
        Type mapType = new TypeToken<Map<String, List<Float>>>() {
        }.getType();
        Map<String, List<Float>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
        if (metadata == null) {
            // Gson returns null for an empty (or literal "null") document; fail with context
            // instead of a NullPointerException in the loop below.
            throw new IOException("Index file is empty: " + indexFile);
        }
        for (Map.Entry<String, List<Float>> entry : metadata.entrySet()) {
            String imgName = entry.getKey();
            List<Float> label = entry.getValue();
            // Index 8 is the highest element read below.
            if (label == null || label.size() < 9) {
                throw new IOException("Malformed label for image " + imgName + " in " + indexFile);
            }
            long objectClass = label.get(4).longValue();
            Rectangle objectLocation = new Rectangle(new Point(label.get(5), label.get(6)), label.get(7), label.get(8));
            // Append the image and its label together, only after the label parsed successfully,
            // so imagePaths and labels never fall out of step.
            imagePaths.add(usagePath.resolve(imgName));
            labels.add(objectClass, objectLocation);
        }
    }
    prepared = true;
}
Also used : Path(java.nio.file.Path) Rectangle(ai.djl.modality.cv.output.Rectangle) Reader(java.io.Reader) Point(ai.djl.modality.cv.output.Point) Artifact(ai.djl.repository.Artifact) Type(java.lang.reflect.Type) ArrayList(java.util.ArrayList) PairList(ai.djl.util.PairList) List(java.util.List) Map(java.util.Map)

Example 22 with Artifact

Use of ai.djl.repository.Artifact in project djl-serving by deepjavalibrary.

Class ModelServer, method inferEngineFromUrl.

/**
 * Guesses the engine of the model located at {@code modelUrl}.
 *
 * <p>The URL is opened as a {@code "modelStore"} repository. The default artifact of its first
 * resource is prepared, and {@code inferEngine} is called with the artifact's resource directory
 * and name.
 *
 * @param modelUrl the model URL to inspect
 * @return the result of {@code inferEngine}, or {@code null} if an {@link IOException} occurred
 * @throws IllegalArgumentException if the repository exposes no resources
 */
private String inferEngineFromUrl(String modelUrl) {
    try {
        Repository modelRepo = Repository.newInstance("modelStore", modelUrl);
        List<MRL> resources = modelRepo.getResources();
        if (resources.isEmpty()) {
            throw new IllegalArgumentException("Invalid model url: " + modelUrl);
        }
        MRL firstResource = resources.get(0);
        Artifact defaultArtifact = firstResource.getDefaultArtifact();
        modelRepo.prepare(defaultArtifact);
        // Arguments evaluate left to right: directory lookup first, then the artifact name.
        return inferEngine(modelRepo.getResourceDirectory(defaultArtifact), defaultArtifact.getName());
    } catch (IOException ioe) {
        logger.warn("Failed to extract model: " + modelUrl, ioe);
        return null;
    }
}
Also used : Path(java.nio.file.Path) Repository(ai.djl.repository.Repository) MRL(ai.djl.repository.MRL) IOException(java.io.IOException) Artifact(ai.djl.repository.Artifact)

Example 23 with Artifact

Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.

Class RepositoryTest, method testSimpleRepositoryDir.

/**
 * Checks that an {@code "archive"} repository whose URL carries {@code artifact_id=resnet} and
 * {@code model_name=resnet18} exposes one resource. That resource should have one
 * {@code resnet18} artifact containing two files.
 */
@Test
public void testSimpleRepositoryDir() throws IOException {
    String archiveUrl = "build/models/archive?artifact_id=resnet&model_name=resnet18";
    Repository archiveRepo = Repository.newInstance("archive", archiveUrl);

    // Exactly one resource is expected.
    List<MRL> mrls = archiveRepo.getResources();
    Assert.assertEquals(mrls.size(), 1);

    // Metadata: undefined application, default group id, artifact id taken from the URL.
    Metadata meta = archiveRepo.locate(mrls.get(0));
    Assert.assertEquals(meta.getApplication(), Application.UNDEFINED);
    Assert.assertEquals(meta.getGroupId(), DefaultModelZoo.GROUP_ID);
    Assert.assertEquals(meta.getArtifactId(), "resnet");

    // A single artifact named after model_name, holding two files.
    List<Artifact> artifactList = meta.getArtifacts();
    Assert.assertEquals(artifactList.size(), 1);
    Artifact onlyArtifact = artifactList.get(0);
    Assert.assertEquals(onlyArtifact.getName(), "resnet18");
    Assert.assertEquals(onlyArtifact.getFiles().size(), 2);
}
Also used : Repository(ai.djl.repository.Repository) MRL(ai.djl.repository.MRL) Metadata(ai.djl.repository.Metadata) Artifact(ai.djl.repository.Artifact) Test(org.testng.annotations.Test)

Example 24 with Artifact

Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.

Class BaseModelLoader, method loadModel.

/**
 * {@inheritDoc}
 *
 * <p>The steps are:
 * <ol>
 *   <li>Match an artifact against the criteria filters.</li>
 *   <li>Choose a translator factory, falling back to {@code defaultFactory}.</li>
 *   <li>Prepare the artifact.</li>
 *   <li>Load serving properties from the model directory.</li>
 *   <li>Resolve the engine.</li>
 *   <li>Create and load the model.</li>
 * </ol>
 * If loading the model or building its translator fails, the created model is closed before the
 * exception propagates, so its resources are not leaked.
 *
 * @throws ModelNotFoundException if no artifact matches, no factory supports the input/output
 *     types, no usable engine is found, or translator creation fails
 */
@Override
@SuppressWarnings("unchecked")
public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria) throws IOException, ModelNotFoundException, MalformedModelException {
    Artifact artifact = mrl.match(criteria.getFilters());
    if (artifact == null) {
        throw new ModelNotFoundException("No matching filter found");
    }
    Progress progress = criteria.getProgress();
    Map<String, Object> arguments = artifact.getArguments(criteria.getArguments());
    Map<String, String> options = artifact.getOptions(criteria.getOptions());
    try {
        TranslatorFactory factory = getTranslatorFactory(criteria, arguments);
        Class<I> input = criteria.getInputClass();
        Class<O> output = criteria.getOutputClass();
        if (factory == null || !factory.isSupported(input, output)) {
            // Fall back to the loader's default factory; give up if it cannot handle the types either.
            factory = defaultFactory;
            if (!factory.isSupported(input, output)) {
                throw new ModelNotFoundException(getFactoryLookupErrorMessage(factory));
            }
        }
        mrl.prepare(artifact, progress);
        if (progress != null) {
            progress.reset("Loading", 2);
            progress.update(1);
        }
        Path modelPath = mrl.getRepository().getResourceDirectory(artifact);
        // The resource may be a single file; use its parent directory in that case.
        Path modelDir = Files.isRegularFile(modelPath) ? modelPath.getParent() : modelPath;
        if (modelDir == null) {
            throw new AssertionError("Directory should not be null.");
        }
        loadServingProperties(modelDir, arguments, options);
        Application application = criteria.getApplication();
        if (application != Application.UNDEFINED) {
            arguments.put("application", application.getPath());
        }
        // Engine precedence: explicit criteria first, then the "engine" argument
        // (e.g. from serving.properties).
        String engine = criteria.getEngine();
        if (engine == null) {
            engine = (String) arguments.get("engine");
        }
        // Still unspecified: consult the model zoo registered for this group id. Its engine that
        // matches the default engine wins. Otherwise the last listed engine that is available is
        // used. If no model zoo is registered, engine stays null and is passed to createModel as is.
        if (engine == null) {
            ModelZoo modelZoo = ModelZoo.getModelZoo(mrl.getGroupId());
            if (modelZoo != null) {
                String defaultEngine = Engine.getDefaultEngineName();
                for (String supportedEngine : modelZoo.getSupportedEngines()) {
                    if (supportedEngine.equals(defaultEngine)) {
                        engine = supportedEngine;
                        break;
                    } else if (Engine.hasEngine(supportedEngine)) {
                        engine = supportedEngine;
                    }
                }
                if (engine == null) {
                    throw new ModelNotFoundException("No supported engine available for model zoo: " + modelZoo.getGroupId());
                }
            }
        }
        if (engine != null && !Engine.hasEngine(engine)) {
            throw new ModelNotFoundException(engine + " is not supported");
        }
        String modelName = criteria.getModelName();
        if (modelName == null) {
            modelName = artifact.getName();
        }
        Model model = createModel(modelDir, modelName, criteria.getDevice(), criteria.getBlock(), arguments, engine);
        boolean handedOff = false;
        try {
            model.load(modelPath, null, options);
            Translator<I, O> translator = (Translator<I, O>) factory.newInstance(input, output, model, arguments);
            ZooModel<I, O> zooModel = new ZooModel<>(model, translator);
            handedOff = true;
            return zooModel;
        } finally {
            if (!handedOff) {
                // The model never reached the caller, so nothing else will close it.
                model.close();
            }
        }
    } catch (TranslateException e) {
        throw new ModelNotFoundException("No matching translator found", e);
    } finally {
        if (progress != null) {
            progress.end();
        }
    }
}
Also used : Path(java.nio.file.Path) Progress(ai.djl.util.Progress) DefaultTranslatorFactory(ai.djl.translate.DefaultTranslatorFactory) TranslatorFactory(ai.djl.translate.TranslatorFactory) TranslateException(ai.djl.translate.TranslateException) Artifact(ai.djl.repository.Artifact) Translator(ai.djl.translate.Translator) Model(ai.djl.Model) Application(ai.djl.Application)

Example 25 with Artifact

Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.

Class BananaDetection, method prepare.

/**
 * {@inheritDoc}
 *
 * <p>Prepares the default artifact of {@code mrl}, then reads {@code index.file} from the
 * {@code train} or {@code test} sub-directory selected by {@code usage}. The index is a JSON object
 * mapping an image file name to a list of floats. Element 0 is read as the object class. Elements
 * 1 and 2 form the rectangle's point (x, y), and elements 3 and 4 are its width and height.
 *
 * @param progress the progress tracker passed to {@code mrl.prepare}
 * @throws IOException if the artifact cannot be prepared, the index file cannot be read, or the
 *     index file is empty or contains a label list that is too short
 * @throws TranslateException declared by this method's signature
 * @throws UnsupportedOperationException if {@code usage} is {@code VALIDATION} (or any other
 *     value besides {@code TRAIN}/{@code TEST})
 */
@Override
public void prepare(Progress progress) throws IOException, TranslateException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path root = mrl.getRepository().getResourceDirectory(artifact);
    Path usagePath;
    switch (usage) {
        case TRAIN:
            usagePath = Paths.get("train");
            break;
        case TEST:
            usagePath = Paths.get("test");
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    usagePath = root.resolve(usagePath);
    Path indexFile = usagePath.resolve("index.file");
    try (Reader reader = Files.newBufferedReader(indexFile)) {
        Type mapType = new TypeToken<Map<String, List<Float>>>() {
        }.getType();
        Map<String, List<Float>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
        if (metadata == null) {
            // Gson returns null for an empty (or literal "null") document; fail with context
            // instead of a NullPointerException in the loop below.
            throw new IOException("Index file is empty: " + indexFile);
        }
        for (Map.Entry<String, List<Float>> entry : metadata.entrySet()) {
            String imgName = entry.getKey();
            List<Float> label = entry.getValue();
            // Index 4 is the highest element read below.
            if (label == null || label.size() < 5) {
                throw new IOException("Malformed label for image " + imgName + " in " + indexFile);
            }
            long objectClass = label.get(0).longValue();
            Rectangle objectLocation = new Rectangle(new Point(label.get(1), label.get(2)), label.get(3), label.get(4));
            // Append the image and its label together, only after the label parsed successfully,
            // so imagePaths and labels never fall out of step.
            imagePaths.add(usagePath.resolve(imgName));
            labels.add(objectClass, objectLocation);
        }
    }
    prepared = true;
}
Also used : Path(java.nio.file.Path) Rectangle(ai.djl.modality.cv.output.Rectangle) Reader(java.io.Reader) Point(ai.djl.modality.cv.output.Point) Artifact(ai.djl.repository.Artifact) Type(java.lang.reflect.Type) ArrayList(java.util.ArrayList) PairList(ai.djl.util.PairList) List(java.util.List) Map(java.util.Map)

Aggregations

Artifact (ai.djl.repository.Artifact): 33, Path (java.nio.file.Path): 15, MRL (ai.djl.repository.MRL): 10, Test (org.testng.annotations.Test): 9, Repository (ai.djl.repository.Repository): 8, Metadata (ai.djl.repository.Metadata): 7, IOException (java.io.IOException): 5, ArrayList (java.util.ArrayList): 5, List (java.util.List): 4, Application (ai.djl.Application): 3, Model (ai.djl.Model): 3, Rectangle (ai.djl.modality.cv.output.Rectangle): 3, ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 3, Point (ai.djl.modality.cv.output.Point): 2, PairList (ai.djl.util.PairList): 2, Progress (ai.djl.util.Progress): 2, Reader (java.io.Reader): 2, Type (java.lang.reflect.Type): 2, Map (java.util.Map): 2, FtModel (ai.djl.fasttext.FtModel): 1