Search in sources:

Example 11 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

Class TrainResnetWithCifar10, method getDataset.

/**
 * Builds and prepares the CIFAR-10 dataset split for the given usage.
 *
 * <p>Each image is converted with {@code ToTensor} and then normalized with the
 * CIFAR-10 mean/std constants exposed by {@code Cifar10}.
 *
 * @param usage which dataset split to load
 * @param arguments supplies the batch size and the optional sample limit
 * @return the prepared CIFAR-10 dataset
 * @throws IOException if preparing (downloading/reading) the dataset fails
 */
private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments) throws IOException {
    Normalize normalize = new Normalize(Cifar10.NORMALIZE_MEAN, Cifar10.NORMALIZE_STD);
    Pipeline transforms = new Pipeline(new ToTensor(), normalize);

    Cifar10 dataset =
            Cifar10.builder()
                    .optUsage(usage)
                    // second argument enables random sampling order
                    .setSampling(arguments.getBatchSize(), true)
                    .optLimit(arguments.getLimit())
                    .optPipeline(transforms)
                    .build();

    dataset.prepare(new ProgressBar());
    return dataset;
}
Also used : Normalize(ai.djl.modality.cv.transform.Normalize) ToTensor(ai.djl.modality.cv.transform.ToTensor) Cifar10(ai.djl.basicdataset.cv.classification.Cifar10) ProgressBar(ai.djl.training.util.ProgressBar) Pipeline(ai.djl.translate.Pipeline)

Example 12 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

Class TrainPikachu, method getDataset.

/**
 * Builds and prepares the Pikachu detection dataset split for the given usage.
 *
 * <p>The only preprocessing applied is a {@code ToTensor} transform.
 *
 * @param usage which dataset split to load
 * @param arguments supplies the sample limit and the batch size
 * @return the prepared Pikachu detection dataset
 * @throws IOException if preparing (downloading/reading) the dataset fails
 */
private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments) throws IOException {
    Pipeline transforms = new Pipeline(new ToTensor());

    PikachuDetection dataset =
            PikachuDetection.builder()
                    .optUsage(usage)
                    .optLimit(arguments.getLimit())
                    .optPipeline(transforms)
                    // second argument enables random sampling order
                    .setSampling(arguments.getBatchSize(), true)
                    .build();

    dataset.prepare(new ProgressBar());
    return dataset;
}
Also used : ToTensor(ai.djl.modality.cv.transform.ToTensor) PikachuDetection(ai.djl.basicdataset.cv.PikachuDetection) ProgressBar(ai.djl.training.util.ProgressBar) Pipeline(ai.djl.translate.Pipeline)

Example 13 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

Class ProfilerTest, method testProfiler.

/**
 * Runs one image-classification prediction under the native profiler and checks
 * that a profile file was written to {@code build/profile.json}.
 *
 * <p>The input is an all-zero 3x224x224 UINT8 image, so the test exercises the
 * profiling path rather than any particular classification result.
 */
@Test
public void testProfiler() throws MalformedModelException, ModelNotFoundException, IOException, TranslateException {
    try (NDManager manager = NDManager.newBaseManager()) {
        ImageClassificationTranslator translator = ImageClassificationTranslator.builder().addTransform(new ToTensor()).build();
        Criteria<Image, Classifications> criteria = Criteria.builder().setTypes(Image.class, Classifications.class).optApplication(Application.CV.IMAGE_CLASSIFICATION).optFilter("layers", "18").optTranslator(translator).optProgress(new ProgressBar()).build();
        String outputFile = "build/profile.json";
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
            Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image image = ImageFactory.getInstance().fromNDArray(manager.zeros(new Shape(3, 224, 224), DataType.UINT8));
            JniUtils.startProfile(false, true, true);
            try {
                predictor.predict(image);
            } finally {
                // Stop profiling even if prediction fails, so the native profiler is not
                // left enabled for other tests running in the same JVM.
                JniUtils.stopProfile(outputFile);
            }
        }
        Assert.assertTrue(Files.exists(Paths.get(outputFile)), "The profiler file not found!");
    }
}
Also used : ImageClassificationTranslator(ai.djl.modality.cv.translator.ImageClassificationTranslator) Classifications(ai.djl.modality.Classifications) Shape(ai.djl.ndarray.types.Shape) ToTensor(ai.djl.modality.cv.transform.ToTensor) NDManager(ai.djl.ndarray.NDManager) Image(ai.djl.modality.cv.Image) ProgressBar(ai.djl.training.util.ProgressBar) Test(org.testng.annotations.Test)

Example 14 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

Class TrainPikachu, method predict.

/**
 * Loads the trained Pikachu SSD parameters from {@code outputDir}, runs detection on
 * {@code imageFile}, and writes an annotated copy of the image to
 * {@code outputDir/pikachu_output.png}.
 *
 * @param outputDir directory holding the saved model parameters; the annotated image is written here too
 * @param imageFile path of the image to run detection on
 * @return the number of Pikachu objects detected above the 0.6 threshold
 * @throws IOException if loading parameters, reading the image, or writing the output fails
 * @throws MalformedModelException if the saved parameters cannot be loaded into the block
 * @throws TranslateException if prediction fails
 */
