Search in sources:

Example 1 with MalformedModelException

use of ai.djl.MalformedModelException in project djl by deepjavalibrary.

In class Parameter, method load.

/**
 * Reads this parameter's serialized NDArray from a stream.
 *
 * <p>The stream starts with a one-char marker. {@code 'N'} ends loading immediately without reading
 * anything else. {@code 'P'} is followed by an encoding version byte, the parameter name, and the
 * encoded array. Any other marker is rejected.
 *
 * <p>The exact NDArray subclass cannot be deserialized yet, so a saved SparseNDArray is loaded as a
 * plain NDArray.
 *
 * @param manager the NDManager used to decode the array
 * @param dis the InputStream to read from
 * @throws IOException if reading from the stream fails
 * @throws MalformedModelException if the marker, encoding version, or stored parameter name is not
 *     the expected one
 */
public void load(NDManager manager, DataInputStream dis) throws IOException, MalformedModelException {
    switch (dis.readChar()) {
        case 'N':
            // Nothing further is read for this parameter.
            return;
        case 'P':
            break;
        default:
            throw new MalformedModelException("Invalid input data.");
    }

    byte encodingVersion = dis.readByte();
    if (encodingVersion != VERSION) {
        throw new MalformedModelException("Unsupported encoding version: " + encodingVersion);
    }

    String storedName = dis.readUTF();
    if (!storedName.equals(getName())) {
        throw new MalformedModelException(
                "Unexpected parameter name: " + storedName + ", expected: " + name);
    }

    array = manager.decode(dis);
    // Recording the shape here means prepare() can be skipped for this parameter.
    shape = array.getShape();
}
Also used : MalformedModelException(ai.djl.MalformedModelException)

Example 2 with MalformedModelException

use of ai.djl.MalformedModelException in project djl by deepjavalibrary.

In class Criteria, method loadModel.

/**
 * Loads the {@link ZooModel} that matches this criteria.
 *
 * <p>If a model zoo was set on this criteria, only that zoo is searched, and it must agree with any
 * groupId and engine criteria. Otherwise, every zoo returned by {@link ModelZoo#listModelZoo()}
 * whose groupId and supported engines match is searched. Within each candidate zoo, model loaders
 * are filtered by artifactId and application. The first loader that loads a model without throwing
 * {@link ModelNotFoundException} wins.
 *
 * @return the model that matches the criteria
 * @throws IOException for various exceptions loading data from the repository
 * @throws ModelNotFoundException if no model with the specified criteria is found; the failure of
 *     the last loader tried, if any, is attached as the cause
 * @throws MalformedModelException if the model data is malformed
 */
public ZooModel<I, O> loadModel() throws IOException, ModelNotFoundException, MalformedModelException {
    Logger logger = LoggerFactory.getLogger(ModelZoo.class);
    logger.debug("Loading model with {}", this);
    List<ModelZoo> list = new ArrayList<>();
    if (modelZoo != null) {
        String zooGroupId = modelZoo.getGroupId();
        logger.debug("Searching model in specified model zoo: {}", zooGroupId);
        // groupId is non-null inside this condition, so calling equals on it stays null-safe
        // even if the zoo reports no groupId.
        if (groupId != null && !groupId.equals(zooGroupId)) {
            throw new ModelNotFoundException("groupId conflict with ModelZoo criteria: " + zooGroupId + " v.s. " + groupId);
        }
        Set<String> supportedEngine = modelZoo.getSupportedEngines();
        if (engine != null && !supportedEngine.contains(engine)) {
            throw new ModelNotFoundException("ModelZoo doesn't support specified engine: " + engine);
        }
        list.add(modelZoo);
    } else {
        for (ModelZoo zoo : ModelZoo.listModelZoo()) {
            if (groupId != null && !groupId.equals(zoo.getGroupId())) {
                // filter out ModelZoo by groupId
                logger.debug("Ignore ModelZoo {} by groupId: {}", zoo.getGroupId(), groupId);
                continue;
            }
            Set<String> supportedEngine = zoo.getSupportedEngines();
            if (engine != null && !supportedEngine.contains(engine)) {
                logger.debug("Ignore ModelZoo {} by engine: {}", zoo.getGroupId(), engine);
                continue;
            }
            list.add(zoo);
        }
    }
    // Only ModelNotFoundException is caught below, so that is the only type ever stored here.
    ModelNotFoundException lastException = null;
    for (ModelZoo zoo : list) {
        String loaderGroupId = zoo.getGroupId();
        for (ModelLoader loader : zoo.getModelLoaders()) {
            Application app = loader.getApplication();
            String loaderArtifactId = loader.getArtifactId();
            logger.debug("Checking ModelLoader: {}", loader);
            if (artifactId != null && !artifactId.equals(loaderArtifactId)) {
                // filter out by model loader artifactId
                logger.debug("artifactId mismatch for ModelLoader: {}:{}", loaderGroupId, loaderArtifactId);
                continue;
            }
            // An UNDEFINED application on either side acts as a wildcard.
            if (application != Application.UNDEFINED && app != Application.UNDEFINED && !app.matches(application)) {
                // filter out ModelLoader by application
                logger.debug("application mismatch for ModelLoader: {}:{}", loaderGroupId, loaderArtifactId);
                continue;
            }
            try {
                return loader.loadModel(this);
            } catch (ModelNotFoundException e) {
                // Keep trying other loaders; remember the failure so it can be reported as the cause.
                lastException = e;
                logger.trace("", e);
                logger.debug("{} for ModelLoader: {}:{}", e.getMessage(), loaderGroupId, loaderArtifactId);
            }
        }
    }
    throw new ModelNotFoundException("No matching model with specified Input/Output type found.", lastException);
}
Also used : ArrayList(java.util.ArrayList) Logger(org.slf4j.Logger) Application(ai.djl.Application) MalformedURLException(java.net.MalformedURLException) IOException(java.io.IOException) MalformedModelException(ai.djl.MalformedModelException)

