Usage of ai.djl.modality.cv.transform.ToTensor in the PissAI project by DxsSucuk.
Class ModelTrainer, method getValidationDataSet.
/**
 * Builds and prepares the validation image dataset read from the {@code validation} folder.
 *
 * <p>Every image is resized to 256x256 and passed through {@link ToTensor}. Batches hold 32
 * images; the {@code true} sampling flag presumably enables random ordering — confirm against
 * the DJL {@code setSampling} documentation.
 *
 * @return the prepared validation dataset
 * @throws TranslateException if preparing the dataset fails during translation
 * @throws IOException if the image folder cannot be read
 */
public ImageFolder getValidationDataSet() throws TranslateException, IOException {
    final int batchSize = 32;
    Repository validationRepo = Repository.newInstance("folder", Paths.get("validation"));
    ImageFolder validationSet =
            ImageFolder.builder()
                    .setRepository(validationRepo)
                    .addTransform(new Resize(256, 256))
                    .addTransform(new ToTensor())
                    .setSampling(batchSize, true)
                    .build();
    // The dataset has to be prepared before anyone iterates over it.
    validationSet.prepare();
    return validationSet;
}
Usage of ai.djl.modality.cv.transform.ToTensor in the djl project by deepjavalibrary.
Class ImageFolderTest, method testImageFolder.
@Test
public void testImageFolder() throws IOException, TranslateException {
    Repository repository = Repository.newInstance("test", "src/test/resources/imagefolder");
    TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss());
    try (Model model = Model.newInstance("model")) {
        model.setBlock(Blocks.identityBlock());
        ImageFolder dataset =
                ImageFolder.builder()
                        .setRepository(repository)
                        .addTransform(new Resize(100, 100))
                        .addTransform(new ToTensor())
                        .setSampling(1, false)
                        .build();
        List<String> synsets = Arrays.asList("cat", "dog", "misc");
        Assert.assertEquals(synsets, dataset.getSynset());
        try (Trainer trainer = model.newTrainer(config)) {
            NDManager manager = trainer.getManager();
            Iterator<Batch> ds = trainer.iterateDataset(dataset).iterator();
            // Batches of size 1 with sequential sampling come back in synset order, one image each.
            assertNextBatch(ds, manager, "src/test/resources/imagefolder/cat/kitten.jpg", 0);
            assertNextBatch(ds, manager, "src/test/resources/imagefolder/dog/dog_bike_car.jpg", 1);
            assertNextBatch(ds, manager, "src/test/resources/imagefolder/misc/pikachu.png", 2);
        }
    }
}

/**
 * Pulls the next batch from {@code ds} and checks that it holds the given image and label.
 *
 * <p>The expected data is built with the same resize (100x100) and {@code toTensor} steps as
 * the dataset's transforms, plus a leading batch dimension. The batch is closed through
 * try-with-resources, so it is released even when an assertion fails.
 *
 * @param ds iterator over the dataset batches
 * @param manager manager used to allocate the expected arrays
 * @param imagePath path of the image expected in the next batch
 * @param expectedLabel label expected for that image
 * @throws IOException if the reference image cannot be read
 */
private static void assertNextBatch(
        Iterator<Batch> ds, NDManager manager, String imagePath, long expectedLabel)
        throws IOException {
    NDArray image =
            ImageFactory.getInstance().fromFile(Paths.get(imagePath)).toNDArray(manager);
    try (Batch batch = ds.next()) {
        Assertions.assertAlmostEquals(
                batch.getData().singletonOrThrow(),
                NDImageUtils.toTensor(NDImageUtils.resize(image, 100, 100)).expandDims(0));
        Assert.assertEquals(
                batch.getLabels().singletonOrThrow(), manager.create(new long[] {expectedLabel}));
    }
}
Usage of ai.djl.modality.cv.transform.ToTensor in the djl project by deepjavalibrary.
Class ImageClassificationDataset, method makeTranslator.
/**
 * Returns the {@link ImageClassificationTranslator} matching the format of this dataset.
 *
 * <p>The translator's pipeline resizes images only when both {@code getImageWidth()} and
 * {@code getImageHeight()} report a value. It always ends with {@link ToTensor}. The synset is
 * taken from {@code getClasses()}.
 *
 * @return the {@link ImageClassificationTranslator} matching the format of this dataset
 */
