Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.
Class ImageClassification, method predict.
public static Classifications predict() throws IOException, ModelException, TranslateException {
    Path imageFile = Paths.get("src/test/resources/0.png");
    Image img = ImageFactory.getInstance().fromFile(imageFile);

    String modelName = "mlp";
    try (Model model = Model.newInstance(modelName)) {
        model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

        // Assumes you have run the TrainMnist.java example and saved the model in the build/model folder.
        Path modelDir = Paths.get("build/model");
        model.load(modelDir);

        List<String> classes =
                IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
        Translator<Image, Classifications> translator =
                ImageClassificationTranslator.builder()
                        .addTransform(new ToTensor())
                        .optSynset(classes)
                        .optApplySoftmax(true)
                        .build();

        try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
            return predictor.predict(img);
        }
    }
}
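In predict, ToTensor is the only transform the translator applies before the Mlp sees the image. DJL's documentation describes ToTensor as converting an HWC image with values in [0, 255] into a float32 CHW tensor with values in [0, 1]. The plain-Java sketch below reproduces that conversion on a flat int array so the index arithmetic is visible. It is an illustration based on that description, not DJL's implementation, and the class and method names are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the ToTensor conversion: interleaved HWC pixels in
 * [0, 255] become planar CHW floats in [0, 1]. Not DJL's implementation.
 */
public class ToTensorSketch {

    /** Converts interleaved HWC pixel data (0..255) into planar CHW floats (0..1). */
    static float[] toTensor(int[] hwc, int height, int width, int channels) {
        if (hwc.length != height * width * channels) {
            throw new IllegalArgumentException("pixel count does not match H*W*C");
        }
        float[] chw = new float[hwc.length];
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                for (int c = 0; c < channels; c++) {
                    int src = (h * width + w) * channels + c; // position in HWC layout
                    int dst = (c * height + h) * width + w;   // position in CHW layout
                    chw[dst] = hwc[src] / 255f;               // rescale to [0, 1]
                }
            }
        }
        return chw;
    }

    public static void main(String[] args) {
        // A 1x2 RGB image: one pure-red pixel followed by one white pixel.
        int[] hwc = {255, 0, 0, 255, 255, 255};
        System.out.println(Arrays.toString(toTensor(hwc, 1, 2, 3)));
        // prints [1.0, 1.0, 0.0, 1.0, 0.0, 1.0] (R plane, then G plane, then B plane)
    }
}
```

For a grayscale image the same code runs with channels = 1, producing a 1×H×W tensor.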
Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.
Class TrtTest, method testTrtUff.
@Test
public void testTrtUff() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your os configuration doesn't support TensorRT.");
    }
    if (!engine.defaultDevice().isGpu()) {
        throw new SkipException("TensorRT only support GPU.");
    }
    List<String> synset =
            IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
    ImageClassificationTranslator translator =
            ImageClassificationTranslator.builder()
                    .optFlag(Image.Flag.GRAYSCALE)
                    .optSynset(synset)
                    .optApplySoftmax(true)
                    .addTransform(new ToTensor())
                    .optBatchifier(null)
                    .build();
    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optModelUrls("https://resources.djl.ai/test-models/tensorrt/lenet5.zip")
                    .optTranslator(translator)
                    .optEngine("TensorRT")
                    .build();
    try (ZooModel<Image, Classifications> model = criteria.loadModel();
            Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Path path = Paths.get("../../examples/src/test/resources/0.png");
        Image image = ImageFactory.getInstance().fromFile(path);
        Classifications ret = predictor.predict(image);
        Assert.assertEquals(ret.best().getClassName(), "0");
    }
}
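The test asserts on ret.best().getClassName(). With optApplySoftmax(true) and a synset of the strings "0" to "9", that amounts, conceptually, to normalizing the network's raw scores with a softmax and returning the label at the highest-probability index. The sketch below shows that post-processing in plain Java; the logits are invented for illustration, and SoftmaxSketch, softmax and best are hypothetical names, not DJL's API.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Hypothetical sketch of softmax post-processing plus synset lookup.
 * Not DJL's Classifications implementation.
 */
public class SoftmaxSketch {

    /** Numerically stable softmax: subtract the largest score before exponentiating. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum; // probabilities now sum to 1
        }
        return probs;
    }

    /** Returns the synset label at the index with the largest probability. */
    static String best(double[] probs, List<String> synset) {
        int bestIndex = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[bestIndex]) {
                bestIndex = i;
            }
        }
        return synset.get(bestIndex);
    }

    public static void main(String[] args) {
        // Same synset construction as in the test: "0" .. "9".
        List<String> synset =
                IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
        // Hypothetical raw scores for a digit-zero image.
        double[] logits = {4.0, 0.5, 0.1, -1.0, 0.0, 0.2, 0.3, -0.5, 1.0, 0.0};
        System.out.println(best(softmax(logits), synset)); // prints 0
    }
}
```

Because softmax is monotonic, the argmax label is the same with or without it; the softmax only changes the reported probabilities.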
Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.
Class MyTranslator, method processInput.
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
    // Decode the raw request bytes into an Image.
    byte[] data = input.getAsBytes(0);
    ImageFactory factory = ImageFactory.getInstance();
    Image image = factory.fromInputStream(new ByteArrayInputStream(data));
    // Convert to a single-channel NDArray owned by this translator context's manager.
    NDArray array = image.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
    // Crop, scale to 28x28, then convert HWC [0, 255] to CHW [0, 1].
    Pipeline pipeline = new Pipeline();
    pipeline.add(new CenterCrop());
    pipeline.add(new Resize(28, 28));
    pipeline.add(new ToTensor());
    return pipeline.transform(new NDList(array));
}
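processInput chains CenterCrop, Resize(28, 28) and ToTensor. Below is a minimal sketch of the first two steps on a grayscale image stored as int[height][width]. It assumes the no-argument CenterCrop crops a centered square whose side is min(height, width), and it resizes with nearest-neighbour sampling for brevity, which may differ from the interpolation DJL's Resize uses. CropResizeSketch and its methods are hypothetical names.

```java
/**
 * Hypothetical sketch of center-crop-then-resize on a grayscale image.
 * Assumes a square crop of side min(h, w) and nearest-neighbour resizing.
 */
public class CropResizeSketch {

    /** Crops the largest centered square from img (rows = height, columns = width). */
    static int[][] centerCropSquare(int[][] img) {
        int h = img.length;
        int w = img[0].length;
        int side = Math.min(h, w);
        int top = (h - side) / 2;
        int left = (w - side) / 2;
        int[][] out = new int[side][side];
        for (int y = 0; y < side; y++) {
            System.arraycopy(img[top + y], left, out[y], 0, side);
        }
        return out;
    }

    /** Nearest-neighbour resize to newW x newH (width first, mirroring Resize(width, height)). */
    static int[][] resizeNearest(int[][] img, int newW, int newH) {
        int h = img.length;
        int w = img[0].length;
        int[][] out = new int[newH][newW];
        for (int y = 0; y < newH; y++) {
            int srcY = y * h / newH; // source row for this output row
            for (int x = 0; x < newW; x++) {
                out[y][x] = img[srcY][x * w / newW];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[][] img = new int[40][56]; // hypothetical 56-wide, 40-tall image
        int[][] digit = resizeNearest(centerCropSquare(img), 28, 28);
        System.out.println(digit.length + "x" + digit[0].length); // prints 28x28
    }
}
```

Cropping before resizing keeps the aspect ratio intact, so a non-square input is not stretched when it is scaled to 28x28.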
Use of ai.djl.modality.cv.transform.ToTensor in project djl-demo by deepjavalibrary.
Class DoodleController, method model.
@Bean
public ZooModel<Image, Classifications> model() throws ModelException, IOException {
    Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
            .optFlag(Image.Flag.GRAYSCALE)
            .setPipeline(new Pipeline(new ToTensor()))
            .optApplySoftmax(true)
            .build();
    Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optModelUrls(MODEL_URL).optTranslator(translator).build();
    return criteria.loadModel();
}
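Here the translator receives a complete Pipeline through setPipeline, whereas the other examples call addTransform on the builder; both presumably end with an ordered list of transforms applied to the input. The sketch below models only that ordering idea with a hypothetical PipelineSketch class over float[] values. It does not reproduce DJL's Pipeline or Transform types, and scaleTo01 and invert are illustrative stand-ins.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;

/**
 * Hypothetical sketch of a transform pipeline: steps are kept in insertion
 * order and applied one after another. Not DJL's Pipeline class.
 */
public class PipelineSketch {

    private final List<UnaryOperator<float[]>> transforms = new ArrayList<>();

    @SafeVarargs
    PipelineSketch(UnaryOperator<float[]>... initial) {
        transforms.addAll(Arrays.asList(initial));
    }

    /** Appends a step; returns this so calls can be chained. */
    PipelineSketch add(UnaryOperator<float[]> t) {
        transforms.add(t);
        return this;
    }

    /** Feeds the output of each step into the next, in insertion order. */
    float[] transform(float[] input) {
        float[] current = input;
        for (UnaryOperator<float[]> t : transforms) {
            current = t.apply(current);
        }
        return current;
    }

    /** Illustrative stand-in for ToTensor's rescaling: divides every value by 255. */
    static float[] scaleTo01(float[] a) {
        float[] out = a.clone();
        for (int i = 0; i < out.length; i++) {
            out[i] /= 255f;
        }
        return out;
    }

    /** Illustrative extra step: maps x to 1 - x. */
    static float[] invert(float[] a) {
        float[] out = a.clone();
        for (int i = 0; i < out.length; i++) {
            out[i] = 1f - out[i];
        }
        return out;
    }

    public static void main(String[] args) {
        PipelineSketch p = new PipelineSketch(PipelineSketch::scaleTo01).add(PipelineSketch::invert);
        System.out.println(Arrays.toString(p.transform(new float[] {0f, 255f}))); // prints [1.0, 0.0]
    }
}
```

Swapping the two steps gives a different result, which is why the order in which transforms are added matters.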
Use of ai.djl.modality.cv.transform.ToTensor in project djl-demo by deepjavalibrary.
Class Training, method initDataset.
private static ImageFolder initDataset(String datasetRoot) throws IOException, TranslateException {
    ImageFolder dataset = ImageFolder.builder()
            .setRepositoryPath(Paths.get(datasetRoot)).optMaxDepth(10)
            .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
            .addTransform(new ToTensor())
            .setSampling(BATCH_SIZE, true)
            .build();
    dataset.prepare();
    return dataset;
}
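setSampling(BATCH_SIZE, true) requests batches of BATCH_SIZE drawn in random order. The sketch below shows that idea with the standard library: shuffle the indices, then cut them into consecutive batches. The fixed seed, the helper names, and keeping a smaller final batch are assumptions of this sketch, not a description of DJL's sampler.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical sketch of shuffled batching. Keeps a smaller final batch;
 * whether DJL's sampler does the same is not modelled here.
 */
public class ShuffledBatchesSketch {

    static List<List<Integer>> batches(int datasetSize, int batchSize, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < datasetSize; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed)); // random visiting order
        List<List<Integer>> result = new ArrayList<>();
        for (int start = 0; start < datasetSize; start += batchSize) {
            int end = Math.min(start + batchSize, datasetSize);
            result.add(new ArrayList<>(indices.subList(start, end))); // copy each slice
        }
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> b = batches(10, 4, 42L);
        System.out.println(b.size()); // prints 3 (batches of 4, 4 and 2)
    }
}
```

Every index appears exactly once per pass; only the grouping and order change between seeds.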