Search in sources:

Example 46 with NotebookOutput

Use of com.simiacryptus.util.io.NotebookOutput in the MindsEye project by SimiaCryptus.

Class ImageClassifierTestBase, method run.

/**
 * Runs the classifier report.
 *
 * <p>It loads sample images in the background, renders a diagram of the classifier's network,
 * computes predictions for every image, and writes the results as a table.
 *
 * <p>The single-thread executor that loads the images is shut down right after the load task is
 * submitted. Its worker thread therefore exits once loading completes instead of outliving the
 * report. The network reference is released even if diagram rendering, image loading, or
 * prediction fails.
 *
 * @param log the notebook that receives the headings, network diagram, and result table
 * @throws RuntimeException if image loading fails or the calling thread is interrupted while
 *     waiting for it
 */
public void run(@Nonnull NotebookOutput log) {
    // Load the images concurrently with classifier construction and diagram rendering below.
    final java.util.concurrent.ExecutorService imageLoader = Executors.newSingleThreadExecutor();
    final Future<Tensor[][]> submit = imageLoader.submit(() -> Arrays.stream(EncodingUtil.getImages(log,
        // Identity transform: images are passed through unchanged (no resizing).
        img -> img, 10, new CharSequence[] {})).toArray(i -> new Tensor[i][]));
    // shutdown() still lets the already-submitted task run to completion. It only stops the
    // executor from accepting new work, so its (non-daemon) worker thread can terminate.
    imageLoader.shutdown();

    ImageClassifier vgg16 = getImageClassifier(log);
    @Nonnull Layer network = vgg16.getNetwork();
    // Keyed by model name. Each value holds one prediction map per image. Insertion order fixes
    // the column order of the result table.
    @Nonnull Map<CharSequence, List<LinkedHashMap<CharSequence, Double>>> modelPredictions = new LinkedHashMap<>();
    final Tensor[][] images;
    try {
        log.h1("Network Diagram");
        log.p("This is a diagram of the imported network:");
        log.code(() -> {
            return Graphviz.fromGraph(TestUtil.toGraph((DAGNetwork) network)).height(4000).width(800).render(Format.PNG).toImage();
        });
        log.h1("Predictions");
        try {
            images = submit.get();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        modelPredictions.put("Source", predict(log, vgg16, network, images));
    } finally {
        // Release the network on every path, not only after a successful prediction.
        network.freeRef();
    }

    log.h1("Result");
    log.code(() -> {
        @Nonnull TableOutput tableOutput = new TableOutput();
        for (int i = 0; i < images.length; i++) {
            final int index = i;
            @Nonnull HashMap<CharSequence, Object> row = new HashMap<>();
            // NOTE(review): element [1] of each image's tensor array is rendered; confirm which
            // variant of the image that is.
            row.put("Image", log.image(images[i][1].toImage(), ""));
            modelPredictions.forEach((model, predictions) -> {
                // joining() yields "" for an empty prediction map instead of throwing the way
                // reduce(...).get() would.
                row.put(model, predictions.get(index).entrySet().stream()
                    .map(e -> String.format("%s -> %.2f", e.getKey(), 100 * e.getValue()))
                    .collect(java.util.stream.Collectors.joining("<br/>")));
            });
            tableOutput.putRow(row);
        }
        return tableOutput;
    }, 256 * 1024);
}
Also used: Graphviz (guru.nidi.graphviz.engine.Graphviz), Arrays (java.util.Arrays), TableOutput (com.simiacryptus.util.TableOutput), Tensor (com.simiacryptus.mindseye.lang.Tensor), NotebookReportBase (com.simiacryptus.mindseye.test.NotebookReportBase), Test (org.junit.Test), HashMap (java.util.HashMap), TestUtil (com.simiacryptus.mindseye.test.TestUtil), Executors (java.util.concurrent.Executors), LinkedHashMap (java.util.LinkedHashMap), ImageClassifier (com.simiacryptus.mindseye.applications.ImageClassifier), List (java.util.List), Future (java.util.concurrent.Future), Format (guru.nidi.graphviz.engine.Format), Map (java.util.Map), Layer (com.simiacryptus.mindseye.lang.Layer), DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork), NotebookOutput (com.simiacryptus.util.io.NotebookOutput), EncodingUtil (com.simiacryptus.mindseye.labs.encoding.EncodingUtil), Nonnull (javax.annotation.Nonnull)

Example 47 with NotebookOutput

Use of com.simiacryptus.util.io.NotebookOutput in the MindsEye project by SimiaCryptus.

Class CudaLayerTestBase, method getReferenceIOTester.

/**
 * Wraps the superclass reference-IO tester so that every test run also captures a CUDA API log.
 *
 * <p>During {@code test}, a {@code cuda_<name>_io.log} file is linked from the notebook and
 * registered with {@code CudaSystem}. Afterwards it is first unregistered and then closed, even
 * if the inner test throws.
 *
 * @return the wrapping tester, or {@code null} if the superclass provides no reference-IO tester
 */