Example 3 with MalformedModelException

use of ai.djl.MalformedModelException in project djl by deepjavalibrary.

In class MxSymbolBlock, method loadParameters.

/**
 * {@inheritDoc}
 *
 * <p>The stream is read in this order:
 * <ul>
 *   <li>an encoding version byte;
 *   <li>the symbol JSON (only when the version equals {@code VERSION}), as an int length followed
 *       by that many UTF-8 bytes;
 *   <li>an int count of input names followed by that many UTF strings;
 *   <li>the data for each MXNet parameter.
 * </ul>
 *
 * <p>Versions older than {@code VERSION} carry no symbol, so a symbol must already be set on this
 * block.
 *
 * @throws MalformedModelException if the version is newer than supported, the symbol length is
 *     negative, or the stream ends before the full symbol has been read
 * @throws IllegalStateException if an older encoding is loaded while no symbol is set
 */
@Override
public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
    byte version = is.readByte();
    if (version > VERSION) {
        throw new MalformedModelException("Unsupported encoding version: " + version);
    }
    if (version < VERSION && symbol == null) {
        throw new IllegalStateException("Symbol is required for version " + version + ", please use Model to load");
    }
    if (version == VERSION) {
        int len = is.readInt();
        if (len < 0) {
            throw new MalformedModelException("Invalid symbol length: " + len);
        }
        byte[] bytes = new byte[len];
        try {
            // read(byte[]) may return fewer than len bytes; readFully guarantees a complete symbol.
            is.readFully(bytes);
        } catch (java.io.EOFException e) {
            throw new MalformedModelException("InputStream ends at symbol loading!", e);
        }
        // The symbol from the stream replaces any symbol already set, then initBlock() is called.
        // NOTE(review): initBlock() presumably rebuilds state from the new symbol — confirm.
        symbol = Symbol.loadJson((MxNDManager) manager, new String(bytes, StandardCharsets.UTF_8));
        initBlock();
    }
    int size = is.readInt();
    for (int i = 0; i < size; ++i) {
        inputNames.add(is.readUTF());
    }
    // NOTE(review): parameters are loaded with this block's own manager, not the `manager`
    // argument; looks intentional, but confirm.
    for (Parameter parameter : mxNetParams) {
        parameter.load(this.manager, is);
    }
    setInputNames(inputNames);
}
Also used : MalformedModelException(ai.djl.MalformedModelException) Parameter(ai.djl.nn.Parameter)

Example 4 with MalformedModelException

use of ai.djl.MalformedModelException in project djl by deepjavalibrary.

In class OrtModel, method load.

/**
 * {@inheritDoc}
 *
 * <p>The whole stream is read into memory and passed to ONNX Runtime to create a session. A
 * temporary directory is created, assigned to {@code modelDir`}, and marked for deletion on JVM
 * exit. When the manager's device is a GPU, CUDA is added to the session options for that device.
 *
 * @throws UnsupportedOperationException if a block is already set on this model
 * @throws MalformedModelException if ONNX Runtime fails to create a session from the data
 */
