Usage of ai.djl.modality.cv.transform.ToTensor in the djl-demo project by deepjavalibrary.
Class DoodleModel, method loadModel:
/**
 * Loads the doodle classification model.
 *
 * <p>Input images are converted with {@code ToTensor}, read with {@code Image.Flag.GRAYSCALE},
 * and the output scores are passed through softmax. The model archive is fetched from the
 * demo S3 URL, and the engine option {@code "mapLocation"} is set to {@code "true"}.
 *
 * @return the loaded model, ready to create predictors from
 * @throws ModelException if the model cannot be located or loaded
 * @throws IOException if the model artifacts cannot be downloaded or read
 */
public static ZooModel<Image, Classifications> loadModel() throws ModelException, IOException {
    String modelUrl = "https://djl-ai.s3.amazonaws.com/resources/demo/pytorch/doodle_mobilenet.zip";

    ImageClassificationTranslator doodleTranslator =
            ImageClassificationTranslator.builder()
                    .addTransform(new ToTensor())
                    .optFlag(Image.Flag.GRAYSCALE)
                    .optApplySoftmax(true)
                    .build();

    // NOTE(review): "mapLocation" is an engine-specific load option; presumably it remaps
    // saved tensors onto the current device when loading — confirm against DJL PyTorch docs.
    Criteria<Image, Classifications> doodleCriteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optModelUrls(modelUrl)
                    .optOption("mapLocation", "true")
                    .optTranslator(doodleTranslator)
                    .build();

    ZooModel<Image, Classifications> doodleModel = ModelZoo.loadModel(doodleCriteria);
    return doodleModel;
}
Usage of ai.djl.modality.cv.transform.ToTensor in the djl-demo project by deepjavalibrary.
Class ImageClassification, method predict:
/**
 * Classifies a digit image with an MLP whose parameters are stored in {@code model/mnist}.
 *
 * <p>If {@code model/mnist/mlp-0000.params} does not exist, the pre-trained parameters are
 * downloaded from the DJL model repository to exactly that path before loading. A new model
 * and predictor are created (and closed) on every call.
 *
 * @param img the image to classify
 * @return classification results over the labels {@code "0"} through {@code "9"}
 * @throws IOException if the model directory cannot be created, or the parameters cannot be
 *     downloaded or read
 * @throws ModelException if the parameters cannot be loaded into the model
 * @throws TranslateException if pre- or post-processing of the image fails
 */
public static Classifications predict(Image img) throws IOException, ModelException, TranslateException {
    String modelName = "mlp";
    try (Model model = Model.newInstance(modelName)) {
        // 28 * 28 inputs, 10 outputs, hidden sizes {128, 64}
        // (assumes the Mlp(input, output, hidden) argument order).
        model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

        // Parameters are expected in model/mnist (e.g. as saved by the TrainMnist example);
        // when absent, the pre-trained parameters are downloaded instead.
        Path modelDir = Paths.get("model/mnist");
        Files.createDirectories(modelDir);
        Path mlp = modelDir.resolve("mlp-0000.params");
        if (!Files.exists(mlp)) {
            String url = "https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/zoo/mlp/0.0.3/mlp-0000.params.gz";
            // Download to the very path checked above, so the existence check and the
            // download target cannot drift apart (and use the platform's path separator).
            DownloadUtils.download(url, mlp.toString());
        }
        model.load(modelDir);

        // Class labels are the strings "0".."9".
        List<String> classes = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
        Translator<Image, Classifications> translator =
                ImageClassificationTranslator.builder().addTransform(new ToTensor()).optSynset(classes).build();
        try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
            return predictor.predict(img);
        }
    }
}
Usage of ai.djl.modality.cv.transform.ToTensor in the djl-demo project by deepjavalibrary.
Class Handler, method handleRequest:
/**
 * Lambda entry point: decodes the request's base64 image and writes the top-5 doodle
 * classifications to {@code os} as JSON.
 *
 * <p>On a {@code RuntimeException}, {@code ModelException} or {@code TranslateException},
 * the input and the error are logged, and a JSON object {@code {"status": "invoke failed: ..."}}
 * is written instead.
 *
 * @param is the raw request body; it is parsed as a {@code Request}
 * @param os the stream the JSON response is written to
 * @param context the Lambda context, used for logging
 * @throws IOException if reading the input, loading the model, or writing the response fails
 */
@Override
public void handleRequest(InputStream is, OutputStream os, Context context) throws IOException {
    LambdaLogger logger = context.getLogger();
    String input = Utils.toString(is);
    try {
        Request request = GSON.fromJson(input, Request.class);
        // Keeps only the text after the first comma. Presumably imageData is a data URL
        // ("data:image/...;base64,<payload>") — confirm against the client.
        String base64Img = request.getImageData().split(",")[1];
        byte[] imgBytes = Base64.getDecoder().decode(base64Img);
        Image img;
        try (ByteArrayInputStream bis = new ByteArrayInputStream(imgBytes)) {
            ImageFactory factory = ImageFactory.getInstance();
            img = factory.fromInputStream(bis);
        }
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder().addTransform(new ToTensor()).optFlag(Image.Flag.GRAYSCALE).optApplySoftmax(true).build();
        Criteria<Image, Classifications> criteria = Criteria.builder().setTypes(Image.class, Classifications.class).optModelUrls("https://djl-ai.s3.amazonaws.com/resources/demo/pytorch/doodle_mobilenet.zip").optTranslator(translator).build();
        // The model is closed together with the predictor; previously it was never closed,
        // leaking its resources on every invocation.
        // NOTE(review): the model is loaded on every request; caching it (e.g. in a static
        // field) would avoid repeated loading on warm Lambda invocations.
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            List<Classifications.Classification> result = predictor.predict(img).topK(5);
            os.write(GSON.toJson(result).getBytes(StandardCharsets.UTF_8));
        }
    } catch (RuntimeException | ModelException | TranslateException e) {
        logger.log("Failed handle input: " + input);
        logger.log(e.toString());
        // Serialize the message as a JSON string so quotes, backslashes and newlines in the
        // exception text are escaped and the response stays valid JSON.
        String msg = "{\"status\": " + GSON.toJson("invoke failed: " + e.toString()) + "}";
        os.write(msg.getBytes(StandardCharsets.UTF_8));
    }
}
Aggregations