
Example 1 with ImageClassifier

Use of com.simiacryptus.mindseye.applications.ImageClassifier in project MindsEye by SimiaCryptus.

From class VGG16_HDF5_Test, method getImageClassifier.

@Override
public ImageClassifier getImageClassifier(@Nonnull NotebookOutput log) {
    @Nonnull PrintStream apiLog = new PrintStream(log.file("cuda.log"));
    CudaSystem.addLog(apiLog);
    log.p(log.file((String) null, "cuda.log", "GPU Log"));
    return log.code(() -> {
        @Nonnull ImageClassifier vgg16_hdf5 = VGG16.fromHDF5();
        ((HasHDF5) vgg16_hdf5).getHDF5().print();
        return vgg16_hdf5;
    });
}
Also used : PrintStream(java.io.PrintStream) Nonnull(javax.annotation.Nonnull) ImageClassifier(com.simiacryptus.mindseye.applications.ImageClassifier)

Example 2 with ImageClassifier

Use of com.simiacryptus.mindseye.applications.ImageClassifier in project MindsEye by SimiaCryptus.

From class ImageClassificationBase, method run.

/**
 * Test.
 *
 * @param log the log
 */
public void run(@Nonnull NotebookOutput log) {
    log.h1("Model");
    log.p("In this demonstration, we will show how to load an image recognition network and use it to identify objects in images.");
    log.p("We start by loading the VGG16 pretrained model using the HDF5 importer. This downloads the weights from a file in S3 if needed, and reconstructs the network architecture using custom code.");
    log.p("Next, we need an example image to analyze:");
    log.p("We pass this image to the categorization network and get the following top-5 results. Note that multiple objects may be detected, and the total percentage may be greater than 100%.");
    log.p("Once we have identified the categories, we can attempt to localize each object category within the image. We do this via a pipeline that starts with the backpropagated input signal delta and applies several filters (e.g. blurring and normalization) to produce an alpha channel. When this alpha channel is applied to the input image, it highlights the image areas related to the object type in question. Note that this produces a fuzzy blob, which indicates the object's location but is a poor indicator of its boundaries. Below we perform this task for the top 5 object categories:");
    ImageClassifier vgg16 = loadModel(log);
    log.h1("Data");
    Tensor[] images = loadData(log);
    log.h1("Prediction");
    List<LinkedHashMap<CharSequence, Double>> predictions = log.code(() -> {
        return vgg16.predict(5, images);
    });
    log.h1("Results");
    log.code(() -> {
        @Nonnull TableOutput tableOutput = new TableOutput();
        for (int i = 0; i < images.length; i++) {
            @Nonnull HashMap<CharSequence, Object> row = new HashMap<>();
            row.put("Image", log.image(images[i].toImage(), ""));
            row.put("Prediction", predictions.get(i).entrySet().stream().map(e -> String.format("%s -> %.2f", e.getKey(), 100 * e.getValue())).reduce((a, b) -> a + "<br/>" + b).get());
            tableOutput.putRow(row);
        }
        return tableOutput;
    }, 256 * 1024);
    log.setFrontMatterProperty("status", "OK");
}
Also used : BufferedImage(java.awt.image.BufferedImage) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Caltech101(com.simiacryptus.mindseye.test.data.Caltech101) LinkedHashMap(java.util.LinkedHashMap) ImageClassifier(com.simiacryptus.mindseye.applications.ImageClassifier) List(java.util.List) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)
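Example 2 asks for the five highest-scoring categories per image through vgg16.predict(5, images) and renders each one as a "label -> percent" cell, with entries joined by <br/>. The following is a minimal sketch of how such a ranked top-k map could be built from raw per-class scores and formatted, using only the JDK. It is not MindsEye's implementation; the class TopKPredictions, its methods, and the sample labels and scores are hypothetical. It makes no assumption that the scores sum to 1.

```java
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical helper, not part of MindsEye.
public class TopKPredictions {

    /** Keeps the k highest-scoring labels, ordered from highest to lowest score. */
    public static LinkedHashMap<String, Double> topK(String[] labels, double[] scores, int k) {
        if (labels.length != scores.length) {
            throw new IllegalArgumentException("labels and scores must have the same length");
        }
        LinkedHashMap<String, Double> result = new LinkedHashMap<>();
        IntStream.range(0, scores.length)
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> scores[i]).reversed())
                .limit(k)
                .forEach(i -> result.put(labels[i], scores[i]));
        return result;
    }

    /** Renders each entry as "label -> percent" and joins the lines with an HTML line break. */
    public static String format(Map<String, Double> predictions) {
        return predictions.entrySet().stream()
                // Locale.ROOT keeps the decimal point stable regardless of the default locale.
                .map(e -> String.format(Locale.ROOT, "%s -> %.2f", e.getKey(), 100 * e.getValue()))
                .collect(Collectors.joining("<br/>"));
    }

    public static void main(String[] args) {
        // Hypothetical labels and per-class scores, for illustration only.
        String[] labels = {"cat", "dog", "sofa", "car"};
        double[] scores = {0.91, 0.40, 0.75, 0.02};
        System.out.println(format(topK(labels, scores, 2))); // cat -> 91.00<br/>sofa -> 75.00
    }
}
```

Collectors.joining is used here instead of reduce(...).get(), so an empty prediction map yields an empty cell rather than a NoSuchElementException.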

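The Example 2 code only prints the description of the localization step; it does not run it. Below is a minimal sketch of the described pipeline on plain Java arrays, assuming the backpropagated input delta is available as a [height][width][channels] array of doubles. The per-pixel magnitude, box blur, and min-max normalization are stand-ins chosen for illustration, not necessarily the filters MindsEye uses, and the SaliencyMask class and its methods are hypothetical.

```java
// Hypothetical helper, not part of MindsEye: delta -> magnitude -> blur -> normalize -> alpha.
public class SaliencyMask {

    /** Per-pixel L2 norm of an input delta shaped [height][width][channels]. */
    public static double[][] magnitude(double[][][] delta) {
        int h = delta.length, w = delta[0].length;
        double[][] out = new double[h][w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                double sum = 0;
                for (double v : delta[y][x]) sum += v * v;
                out[y][x] = Math.sqrt(sum);
            }
        }
        return out;
    }

    /** Box blur of the given radius; pixels outside the image count as zero. */
    public static double[][] boxBlur(double[][] in, int radius) {
        int h = in.length, w = in[0].length;
        double area = (2 * radius + 1) * (2 * radius + 1);
        double[][] out = new double[h][w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                double sum = 0;
                for (int dy = -radius; dy <= radius; dy++) {
                    for (int dx = -radius; dx <= radius; dx++) {
                        int yy = y + dy, xx = x + dx;
                        if (yy >= 0 && yy < h && xx >= 0 && xx < w) sum += in[yy][xx];
                    }
                }
                out[y][x] = sum / area;
            }
        }
        return out;
    }

    /** Min-max normalization to [0, 1]; a constant map becomes all zeros. */
    public static double[][] normalize(double[][] in) {
        double min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY;
        for (double[] row : in) {
            for (double v : row) {
                min = Math.min(min, v);
                max = Math.max(max, v);
            }
        }
        double range = max - min;
        double[][] out = new double[in.length][in[0].length];
        for (int y = 0; y < in.length; y++) {
            for (int x = 0; x < in[0].length; x++) {
                out[y][x] = range > 0 ? (in[y][x] - min) / range : 0;
            }
        }
        return out;
    }

    /** Alpha channel: blurred, normalized magnitude of the backpropagated delta. */
    public static double[][] alphaMask(double[][][] delta, int blurRadius) {
        return normalize(boxBlur(magnitude(delta), blurRadius));
    }

    /** Scales every channel of each pixel by its alpha value, dimming unrelated regions. */
    public static double[][][] applyAlpha(double[][][] image, double[][] alpha) {
        double[][][] out = new double[image.length][image[0].length][];
        for (int y = 0; y < image.length; y++) {
            for (int x = 0; x < image[0].length; x++) {
                out[y][x] = new double[image[y][x].length];
                for (int c = 0; c < image[y][x].length; c++) {
                    out[y][x][c] = image[y][x][c] * alpha[y][x];
                }
            }
        }
        return out;
    }
}
```

In use, alphaMask(delta, radius) yields values in [0, 1], and applyAlpha(image, mask) dims pixels whose mask value is low. A larger blur radius spreads the mask further, which fits the note that the result is a fuzzy blob rather than a sharp outline.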
Example 3 with ImageClassifier

Use of com.simiacryptus.mindseye.applications.ImageClassifier in project MindsEye by SimiaCryptus.

From class ImageClassifierTestBase, method run.

/**
 * Test.
 *
 * @param log the log
 */
public void run(@Nonnull NotebookOutput log) {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    Future<Tensor[][]> submit = executor.submit(() -> Arrays.stream(EncodingUtil.getImages(log, img -> {
        return img;
    // return TestUtil.resize(img, 224, 224);
    // if(img.getWidth()>img.getHeight()) {
    // return TestUtil.resize(img, 224, img.getHeight() * 224 / img.getWidth());
    // } else {
    // return TestUtil.resize(img, img.getWidth() * 224 / img.getHeight(), 224);
    // }
    }, 10, new CharSequence[] {})).toArray(i -> new Tensor[i][]));
    // No further tasks are submitted; let the worker thread exit once image loading finishes.
    executor.shutdown();
    ImageClassifier vgg16 = getImageClassifier(log);
    @Nonnull Layer network = vgg16.getNetwork();
    log.h1("Network Diagram");
    log.p("This is a diagram of the imported network:");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph((DAGNetwork) network)).height(4000).width(800).render(Format.PNG).toImage();
    });
    // @javax.annotation.Nonnull SerializationTest serializationTest = new SerializationTest();
    // serializationTest.setPersist(true);
    // serializationTest.test(log, network, (Tensor[]) null);
    log.h1("Predictions");
    Tensor[][] images;
    try {
        images = submit.get();
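    } catch (InterruptedException e) {
        // Restore the interrupt flag before propagating.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }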
    @Nonnull Map<CharSequence, List<LinkedHashMap<CharSequence, Double>>> modelPredictions = new HashMap<>();
    modelPredictions.put("Source", predict(log, vgg16, network, images));
    network.freeRef();
    // serializationTest.getModels().forEach((precision, model) -> {
    // log.h2(precision.name());
    // modelPredictions.put(precision.name(), predict(log, vgg16, model, images));
    // });
    log.h1("Result");
    log.code(() -> {
        @Nonnull TableOutput tableOutput = new TableOutput();
        for (int i = 0; i < images.length; i++) {
            int index = i;
            @Nonnull HashMap<CharSequence, Object> row = new HashMap<>();
            row.put("Image", log.image(images[i][1].toImage(), ""));
            modelPredictions.forEach((model, predictions) -> {
                row.put(model, predictions.get(index).entrySet().stream().map(e -> String.format("%s -> %.2f", e.getKey(), 100 * e.getValue())).reduce((a, b) -> a + "<br/>" + b).get());
            });
            tableOutput.putRow(row);
        }
        return tableOutput;
    }, 256 * 1024);
// log.p("CudaSystem Statistics:");
// log.code(() -> {
// return TestUtil.toFormattedJson(CudaSystem.getExecutionStatistics());
// });
}
Also used : Graphviz(guru.nidi.graphviz.engine.Graphviz) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) Test(org.junit.Test) HashMap(java.util.HashMap) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Executors(java.util.concurrent.Executors) LinkedHashMap(java.util.LinkedHashMap) ImageClassifier(com.simiacryptus.mindseye.applications.ImageClassifier) List(java.util.List) Future(java.util.concurrent.Future) Format(guru.nidi.graphviz.engine.Format) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) EncodingUtil(com.simiacryptus.mindseye.labs.encoding.EncodingUtil) Nonnull(javax.annotation.Nonnull)
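Example 3 starts loading the images on a background thread, imports and diagrams the network on the calling thread, and only then blocks on submit.get(). The sketch below shows the same overlap pattern with the JDK's ExecutorService and Future; loadImages and buildModel are hypothetical stand-ins for the image loader and the model import. It shuts the executor down in a finally block and restores the interrupt flag if the wait is interrupted.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class BackgroundLoad {

    /** Hypothetical stand-in for a slow loader, such as fetching and decoding images. */
    static int[][] loadImages(int count) {
        int[][] images = new int[count][];
        for (int i = 0; i < count; i++) images[i] = new int[] {i, i * i};
        return images;
    }

    /** Hypothetical stand-in for importing or building a model. */
    static String buildModel() {
        return "model";
    }

    public static int[][] run(int count) {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            // Start loading data in the background...
            Future<int[][]> pending = executor.submit(() -> loadImages(count));
            // ...while the calling thread prepares the model.
            String model = buildModel();
            System.out.println("Model ready: " + model);
            // Block until the data is available.
            return pending.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            // Surface the loader's own exception rather than the wrapper.
            throw new RuntimeException(e.getCause());
        } finally {
            // Release the worker thread so it does not keep the JVM alive.
            executor.shutdown();
        }
    }

    public static void main(String[] args) {
        int[][] images = run(3);
        System.out.println(images.length + " images loaded");
    }
}
```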

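Example 3 keeps predictions in a map from model name to a per-image list (only "Source" is filled in; the per-precision loop is commented out) and turns it into one table row per image with one column per model. Here is a minimal sketch of that pivot using plain strings instead of TableOutput and rendered images; the PredictionTable class and the sample data are hypothetical. Each row is a LinkedHashMap so the column order is deterministic, whereas the iteration order of the java.util.HashMap used in the listing is unspecified.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper, not part of MindsEye.
public class PredictionTable {

    /**
     * Pivots per-model predictions (one list entry per image) into table rows
     * (one row per image, one column per model).
     */
    public static List<Map<String, String>> toRows(List<String> images, Map<String, List<String>> predictionsByModel) {
        List<Map<String, String>> rows = new ArrayList<>();
        for (int i = 0; i < images.size(); i++) {
            // "Image" first, then one column per model in the map's iteration order.
            Map<String, String> row = new LinkedHashMap<>();
            row.put("Image", images.get(i));
            for (Map.Entry<String, List<String>> model : predictionsByModel.entrySet()) {
                row.put(model.getKey(), model.getValue().get(i));
            }
            rows.add(row);
        }
        return rows;
    }

    public static void main(String[] args) {
        // Hypothetical image names and pre-formatted prediction cells.
        Map<String, List<String>> byModel = new LinkedHashMap<>();
        byModel.put("Source", Arrays.asList("cat -> 91.00", "car -> 88.00"));
        byModel.put("Other", Arrays.asList("cat -> 90.00", "car -> 87.50"));
        for (Map<String, String> row : toRows(Arrays.asList("a.jpg", "b.jpg"), byModel)) {
            System.out.println(row);
        }
    }
}
```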
Aggregations

ImageClassifier (com.simiacryptus.mindseye.applications.ImageClassifier): 3
Nonnull (javax.annotation.Nonnull): 3
Tensor (com.simiacryptus.mindseye.lang.Tensor): 2
TestUtil (com.simiacryptus.mindseye.test.TestUtil): 2
TableOutput (com.simiacryptus.util.TableOutput): 2
NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 2
HashMap (java.util.HashMap): 2
LinkedHashMap (java.util.LinkedHashMap): 2
List (java.util.List): 2
EncodingUtil (com.simiacryptus.mindseye.labs.encoding.EncodingUtil): 1
Layer (com.simiacryptus.mindseye.lang.Layer): 1
DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork): 1
NotebookReportBase (com.simiacryptus.mindseye.test.NotebookReportBase): 1
Caltech101 (com.simiacryptus.mindseye.test.data.Caltech101): 1
Format (guru.nidi.graphviz.engine.Format): 1
Graphviz (guru.nidi.graphviz.engine.Graphviz): 1
BufferedImage (java.awt.image.BufferedImage): 1
PrintStream (java.io.PrintStream): 1
Arrays (java.util.Arrays): 1
Comparator (java.util.Comparator): 1