@Override
public void load(InputStream is, Map<String, ?> options) throws IOException, MalformedModelException {
    if (block != null) {
        throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
    }
    modelDir = Files.createTempDirectory("ort-model");
    // NOTE(review): deleteOnExit only removes the directory if it is still empty at JVM exit.
    modelDir.toFile().deleteOnExit();
    try {
        byte[] buf = Utils.toByteArray(is);
        SessionOptions ortOptions = getSessionOptions(options);
        Device device = manager.getDevice();
        if (device.isGpu()) {
            // Reuse the device fetched above instead of querying the manager a second time.
            ortOptions.addCUDA(device.getDeviceId());
        }
        OrtSession session = env.createSession(buf, ortOptions);
        block = new OrtSymbolBlock(session, (OrtNDManager) manager);
    } catch (OrtException e) {
        throw new MalformedModelException("ONNX Model cannot be loaded", e);
    }
}
Also used : OrtSession(ai.onnxruntime.OrtSession) Device(ai.djl.Device) SessionOptions(ai.onnxruntime.OrtSession.SessionOptions) MalformedModelException(ai.djl.MalformedModelException) OrtException(ai.onnxruntime.OrtException)

Example 5 with MalformedModelException

use of ai.djl.MalformedModelException in project djl by deepjavalibrary.

In class OrtModel, method load.

/**
 * {@inheritDoc}
 *
 * <p>Sets the model directory to {@code modelPath}, then looks up the model file using
 * {@code prefix}. If that finds nothing, it retries with the model directory's name. The file
 * found is opened as an ONNX Runtime session. When the manager's device is a GPU, CUDA is added to
 * the session options for that device.
 *
 * @throws UnsupportedOperationException if a block is already set on this model
 * @throws FileNotFoundException if no model file is found by either lookup
 * @throws MalformedModelException if ONNX Runtime fails to create a session from the file
 */
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
    setModelDir(modelPath);
    if (block != null) {
        throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
    }
    Path modelFile = findModelFile(prefix);
    if (modelFile == null) {
        // Fall back to a file named after the model directory itself.
        modelFile = findModelFile(modelDir.toFile().getName());
        if (modelFile == null) {
            throw new FileNotFoundException(".onnx file not found in: " + modelPath);
        }
    }
    try {
        SessionOptions ortOptions = getSessionOptions(options);
        Device device = manager.getDevice();
        if (device.isGpu()) {
            // Reuse the device fetched above instead of querying the manager a second time.
            ortOptions.addCUDA(device.getDeviceId());
        }
        OrtSession session = env.createSession(modelFile.toString(), ortOptions);
        block = new OrtSymbolBlock(session, (OrtNDManager) manager);
    } catch (OrtException e) {
        throw new MalformedModelException("ONNX Model cannot be loaded", e);
    }
}
Also used : Path(java.nio.file.Path) OrtSession(ai.onnxruntime.OrtSession) Device(ai.djl.Device) SessionOptions(ai.onnxruntime.OrtSession.SessionOptions) FileNotFoundException(java.io.FileNotFoundException) MalformedModelException(ai.djl.MalformedModelException) OrtException(ai.onnxruntime.OrtException)

Aggregations

MalformedModelException (ai.djl.MalformedModelException): 9; FileNotFoundException (java.io.FileNotFoundException): 3; IOException (java.io.IOException): 3; Path (java.nio.file.Path): 3; Application (ai.djl.Application): 2; Device (ai.djl.Device): 2; ModelNotFoundException (ai.djl.repository.zoo.ModelNotFoundException): 2; OrtException (ai.onnxruntime.OrtException): 2; OrtSession (ai.onnxruntime.OrtSession): 2; SessionOptions (ai.onnxruntime.OrtSession.SessionOptions): 2; ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 2; FtWrapper (ai.djl.fasttext.jni.FtWrapper): 1; FtTextClassification (ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification): 1; FtWordEmbeddingBlock (ai.djl.fasttext.zoo.nlp.word_embedding.FtWordEmbeddingBlock): 1; Predictor (ai.djl.inference.Predictor): 1; Classifications (ai.djl.modality.Classifications): 1; Input (ai.djl.modality.Input): 1; Output (ai.djl.modality.Output): 1; Parameter (ai.djl.nn.Parameter): 1; Criteria (ai.djl.repository.zoo.Criteria): 1