Search in sources :

Example 6 with ToTensor

use of ai.djl.modality.cv.transform.ToTensor in project PissAI by DxsSucuk.

the class ModelTrainer method getValidationDataSet.

/**
 * Creates and prepares the validation {@link ImageFolder} dataset backed by the local
 * {@code validation} directory.
 *
 * <p>Every image is passed through a 256x256 {@link Resize} followed by {@link ToTensor}.
 * Sampling is configured with a batch size of 32 and the second sampling argument set to
 * {@code true}.
 *
 * @return the dataset, already prepared and ready to be iterated
 * @throws TranslateException propagated from the underlying dataset calls
 * @throws IOException propagated from the underlying dataset calls
 */
public ImageFolder getValidationDataSet() throws TranslateException, IOException {
    final int samplesPerBatch = 32;

    // Point the repository at the folder that holds the validation images.
    Repository validationRepository =
            Repository.newInstance("folder", Paths.get("validation"));

    ImageFolder validationSet =
            ImageFolder.builder()
                    .setRepository(validationRepository)
                    .addTransform(new Resize(256, 256))
                    .addTransform(new ToTensor())
                    .setSampling(samplesPerBatch, true)
                    .build();

    // The dataset has to be prepared before anyone consumes it.
    validationSet.prepare();
    return validationSet;
}
Also used : Repository(ai.djl.repository.Repository) ImageFolder(ai.djl.basicdataset.cv.classification.ImageFolder) ToTensor(ai.djl.modality.cv.transform.ToTensor) Resize(ai.djl.modality.cv.transform.Resize)

Example 7 with ToTensor

use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

the class ImageFolderTest method testImageFolder.

/**
 * Checks that {@link ImageFolder} discovers the {@code cat}, {@code dog} and {@code misc}
 * classes under {@code src/test/resources/imagefolder} and, with a batch size of one and the
 * second sampling argument set to {@code false}, yields the kitten, dog and pikachu images in
 * that order.
 *
 * <p>Each batch's data is compared against the same image loaded directly, resized to 100x100
 * and converted with {@code NDImageUtils.toTensor}; each label is compared against the index
 * of the class in the synset list.
 *
 * @throws IOException propagated from image loading or dataset access
 * @throws TranslateException propagated from dataset iteration
 */
@Test
public void testImageFolder() throws IOException, TranslateException {
    Repository repository = Repository.newInstance("test", "src/test/resources/imagefolder");
    TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss());
    try (Model model = Model.newInstance("model")) {
        model.setBlock(Blocks.identityBlock());
        ImageFolder dataset =
                ImageFolder.builder()
                        .setRepository(repository)
                        .addTransform(new Resize(100, 100))
                        .addTransform(new ToTensor())
                        .setSampling(1, false)
                        .build();
        List<String> synsets = Arrays.asList("cat", "dog", "misc");
        // TestNG's assertEquals takes (actual, expected); keep that order so a failure
        // message reports the values the right way round, like the other assertions below.
        Assert.assertEquals(dataset.getSynset(), synsets);
        try (Trainer trainer = model.newTrainer(config)) {
            NDManager manager = trainer.getManager();
            NDArray cat = ImageFactory.getInstance().fromFile(Paths.get("src/test/resources/imagefolder/cat/kitten.jpg")).toNDArray(manager);
            NDArray dog = ImageFactory.getInstance().fromFile(Paths.get("src/test/resources/imagefolder/dog/dog_bike_car.jpg")).toNDArray(manager);
            NDArray pikachu = ImageFactory.getInstance().fromFile(Paths.get("src/test/resources/imagefolder/misc/pikachu.png")).toNDArray(manager);
            Iterator<Batch> ds = trainer.iterateDataset(dataset).iterator();
            // try-with-resources closes each batch even when one of its assertions fails.
            try (Batch catBatch = ds.next()) {
                Assertions.assertAlmostEquals(catBatch.getData().singletonOrThrow(), NDImageUtils.toTensor(NDImageUtils.resize(cat, 100, 100)).expandDims(0));
                Assert.assertEquals(catBatch.getLabels().singletonOrThrow(), manager.create(new long[] { 0 }));
            }
            try (Batch dogBatch = ds.next()) {
                Assertions.assertAlmostEquals(dogBatch.getData().singletonOrThrow(), NDImageUtils.toTensor(NDImageUtils.resize(dog, 100, 100)).expandDims(0));
                Assert.assertEquals(dogBatch.getLabels().singletonOrThrow(), manager.create(new long[] { 1 }));
            }
            try (Batch pikachuBatch = ds.next()) {
                Assertions.assertAlmostEquals(pikachuBatch.getData().singletonOrThrow(), NDImageUtils.toTensor(NDImageUtils.resize(pikachu, 100, 100)).expandDims(0));
                Assert.assertEquals(pikachuBatch.getLabels().singletonOrThrow(), manager.create(new long[] { 2 }));
            }
        }
    }
}
Also used : ToTensor(ai.djl.modality.cv.transform.ToTensor) Resize(ai.djl.modality.cv.transform.Resize) Trainer(ai.djl.training.Trainer) Repository(ai.djl.repository.Repository) ImageFolder(ai.djl.basicdataset.cv.classification.ImageFolder) Batch(ai.djl.training.dataset.Batch) Model(ai.djl.Model) NDArray(ai.djl.ndarray.NDArray) NDManager(ai.djl.ndarray.NDManager) DefaultTrainingConfig(ai.djl.training.DefaultTrainingConfig) TrainingConfig(ai.djl.training.TrainingConfig) Test(org.testng.annotations.Test)