public static int predict(String outputDir, String imageFile) throws IOException, MalformedModelException, TranslateException {
    try (Model model = Model.newInstance("pikachu-ssd")) {
        float detectionThreshold = 0.6f;
        Path outputPath = Paths.get(outputDir);
        // load parameters back to original training block
        model.setBlock(getSsdTrainBlock());
        model.load(outputPath);
        // append prediction logic at end of training block with parameter loaded
        Block ssdTrain = model.getBlock();
        model.setBlock(getSsdPredictBlock(ssdTrain));
        Path imagePath = Paths.get(imageFile);
        SingleShotDetectionTranslator translator = SingleShotDetectionTranslator.builder().addTransform(new ToTensor()).optSynset(Collections.singletonList("pikachu")).optThreshold(detectionThreshold).build();
        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor(translator)) {
            Image image = ImageFactory.getInstance().fromFile(imagePath);
            DetectedObjects detectedObjects = predictor.predict(image);
            image.drawBoundingBoxes(detectedObjects);
            Path out = outputPath.resolve("pikachu_output.png");
            // Close the stream so the PNG is flushed and the file handle is not leaked.
            try (java.io.OutputStream os = Files.newOutputStream(out)) {
                image.save(os, "png");
            }
            // return number of pikachu detected
            return detectedObjects.getNumberOfObjects();
        }
    }
}
Also used : Path(java.nio.file.Path) SingleShotDetectionTranslator(ai.djl.modality.cv.translator.SingleShotDetectionTranslator) ToTensor(ai.djl.modality.cv.transform.ToTensor) Model(ai.djl.Model) Block(ai.djl.nn.Block) LambdaBlock(ai.djl.nn.LambdaBlock) SequentialBlock(ai.djl.nn.SequentialBlock) DetectedObjects(ai.djl.modality.cv.output.DetectedObjects) Image(ai.djl.modality.cv.Image)

Example 15 with ToTensor

Use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

Class TrainResnetWithCifar10, method testSaveParameters.

/**
 * Reloads saved ResNet parameters into {@code block} and classifies a sample airplane image.
 *
 * <p>The translator applies {@code ToTensor}, then CIFAR-10 normalization, and applies
 * softmax to the output. Class labels come from the remote CIFAR-10 synset file.
 *
 * @param block the network block whose parameters are loaded
 * @param path directory containing the saved model named {@code resnetv1}
 * @return classification results for {@code src/test/resources/airplane1.png}
 * @throws IOException if the image or model files cannot be read
 * @throws ModelException if the model cannot be loaded
 * @throws TranslateException if prediction fails
 */
private static Classifications testSaveParameters(Block block, Path path) throws IOException, ModelException, TranslateException {
    String labelsUrl = "https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset_cifar10.txt";

    ImageClassificationTranslator imageTranslator =
            ImageClassificationTranslator.builder()
                    .addTransform(new ToTensor())
                    .addTransform(new Normalize(Cifar10.NORMALIZE_MEAN, Cifar10.NORMALIZE_STD))
                    .optSynsetUrl(labelsUrl)
                    .optApplySoftmax(true)
                    .build();

    Image airplane = ImageFactory.getInstance().fromUrl("src/test/resources/airplane1.png");

    Criteria<Image, Classifications> modelCriteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optModelPath(path)
                    .optTranslator(imageTranslator)
                    .optBlock(block)
                    .optModelName("resnetv1")
                    .build();

    try (ZooModel<Image, Classifications> zooModel = modelCriteria.loadModel();
            Predictor<Image, Classifications> classifier = zooModel.newPredictor()) {
        return classifier.predict(airplane);
    }
}
Also used : ImageClassificationTranslator(ai.djl.modality.cv.translator.ImageClassificationTranslator) Normalize(ai.djl.modality.cv.transform.Normalize) Classifications(ai.djl.modality.Classifications) ToTensor(ai.djl.modality.cv.transform.ToTensor) Image(ai.djl.modality.cv.Image)

Aggregations

ToTensor (ai.djl.modality.cv.transform.ToTensor): 23, Image (ai.djl.modality.cv.Image): 15, Classifications (ai.djl.modality.Classifications): 12, Resize (ai.djl.modality.cv.transform.Resize): 11, ImageClassificationTranslator (ai.djl.modality.cv.translator.ImageClassificationTranslator): 6, ProgressBar (ai.djl.training.util.ProgressBar): 6, Pipeline (ai.djl.translate.Pipeline): 6, Path (java.nio.file.Path): 6, Model (ai.djl.Model): 5, ImageFolder (ai.djl.basicdataset.cv.classification.ImageFolder): 4, Test (org.testng.annotations.Test): 4, Normalize (ai.djl.modality.cv.transform.Normalize): 3, NDArray (ai.djl.ndarray.NDArray): 3, Repository (ai.djl.repository.Repository): 3, ModelException (ai.djl.ModelException): 2, Cifar10 (ai.djl.basicdataset.cv.classification.Cifar10): 2, Mlp (ai.djl.basicmodelzoo.basic.Mlp): 2, ImageFactory (ai.djl.modality.cv.ImageFactory): 2, NDManager (ai.djl.ndarray.NDManager): 2, TranslateException (ai.djl.translate.TranslateException): 2