Search in sources:

Example 16 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

the class IterativeTrainer method measure.

/**
 * Measures the training subject, re-measuring a bounded number of times while
 * the measured mean is not finite.
 *
 * <p>When {@code reset} is true, each attempt first resets the orientation
 * strategy, reshuffles every {@link StochasticComponent} found in the subject's
 * layer (when that layer is a {@link DAGNetwork}) and asks the subject to reseed.
 *
 * @param reset whether to reset orientation and reseed the subject before each measurement
 * @return the last measured sample, whose mean is finite
 * @throws IterativeStopException if the subject cannot be reseeded on a retry, or if no
 *                                finite measurement was obtained within the retry budget
 */
@Nullable
public PointSample measure(boolean reset) {
    // Number of additional measurements allowed after the first non-finite one.
    final int maxRetries = 10;
    @Nullable PointSample currentPoint = null;
    int retries = 0;
    do {
        if (reset) {
            orientation.reset();
            if (subject.getLayer() instanceof DAGNetwork) {
                ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
                    if (layer instanceof StochasticComponent) {
                        ((StochasticComponent) layer).shuffle(StochasticComponent.random.get().nextLong());
                    }
                });
            }
            if (!subject.reseed(System.nanoTime())) {
                // A failed reseed is tolerated on the first attempt only.
                if (retries > 0) {
                    // Release the sample from the previous attempt before bailing out.
                    if (null != currentPoint) {
                        currentPoint.freeRef();
                    }
                    throw new IterativeStopException("Failed to reset training subject");
                }
            } else {
                monitor.log("Reset training subject");
            }
        }
        // Drop the non-finite sample from the previous attempt, if any.
        if (null != currentPoint) {
            currentPoint.freeRef();
        }
        currentPoint = subject.measure(monitor);
        // The original condition was "10 < retries++", which is never true when counting
        // up from zero and therefore disabled retrying altogether.
    } while (!Double.isFinite(currentPoint.getMean()) && retries++ < maxRetries);
    if (!Double.isFinite(currentPoint.getMean())) {
        currentPoint.freeRef();
        throw new IterativeStopException();
    }
    return currentPoint;
}
Also used : StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) IterativeStopException(com.simiacryptus.mindseye.lang.IterativeStopException) PointSample(com.simiacryptus.mindseye.lang.PointSample) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Nullable(javax.annotation.Nullable)

Example 17 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

the class IterativeTrainer method run.

/**
 * Runs the training loop until the timeout elapses, the measured mean drops to
 * {@code terminateThreshold} or below, the iteration budget is exhausted, or an
 * iteration fails and the subject cannot be reseeded.
 *
 * <p>Each sub-iteration measures the subject, asks the orientation strategy for a
 * search direction, performs a line-search step along it and reports the outcome
 * to the monitor. After the loop, noise is cleared on every
 * {@link StochasticComponent} of the subject's layer when it is a {@link DAGNetwork}.
 *
 * @return the mean of the final sample, or {@code Double.NaN} if there is no final sample
 */
