Usage of ai.djl.modality.cv.transform.ToTensor in the project djl-demo by deepjavalibrary.
Class CanaryTest, method testDlr:
/**
 * Smoke-tests the DLR engine by classifying a sample image downloaded from a URL and logging
 * the resulting classifications.
 *
 * <p>The model zoo lookup is filtered by {@code "layers" = "50"} and by the host operating
 * system ({@code "osx"} or {@code "linux"}); any other OS fails fast.
 *
 * @throws ModelException if a model matching the criteria cannot be located or loaded
 * @throws IOException if the image or the model artifacts cannot be read/downloaded
 * @throws TranslateException if pre- or post-processing fails during prediction
 * @throws AssertionError if the host OS is neither macOS nor Linux
 */
private static void testDlr() throws ModelException, IOException, TranslateException {
    // Read the property once and lower-case it with a fixed locale so the prefix checks do not
    // depend on the JVM's default locale. An absent property falls through to the error below.
    String osName = System.getProperty("os.name", "").toLowerCase(java.util.Locale.ROOT);
    String os;
    if (osName.startsWith("mac")) {
        os = "osx";
    } else if (osName.startsWith("linux")) {
        os = "linux";
    } else {
        throw new AssertionError("DLR only work on mac and Linux.");
    }

    // Pre-processing: resize to 224x224, then convert the image to a tensor.
    ImageClassificationTranslator translator =
            ImageClassificationTranslator.builder()
                    .addTransform(new Resize(224, 224))
                    .addTransform(new ToTensor())
                    .build();
    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                    .optFilter("layers", "50")
                    .optFilter("os", os)
                    .optTranslator(translator)
                    .optEngine("DLR")
                    .optProgress(new ProgressBar())
                    .build();

    String url = "https://resources.djl.ai/images/kitten.jpg";
    Image image = ImageFactory.getInstance().fromUrl(url);
    // Both the model and the predictor are AutoCloseable; release them when done.
    try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
            Predictor<Image, Classifications> predictor = model.newPredictor()) {
        Classifications classifications = predictor.predict(image);
        logger.info("{}", classifications);
    }
}
Usage of ai.djl.modality.cv.transform.ToTensor in the project djl-demo by deepjavalibrary.
Class Inference, method main:
/**
 * Loads the locally saved model from the {@code models} directory, classifies one image and
 * prints the classification result to standard output.
 *
 * @param args optional command-line arguments; when present, {@code args[0]} is the path of the
 *     image to classify, otherwise a default sample path is used
 * @throws ModelException if the model cannot be loaded
 * @throws TranslateException if pre- or post-processing fails during prediction
 * @throws IOException if the image or the model files cannot be read
 */
public static void main(String[] args) throws ModelException, TranslateException, IOException {
    // Directory the trained model was saved into.
    Path savedModelDir = Paths.get("models");

    // Use the caller-supplied image if there is one, else fall back to a sample image.
    String imagePath =
            args.length > 0
                    ? args[0]
                    : "ut-zap50k-images-square/Sandals/Heel/Annie/7350693.3.jpg";
    Image inputImage = ImageFactory.getInstance().fromFile(Paths.get(imagePath));

    try (Model classifier = Models.getModel()) {
        // Populate the freshly created model with the saved parameters.
        classifier.load(savedModelDir, Models.MODEL_NAME);

        // Resize to the model's expected input size, convert to a tensor, and apply softmax to
        // the output so the scores read as probabilities.
        Translator<Image, Classifications> imageTranslator =
                ImageClassificationTranslator.builder()
                        .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                        .addTransform(new ToTensor())
                        .optApplySoftmax(true)
                        .build();

        try (Predictor<Image, Classifications> imagePredictor =
                classifier.newPredictor(imageTranslator)) {
            // One probability score per label.
            Classifications result = imagePredictor.predict(inputImage);
            System.out.println(result);
        }
    }
}
Usage of ai.djl.modality.cv.transform.ToTensor in the project PissAI by DxsSucuk.
Class PissAI, method checkImage:
/**
 * Classifies {@code image} and returns the JSON entry of the class with the highest probability.
 *
 * <p>As a side effect, every class name and its rounded percentage is printed to standard output.
 *
 * @param image the image to classify (resized to 256x256 and converted to a tensor)
 * @param model the model used to create a predictor; it is not closed by this method
 * @param classes the synset (class names) handed to the translator
 * @return the JSON object (with {@code className} and {@code probability}) of the most probable
 *     class, or an empty object if the classifications did not serialize to a JSON array or no
 *     class had a probability above zero
 * @throws TranslateException if pre- or post-processing fails during prediction
 */
