
Example 16 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

Source: class ImageClassification, method predict.

public static Classifications predict() throws IOException, ModelException, TranslateException {
    Path imageFile = Paths.get("src/test/resources/0.png");
    Image img = ImageFactory.getInstance().fromFile(imageFile);
    String modelName = "mlp";
    try (Model model = Model.newInstance(modelName)) {
        model.setBlock(new Mlp(28 * 28, 10, new int[] { 128, 64 }));
        // Assumes you have run the TrainMnist.java example and saved the model in the build/model folder.
        Path modelDir = Paths.get("build/model");
        model.load(modelDir);
        // Class labels "0" through "9", one per MNIST digit.
        List<String> classes = IntStream.range(0, 10)
                .mapToObj(String::valueOf)
                .collect(Collectors.toList());
        // ToTensor converts the image to a CHW float tensor; softmax turns the raw outputs into probabilities.
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .addTransform(new ToTensor())
                .optSynset(classes)
                .optApplySoftmax(true)
                .build();
        try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
            return predictor.predict(img);
        }
    }
}
Also used : Path(java.nio.file.Path) Mlp(ai.djl.basicmodelzoo.basic.Mlp) Classifications(ai.djl.modality.Classifications) ToTensor(ai.djl.modality.cv.transform.ToTensor) Model(ai.djl.Model) Image(ai.djl.modality.cv.Image)
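The Mlp block above is built as Mlp(28 * 28, 10, new int[] { 128, 64 }): a flattened 28x28 input, hidden layers of 128 and 64 units, and 10 outputs. As a rough sanity check on model size, the sketch below counts parameters. It assumes every layer is fully connected with a bias vector and that activations add no parameters. It is plain Java arithmetic, not a DJL API call, and the class name MlpParamCount is hypothetical.

```java
public final class MlpParamCount {
    // Counts weights + biases for a stack of fully connected layers:
    // input -> hidden[0] -> ... -> hidden[k-1] -> output.
    // Assumption: each layer is dense with a bias term; activations have no parameters.
    static long count(int input, int output, int[] hidden) {
        long total = 0;
        int prev = input;
        for (int h : hidden) {
            total += (long) prev * h + h; // weight matrix + bias vector
            prev = h;
        }
        total += (long) prev * output + output; // final layer to the output classes
        return total;
    }

    public static void main(String[] args) {
        // Same layer sizes as the Mlp in predict().
        System.out.println(count(28 * 28, 10, new int[] {128, 64})); // prints 109386
    }
}
```

For the sizes in predict(), this works out to 784 * 128 + 128 + 128 * 64 + 64 + 64 * 10 + 10 = 109,386 parameters under the stated assumption.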

Example 17 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

Source: class TrtTest, method testTrtUff.

@Test
public void testTrtUff() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    if (!engine.defaultDevice().isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }
    List<String> synset = IntStream.range(0, 10)
            .mapToObj(String::valueOf)
            .collect(Collectors.toList());
    ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
            .optFlag(Image.Flag.GRAYSCALE)
            .optSynset(synset)
            .optApplySoftmax(true)
            .addTransform(new ToTensor())
            .optBatchifier(null)
            .build();
    Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optModelUrls("https://resources.djl.ai/test-models/tensorrt/lenet5.zip")
            .optTranslator(translator)
            .optEngine("TensorRT")
            .build();
    try (ZooModel<Image, Classifications> model = criteria.loadModel();
        Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Path path = Paths.get("../../examples/src/test/resources/0.png");
        Image image = ImageFactory.getInstance().fromFile(path);
        Classifications ret = predictor.predict(image);
        Assert.assertEquals(ret.best().getClassName(), "0");
    }
}
Also used : Path(java.nio.file.Path) ImageClassificationTranslator(ai.djl.modality.cv.translator.ImageClassificationTranslator) Classifications(ai.djl.modality.Classifications) ToTensor(ai.djl.modality.cv.transform.ToTensor) Image(ai.djl.modality.cv.Image) SkipException(org.testng.SkipException) ModelException(ai.djl.ModelException) TranslateException(ai.djl.translate.TranslateException) IOException(java.io.IOException) Engine(ai.djl.engine.Engine) Test(org.testng.annotations.Test)
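With optApplySoftmax(true), the translator converts the model's raw outputs into probabilities before the test checks ret.best().getClassName(). The sketch below shows that post-processing idea in plain Java: a numerically stable softmax followed by an argmax over the synset. It is an illustration under that assumption, not DJL's Classifications code. The logits are invented, and the class name SoftmaxBestSketch is hypothetical.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public final class SoftmaxBestSketch {
    // Numerically stable softmax: subtracting the largest logit avoids overflow in Math.exp.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] probs = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    // Index of the highest value, i.e. the "best" class.
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Same synset as the test: class names "0" through "9".
        List<String> synset = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
        // Made-up logits for illustration only.
        double[] logits = {4.0, 0.5, -1.0, 0.2, 0.0, -0.3, 1.0, 0.1, -2.0, 0.7};
        double[] probs = softmax(logits);
        int best = argmax(probs);
        System.out.println("best class: " + synset.get(best) + ", probability: " + probs[best]);
    }
}
```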

Example 18 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

Source: class MyTranslator, method processInput.

@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
    byte[] data = input.getAsBytes(0);
    ImageFactory factory = ImageFactory.getInstance();
    Image image = factory.fromInputStream(new ByteArrayInputStream(data));
    NDArray array = image.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
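    // Center-crop, resize to the 28x28 MNIST input size, then convert to a CHW float tensor in [0, 1].
    Pipeline pipeline = new Pipeline();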
    pipeline.add(new CenterCrop());
    pipeline.add(new Resize(28, 28));
    pipeline.add(new ToTensor());
    return pipeline.transform(new NDList(array));
}
Also used : ImageFactory(ai.djl.modality.cv.ImageFactory) ToTensor(ai.djl.modality.cv.transform.ToTensor) ByteArrayInputStream(java.io.ByteArrayInputStream) Resize(ai.djl.modality.cv.transform.Resize) NDList(ai.djl.ndarray.NDList) CenterCrop(ai.djl.modality.cv.transform.CenterCrop) NDArray(ai.djl.ndarray.NDArray) Image(ai.djl.modality.cv.Image) Pipeline(ai.djl.translate.Pipeline)
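DJL documents ToTensor as converting an image NDArray from (H, W, C) layout with values in [0, 255] to (C, H, W) layout with values in [0, 1]. The plain-Java sketch below reproduces that behavior on flat arrays so the index arithmetic is visible. It does not use DJL, and the class and method names are hypothetical.

```java
import java.util.Arrays;

public final class ToTensorSketch {
    // Converts an interleaved HWC image (integer values 0..255) into a planar CHW float array
    // scaled to [0, 1], mirroring the behavior DJL documents for ToTensor.
    static float[] toTensor(int[] hwc, int height, int width, int channels) {
        if (hwc.length != height * width * channels) {
            throw new IllegalArgumentException("expected height * width * channels values");
        }
        float[] chw = new float[hwc.length];
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                for (int c = 0; c < channels; c++) {
                    int src = (h * width + w) * channels + c; // HWC offset
                    int dst = (c * height + h) * width + w;   // CHW offset
                    chw[dst] = hwc[src] / 255f;
                }
            }
        }
        return chw;
    }

    public static void main(String[] args) {
        // A 1x2 RGB image: a red pixel followed by a green pixel.
        int[] hwc = {255, 0, 0, 0, 255, 0};
        System.out.println(Arrays.toString(toTensor(hwc, 1, 2, 3))); // prints [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
    }
}
```

For the grayscale 28x28 input in processInput, channels is 1, so the result corresponds to shape (1, 28, 28).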

Example 19 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl-demo by deepjavalibrary.

Source: class DoodleController, method model.

@Bean
public ZooModel<Image, Classifications> model() throws ModelException, IOException {
    Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
            .optFlag(Image.Flag.GRAYSCALE)
            .setPipeline(new Pipeline(new ToTensor()))
            .optApplySoftmax(true)
            .build();
    Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            .optModelUrls(MODEL_URL)
            .optTranslator(translator)
            .build();
    return criteria.loadModel();
}
Also used : Classifications(ai.djl.modality.Classifications) ToTensor(ai.djl.modality.cv.transform.ToTensor) Image(ai.djl.modality.cv.Image) Pipeline(ai.djl.translate.Pipeline) Bean(org.springframework.context.annotation.Bean)
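Here the translator is given an explicit Pipeline containing only ToTensor via setPipeline, whereas the earlier examples call addTransform on the builder. Conceptually, a pipeline applies its transforms in insertion order, with each stage feeding the next. The sketch below models that idea with plain Java functions. PipelineSketch and scaleTo01 are hypothetical names; this is not DJL's ai.djl.translate.Pipeline.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

public final class PipelineSketch {
    // Hypothetical stand-in for a transform pipeline: stages run in the order they were added,
    // and each stage's output becomes the next stage's input.
    private final List<UnaryOperator<float[]>> stages = new ArrayList<>();

    PipelineSketch add(UnaryOperator<float[]> stage) {
        stages.add(stage);
        return this;
    }

    float[] transform(float[] input) {
        float[] current = input;
        for (UnaryOperator<float[]> stage : stages) {
            current = stage.apply(current);
        }
        return current;
    }

    // Example stage: scale pixel values from [0, 255] to [0, 1], like the scaling part of ToTensor.
    static float[] scaleTo01(float[] values) {
        float[] out = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = values[i] / 255f;
        }
        return out;
    }

    public static void main(String[] args) {
        PipelineSketch pipeline = new PipelineSketch().add(PipelineSketch::scaleTo01);
        float[] result = pipeline.transform(new float[] {0f, 255f});
        System.out.println(result[0] + " " + result[1]); // prints 0.0 1.0
    }
}
```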

Example 20 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl-demo by deepjavalibrary.

Source: class Training, method initDataset.

private static ImageFolder initDataset(String datasetRoot) throws IOException, TranslateException {
    ImageFolder dataset = ImageFolder.builder()
            .setRepositoryPath(Paths.get(datasetRoot))
            .optMaxDepth(10)
            .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
            .addTransform(new ToTensor())
            .setSampling(BATCH_SIZE, true)
            .build();
    dataset.prepare();
    return dataset;
}
Also used : ImageFolder(ai.djl.basicdataset.cv.classification.ImageFolder) ToTensor(ai.djl.modality.cv.transform.ToTensor) Resize(ai.djl.modality.cv.transform.Resize)
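setSampling(BATCH_SIZE, true) asks the dataset to serve mini-batches in random order. The sketch below shows one way such sampling can work: shuffle the example indices, then cut them into consecutive batches. It assumes a final partial batch is kept rather than dropped, which may not match DJL's sampler configuration in every case. The class name, the batches method, and its seed parameter are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public final class BatchSamplingSketch {
    // Returns the index batches for one epoch over n examples.
    // When random is true the indices are shuffled first; the last batch may be smaller than batchSize.
    static List<List<Integer>> batches(int n, int batchSize, boolean random, long seed) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        if (random) {
            Collections.shuffle(indices, new Random(seed));
        }
        List<List<Integer>> result = new ArrayList<>();
        for (int start = 0; start < n; start += batchSize) {
            result.add(new ArrayList<>(indices.subList(start, Math.min(start + batchSize, n))));
        }
        return result;
    }

    public static void main(String[] args) {
        // 10 examples with a batch size of 4 yields batches of 4, 4 and 2 indices.
        for (List<Integer> batch : batches(10, 4, true, 42L)) {
            System.out.println(batch);
        }
    }
}
```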

Aggregations

ToTensor (ai.djl.modality.cv.transform.ToTensor): 23
Image (ai.djl.modality.cv.Image): 15
Classifications (ai.djl.modality.Classifications): 12
Resize (ai.djl.modality.cv.transform.Resize): 11
ImageClassificationTranslator (ai.djl.modality.cv.translator.ImageClassificationTranslator): 6
ProgressBar (ai.djl.training.util.ProgressBar): 6
Pipeline (ai.djl.translate.Pipeline): 6
Path (java.nio.file.Path): 6
Model (ai.djl.Model): 5
ImageFolder (ai.djl.basicdataset.cv.classification.ImageFolder): 4
Test (org.testng.annotations.Test): 4
Normalize (ai.djl.modality.cv.transform.Normalize): 3
NDArray (ai.djl.ndarray.NDArray): 3
Repository (ai.djl.repository.Repository): 3
ModelException (ai.djl.ModelException): 2
Cifar10 (ai.djl.basicdataset.cv.classification.Cifar10): 2
Mlp (ai.djl.basicmodelzoo.basic.Mlp): 2
ImageFactory (ai.djl.modality.cv.ImageFactory): 2
NDManager (ai.djl.ndarray.NDManager): 2
TranslateException (ai.djl.translate.TranslateException): 2