Example 8 with ToTensor

use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

the class ImageClassificationDataset method makeTranslator.

/**
 * Builds a translator whose preprocessing matches this dataset.
 *
 * <p>The pipeline starts with a {@link Resize} step only when both the image width and the
 * image height are available, and always ends with {@link ToTensor}. The dataset's classes
 * are supplied to the translator as its synset.
 *
 * @return an {@link ImageClassificationTranslator} configured for this dataset
 */
public Translator<Image, Classifications> makeTranslator() {
    Pipeline transforms = new Pipeline();

    Optional<Integer> fixedWidth = getImageWidth();
    Optional<Integer> fixedHeight = getImageHeight();

    // A resize is added only for a fully specified size, i.e. both dimensions present.
    fixedWidth
            .flatMap(w -> fixedHeight.map(h -> new Resize(w, h)))
            .ifPresent(resize -> transforms.add(resize));

    transforms.add(new ToTensor());

    return ImageClassificationTranslator.builder()
            .optSynset(getClasses())
            .setPipeline(transforms)
            .build();
}
Also used : ToTensor(ai.djl.modality.cv.transform.ToTensor) Resize(ai.djl.modality.cv.transform.Resize) Pipeline(ai.djl.translate.Pipeline)

Example 9 with ToTensor

use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

the class DlrTest method testDlr.

/**
 * Loads an image-classification model through the {@code DLR} engine and checks that the
 * kitten test image is classified as {@code "n02123045 tabby, tabby cat"}.
 *
 * <p>The model is selected with the filters {@code layers=50} and {@code os=<osx|linux>},
 * where the OS value is derived from the {@code os.name} system property. Any other operating
 * system makes the test fail with an {@link IllegalStateException}.
 *
 * @throws ModelException propagated from model loading
 * @throws IOException propagated from image or model loading
 * @throws TranslateException propagated from prediction
 */
