Use of ai.djl.MalformedModelException in project djl by deepjavalibrary.
Class Parameter, method load.
/**
 * Loads parameter NDArrays from InputStream.
 *
 * <p>Currently, we cannot deserialize into the exact subclass of NDArray. The SparseNDArray
 * will be loaded as NDArray only.
 *
 * <p>The stream is expected to start with a one-char marker. {@code 'N'} means nothing else is
 * read for this parameter and its current state is left untouched. {@code 'P'} is followed by a
 * one-byte encoding version (must equal {@code VERSION}), the parameter name as written by
 * {@link DataInputStream#readUTF()}, and the array itself as read by {@code manager.decode}.
 *
 * @param manager the NDManager used to decode the array
 * @param dis the InputStream
 * @throws IOException if failed to read
 * @throws MalformedModelException if the marker is neither {@code 'N'} nor {@code 'P'}, the
 *     encoding version is unsupported, or the stored name differs from {@link #getName()}
 */
public void load(NDManager manager, DataInputStream dis) throws IOException, MalformedModelException {
    char magic = dis.readChar();
    if (magic == 'N') {
        // NOTE(review): presumably the marker for a parameter saved without an array — confirm
        // against the matching save routine.
        return;
    }
    if (magic != 'P') {
        throw new MalformedModelException("Invalid input data.");
    }
    // Version
    byte version = dis.readByte();
    if (version != VERSION) {
        throw new MalformedModelException("Unsupported encoding version: " + version);
    }
    // Use the same value for the check and for the message so they can never disagree.
    String expectedName = getName();
    String parameterName = dis.readUTF();
    if (!parameterName.equals(expectedName)) {
        throw new MalformedModelException(
                "Unexpected parameter name: " + parameterName + ", expected: " + expectedName);
    }
    array = manager.decode(dis);
    // set the shape of the parameter and prepare() can be skipped
    shape = array.getShape();
}
Use of ai.djl.MalformedModelException in project djl by deepjavalibrary.
Class Criteria, method loadModel.
/**
 * Load the {@link ZooModel} that matches this criteria.
 *
 * <p>If a model zoo was set on this criteria, only that zoo is searched, and a groupId or engine
 * conflict with it fails immediately. Otherwise every zoo from {@link ModelZoo#listModelZoo()}
 * is considered, skipping those whose groupId or supported engines do not match. Within each
 * candidate zoo, loaders are filtered by artifactId and application. The first loader whose
 * {@code loadModel(this)} succeeds wins. A {@link ModelNotFoundException} from a loader is logged
 * and the search continues.
 *
 * @return the model that matches the criteria
 * @throws IOException for various exceptions loading data from the repository
 * @throws ModelNotFoundException if no model with the specified criteria is found; the last
 *     loader failure, if any, is attached as the cause
 * @throws MalformedModelException if the model data is malformed
 */
public ZooModel<I, O> loadModel() throws IOException, ModelNotFoundException, MalformedModelException {
    Logger logger = LoggerFactory.getLogger(ModelZoo.class);
    logger.debug("Loading model with {}", this);

    // Step 1: collect the candidate model zoos.
    List<ModelZoo> list = new ArrayList<>();
    if (modelZoo != null) {
        logger.debug("Searching model in specified model zoo: {}", modelZoo.getGroupId());
        if (groupId != null && !modelZoo.getGroupId().equals(groupId)) {
            throw new ModelNotFoundException(
                    "groupId conflict with ModelZoo criteria: "
                            + modelZoo.getGroupId()
                            + " v.s. "
                            + groupId);
        }
        Set<String> supportedEngine = modelZoo.getSupportedEngines();
        if (engine != null && !supportedEngine.contains(engine)) {
            throw new ModelNotFoundException("ModelZoo doesn't support specified engine: " + engine);
        }
        list.add(modelZoo);
    } else {
        for (ModelZoo zoo : ModelZoo.listModelZoo()) {
            if (groupId != null && !zoo.getGroupId().equals(groupId)) {
                // filter out ModelZoo by groupId
                logger.debug("Ignore ModelZoo {} by groupId: {}", zoo.getGroupId(), groupId);
                continue;
            }
            Set<String> supportedEngine = zoo.getSupportedEngines();
            if (engine != null && !supportedEngine.contains(engine)) {
                logger.debug("Ignore ModelZoo {} by engine: {}", zoo.getGroupId(), engine);
                continue;
            }
            list.add(zoo);
        }
    }

    // Step 2: try each matching loader in order; the first successful load is returned.
    ModelNotFoundException lastException = null;
    for (ModelZoo zoo : list) {
        String loaderGroupId = zoo.getGroupId();
        for (ModelLoader loader : zoo.getModelLoaders()) {
            Application app = loader.getApplication();
            String loaderArtifactId = loader.getArtifactId();
            logger.debug("Checking ModelLoader: {}", loader);
            if (artifactId != null && !artifactId.equals(loaderArtifactId)) {
                // filter out by model loader artifactId
                logger.debug("artifactId mismatch for ModelLoader: {}:{}", loaderGroupId, loaderArtifactId);
                continue;
            }
            // UNDEFINED on either side acts as a wildcard for the application filter.
            if (application != Application.UNDEFINED
                    && app != Application.UNDEFINED
                    && !app.matches(application)) {
                // filter out ModelLoader by application
                logger.debug("application mismatch for ModelLoader: {}:{}", loaderGroupId, loaderArtifactId);
                continue;
            }
            try {
                return loader.loadModel(this);
            } catch (ModelNotFoundException e) {
                // Remember the failure so it can be chained onto the final exception.
                lastException = e;
                logger.trace("", e);
                logger.debug("{} for ModelLoader: {}:{}", e.getMessage(), loaderGroupId, loaderArtifactId);
            }
        }
    }
    throw new ModelNotFoundException("No matching model with specified Input/Output type found.", lastException);
}
Use of ai.djl.MalformedModelException in project djl by deepjavalibrary.
Class MxSymbolBlock, method loadParameters.
/**
 * {@inheritDoc}
 *
 * <p>Reads a one-byte encoding version first. Versions newer than {@code VERSION} are rejected.
 * Older versions require the symbol to already be set on this block. For the current version,
 * an int length and that many UTF-8 bytes of symbol JSON are read, and the block is rebuilt from
 * them. Then an int count of input names is read, followed by that many names, followed by
 * every parameter in {@code mxNetParams} in order.
 */
