
Example 36 with NotebookOutput

Use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

From the class TrainingTester, method train.

private List<StepRecord> train(@Nonnull NotebookOutput log, @Nonnull BiFunction<NotebookOutput, Trainable, List<StepRecord>> opt, @Nonnull Layer layer, @Nonnull Tensor[][] data, @Nonnull boolean... mask) {
    try {
        int inputs = data[0].length;
        @Nonnull final PipelineNetwork network = new PipelineNetwork(inputs);
        network.wrap(new MeanSqLossLayer(), network.add(layer, IntStream.range(0, inputs - 1).mapToObj(i -> network.getInput(i)).toArray(i -> new DAGNode[i])), network.getInput(inputs - 1));
        @Nonnull ArrayTrainable trainable = new ArrayTrainable(data, network);
        if (0 < mask.length)
            trainable.setMask(mask);
        List<StepRecord> history;
        try {
            history = opt.apply(log, trainable);
            if (history.stream().mapToDouble(x -> x.fitness).min().orElse(1) > 1e-5) {
                if (!network.isFrozen()) {
                    log.p("This training run resulted in the following configuration:");
                    log.code(() -> {
                        return network.state().stream().map(Arrays::toString).reduce((a, b) -> a + "\n" + b).orElse("");
                    });
                }
                if (0 < mask.length) {
                    log.p("And regressed input:");
                    log.code(() -> {
                        return Arrays.stream(data).flatMap(x -> Arrays.stream(x)).limit(1).map(x -> x.prettyPrint()).reduce((a, b) -> a + "\n" + b).orElse("");
                    });
                }
                log.p("To produce the following output:");
                log.code(() -> {
                    Result[] array = ConstantResult.batchResultArray(pop(data));
                    @Nullable Result eval = layer.eval(array);
                    for (@Nonnull Result result : array) {
                        result.freeRef();
                        result.getData().freeRef();
                    }
                    TensorList tensorList = eval.getData();
                    eval.freeRef();
                    String str = tensorList.stream().limit(1).map(x -> {
                        String s = x.prettyPrint();
                        x.freeRef();
                        return s;
                    }).reduce((a, b) -> a + "\n" + b).orElse("");
                    tensorList.freeRef();
                    return str;
                });
            } else {
                log.p("Training Converged");
            }
        } finally {
            trainable.freeRef();
            network.freeRef();
        }
        return history;
    } finally {
        layer.freeRef();
        for (@Nonnull Tensor[] tensors : data) {
            for (@Nonnull Tensor tensor : tensors) {
                tensor.freeRef();
            }
        }
    }
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing)
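The decision in train between dumping diagnostics and logging "Training Converged" comes down to one comparison: the lowest recorded fitness against 1e-5. Below is a minimal, self-contained sketch of that check, not MindsEye code. The names ConvergenceCheck and StepSketch are hypothetical; only the fitness field, the orElse(1) fallback and the 1e-5 threshold come from the example above.

```java
import java.util.List;

/**
 * Sketch of the convergence test in TrainingTester.train: training counts as
 * converged once the lowest recorded fitness is at or below a threshold.
 * StepSketch is a hypothetical stand-in for MindsEye's StepRecord.
 */
final class ConvergenceCheck {

    /** Hypothetical stand-in for StepRecord, holding only what the check needs. */
    static final class StepSketch {
        final long iteration;
        final double fitness;

        StepSketch(long iteration, double fitness) {
            this.iteration = iteration;
            this.fitness = fitness;
        }
    }

    /**
     * Returns true when the best (lowest) fitness in the history is at or below the threshold.
     * An empty history falls back to a fitness of 1, mirroring orElse(1) in the original.
     */
    static boolean converged(List<StepSketch> history, double threshold) {
        double best = history.stream().mapToDouble(step -> step.fitness).min().orElse(1);
        return best <= threshold;
    }

    public static void main(String[] args) {
        List<StepSketch> history = List.of(
                new StepSketch(1, 0.5),
                new StepSketch(2, 1e-3),
                new StepSketch(3, 1e-7));
        // Same branch as the original: report convergence, otherwise log diagnostics
        System.out.println(converged(history, 1e-5) ? "Training Converged" : "Not converged; log diagnostics");
    }
}
```

Because the fallback for an empty history is 1, an empty run is treated as not converged under the 1e-5 threshold.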

Example 37 with NotebookOutput

Use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

From the class CudaLayerTester, method testNonstandardBounds.

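/**
 * Checks that the layer gives the same outputs and derivatives for irregular (non-dense) GPU inputs as for dense ones.
 *
 * @param log            the notebook to write the report to
 * @param reference      the layer under test
 * @param inputPrototype the example inputs; each is re-filled with random values
 * @return the combined tolerance statistics for the outputs and the derivatives
 */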
@Nonnull
public ToleranceStatistics testNonstandardBounds(final NotebookOutput log, @Nullable final Layer reference, @Nonnull final Tensor[] inputPrototype) {
    log.h2("Irregular Input");
    log.p("This layer should be able to accept non-dense inputs.");
    return log.code(() -> {
        Tensor[] randomized = Arrays.stream(inputPrototype).map(x -> x.map(v -> getRandom())).toArray(i -> new Tensor[i]);
        logger.info("Input: " + Arrays.stream(randomized).map(Tensor::prettyPrint).collect(Collectors.toList()));
        Precision precision = Precision.Double;
        TensorList[] controlInput = CudaSystem.run(gpu -> {
            return Arrays.stream(randomized).map(original -> {
                TensorArray data = TensorArray.create(original);
                CudaTensorList wrap = CudaTensorList.wrap(gpu.getTensor(data, precision, MemoryType.Managed, false), 1, original.getDimensions(), precision);
                data.freeRef();
                return wrap;
            }).toArray(i -> new TensorList[i]);
        }, 0);
        @Nonnull final SimpleResult controlResult = CudaSystem.run(gpu -> {
            return SimpleGpuEval.run(reference, gpu, controlInput);
        }, 1);
        final TensorList[] irregularInput = CudaSystem.run(gpu -> {
            return Arrays.stream(randomized).map(original -> {
                return buildIrregularCudaTensor(gpu, precision, original);
            }).toArray(i -> new TensorList[i]);
        }, 0);
        @Nonnull final SimpleResult testResult = CudaSystem.run(gpu -> {
            return SimpleGpuEval.run(reference, gpu, irregularInput);
        }, 1);
        try {
            ToleranceStatistics compareOutput = compareOutput(controlResult, testResult);
            ToleranceStatistics compareDerivatives = compareDerivatives(controlResult, testResult);
            return compareDerivatives.combine(compareOutput);
        } finally {
            Arrays.stream(randomized).forEach(ReferenceCountingBase::freeRef);
            Arrays.stream(controlInput).forEach(ReferenceCounting::freeRef);
            Arrays.stream(irregularInput).forEach(x -> x.freeRef());
            controlResult.freeRef();
            testResult.freeRef();
        }
    });
}
Also used : IntStream(java.util.stream.IntStream) SimpleResult(com.simiacryptus.mindseye.test.SimpleResult) SimpleGpuEval(com.simiacryptus.mindseye.test.SimpleGpuEval) Arrays(java.util.Arrays) CudaMemory(com.simiacryptus.mindseye.lang.cudnn.CudaMemory) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) Random(java.util.Random) Function(java.util.function.Function) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) CudnnHandle(com.simiacryptus.mindseye.lang.cudnn.CudnnHandle) Layer(com.simiacryptus.mindseye.lang.Layer) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) IntFunction(java.util.function.IntFunction) Logger(org.slf4j.Logger) CudaDevice(com.simiacryptus.mindseye.lang.cudnn.CudaDevice) CudaTensor(com.simiacryptus.mindseye.lang.cudnn.CudaTensor) CudaTensorList(com.simiacryptus.mindseye.lang.cudnn.CudaTensorList) Collectors(java.util.stream.Collectors) Stream(java.util.stream.Stream) CudaSystem(com.simiacryptus.mindseye.lang.cudnn.CudaSystem) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) MemoryType(com.simiacryptus.mindseye.lang.cudnn.MemoryType) TensorArray(com.simiacryptus.mindseye.lang.TensorArray)
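testNonstandardBounds evaluates the same randomized inputs twice, once from dense GPU tensors and once from irregular ones, then combines a comparison of the outputs with a comparison of the derivatives. The sketch below illustrates that compare-and-combine idea on plain double arrays. It is an assumption-laden illustration, not the MindsEye API: OutputComparison and ErrorSummary are hypothetical, and the choice of worst-case absolute and relative error is ours, not necessarily what ToleranceStatistics records.

```java
/**
 * Sketch of the comparison behind testNonstandardBounds: the same inputs are evaluated
 * along a dense (control) path and an irregular (test) path, and the disagreement between
 * outputs and between derivatives is summarized and combined.
 * ErrorSummary is hypothetical and is not MindsEye's ToleranceStatistics.
 */
final class OutputComparison {

    /** Worst-case absolute and relative error between two flattened tensors. */
    static final class ErrorSummary {
        final double maxAbsoluteError;
        final double maxRelativeError;

        ErrorSummary(double maxAbsoluteError, double maxRelativeError) {
            this.maxAbsoluteError = maxAbsoluteError;
            this.maxRelativeError = maxRelativeError;
        }

        /** Merges two summaries by keeping the worst case of each metric. */
        ErrorSummary combine(ErrorSummary other) {
            return new ErrorSummary(
                    Math.max(maxAbsoluteError, other.maxAbsoluteError),
                    Math.max(maxRelativeError, other.maxRelativeError));
        }
    }

    /** Element-wise comparison; relative error is taken against the larger magnitude. */
    static ErrorSummary compare(double[] control, double[] test) {
        if (control.length != test.length) {
            throw new IllegalArgumentException("length mismatch: " + control.length + " vs " + test.length);
        }
        double maxAbs = 0;
        double maxRel = 0;
        for (int i = 0; i < control.length; i++) {
            double abs = Math.abs(control[i] - test[i]);
            double scale = Math.max(Math.abs(control[i]), Math.abs(test[i]));
            maxAbs = Math.max(maxAbs, abs);
            // Two exact zeros agree perfectly, so their relative error is 0
            maxRel = Math.max(maxRel, scale == 0 ? 0 : abs / scale);
        }
        return new ErrorSummary(maxAbs, maxRel);
    }

    public static void main(String[] args) {
        double[] controlOutput = {1.0, 2.0, 3.0};
        double[] irregularOutput = {1.0, 2.0, 3.0000001};
        double[] controlDerivative = {0.5, 0.5, 0.5};
        double[] irregularDerivative = {0.5, 0.5, 0.5};
        // Same shape as compareDerivatives(...).combine(compareOutput(...)) in the example
        ErrorSummary summary = compare(controlDerivative, irregularDerivative)
                .combine(compare(controlOutput, irregularOutput));
        System.out.printf("max abs error %.3g, max rel error %.3g%n",
                summary.maxAbsoluteError, summary.maxRelativeError);
    }
}
```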

Example 38 with NotebookOutput

Use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

From the class NotebookReportBase, method run.

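/**
 * Runs a report body against a notebook log, recording the outcome and the execution time as front-matter properties.
 *
 * @param fn      the report body to run against the notebook log
 * @param logPath the log path components; defaults to this class's simple name when empty
 */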
public void run(@Nonnull Consumer<NotebookOutput> fn, @Nonnull CharSequence... logPath) {
    try (@Nonnull NotebookOutput log = getLog(logPath.length == 0 ? new String[] { getClass().getSimpleName() } : logPath)) {
        printHeader(log);
        @Nonnull TimedResult<Void> time = TimedResult.time(() -> {
            try {
                fn.accept(log);
                log.setFrontMatterProperty("result", "OK");
            } catch (Throwable e) {
                log.setFrontMatterProperty("result", getExceptionString(e).toString().replaceAll("\n", "<br/>").trim());
                throw (RuntimeException) (e instanceof RuntimeException ? e : new RuntimeException(e));
            }
        });
        log.setFrontMatterProperty("execution_time", String.format("%.6f", time.timeNanos / 1e9));
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}
Also used : Nonnull(javax.annotation.Nonnull) MarkdownNotebookOutput(com.simiacryptus.util.io.MarkdownNotebookOutput) HtmlNotebookOutput(com.simiacryptus.util.io.HtmlNotebookOutput) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) IOException(java.io.IOException)
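The run method follows a common pattern: execute the body, record "OK" or the exception text, rethrow failures as unchecked exceptions, and record the elapsed time in seconds with six decimals. The sketch below reproduces that control flow under stated assumptions: a StringBuilder stands in for the notebook and a Map for its front matter, and ReportRunner is a hypothetical name. It does not use TimedResult or NotebookOutput.

```java
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Consumer;

/**
 * Sketch of the control flow in NotebookReportBase.run. A StringBuilder stands in for the
 * notebook and a Map stands in for its front matter; both substitutions are hypothetical
 * and are not the NotebookOutput API.
 */
final class ReportRunner {

    static void run(Consumer<StringBuilder> body, Map<String, String> frontMatter) {
        StringBuilder log = new StringBuilder();
        long start = System.nanoTime();
        try {
            body.accept(log);
            frontMatter.put("result", "OK");
        } catch (Throwable e) {
            // Record the failure with newlines turned into <br/>, then rethrow unchecked
            frontMatter.put("result", e.toString().replace("\n", "<br/>").trim());
            throw e instanceof RuntimeException ? (RuntimeException) e : new RuntimeException(e);
        }
        // As in the original, the execution time is only recorded when the body succeeds.
        // Locale.ROOT keeps the decimal separator stable; the original uses the default locale.
        frontMatter.put("execution_time", String.format(Locale.ROOT, "%.6f", (System.nanoTime() - start) / 1e9));
    }

    public static void main(String[] args) {
        Map<String, String> frontMatter = new LinkedHashMap<>();
        run(log -> log.append("report body"), frontMatter);
        System.out.println(frontMatter);
    }
}
```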

Example 39 with NotebookOutput

Use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

From the class TestUtil, method printHistory.

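/**
 * Plots the convergence history (log10 of fitness against iteration) to the notebook.
 * Does nothing if the history is empty.
 *
 * @param log     the notebook to write the plot to
 * @param history the recorded training steps
 */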
public static void printHistory(@Nonnull final NotebookOutput log, @Nonnull final List<StepRecord> history) {
    if (!history.isEmpty()) {
        log.out("Convergence Plot: ");
        log.code(() -> {
            final DoubleSummaryStatistics valueStats = history.stream().mapToDouble(x -> x.fitness).filter(x -> x > 0).summaryStatistics();
            @Nonnull final PlotCanvas plot = ScatterPlot.plot(history.stream().map(step -> new double[] { step.iteration, Math.log10(Math.max(valueStats.getMin(), step.fitness)) }).toArray(i -> new double[i][]));
            plot.setTitle("Convergence Plot");
            plot.setAxisLabels("Iteration", "log10(Fitness)");
            plot.setSize(600, 400);
            return plot;
        });
    }
}
Also used : Arrays(java.util.Arrays) ScheduledFuture(java.util.concurrent.ScheduledFuture) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) IntUnaryOperator(java.util.function.IntUnaryOperator) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) ImageIO(javax.imageio.ImageIO) Layer(com.simiacryptus.mindseye.lang.Layer) URI(java.net.URI) Graph(guru.nidi.graphviz.model.Graph) LongToIntFunction(java.util.function.LongToIntFunction) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) BufferedImage(java.awt.image.BufferedImage) UUID(java.util.UUID) ComponentEvent(java.awt.event.ComponentEvent) WindowAdapter(java.awt.event.WindowAdapter) DAGNode(com.simiacryptus.mindseye.network.DAGNode) Collectors(java.util.stream.Collectors) WindowEvent(java.awt.event.WindowEvent) Executors(java.util.concurrent.Executors) List(java.util.List) Stream(java.util.stream.Stream) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) IntStream(java.util.stream.IntStream) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ActionListener(java.awt.event.ActionListener) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) LinkSource(guru.nidi.graphviz.model.LinkSource) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Supplier(java.util.function.Supplier) JsonUtil(com.simiacryptus.util.io.JsonUtil) MutableNode(guru.nidi.graphviz.model.MutableNode) Charset(java.nio.charset.Charset) Factory(guru.nidi.graphviz.model.Factory) ScheduledExecutorService(java.util.concurrent.ScheduledExecutorService) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) WeakReference(java.lang.ref.WeakReference) LinkTarget(guru.nidi.graphviz.model.LinkTarget) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) LongSummaryStatistics(java.util.LongSummaryStatistics) PrintStream(java.io.PrintStream) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) RankDir(guru.nidi.graphviz.attribute.RankDir) IOException(java.io.IOException) FileFilter(javax.swing.filechooser.FileFilter) ActionEvent(java.awt.event.ActionEvent) PercentileStatistics(com.simiacryptus.util.data.PercentileStatistics) File(java.io.File) java.awt(java.awt) ComponentAdapter(java.awt.event.ComponentAdapter) TimeUnit(java.util.concurrent.TimeUnit) Consumer(java.util.function.Consumer) MonitoredObject(com.simiacryptus.util.MonitoredObject) IntToLongFunction(java.util.function.IntToLongFunction) Link(guru.nidi.graphviz.model.Link) Step(com.simiacryptus.mindseye.opt.Step) Comparator(java.util.Comparator) javax.swing(javax.swing)
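Before plotting, printHistory turns each step into a point (iteration, log10(fitness)), clamping fitness from below to the smallest positive value so that zero or negative fitness does not produce -Infinity or NaN. The sketch below isolates that transform without the smile plotting step. ConvergencePoints is a hypothetical name, and plain arrays stand in for the StepRecord list.

```java
import java.util.Arrays;
import java.util.DoubleSummaryStatistics;

/**
 * Sketch of the data transform in TestUtil.printHistory, without the smile plotting step:
 * each step becomes (iteration, log10(fitness)), with fitness clamped from below to the
 * smallest positive value so that zero or negative entries do not produce -Infinity or NaN.
 */
final class ConvergencePoints {

    static double[][] points(long[] iterations, double[] fitness) {
        if (iterations.length != fitness.length) {
            throw new IllegalArgumentException("iterations and fitness must have the same length");
        }
        // Smallest strictly positive fitness; +Infinity if there is none (as in the original)
        DoubleSummaryStatistics positive = Arrays.stream(fitness).filter(f -> f > 0).summaryStatistics();
        double floor = positive.getMin();
        double[][] points = new double[fitness.length][];
        for (int i = 0; i < fitness.length; i++) {
            points[i] = new double[] {iterations[i], Math.log10(Math.max(floor, fitness[i]))};
        }
        return points;
    }

    public static void main(String[] args) {
        double[][] pts = points(new long[] {1, 2, 3}, new double[] {100.0, 0.01, 0.0});
        for (double[] p : pts) {
            System.out.println(Arrays.toString(p));
        }
    }
}
```

Note that, as in the original, a history with no positive fitness at all yields a floor of +Infinity, so every point plots at +Infinity.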

Example 40 with NotebookOutput

Use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

From the class ReferenceIO, method test.

@Nullable
@Override
public ToleranceStatistics test(@Nonnull final NotebookOutput log, @Nonnull final Layer layer, @Nonnull final Tensor... inputPrototype) {
    if (!referenceIO.isEmpty()) {
        log.h1("Reference Input/Output Pairs");
        log.p("Display preset input/output example pairs:");
        referenceIO.forEach((input, output) -> {
            log.code(() -> {
                @Nonnull final SimpleEval eval = SimpleEval.run(layer, input);
                Tensor add = output.scale(-1).addAndFree(eval.getOutput());
                @Nonnull final DoubleStatistics error = new DoubleStatistics().accept(add.getData());
                add.freeRef();
                String format = String.format("--------------------\nInput: \n[%s]\n--------------------\nOutput: \n%s\nError: %s\n--------------------\nDerivative: \n%s",
                        Arrays.stream(input).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse(""),
                        eval.getOutput().prettyPrint(),
                        error,
                        Arrays.stream(eval.getDerivative()).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse(""));
                eval.freeRef();
                return format;
            });
        });
    } else {
        log.h1("Example Input/Output Pair");
        log.p("Display input/output pairs from random executions:");
        log.code(() -> {
            @Nonnull final SimpleEval eval = SimpleEval.run(layer, inputPrototype);
            String format = String.format("--------------------\nInput: \n[%s]\n--------------------\nOutput: \n%s\n--------------------\nDerivative: \n%s",
                    Arrays.stream(inputPrototype).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse(""),
                    eval.getOutput().prettyPrint(),
                    Arrays.stream(eval.getDerivative()).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse(""));
            eval.freeRef();
            return format;
        });
    }
    return null;
}
Also used : Arrays(java.util.Arrays) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) Layer(com.simiacryptus.mindseye.lang.Layer) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)
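For each preset pair, ReferenceIO.test evaluates the layer on the input and measures the error as actual minus expected (output.scale(-1).addAndFree(actual)). The sketch below shows that check with double[] standing in for Tensor, a UnaryOperator standing in for the layer, and java.util.DoubleSummaryStatistics in place of MindsEye's DoubleStatistics. All of these substitutions, and the name ReferencePairCheck, are hypothetical. It also shows that Collectors.joining returns "" for an empty stream, matching reduce(...).orElse("").

```java
import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

/**
 * Sketch of the reference input/output check in ReferenceIO.test, with double[] standing in
 * for Tensor and a UnaryOperator standing in for the layer (both hypothetical substitutions).
 */
final class ReferencePairCheck {

    /** Element-wise actual - expected, i.e. expected * -1 + actual. */
    static double[] error(double[] expected, double[] actual) {
        if (expected.length != actual.length) {
            throw new IllegalArgumentException("shape mismatch");
        }
        double[] err = new double[expected.length];
        for (int i = 0; i < err.length; i++) {
            err[i] = actual[i] - expected[i];
        }
        return err;
    }

    /** Runs the stand-in layer on the input and summarizes the error against the expected output. */
    static DoubleSummaryStatistics errorStats(UnaryOperator<double[]> layer, double[] input, double[] expected) {
        return Arrays.stream(error(expected, layer.apply(input))).summaryStatistics();
    }

    /** Joins tensors with ",\n"; an empty input yields "", like reduce(...).orElse(""). */
    static String joinPretty(double[][] tensors) {
        return Arrays.stream(tensors).map(Arrays::toString).collect(Collectors.joining(",\n"));
    }

    public static void main(String[] args) {
        UnaryOperator<double[]> doubler = x -> Arrays.stream(x).map(v -> 2 * v).toArray();
        // Preset pairs, iterated in insertion order (array keys compare by identity)
        Map<double[], double[]> referenceIO = new LinkedHashMap<>();
        referenceIO.put(new double[] {1, 2}, new double[] {2, 4});
        referenceIO.put(new double[] {3}, new double[] {5});
        referenceIO.forEach((input, output) -> System.out.println(
                "Input: " + joinPretty(new double[][] {input}) + " Error: " + errorStats(doubler, input, output)));
    }
}
```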

Aggregations

NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 48
Nonnull (javax.annotation.Nonnull): 48
Tensor (com.simiacryptus.mindseye.lang.Tensor): 46
Nullable (javax.annotation.Nullable): 40
Layer (com.simiacryptus.mindseye.lang.Layer): 39
Arrays (java.util.Arrays): 38
List (java.util.List): 37
IntStream (java.util.stream.IntStream): 31
TestUtil (com.simiacryptus.mindseye.test.TestUtil): 25
Logger (org.slf4j.Logger): 25
LoggerFactory (org.slf4j.LoggerFactory): 25
Stream (java.util.stream.Stream): 23
Collectors (java.util.stream.Collectors): 22
ArrayList (java.util.ArrayList): 21
HashMap (java.util.HashMap): 21
DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork): 20
PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork): 19
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 19
TimeUnit (java.util.concurrent.TimeUnit): 19
StepRecord (com.simiacryptus.mindseye.test.StepRecord): 18