@Test
public void testDlr() throws ModelException, IOException, TranslateException {
    // Platform guards from the test utilities (presumably skip the test on Windows and ARM).
    TestRequirements.notWindows();
    TestRequirements.notArm();

    // Read the property once and lower-case it with Locale.ROOT so the prefix checks do not
    // depend on the JVM's default locale.
    String osName = System.getProperty("os.name").toLowerCase(java.util.Locale.ROOT);
    String os;
    if (osName.startsWith("mac")) {
        os = "osx";
    } else if (osName.startsWith("linux")) {
        os = "linux";
    } else {
        throw new IllegalStateException("Unexpected os");
    }

    ImageClassificationTranslator translator =
            ImageClassificationTranslator.builder()
                    .addTransform(new Resize(224, 224))
                    .addTransform(new ToTensor())
                    .build();
    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                    .optFilter("layers", "50")
                    .optFilter("os", os)
                    .optTranslator(translator)
                    .optEngine("DLR")
                    .optProgress(new ProgressBar())
                    .build();
    Path file = Paths.get("../../../examples/src/test/resources/kitten.jpg");
    Image image = ImageFactory.getInstance().fromFile(file);
    try (ZooModel<Image, Classifications> model = criteria.loadModel();
        Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Classifications result = predictor.predict(image);
        Assert.assertEquals(result.best().getClassName(), "n02123045 tabby, tabby cat");
    }
}
Also used : Path(java.nio.file.Path) ImageClassificationTranslator(ai.djl.modality.cv.translator.ImageClassificationTranslator) Classifications(ai.djl.modality.Classifications) ToTensor(ai.djl.modality.cv.transform.ToTensor) Resize(ai.djl.modality.cv.transform.Resize) Image(ai.djl.modality.cv.Image) ProgressBar(ai.djl.training.util.ProgressBar) Test(org.testng.annotations.Test)

Example 10 with ToTensor

use of ai.djl.modality.cv.transform.ToTensor in project djl by deepjavalibrary.

the class TrainWithOptimizers method getDataset.

/**
 * Builds and prepares the CIFAR-10 dataset for the requested usage.
 *
 * <p>Images are converted with {@link ToTensor} and then normalized using
 * {@code Cifar10.NORMALIZE_MEAN} and {@code Cifar10.NORMALIZE_STD}. Batch size and the
 * dataset limit are taken from {@code arguments}; the second sampling argument is
 * {@code true}. Preparation reports progress through a {@link ProgressBar}.
 *
 * @param usage the dataset split to build
 * @param arguments supplies the batch size and the dataset limit
 * @return the prepared CIFAR-10 dataset
 * @throws IOException propagated from dataset preparation
 */
private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments) throws IOException {
    Normalize normalize = new Normalize(Cifar10.NORMALIZE_MEAN, Cifar10.NORMALIZE_STD);
    Pipeline transforms = new Pipeline(new ToTensor(), normalize);

    Cifar10 dataset =
            Cifar10.builder()
                    .optUsage(usage)
                    .setSampling(arguments.getBatchSize(), true)
                    .optLimit(arguments.getLimit())
                    .optPipeline(transforms)
                    .build();

    dataset.prepare(new ProgressBar());
    return dataset;
}
Also used : Normalize(ai.djl.modality.cv.transform.Normalize) ToTensor(ai.djl.modality.cv.transform.ToTensor) Cifar10(ai.djl.basicdataset.cv.classification.Cifar10) ProgressBar(ai.djl.training.util.ProgressBar) Pipeline(ai.djl.translate.Pipeline)

Aggregations

ToTensor (ai.djl.modality.cv.transform.ToTensor)23 Image (ai.djl.modality.cv.Image)15 Classifications (ai.djl.modality.Classifications)12 Resize (ai.djl.modality.cv.transform.Resize)11 ImageClassificationTranslator (ai.djl.modality.cv.translator.ImageClassificationTranslator)6 ProgressBar (ai.djl.training.util.ProgressBar)6 Pipeline (ai.djl.translate.Pipeline)6 Path (java.nio.file.Path)6 Model (ai.djl.Model)5 ImageFolder (ai.djl.basicdataset.cv.classification.ImageFolder)4 Test (org.testng.annotations.Test)4 Normalize (ai.djl.modality.cv.transform.Normalize)3 NDArray (ai.djl.ndarray.NDArray)3 Repository (ai.djl.repository.Repository)3 ModelException (ai.djl.ModelException)2 Cifar10 (ai.djl.basicdataset.cv.classification.Cifar10)2 Mlp (ai.djl.basicmodelzoo.basic.Mlp)2 ImageFactory (ai.djl.modality.cv.ImageFactory)2 NDManager (ai.djl.ndarray.NDManager)2 TranslateException (ai.djl.translate.TranslateException)2