public Translator<Image, Classifications> makeTranslator() {
    Pipeline pipeline = new Pipeline();
    Optional<Integer> fixedWidth = getImageWidth();
    Optional<Integer> fixedHeight = getImageHeight();
    // A Resize step exists only when both dimensions are fixed.
    fixedWidth.flatMap(w -> fixedHeight.map(h -> new Resize(w, h))).ifPresent(pipeline::add);
    pipeline.add(new ToTensor());
    return ImageClassificationTranslator.builder()
            .optSynset(getClasses())
            .setPipeline(pipeline)
            .build();
}
Usage of ai.djl.modality.cv.transform.ToTensor in the djl project by deepjavalibrary.
Class DlrTest, method testDlr.
/**
 * Loads the 50-layer DLR image-classification model for the host OS and checks that it
 * classifies the kitten test image as a tabby cat.
 *
 * <p>Skipped on Windows and ARM. Throws {@link IllegalStateException} on any other OS that is
 * neither macOS nor Linux.
 */
@Test
public void testDlr() throws ModelException, IOException, TranslateException {
    TestRequirements.notWindows();
    TestRequirements.notArm();
    // Read and normalize the property once; Locale.ROOT keeps the lower-casing independent of
    // the JVM's default locale.
    String osName = System.getProperty("os.name").toLowerCase(java.util.Locale.ROOT);
    String os;
    if (osName.startsWith("mac")) {
        os = "osx";
    } else if (osName.startsWith("linux")) {
        os = "linux";
    } else {
        throw new IllegalStateException("Unexpected os: " + osName);
    }
    ImageClassificationTranslator translator =
            ImageClassificationTranslator.builder()
                    .addTransform(new Resize(224, 224))
                    .addTransform(new ToTensor())
                    .build();
    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                    .optFilter("layers", "50")
                    .optFilter("os", os)
                    .optTranslator(translator)
                    .optEngine("DLR")
                    .optProgress(new ProgressBar())
                    .build();
    Path file = Paths.get("../../../examples/src/test/resources/kitten.jpg");
    Image image = ImageFactory.getInstance().fromFile(file);
    try (ZooModel<Image, Classifications> model = criteria.loadModel();
            Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Classifications result = predictor.predict(image);
        Assert.assertEquals(result.best().getClassName(), "n02123045 tabby, tabby cat");
    }
}
Usage of ai.djl.modality.cv.transform.ToTensor in the djl project by deepjavalibrary.
Class TrainWithOptimizers, method getDataset.
/**
 * Builds and prepares the CIFAR-10 split for {@code usage}.
 *
 * <p>Images go through {@link ToTensor} and are then normalized with
 * {@code Cifar10.NORMALIZE_MEAN} / {@code Cifar10.NORMALIZE_STD}. The batch size and limit
 * come from {@code arguments}, and the {@code true} sampling flag is passed through unchanged.
 *
 * @param usage which dataset split to load
 * @param arguments supplies the batch size and the sample limit
 * @return the prepared CIFAR-10 dataset
 * @throws IOException if the dataset cannot be downloaded or read
 */
private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
        throws IOException {
    ToTensor toTensor = new ToTensor();
    Normalize normalize = new Normalize(Cifar10.NORMALIZE_MEAN, Cifar10.NORMALIZE_STD);
    Pipeline transforms = new Pipeline(toTensor, normalize);

    int batchSize = arguments.getBatchSize();
    Cifar10 dataset =
            Cifar10.builder()
                    .optUsage(usage)
                    .setSampling(batchSize, true)
                    .optLimit(arguments.getLimit())
                    .optPipeline(transforms)
                    .build();
    dataset.prepare(new ProgressBar());
    return dataset;
}
Aggregations