@Override
public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
    byte version = is.readByte();
    if (version > VERSION) {
        throw new MalformedModelException("Unsupported encoding version: " + version);
    }
    if (version < VERSION && symbol == null) {
        throw new IllegalStateException("Symbol is required for version 2, please use Model to load");
    }
    if (version == VERSION) {
        int len = is.readInt();
        if (len < 0) {
            // A negative length would otherwise surface as NegativeArraySizeException.
            throw new MalformedModelException("Invalid symbol length: " + len);
        }
        byte[] bytes = new byte[len];
        try {
            // read(byte[]) may return after a partial read; readFully either fills the whole
            // buffer or throws EOFException, so a truncated symbol can never be parsed.
            is.readFully(bytes);
        } catch (java.io.EOFException e) {
            throw new MalformedModelException("InputStream ends at symbol loading!", e);
        }
        // The stored symbol always replaces any existing one, and the block is re-initialised.
        symbol = Symbol.loadJson((MxNDManager) manager, new String(bytes, StandardCharsets.UTF_8));
        initBlock();
    }
    int size = is.readInt();
    for (int i = 0; i < size; ++i) {
        inputNames.add(is.readUTF());
    }
    // Parameters are decoded with this block's own manager, not the manager argument.
    for (Parameter parameter : mxNetParams) {
        parameter.load(this.manager, is);
    }
    setInputNames(inputNames);
}
Use of ai.djl.MalformedModelException in project djl by deepjavalibrary.
Class OrtModel, method load (InputStream overload).
/**
 * {@inheritDoc}
 *
 * <p>Fails with {@link UnsupportedOperationException} if a block was already set. Otherwise a new
 * temporary directory is created and assigned to {@code modelDir}. The whole stream is read into
 * memory, and an ONNX Runtime session is created from those bytes. CUDA is enabled on the session
 * options when the manager's device is a GPU. Any {@code OrtException} is rethrown as
 * {@link MalformedModelException}.
 */
@Override
public void load(InputStream is, Map<String, ?> options) throws IOException, MalformedModelException {
    if (block != null) {
        throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
    }
    modelDir = Files.createTempDirectory("ort-model");
    // Note: File#deleteOnExit only removes a directory that is empty at JVM shutdown.
    modelDir.toFile().deleteOnExit();
    try {
        byte[] buf = Utils.toByteArray(is);
        SessionOptions ortOptions = getSessionOptions(options);
        Device device = manager.getDevice();
        if (device.isGpu()) {
            ortOptions.addCUDA(device.getDeviceId());
        }
        OrtSession session = env.createSession(buf, ortOptions);
        block = new OrtSymbolBlock(session, (OrtNDManager) manager);
    } catch (OrtException e) {
        throw new MalformedModelException("ONNX Model cannot be loaded", e);
    }
}
Use of ai.djl.MalformedModelException in project djl by deepjavalibrary.
Class OrtModel, method load (Path overload).
/**
 * {@inheritDoc}
 *
 * <p>Fails with {@link UnsupportedOperationException} if a block was already set. In that case
 * the check runs before {@code setModelDir}, so a rejected call leaves the model directory
 * unchanged. The model file is looked up first by {@code prefix}, then by the model directory's
 * name. An ONNX Runtime session is then created from that file, with CUDA enabled when the
 * manager's device is a GPU.
 *
 * @throws FileNotFoundException if no model file is found under either name
 */
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
    // Check before mutating any state so a rejected call has no side effects.
    if (block != null) {
        throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
    }
    setModelDir(modelPath);
    Path modelFile = findModelFile(prefix);
    if (modelFile == null) {
        // Fall back to a file named after the model directory itself.
        modelFile = findModelFile(modelDir.toFile().getName());
        if (modelFile == null) {
            throw new FileNotFoundException(".onnx file not found in: " + modelPath);
        }
    }
    try {
        SessionOptions ortOptions = getSessionOptions(options);
        Device device = manager.getDevice();
        if (device.isGpu()) {
            ortOptions.addCUDA(device.getDeviceId());
        }
        OrtSession session = env.createSession(modelFile.toString(), ortOptions);
        block = new OrtSymbolBlock(session, (OrtNDManager) manager);
    } catch (OrtException e) {
        throw new MalformedModelException("ONNX Model cannot be loaded", e);
    }
}
Aggregations