Usage of `ai.djl.MalformedModelException` in project djl (by deepjavalibrary):
class `OrtModel`, method `load(Path, String, Map)`.
/**
 * {@inheritDoc}
 *
 * <p>The model file is searched for with {@code prefix} first and, failing that, with the name of
 * the model directory. The file found is opened as an ONNX Runtime session and wrapped in an
 * {@link OrtSymbolBlock}.
 *
 * @param modelPath directory holding the model
 * @param prefix name passed to {@code findModelFile} for the first lookup
 * @param options options turned into ONNX Runtime session options by {@code getSessionOptions}
 * @throws FileNotFoundException if neither lookup finds a model file
 * @throws MalformedModelException if ONNX Runtime fails to create a session from the file
 * @throws UnsupportedOperationException if this model already has a block assigned
 */
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
    setModelDir(modelPath);
    if (block != null) {
        throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
    }

    Path located = findModelFile(prefix);
    if (located == null) {
        // Second attempt: a model file named after its enclosing directory.
        located = findModelFile(modelDir.toFile().getName());
    }
    if (located == null) {
        throw new FileNotFoundException(".onnx file not found in: " + modelPath);
    }

    try {
        SessionOptions sessionOptions = getSessionOptions(options);
        OrtSession ortSession = env.createSession(located.toString(), sessionOptions);
        block = new OrtSymbolBlock(ortSession, (OrtNDManager) manager);
    } catch (OrtException cause) {
        throw new MalformedModelException("ONNX Model cannot be loaded", cause);
    }
}
Usage of `ai.djl.MalformedModelException` in project djl (by deepjavalibrary):
class `OrtModel`, method `load(InputStream, Map)`.
/**
 * {@inheritDoc}
 *
 * <p>The stream is read fully into memory and handed to ONNX Runtime as a byte array. A fresh
 * temporary directory, marked for deletion when the JVM exits, becomes the model directory.
 *
 * @param is stream containing the serialized ONNX model
 * @param options options turned into ONNX Runtime session options by {@code getSessionOptions}
 * @throws IOException if the stream cannot be read or the temporary directory cannot be created
 * @throws MalformedModelException if ONNX Runtime fails to create a session from the bytes
 * @throws UnsupportedOperationException if this model already has a block assigned
 */