public double run() {
    final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
    long lastIterationTime = System.nanoTime();
    @Nullable PointSample currentPoint = measure(true);
    mainLoop: while (timeoutMs > System.currentTimeMillis() && currentPoint.getMean() > terminateThreshold) {
        if (currentIteration.get() > maxIterations) {
            break;
        }
        currentPoint.freeRef();
        currentPoint = measure(true);
        assert 0 < currentPoint.delta.getMap().size() : "Nothing to optimize";
        // A non-positive iterationsPerSample means "no per-sample limit": the loop then
        // only ends through one of the labeled breaks below.
        subiterationLoop: for (int subiteration = 0; subiteration < iterationsPerSample || iterationsPerSample <= 0; subiteration++) {
            if (timeoutMs < System.currentTimeMillis()) {
                break mainLoop;
            }
            if (currentIteration.incrementAndGet() > maxIterations) {
                break mainLoop;
            }
            currentPoint.freeRef();
            currentPoint = measure(true);
            // Lambdas need an effectively final reference to the current sample.
            @Nullable final PointSample _currentPoint = currentPoint;
            @Nonnull final TimedResult<LineSearchCursor> timedOrientation = TimedResult.time(() -> orientation.orient(subject, _currentPoint, monitor));
            final LineSearchCursor direction = timedOrientation.result;
            final CharSequence directionType = direction.getDirectionType();
            // Keep our own reference to the pre-step sample; it is released in the finally block.
            @Nullable final PointSample previous = currentPoint;
            previous.addRef();
            try {
                @Nonnull final TimedResult<PointSample> timedLineSearch = TimedResult.time(() -> step(direction, directionType, previous));
                currentPoint.freeRef();
                currentPoint = timedLineSearch.result;
                final long now = System.nanoTime();
                final CharSequence perfString = String.format("Total: %.4f; Orientation: %.4f; Line Search: %.4f", (now - lastIterationTime) / 1e9, timedOrientation.timeNanos / 1e9, timedLineSearch.timeNanos / 1e9);
                lastIterationTime = now;
                monitor.log(String.format("Fitness changed from %s to %s", previous.getMean(), currentPoint.getMean()));
                if (previous.getMean() <= currentPoint.getMean()) {
                    // The step did not reduce the mean.
                    if (previous.getMean() < currentPoint.getMean()) {
                        // The mean got worse: replace the sample with the cursor's zero-length step.
                        monitor.log(String.format("Resetting Iteration %s", perfString));
                        currentPoint.freeRef();
                        currentPoint = direction.step(0, monitor).point;
                    } else {
                        monitor.log(String.format("Static Iteration %s", perfString));
                    }
                    if (subject.reseed(System.nanoTime())) {
                        monitor.log(String.format("Iteration %s failed, retrying. Error: %s", currentIteration.get(), currentPoint.getMean()));
                        monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
                        break subiterationLoop;
                    } else {
                        monitor.log(String.format("Iteration %s failed, aborting. Error: %s", currentIteration.get(), currentPoint.getMean()));
                        monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
                        break mainLoop;
                    }
                } else {
                    // perfString is passed as an argument rather than spliced into the format
                    // string, so its content can never be interpreted as format specifiers.
                    monitor.log(String.format("Iteration %s complete. Error: %s %s", currentIteration.get(), currentPoint.getMean(), perfString));
                }
                monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
            } finally {
                previous.freeRef();
                direction.freeRef();
            }
        }
    }
    if (subject.getLayer() instanceof DAGNetwork) {
        ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
            if (layer instanceof StochasticComponent) {
                ((StochasticComponent) layer).clearNoise();
            }
        });
    }
    // Only release the sample when there is one; the original null-checked the mean
    // lookup but then called freeRef() unconditionally.
    if (null == currentPoint) {
        return Double.NaN;
    }
    final double mean = currentPoint.getMean();
    currentPoint.freeRef();
    return mean;
}
Also used : Nonnull(javax.annotation.Nonnull) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) FailsafeLineSearchCursor(com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable)

Example 18 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

the class MnistTestBase method run.

/**
 * Executes the complete test workflow: obtains the training data, builds the
 * model, attaches monitoring to it, trains, reports, validates and finally
 * detaches the monitoring again.
 *
 * @param log the notebook output handed to each stage that produces output
 */
public void run(@Nonnull NotebookOutput log) {
    // Collaborators shared between the training and reporting stages.
    @Nonnull final List<Step> stepHistory = new ArrayList<>();
    @Nonnull final MonitoredObject metricsRoot = new MonitoredObject();
    @Nonnull final TrainingMonitor trainingMonitor = getMonitor(stepHistory);

    // Inputs: data first, then the model, in the same order as before.
    final Tensor[][] samples = getTrainingData(log);
    final DAGNetwork model = buildModel(log);

    addMonitoring(model, metricsRoot);

    log.h1("Training");
    train(log, model, samples, trainingMonitor);

    report(log, metricsRoot, stepHistory, model);
    validate(log, model);

    removeMonitoring(model);
}
Also used : Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) MonitoredObject(com.simiacryptus.util.MonitoredObject) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork)

Example 19 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

the class ImageClassifierTestBase method run.

/**
 * Loads sample images in the background, renders a diagram of the classifier
 * network, runs predictions for the images and writes a result table to the
 * notebook.
 *
 * @param log the notebook output that receives headings, the diagram and the result table
 * @throws RuntimeException if loading the images fails or the calling thread is
 *                          interrupted while waiting for them
 */
