Usage of ai.djl.repository.Artifact in project djl (by deepjavalibrary).
Class PikachuDetection, method prepare:
/**
 * {@inheritDoc}
 *
 * <p>Prepares the default artifact of {@code mrl}, then reads {@code index.file} from the
 * {@code train} or {@code test} sub-directory (selected by {@code usage}). For every entry it
 * appends the resolved image path to {@code imagePaths} and the corresponding class id and
 * {@link Rectangle} to {@code labels}. Subsequent calls are no-ops once preparation succeeded.
 *
 * @param progress the progress tracker passed to {@code mrl.prepare}
 * @throws IOException if the artifact cannot be prepared, the index file cannot be read, the
 *     index file is empty, or a label entry has fewer than 9 values
 * @throws UnsupportedOperationException if {@code usage} is {@code VALIDATION} (or any other
 *     value than {@code TRAIN}/{@code TEST})
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path root = mrl.getRepository().getResourceDirectory(artifact);
    Path usagePath;
    switch (usage) {
        case TRAIN:
            usagePath = Paths.get("train");
            break;
        case TEST:
            usagePath = Paths.get("test");
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    usagePath = root.resolve(usagePath);
    Path indexFile = usagePath.resolve("index.file");
    try (Reader reader = Files.newBufferedReader(indexFile)) {
        Type mapType = new TypeToken<Map<String, List<Float>>>() {}.getType();
        Map<String, List<Float>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
        if (metadata == null) {
            // Gson returns null when the reader is already at EOF, e.g. an empty index file.
            throw new IOException("Empty index file: " + indexFile);
        }
        for (Map.Entry<String, List<Float>> entry : metadata.entrySet()) {
            String imgName = entry.getKey();
            List<Float> label = entry.getValue();
            // Index 4 is the class id; 5-6 build the Point and 7-8 are the remaining Rectangle
            // arguments (presumably width/height). Validate before touching imagePaths/labels
            // so that a malformed entry cannot leave the two collections misaligned.
            if (label == null || label.size() < 9) {
                throw new IOException("Malformed label for image " + imgName + " in " + indexFile);
            }
            long objectClass = label.get(4).longValue();
            Rectangle objectLocation =
                    new Rectangle(new Point(label.get(5), label.get(6)), label.get(7), label.get(8));
            imagePaths.add(usagePath.resolve(imgName));
            labels.add(objectClass, objectLocation);
        }
    }
    prepared = true;
}
Usage of ai.djl.repository.Artifact in project djl-serving (by deepjavalibrary).
Class ModelServer, method inferEngineFromUrl:
/**
 * Infers the engine for the model located at {@code modelUrl}.
 *
 * <p>Opens the URL as a repository named {@code "modelStore"}, prepares the default artifact of
 * its first resource, and delegates to {@code inferEngine} with that artifact's resource
 * directory and name.
 *
 * @param modelUrl the model URL to inspect
 * @return the result of {@code inferEngine}, or {@code null} if an {@link IOException} occurred
 *     (the failure is logged at warn level)
 * @throws IllegalArgumentException if the repository exposes no resources
 */
private String inferEngineFromUrl(String modelUrl) {
    try {
        Repository repo = Repository.newInstance("modelStore", modelUrl);
        List<MRL> resources = repo.getResources();
        if (!resources.isEmpty()) {
            Artifact first = resources.get(0).getDefaultArtifact();
            repo.prepare(first);
            return inferEngine(repo.getResourceDirectory(first), first.getName());
        }
        throw new IllegalArgumentException("Invalid model url: " + modelUrl);
    } catch (IOException e) {
        logger.warn("Failed to extract model: " + modelUrl, e);
        return null;
    }
}
Usage of ai.djl.repository.Artifact in project djl (by deepjavalibrary).
Class RepositoryTest, method testSimpleRepositoryDir:
/**
 * Checks that an "archive" repository built from a local directory URL exposes exactly one
 * resource whose metadata carries the expected application, group id, artifact id, and a single
 * {@code resnet18} artifact with two files.
 */