public JsonObject checkImage(Image image, Model model, List<String> classes) throws TranslateException {
    Translator<Image, Classifications> translator =
            ImageClassificationTranslator.builder()
                    .addTransform(new Resize(256, 256))
                    .addTransform(new ToTensor())
                    .optSynset(classes)
                    .optApplySoftmax(true)
                    .build();
    // Predictor is AutoCloseable; the original never closed it, leaking a predictor per call.
    try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
        Classifications classifications = predictor.predict(image);
        JsonElement jsonElement = JsonParser.parseString(classifications.toJson());
        if (!jsonElement.isJsonArray()) {
            return new JsonObject();
        }
        JsonArray jsonArray = jsonElement.getAsJsonArray();
        JsonObject detected = new JsonObject();
        float highestValue = 0.0f;
        for (JsonElement entry : jsonArray) {
            if (!entry.isJsonObject()) {
                continue;
            }
            JsonObject jsonObject = entry.getAsJsonObject();
            String name = jsonObject.get("className").getAsString();
            float currentValue = jsonObject.get("probability").getAsFloat();
            System.out.println(name + " - " + Math.round(currentValue * 100) + "%");
            // Strict comparison: on ties, the first class seen wins.
            if (highestValue < currentValue) {
                highestValue = currentValue;
                detected = jsonObject;
            }
        }
        return detected;
    }
}
Usage of ai.djl.modality.cv.transform.ToTensor in the project PissAI by DxsSucuk.
Class PissAI, method runTest:
/**
 * Downloads one of two fixed sample images, runs the binary classifier on it and prints the
 * resulting score as a percentage.
 *
 * @param model the model used to create a predictor; it is not closed by this method
 * @param classes currently unused; kept for signature compatibility
 * @param valid {@code true} to test with the "valid" sample image, {@code false} for the
 *     "invalid" one
 * @throws IOException if the image cannot be downloaded or decoded
 * @throws TranslateException if pre- or post-processing fails during prediction
 */
public void runTest(Model model, List<String> classes, boolean valid) throws IOException, TranslateException {
    String validImage = "https://upload.wikimedia.org/wikipedia/en/thumb/1/1d/Dream_icon.svg/1200px-Dream_icon.svg.png";
    String invalidImage = "https://upload.wikimedia.org/wikipedia/commons/thumb/2/25/Red.svg/2048px-Red.svg.png";
    String imageUrl = valid ? validImage : invalidImage;
    Image downloaded = ImageFactory.getInstance().fromUrl(imageUrl);

    // The original never closed this manager or the predictor. Both are AutoCloseable, so close
    // them here. The manager stays open until prediction finishes because the re-created image
    // is derived from an NDArray it owns.
    try (NDManager manager = NDManager.newBaseManager()) {
        // Round-trip through an NDArray to drop any size-1 dimensions from the image.
        Image imageToCheck =
                ImageFactory.getInstance().fromNDArray(downloaded.toNDArray(manager).squeeze());
        Translator<Image, Float> translator =
                BinaryImageTranslator.builder()
                        .addTransform(new Resize(256, 256))
                        .addTransform(new ToTensor())
                        .addTransform(NDArray::squeeze)
                        .optApplySoftmax(true)
                        .build();
        try (Predictor<Image, Float> predictor = model.newPredictor(translator)) {
            // NOTE(review): presumably the probability of the positive ("dream") class — confirm
            // against BinaryImageTranslator.
            float score = predictor.predict(imageToCheck);
            System.out.println("It is most likely dream, about " + Math.round(score * 100) + "%");
        }
    }
}
Usage of ai.djl.modality.cv.transform.ToTensor in the project PissAI by DxsSucuk.
Class ModelTrainer, method createDataSet:
/**
 * Builds an image-folder dataset backed by the local {@code imagefolder} directory and prepares
 * it for use.
 *
 * <p>Every image is resized to 256x256 and converted to a tensor. Samples are drawn in batches of
 * 32 with random sampling enabled.
 *
 * @return the prepared dataset
 * @throws TranslateException if preparing the dataset fails while transforming data
 * @throws IOException if the image folder cannot be read
 */
public ImageFolder createDataSet() throws TranslateException, IOException {
    final int samplesPerBatch = 32;

    // "folder" repository type pointing at the local image directory.
    Repository imageRepo = Repository.newInstance("folder", Paths.get("imagefolder"));

    ImageFolder imageFolder =
            ImageFolder.builder()
                    .setRepository(imageRepo)
                    .addTransform(new Resize(256, 256))
                    .addTransform(new ToTensor())
                    .setSampling(samplesPerBatch, true)
                    .build();

    // The dataset must be prepared before it can be iterated.
    imageFolder.prepare();
    return imageFolder;
}
Aggregations