public void run(@Nonnull NotebookOutput log) {
    // Load the images on a dedicated thread so the network can be prepared meanwhile.
    final java.util.concurrent.ExecutorService imageLoader = Executors.newSingleThreadExecutor();
    final Future<Tensor[][]> submit = imageLoader.submit(() -> Arrays.stream(EncodingUtil.getImages(log, img -> {
        return img;
    // return TestUtil.resize(img, 224, 224);
    // if(img.getWidth()>img.getHeight()) {
    // return TestUtil.resize(img, 224, img.getHeight() * 224 / img.getWidth());
    // } else {
    // return TestUtil.resize(img, img.getWidth() * 224 / img.getHeight(), 224);
    // }
    }, 10, new CharSequence[] {})).toArray(i -> new Tensor[i][]));
    // shutdown() lets the already-submitted task finish but releases the worker
    // thread afterwards; previously the executor was never shut down.
    imageLoader.shutdown();
    ImageClassifier vgg16 = getImageClassifier(log);
    @Nonnull Layer network = vgg16.getNetwork();
    log.h1("Network Diagram");
    log.p("This is a diagram of the imported network:");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph((DAGNetwork) network)).height(4000).width(800).render(Format.PNG).toImage();
    });
    // @javax.annotation.Nonnull SerializationTest serializationTest = new SerializationTest();
    // serializationTest.setPersist(true);
    // serializationTest.test(log, network, (Tensor[]) null);
    log.h1("Predictions");
    Tensor[][] images;
    try {
        images = submit.get();
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers further up can observe the interruption.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    @Nonnull Map<CharSequence, List<LinkedHashMap<CharSequence, Double>>> modelPredictions = new HashMap<>();
    modelPredictions.put("Source", predict(log, vgg16, network, images));
    network.freeRef();
    // serializationTest.getModels().forEach((precision, model) -> {
    // log.h2(precision.name());
    // modelPredictions.put(precision.name(), predict(log, vgg16, model, images));
    // });
    log.h1("Result");
    log.code(() -> {
        @Nonnull TableOutput tableOutput = new TableOutput();
        for (int i = 0; i < images.length; i++) {
            // Lambdas need an effectively final copy of the loop index.
            int index = i;
            @Nonnull HashMap<CharSequence, Object> row = new HashMap<>();
            row.put("Image", log.image(images[i][1].toImage(), ""));
            modelPredictions.forEach((model, predictions) -> {
                // One "label -> percent" line per prediction; an empty prediction map
                // yields an empty cell instead of a NoSuchElementException.
                row.put(model, predictions.get(index).entrySet().stream().map(e -> String.format("%s -> %.2f", e.getKey(), 100 * e.getValue())).reduce((a, b) -> a + "<br/>" + b).orElse(""));
            });
            tableOutput.putRow(row);
        }
        return tableOutput;
    }, 256 * 1024);
// log.p("CudaSystem Statistics:");
// log.code(() -> {
// return TestUtil.toFormattedJson(CudaSystem.getExecutionStatistics());
// });
}
Also used : Graphviz(guru.nidi.graphviz.engine.Graphviz) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) Test(org.junit.Test) HashMap(java.util.HashMap) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Executors(java.util.concurrent.Executors) LinkedHashMap(java.util.LinkedHashMap) ImageClassifier(com.simiacryptus.mindseye.applications.ImageClassifier) List(java.util.List) Future(java.util.concurrent.Future) Format(guru.nidi.graphviz.engine.Format) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) EncodingUtil(com.simiacryptus.mindseye.labs.encoding.EncodingUtil) Nonnull(javax.annotation.Nonnull) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Layer(com.simiacryptus.mindseye.lang.Layer) TableOutput(com.simiacryptus.util.TableOutput) ImageClassifier(com.simiacryptus.mindseye.applications.ImageClassifier) List(java.util.List)

Aggregations

DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)19 Nonnull (javax.annotation.Nonnull)16 Layer (com.simiacryptus.mindseye.lang.Layer)11 Nullable (javax.annotation.Nullable)11 Tensor (com.simiacryptus.mindseye.lang.Tensor)10 Arrays (java.util.Arrays)10 List (java.util.List)10 ArrayList (java.util.ArrayList)9 StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent)7 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)7 Map (java.util.Map)7 TimeUnit (java.util.concurrent.TimeUnit)7 DAGNode (com.simiacryptus.mindseye.network.DAGNode)6 TestUtil (com.simiacryptus.mindseye.test.TestUtil)6 IntStream (java.util.stream.IntStream)6 Format (guru.nidi.graphviz.engine.Format)5 Graphviz (guru.nidi.graphviz.engine.Graphviz)5 File (java.io.File)5 IOException (java.io.IOException)5 HashMap (java.util.HashMap)5