@Test
public void testSimpleRepositoryDir() throws IOException {
    String url = "build/models/archive?artifact_id=resnet&model_name=resnet18";
    Repository repository = Repository.newInstance("archive", url);

    List<MRL> mrls = repository.getResources();
    Assert.assertEquals(mrls.size(), 1);

    Metadata located = repository.locate(mrls.get(0));
    Assert.assertEquals(located.getApplication(), Application.UNDEFINED);
    Assert.assertEquals(located.getGroupId(), DefaultModelZoo.GROUP_ID);
    Assert.assertEquals(located.getArtifactId(), "resnet");

    List<Artifact> found = located.getArtifacts();
    Assert.assertEquals(found.size(), 1);

    Artifact only = found.get(0);
    Assert.assertEquals(only.getName(), "resnet18");
    Assert.assertEquals(only.getFiles().size(), 2);
}
Usage of ai.djl.repository.Artifact in project djl (by deepjavalibrary).
Class BaseModelLoader, method loadModel:
/**
 * {@inheritDoc}
 *
 * <p>Matches an artifact against the criteria filters, selects a translator factory (falling
 * back to {@code defaultFactory}), prepares the artifact, merges {@code serving.properties}
 * into the arguments/options, resolves the engine, then creates and loads the model. If loading
 * the model or building its translator fails, the model is closed before the exception
 * propagates, so a partially constructed model is never leaked.
 *
 * @throws ModelNotFoundException if no artifact matches, no translator factory supports the
 *     requested input/output classes, no engine is available, or translator creation fails
 */
@Override
@SuppressWarnings("unchecked")
public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
        throws IOException, ModelNotFoundException, MalformedModelException {
    Artifact artifact = mrl.match(criteria.getFilters());
    if (artifact == null) {
        throw new ModelNotFoundException("No matching filter found");
    }
    Progress progress = criteria.getProgress();
    Map<String, Object> arguments = artifact.getArguments(criteria.getArguments());
    Map<String, String> options = artifact.getOptions(criteria.getOptions());
    try {
        TranslatorFactory factory = getTranslatorFactory(criteria, arguments);
        Class<I> input = criteria.getInputClass();
        Class<O> output = criteria.getOutputClass();
        // Fall back to the loader's default factory when none was found or it cannot map I -> O.
        if (factory == null || !factory.isSupported(input, output)) {
            factory = defaultFactory;
            if (!factory.isSupported(input, output)) {
                throw new ModelNotFoundException(getFactoryLookupErrorMessage(factory));
            }
        }
        mrl.prepare(artifact, progress);
        if (progress != null) {
            progress.reset("Loading", 2);
            progress.update(1);
        }
        Path modelPath = mrl.getRepository().getResourceDirectory(artifact);
        // The resource may be a single file; its parent directory is then the model directory.
        Path modelDir = Files.isRegularFile(modelPath) ? modelPath.getParent() : modelPath;
        if (modelDir == null) {
            throw new AssertionError("Directory should not be null.");
        }
        loadServingProperties(modelDir, arguments, options);
        Application application = criteria.getApplication();
        if (application != Application.UNDEFINED) {
            arguments.put("application", application.getPath());
        }
        // Engine precedence: criteria, then the "engine" argument (from serving.properties).
        String engine = criteria.getEngine();
        if (engine == null) {
            engine = (String) arguments.get("engine");
        }
        // Still unset: choose among the model zoo's supported engines, preferring the default
        // engine, otherwise the last supported engine that is available. Without a model zoo the
        // engine stays null and the choice is left to createModel.
        if (engine == null) {
            ModelZoo modelZoo = ModelZoo.getModelZoo(mrl.getGroupId());
            if (modelZoo != null) {
                String defaultEngine = Engine.getDefaultEngineName();
                for (String supportedEngine : modelZoo.getSupportedEngines()) {
                    if (supportedEngine.equals(defaultEngine)) {
                        engine = supportedEngine;
                        break;
                    } else if (Engine.hasEngine(supportedEngine)) {
                        engine = supportedEngine;
                    }
                }
                if (engine == null) {
                    throw new ModelNotFoundException(
                            "No supported engine available for model zoo: " + modelZoo.getGroupId());
                }
            }
        }
        if (engine != null && !Engine.hasEngine(engine)) {
            throw new ModelNotFoundException(engine + " is not supported");
        }
        String modelName = criteria.getModelName();
        if (modelName == null) {
            modelName = artifact.getName();
        }
        Model model =
                createModel(modelDir, modelName, criteria.getDevice(), criteria.getBlock(), arguments, engine);
        try {
            model.load(modelPath, null, options);
            Translator<I, O> translator =
                    (Translator<I, O>) factory.newInstance(input, output, model, arguments);
            return new ZooModel<>(model, translator);
        } catch (Throwable t) {
            // The caller never receives the model on failure, so release it here. A failure while
            // closing must not mask the original cause.
            try {
                model.close();
            } catch (RuntimeException closeFailure) {
                t.addSuppressed(closeFailure);
            }
            throw t;
        }
    } catch (TranslateException e) {
        throw new ModelNotFoundException("No matching translator found", e);
    } finally {
        if (progress != null) {
            progress.end();
        }
    }
}
Usage of ai.djl.repository.Artifact in project djl (by deepjavalibrary).
Class BananaDetection, method prepare:
/**
 * {@inheritDoc}
 *
 * <p>Prepares the default artifact of {@code mrl}, then reads {@code index.file} from the
 * {@code train} or {@code test} sub-directory (selected by {@code usage}). For every entry it
 * appends the resolved image path to {@code imagePaths} and the corresponding class id and
 * {@link Rectangle} to {@code labels}. Subsequent calls are no-ops once preparation succeeded.
 *
 * @param progress the progress tracker passed to {@code mrl.prepare}
 * @throws IOException if the artifact cannot be prepared, the index file cannot be read, the
 *     index file is empty, or a label entry has fewer than 5 values
 * @throws TranslateException declared by this method's contract
 * @throws UnsupportedOperationException if {@code usage} is {@code VALIDATION} (or any other
 *     value than {@code TRAIN}/{@code TEST})
 */