@Nullable
@Override
protected ComponentTest<ToleranceStatistics> getReferenceIOTester() {
    @Nullable final ComponentTest<ToleranceStatistics> inner = super.getReferenceIOTester();
    if (null == inner) {
        // Nothing to wrap. A wrapper around null would throw NPE in both test() and _free().
        return null;
    }
    return new ComponentTestBase<ToleranceStatistics>() {

        @Override
        protected void _free() {
            inner.freeRef();
            super._free();
        }

        @Override
        public ToleranceStatistics test(@Nonnull NotebookOutput log, Layer component, Tensor... inputPrototype) {
            @Nonnull String logName = "cuda_" + log.getName() + "_io.log";
            log.p(log.file((String) null, logName, "GPU Log"));
            try (PrintStream apiLog = new PrintStream(log.file(logName))) {
                try {
                    CudaSystem.addLog(apiLog);
                    return inner.test(log, component, inputPrototype);
                } finally {
                    // Unregister before try-with-resources closes the stream. This ensures logging
                    // never targets a closed stream that is still registered.
                    CudaSystem.apiLog.remove(apiLog);
                }
            }
        }
    };
}
Also used: ComponentTestBase (com.simiacryptus.mindseye.test.unit.ComponentTestBase), PrintStream (java.io.PrintStream), Tensor (com.simiacryptus.mindseye.lang.Tensor), Nonnull (javax.annotation.Nonnull), NotebookOutput (com.simiacryptus.util.io.NotebookOutput), ToleranceStatistics (com.simiacryptus.mindseye.test.ToleranceStatistics), Layer (com.simiacryptus.mindseye.lang.Layer), Nullable (javax.annotation.Nullable)

Example 48 with NotebookOutput

Use of com.simiacryptus.util.io.NotebookOutput in the MindsEye project by SimiaCryptus.

Class CudaLayerTestBase, method getPerformanceTester.

/**
 * Builds a performance tester that also captures a CUDA API log for every test run.
 *
 * <p>The inner tester is a {@code PerformanceTester} configured with {@code testingBatchSize}
 * batches. During {@code test}, a {@code cuda_<name>_perf.log} file is linked from the notebook
 * and registered with {@code CudaSystem}. Afterwards it is first unregistered and then closed,
 * even if the inner test throws.
 *
 * @return the wrapping tester, or {@code null} if no inner tester could be built
 */
@Nullable
@Override
public ComponentTest<ToleranceStatistics> getPerformanceTester() {
    @Nullable final ComponentTest<ToleranceStatistics> inner = new PerformanceTester().setBatches(testingBatchSize);
    if (null == inner) {
        // Honour the @Nullable contract rather than returning a wrapper that would throw NPE.
        return null;
    }
    return new ComponentTestBase<ToleranceStatistics>() {

        @Override
        protected void _free() {
            inner.freeRef();
            super._free();
        }

        @Override
        public ToleranceStatistics test(@Nonnull NotebookOutput log, Layer component, Tensor... inputPrototype) {
            @Nonnull String logName = "cuda_" + log.getName() + "_perf.log";
            log.p(log.file((String) null, logName, "GPU Log"));
            try (PrintStream apiLog = new PrintStream(log.file(logName))) {
                try {
                    CudaSystem.addLog(apiLog);
                    return inner.test(log, component, inputPrototype);
                } finally {
                    // Unregister before try-with-resources closes the stream. This ensures logging
                    // never targets a closed stream that is still registered.
                    CudaSystem.apiLog.remove(apiLog);
                }
            }
        }
    };
}
Also used: ComponentTestBase (com.simiacryptus.mindseye.test.unit.ComponentTestBase), PrintStream (java.io.PrintStream), Tensor (com.simiacryptus.mindseye.lang.Tensor), Nonnull (javax.annotation.Nonnull), NotebookOutput (com.simiacryptus.util.io.NotebookOutput), ToleranceStatistics (com.simiacryptus.mindseye.test.ToleranceStatistics), PerformanceTester (com.simiacryptus.mindseye.test.unit.PerformanceTester), Layer (com.simiacryptus.mindseye.lang.Layer), Nullable (javax.annotation.Nullable)

Aggregations

NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 48; Nonnull (javax.annotation.Nonnull): 48; Tensor (com.simiacryptus.mindseye.lang.Tensor): 46; Nullable (javax.annotation.Nullable): 40; Layer (com.simiacryptus.mindseye.lang.Layer): 39; Arrays (java.util.Arrays): 38; List (java.util.List): 37; IntStream (java.util.stream.IntStream): 31; TestUtil (com.simiacryptus.mindseye.test.TestUtil): 25; Logger (org.slf4j.Logger): 25; LoggerFactory (org.slf4j.LoggerFactory): 25; Stream (java.util.stream.Stream): 23; Collectors (java.util.stream.Collectors): 22; ArrayList (java.util.ArrayList): 21; HashMap (java.util.HashMap): 21; DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork): 20; PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork): 19; TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 19; TimeUnit (java.util.concurrent.TimeUnit): 19; StepRecord (com.simiacryptus.mindseye.test.StepRecord): 18