@Override
public void load(InputStream is, Map<String, ?> options) throws IOException, MalformedModelException {
    if (block != null) {
        throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
    }

    modelDir = Files.createTempDirectory("ort-model");
    modelDir.toFile().deleteOnExit();

    try {
        // Arguments are evaluated left to right: the stream is read before the options are built.
        OrtSession ortSession = env.createSession(Utils.toByteArray(is), getSessionOptions(options));
        block = new OrtSymbolBlock(ortSession, (OrtNDManager) manager);
    } catch (OrtException cause) {
        throw new MalformedModelException("ONNX Model cannot be loaded", cause);
    }
}
Usage of `ai.djl.MalformedModelException` in project djl (by deepjavalibrary):
class `TfModel`, method `load(Path, String, Map)`.
/**
 * {@inheritDoc}
 *
 * <p>The SavedModel export directory is searched for with {@code prefix} (defaulting to the model
 * name) and then with {@code "saved_model.pb"}. Recognized {@code options}:
 *
 * <ul>
 *   <li>{@code Tags} - a {@code String[]}, or a comma-separated {@code String}; whitespace around
 *       each tag is trimmed and empty entries are dropped. Defaults to {@code {"serve"}}.
 *   <li>{@code ConfigProto} - a {@code ConfigProto} instance or its Base64-encoded serialized
 *       form. Defaults to {@code JavacppUtils.getSessionConfig()}.
 *   <li>{@code RunOptions} - a {@code RunOptions} instance or its Base64-encoded serialized form.
 *   <li>{@code SignatureDefKey} - signature to serve (any non-null value, via {@code toString()});
 *       defaults to {@code DEFAULT_SERVING_SIGNATURE_DEF_KEY}.
 * </ul>
 *
 * @param modelPath directory holding the model
 * @param prefix name used for the first export-directory lookup; {@code null} means the model name
 * @param options optional load options described above; may be {@code null}
 * @throws FileNotFoundException if no TensorFlow model directory can be located
 * @throws MalformedModelException if a {@code ConfigProto}/{@code RunOptions} string is not valid
 *     Base64 or does not decode to a valid serialized proto
 */
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws FileNotFoundException, MalformedModelException {
    setModelDir(modelPath);
    if (prefix == null) {
        prefix = modelName;
    }
    Path exportDir = findModelDir(prefix);
    if (exportDir == null) {
        exportDir = findModelDir("saved_model.pb");
        if (exportDir == null) {
            throw new FileNotFoundException("No TensorFlow model found in: " + modelDir);
        }
    }
    String[] tags = null;
    ConfigProto configProto = null;
    RunOptions runOptions = null;
    String signatureDefKey = DEFAULT_SERVING_SIGNATURE_DEF_KEY;
    if (options != null) {
        Object tagOption = options.get("Tags");
        if (tagOption instanceof String[]) {
            tags = (String[]) tagOption;
        } else if (tagOption instanceof String) {
            tags = parseTags((String) tagOption);
        }
        Object config = options.get("ConfigProto");
        if (config instanceof ConfigProto) {
            configProto = (ConfigProto) config;
        } else if (config instanceof String) {
            try {
                byte[] buf = Base64.getDecoder().decode((String) config);
                configProto = ConfigProto.parseFrom(buf);
            } catch (IllegalArgumentException | InvalidProtocolBufferException e) {
                // IllegalArgumentException is thrown by the Base64 decoder for malformed input.
                throw new MalformedModelException("Invalid ConfigProto: " + config, e);
            }
        }
        Object run = options.get("RunOptions");
        if (run instanceof RunOptions) {
            runOptions = (RunOptions) run;
        } else if (run instanceof String) {
            try {
                byte[] buf = Base64.getDecoder().decode((String) run);
                runOptions = RunOptions.parseFrom(buf);
            } catch (IllegalArgumentException | InvalidProtocolBufferException e) {
                // IllegalArgumentException is thrown by the Base64 decoder for malformed input.
                throw new MalformedModelException("Invalid RunOptions: " + run, e);
            }
        }
        Object signature = options.get("SignatureDefKey");
        // A missing or null entry keeps the default instead of clearing the key.
        if (signature != null) {
            signatureDefKey = signature.toString();
        }
    }
    if (tags == null) {
        tags = new String[] { "serve" };
    }
    if (configProto == null) {
        // default one
        configProto = JavacppUtils.getSessionConfig();
    }
    SavedModelBundle bundle = JavacppUtils.loadSavedModelBundle(exportDir.toString(), tags, configProto, runOptions);
    block = new TfSymbolBlock(bundle, signatureDefKey);
}

/**
 * Splits a comma-separated tag list, trimming whitespace around each tag and dropping empty
 * entries. An empty or blank input yields an empty array.
 */
private static String[] parseTags(String tagList) {
    String[] parts = tagList.split(",");
    int count = 0;
    for (String part : parts) {
        String tag = part.trim();
        if (!tag.isEmpty()) {
            parts[count++] = tag;
        }
    }
    String[] tags = new String[count];
    System.arraycopy(parts, 0, tags, 0, count);
    return tags;
}
Usage of `ai.djl.MalformedModelException` in project djl (by deepjavalibrary):
class `FtModel`, method `load(Path, String, Map)`.
/**
 * {@inheritDoc}
 *
 * <p>The model file is searched for with {@code prefix} and then with the model directory's name.
 * It is validated and loaded through {@code FtWrapper}. Every non-null option is copied into the
 * model properties as a string, after which {@code "model-type"} is set to the type reported by
 * fastText. The type then selects the block:
 *
 * <ul>
 *   <li>{@code "sup"} - {@code FtTextClassification}, using the {@code "label-prefix"} property
 *       or {@code FtTextClassification.DEFAULT_LABEL_PREFIX};
 *   <li>{@code "cbow"} / {@code "sg"} - {@code FtWordEmbeddingBlock}.
 * </ul>
 *
 * <p>If loading fails at any point after the native wrapper is created, the wrapper is closed.
 *
 * @param modelPath directory holding the model
 * @param prefix name used for the first model-file lookup
 * @param options optional properties to record on the model; entries with {@code null} values are
 *     ignored. May be {@code null}.
 * @throws FileNotFoundException if the directory or a model file cannot be found
 * @throws MalformedModelException if fastText rejects the file or reports an unsupported type
 */