@Override
public void prepare(Progress progress) throws IOException, TranslateException {
    if (prepared) {
        return;
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path root = mrl.getRepository().getResourceDirectory(artifact);
    Path usagePath;
    switch (usage) {
        case TRAIN:
            usagePath = Paths.get("train");
            break;
        case TEST:
            usagePath = Paths.get("test");
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    usagePath = root.resolve(usagePath);
    Path indexFile = usagePath.resolve("index.file");
    try (Reader reader = Files.newBufferedReader(indexFile)) {
        Type mapType = new TypeToken<Map<String, List<Float>>>() {}.getType();
        Map<String, List<Float>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
        if (metadata == null) {
            // Gson returns null when the reader is already at EOF, e.g. an empty index file.
            throw new IOException("Empty index file: " + indexFile);
        }
        for (Map.Entry<String, List<Float>> entry : metadata.entrySet()) {
            String imgName = entry.getKey();
            List<Float> label = entry.getValue();
            // Index 0 is the class id; 1-2 build the Point and 3-4 are the remaining Rectangle
            // arguments (presumably width/height). Validate before touching imagePaths/labels
            // so that a malformed entry cannot leave the two collections misaligned.
            if (label == null || label.size() < 5) {
                throw new IOException("Malformed label for image " + imgName + " in " + indexFile);
            }
            long objectClass = label.get(0).longValue();
            Rectangle objectLocation =
                    new Rectangle(new Point(label.get(1), label.get(2)), label.get(3), label.get(4));
            imagePaths.add(usagePath.resolve(imgName));
            labels.add(objectClass, objectLocation);
        }
    }
    prepared = true;
}
Aggregations