@Override
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, MalformedModelException {
    if (Files.notExists(modelPath)) {
        throw new FileNotFoundException("Model directory doesn't exist: " + modelPath.toAbsolutePath());
    }
    modelDir = modelPath.toAbsolutePath();
    Path modelFile = findModelFile(prefix);
    if (modelFile == null) {
        modelFile = findModelFile(modelDir.toFile().getName());
        if (modelFile == null) {
            throw new FileNotFoundException("No .ftz or .bin file found in : " + modelPath);
        }
    }
    String modelFilePath = modelFile.toString();
    FtWrapper fta = FtWrapper.newInstance();
    boolean loaded = false;
    try {
        if (!fta.checkModel(modelFilePath)) {
            throw new MalformedModelException("Malformed FastText model file:" + modelFilePath);
        }
        fta.loadModel(modelFilePath);
        if (options != null) {
            for (Map.Entry<String, ?> entry : options.entrySet()) {
                Object value = entry.getValue();
                // Skip null values rather than failing with a NullPointerException.
                if (value != null) {
                    properties.put(entry.getKey(), value.toString());
                }
            }
        }
        String modelType = fta.getModelType();
        properties.put("model-type", modelType);
        if ("sup".equals(modelType)) {
            String labelPrefix = properties.getOrDefault("label-prefix", FtTextClassification.DEFAULT_LABEL_PREFIX);
            block = new FtTextClassification(fta, labelPrefix);
            modelDir = block.getModelFile();
        } else if ("cbow".equals(modelType) || "sg".equals(modelType)) {
            block = new FtWordEmbeddingBlock(fta);
            modelDir = block.getModelFile();
        } else {
            throw new MalformedModelException("Unexpected FastText model type: " + modelType);
        }
        loaded = true;
    } finally {
        if (!loaded) {
            // Release the native fastText handle so a rejected model does not leak it.
            fta.close();
        }
    }
}
Usage of `ai.djl.MalformedModelException` in project build-your-own-social-media-analytics-with-apache-kafka (by scholzj):
class `TopologyProducer`, method `buildTopology`.
/**
 * Builds the Kafka Streams topology that turns incoming tweets into sentiment alerts.
 *
 * <p>A sentiment-analysis model is loaded from the DJL model zoo into {@code predictor}. Tweets are
 * then read from {@code SOURCE_TOPIC} and retweets are skipped. Every other tweet is classified; when
 * the best class has a probability above 50%, an alert naming the class, its probability and the
 * tweet URL is sent to {@code TARGET_TOPIC}.
 *
 * @return the built topology
 * @throws RuntimeException if the model cannot be loaded
 */
@Produces
public Topology buildTopology() {
    final TweetSerde tweetSerde = new TweetSerde();
    try {
        Criteria<String, Classifications> criteria = Criteria.builder()
                .optApplication(Application.NLP.SENTIMENT_ANALYSIS)
                .setTypes(String.class, Classifications.class)
                .build();
        predictor = ModelZoo.loadModel(criteria).newPredictor();
    } catch (IOException | ModelNotFoundException | MalformedModelException e) {
        LOG.error("Failed to load model", e);
        throw new RuntimeException("Failed to load model", e);
    }
    final StreamsBuilder builder = new StreamsBuilder();
    builder.stream(SOURCE_TOPIC, Consumed.with(Serdes.ByteArray(), tweetSerde)).flatMapValues(value -> {
        if (value.getRetweetedStatus() != null) {
            // We ignore retweets => we do not want an alert for every retweet
            return List.of();
        }
        String tweet = value.getText();
        try {
            Classifications classifications = predictor.predict(tweet);
            Classifications.Classification best = classifications.best();
            String statusUrl = "https://twitter.com/" + value.getUser().getScreenName() + "/status/" + value.getId();
            // Fixed locale so the probability is formatted the same way as the English class name.
            String alert = String.format(Locale.ENGLISH, "The following tweet was classified as %s with %2.2f%% probability: %s", best.getClassName().toLowerCase(Locale.ENGLISH), best.getProbability() * 100, statusUrl);
            // We care only about strong results where probability is > 50%
            if (best.getProbability() > 0.50) {
                LOG.infov("Tweeting: {0}", alert);
                return List.of(alert);
            }
            LOG.infov("Not tweeting: {0}", alert);
            return List.of();
        } catch (TranslateException e) {
            // Log the cause as well, so classification failures can be diagnosed.
            LOG.errorv(e, "Failed to classify the tweet {0}", value);
            return List.of();
        }
    }).peek((key, value) -> LOG.infov("{0}", value)).to(TARGET_TOPIC, Produced.with(Serdes.ByteArray(), Serdes.String()));
    return builder.build();
}
Aggregations