Search in sources :

Example 6 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

In the class StandardLayerTests, method run.

/**
 * Runs the layer test suite and records headings, diagrams and results in the given notebook.
 * <p>
 * A random seed is drawn once and reused for every {@code new Random(seed)} below, so the small
 * layer, the large layer and the per-test performance layers are all built from identically
 * seeded generators. A network diagram is rendered when the small layer is a {@link DAGNetwork},
 * or is {@link Explodable} and explodes into one; plotting failures are logged and do not abort
 * the run.
 *
 * @param log the notebook output that diagrams, sub-test sections and results are written to
 */
public void run(@Nonnull final NotebookOutput log) {
    // One seed, many identically-seeded Random instances.
    long seed = (long) (Math.random() * Long.MAX_VALUE);
    int[][] smallDims = getSmallDims(new Random(seed));
    final Layer smallLayer = getLayer(smallDims, new Random(seed));
    int[][] largeDims = getLargeDims(new Random(seed));
    final Layer largeLayer = getLayer(largeDims, new Random(seed));
    try {
        if (smallLayer instanceof DAGNetwork) {
            // The layer is itself a network: render its graph as a PNG image.
            try {
                log.h1("Network Diagram");
                log.p("This is a network apply the following layout:");
                log.code(() -> {
                    return Graphviz.fromGraph(TestUtil.toGraph((DAGNetwork) smallLayer)).height(400).width(600).render(Format.PNG).toImage();
                });
            } catch (Throwable e) {
                // Best effort: a diagram failure is only logged, the tests below still run.
                logger.info("Error plotting graph", e);
            }
        } else if (smallLayer instanceof Explodable) {
            // The layer can be exploded; a diagram is produced only if the result is a DAGNetwork.
            try {
                Layer explode = ((Explodable) smallLayer).explode();
                if (explode instanceof DAGNetwork) {
                    log.h1("Exploded Network Diagram");
                    log.p("This is a network apply the following layout:");
                    @Nonnull DAGNetwork network = (DAGNetwork) explode;
                    log.code(() -> {
                        @Nonnull Graphviz graphviz = Graphviz.fromGraph(TestUtil.toGraph(network)).height(400).width(600);
                        // Write a standalone SVG into the notebook's resource directory and link to it,
                        // then return the inline SVG markup as the cell result.
                        @Nonnull File file = new File(log.getResourceDir(), log.getName() + "_network.svg");
                        graphviz.render(Format.SVG_STANDALONE).toFile(file);
                        log.link(file, "Saved to File");
                        return graphviz.render(Format.SVG).toString();
                    });
                }
                // NOTE(review): `explode` is not released in this method - confirm whether it needs freeRef().
            } catch (Throwable e) {
                // Best effort, same as above.
                logger.info("Error plotting graph", e);
            }
        }
        @Nonnull ArrayList<TestError> exceptions = standardTests(log, seed);
        // Sub-tests run only when the standard tests reported at least one error.
        if (!exceptions.isEmpty()) {
            if (smallLayer instanceof DAGNetwork) {
                for (@Nonnull Invocation invocation : getInvocations(smallLayer, smallDims)) {
                    log.h1("Small SubTests: " + invocation.getLayer().getClass().getSimpleName());
                    log.p(Arrays.deepToString(invocation.getDims()));
                    tests(log, getLittleTests(), invocation, exceptions);
                    invocation.freeRef();
                }
            }
            if (largeLayer instanceof DAGNetwork) {
                // The equivalency flag is switched off before the large sub-tests and is not
                // restored anywhere in this method.
                testEquivalency = false;
                for (@Nonnull Invocation invocation : getInvocations(largeLayer, largeDims)) {
                    log.h1("Large SubTests: " + invocation.getLayer().getClass().getSimpleName());
                    log.p(Arrays.deepToString(invocation.getDims()));
                    tests(log, getBigTests(), invocation, exceptions);
                    invocation.freeRef();
                }
            }
        }
        // Hand the collected errors to throwException inside a notebook code cell.
        // Presumably this rethrows when the list is non-empty - confirm against throwException.
        log.code(() -> {
            throwException(exceptions);
        });
    } finally {
        // Always release the two layers created at the top of the method.
        smallLayer.freeRef();
        largeLayer.freeRef();
    }
    // Final tests: each non-null test gets a freshly built large layer (same seed) and runs on a copy.
    getFinalTests().stream().filter(x -> null != x).forEach(test -> {
        final Layer perfLayer;
        perfLayer = getLayer(largeDims, new Random(seed));
        perfLayer.assertAlive();
        @Nonnull Layer copy;
        copy = perfLayer.copy();
        Tensor[] randomize = randomize(largeDims);
        try {
            test.test(log, copy, randomize);
        } finally {
            // Release the test, its input tensors, the source layer and the copy, in that order.
            test.freeRef();
            for (@Nonnull Tensor tensor : randomize) {
                tensor.freeRef();
            }
            perfLayer.freeRef();
            copy.freeRef();
        }
    });
}
Also used : JsonObject(com.google.gson.JsonObject) Graphviz(guru.nidi.graphviz.engine.Graphviz) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) Format(guru.nidi.graphviz.engine.Format) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) CudaError(com.simiacryptus.mindseye.lang.cudnn.CudaError) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) SysOutInterceptor(com.simiacryptus.util.test.SysOutInterceptor) Collection(java.util.Collection) TestUtil(com.simiacryptus.mindseye.test.TestUtil) File(java.io.File) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) Explodable(com.simiacryptus.mindseye.layers.cudnn.Explodable) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) LifecycleException(com.simiacryptus.mindseye.lang.LifecycleException) Explodable(com.simiacryptus.mindseye.layers.cudnn.Explodable) Graphviz(guru.nidi.graphviz.engine.Graphviz) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Layer(com.simiacryptus.mindseye.lang.Layer) Random(java.util.Random) File(java.io.File)

Example 7 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

In the class TestUtil, method extractPerformance.

/**
 * Logs per-layer performance metrics for the network and then removes its instrumentation.
 * <p>
 * Every node whose layer is a {@link MonitoringWrapperLayer} is collected, keyed by the wrapped
 * layer's string form plus its class name (a later node with the same key replaces an earlier
 * one). Entries are reported in descending order of mean forward time; backward statistics are
 * appended only when at least one backward sample was recorded. A network without any monitoring
 * wrapper yields an empty report instead of failing.
 *
 * @param log     the notebook output the report cell is written to
 * @param network the network whose monitoring wrappers are read and then removed
 */
public static void extractPerformance(@Nonnull final NotebookOutput log, @Nonnull final DAGNetwork network) {
    log.p("Per-layer Performance Metrics:");
    log.code(() -> {
        @Nonnull final Map<CharSequence, MonitoringWrapperLayer> metrics = new HashMap<>();
        network.visitNodes(node -> {
            if (node.getLayer() instanceof MonitoringWrapperLayer) {
                @Nullable final MonitoringWrapperLayer layer = node.getLayer();
                Layer inner = layer.getInner();
                String str = inner.toString();
                str += " class=" + inner.getClass().getName();
                // if(inner instanceof MultiPrecision<?>) {
                // str += "; precision=" + ((MultiPrecision) inner).getPrecision().name();
                // }
                metrics.put(str, layer);
            }
        });
        // Negating the mean turns the natural ascending sort into slowest-first order.
        // Collectors.joining is used instead of reduce(...).get() so that an empty metrics map
        // produces an empty report rather than a NoSuchElementException.
        final String report = metrics.entrySet().stream().sorted(Comparator.comparing(x -> -x.getValue().getForwardPerformance().getMean())).map(e -> {
            @Nonnull final PercentileStatistics performanceF = e.getValue().getForwardPerformance();
            @Nonnull final PercentileStatistics performanceB = e.getValue().getBackwardPerformance();
            return String.format("%.6fs +- %.6fs (%d) <- %s", performanceF.getMean(), performanceF.getStdDev(), performanceF.getCount(), e.getKey()) + (performanceB.getCount() == 0 ? "" : String.format("%n\tBack: %.6fs +- %.6fs (%s)", performanceB.getMean(), performanceB.getStdDev(), performanceB.getCount()));
        }).collect(Collectors.joining("\n\t"));
        // The `log` parameter shadows the class-level logger, hence the qualified TestUtil.log.
        TestUtil.log.info("Performance: \n\t" + report);
    });
    removeInstrumentation(network);
}
Also used : Arrays(java.util.Arrays) ScheduledFuture(java.util.concurrent.ScheduledFuture) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) IntUnaryOperator(java.util.function.IntUnaryOperator) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) ImageIO(javax.imageio.ImageIO) Layer(com.simiacryptus.mindseye.lang.Layer) URI(java.net.URI) Graph(guru.nidi.graphviz.model.Graph) LongToIntFunction(java.util.function.LongToIntFunction) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) BufferedImage(java.awt.image.BufferedImage) UUID(java.util.UUID) ComponentEvent(java.awt.event.ComponentEvent) WindowAdapter(java.awt.event.WindowAdapter) DAGNode(com.simiacryptus.mindseye.network.DAGNode) Collectors(java.util.stream.Collectors) WindowEvent(java.awt.event.WindowEvent) Executors(java.util.concurrent.Executors) List(java.util.List) Stream(java.util.stream.Stream) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) IntStream(java.util.stream.IntStream) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ActionListener(java.awt.event.ActionListener) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) LinkSource(guru.nidi.graphviz.model.LinkSource) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Supplier(java.util.function.Supplier) JsonUtil(com.simiacryptus.util.io.JsonUtil) MutableNode(guru.nidi.graphviz.model.MutableNode) Charset(java.nio.charset.Charset) Factory(guru.nidi.graphviz.model.Factory) ScheduledExecutorService(java.util.concurrent.ScheduledExecutorService) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) 
WeakReference(java.lang.ref.WeakReference) LinkTarget(guru.nidi.graphviz.model.LinkTarget) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) LongSummaryStatistics(java.util.LongSummaryStatistics) PrintStream(java.io.PrintStream) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) RankDir(guru.nidi.graphviz.attribute.RankDir) IOException(java.io.IOException) FileFilter(javax.swing.filechooser.FileFilter) ActionEvent(java.awt.event.ActionEvent) PercentileStatistics(com.simiacryptus.util.data.PercentileStatistics) File(java.io.File) java.awt(java.awt) ComponentAdapter(java.awt.event.ComponentAdapter) TimeUnit(java.util.concurrent.TimeUnit) Consumer(java.util.function.Consumer) MonitoredObject(com.simiacryptus.util.MonitoredObject) IntToLongFunction(java.util.function.IntToLongFunction) Link(guru.nidi.graphviz.model.Link) Step(com.simiacryptus.mindseye.opt.Step) Comparator(java.util.Comparator) javax.swing(javax.swing) Nonnull(javax.annotation.Nonnull) HashMap(java.util.HashMap) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) Layer(com.simiacryptus.mindseye.lang.Layer) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) Nullable(javax.annotation.Nullable) PercentileStatistics(com.simiacryptus.util.data.PercentileStatistics)

Example 8 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

In the class AutoencodingProblem, method run.

/**
 * Builds an encoder/decoder pair, trains them as an autoencoder and writes diagrams, training
 * plots, the saved models and sample reconstructions to the notebook.
 *
 * @param log the notebook output that receives all diagrams, plots, tables and images
 * @return this problem instance
 */
@Nonnull
@Override
public AutoencodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final DAGNetwork encoder = fwdFactory.imageToVector(log, features);
    @Nonnull final DAGNetwork decoder = revFactory.vectorToImage(log, features);

    // Plain encode -> decode chain; used further down for the validation table only.
    @Nonnull final PipelineNetwork roundTrip = new PipelineNetwork(1);
    roundTrip.add(encoder);
    roundTrip.add(decoder);

    // Training chain: encode -> dropout noise -> decode, scored by mean-squared loss against input 0.
    @Nonnull final PipelineNetwork trainable = new PipelineNetwork(1);
    trainable.add(encoder);
    @Nonnull final DropoutNoiseLayer noise = new DropoutNoiseLayer().setValue(dropout);
    trainable.add(noise);
    trainable.add(decoder);
    trainable.add(new MeanSqLossLayer(), trainable.getHead(), trainable.getInput(0));

    log.h3("Network Diagrams");
    log.code(() -> {
        @Nonnull final Graphviz encoderDiagram = Graphviz.fromGraph(TestUtil.toGraph(encoder)).height(400).width(600);
        return encoderDiagram.render(Format.PNG).toImage();
    });
    log.code(() -> {
        @Nonnull final Graphviz decoderDiagram = Graphviz.fromGraph(TestUtil.toGraph(decoder)).height(400).width(600);
        return decoderDiagram.render(Format.PNG).toImage();
    });
    log.code(() -> {
        @Nonnull final Graphviz trainableDiagram = Graphviz.fromGraph(TestUtil.toGraph(trainable)).height(400).width(600);
        return trainableDiagram.render(Format.PNG).toImage();
    });

    // Forwards everything to the history-backed monitor, reshuffling the dropout mask after each step.
    @Nonnull final TrainingMonitor monitor = new TrainingMonitor() {

        @Nonnull
        TrainingMonitor delegate = TestUtil.getMonitor(history);

        @Override
        public void log(final String message) {
            delegate.log(message);
        }

        @Override
        public void onStepComplete(final Step step) {
            noise.shuffle(StochasticComponent.random.get().nextLong());
            delegate.onStepComplete(step);
        }
    };

    final Tensor[][] samples = getTrainingData(log);
    log.h3("Training");
    TestUtil.instrumentPerformance(trainable);
    // The sampled trainable is given half of the data set as its sample size.
    final SampledArrayTrainable sampledSubject = new SampledArrayTrainable(samples, trainable, samples.length / 2, batchSize);
    final ArrayTrainable fullSubject = new ArrayTrainable(samples, trainable, batchSize);
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, sampledSubject, fullSubject, monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });

    final boolean hasHistory = !history.isEmpty();
    if (hasHistory) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    TestUtil.extractPerformance(log, trainable);

    // The shared model counter is advanced once per saved file, encoder first.
    @Nonnull final String encoderFile = "encoder_model" + AutoencodingProblem.modelNo++ + ".json";
    log.p("Saved model as " + log.file(encoder.getJson().toString(), encoderFile, encoderFile));
    @Nonnull final String decoderFile = "decoder_model" + AutoencodingProblem.modelNo++ + ".json";
    log.p("Saved model as " + log.file(decoder.getJson().toString(), decoderFile, decoderFile));

    log.h3("Validation");
    log.p("Here are some re-encoded examples:");
    log.code(() -> {
        @Nonnull final TableOutput examples = new TableOutput();
        data.validationData()
            .map(sample -> toRow(log, sample, roundTrip.eval(sample.data).getData().get(0).getData()))
            .filter(row -> row != null)
            .limit(10)
            .forEach(examples::putRow);
        return examples;
    });

    log.p("Some rendered unit vectors:");
    // Decode each unit vector of the feature space and emit the resulting image.
    for (int index = 0; index < features; index++) {
        @Nonnull final Tensor unitVector = new Tensor(features).set(index, 1);
        @Nullable final Tensor decoded = decoder.eval(unitVector).getData().get(0);
        log.out(log.image(decoded.toImage(), ""));
    }
    return this;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Graphviz(guru.nidi.graphviz.engine.Graphviz) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) DropoutNoiseLayer(com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Step(com.simiacryptus.mindseye.opt.Step) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Step(com.simiacryptus.mindseye.opt.Step) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) 
ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) DropoutNoiseLayer(com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 9 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

In the class EncodingProblem, method run.

/**
 * Learns a feature-vector encoding for each training image together with the network that decodes
 * it, and writes diagrams, training plots, the saved model and result tables to the notebook.
 * <p>
 * Each training row pairs a feature tensor filled via {@code this::random} with the labeled image.
 * Training happens in two phases: a priming phase on at most the first 1000 rows, then the main
 * phase on the full data set.
 *
 * @param log the notebook output that receives all diagrams, plots, tables and images
 * @return this problem instance
 * @throws RuntimeException if reading the training data or writing the plot image fails with an
 *                          {@link IOException}
 */
@Nonnull
@Override
public EncodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    Tensor[][] trainingData;
    try {
        // Row layout: [0] = feature vector to be learned, [1] = target image.
        trainingData = data.trainingData().map(labeledObject -> {
            return new Tensor[] { new Tensor(features).set(this::random), labeledObject.data };
        }).toArray(i -> new Tensor[i][]);
    } catch (@Nonnull final IOException e) {
        throw new RuntimeException(e);
    }
    @Nonnull final DAGNetwork imageNetwork = revFactory.vectorToImage(log, features);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(imageNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    // Two-input training network: input 0 is the feature vector, input 1 the target image.
    @Nonnull final PipelineNetwork trainingNetwork = new PipelineNetwork(2);
    @Nullable final DAGNode image = trainingNetwork.add(imageNetwork, trainingNetwork.getInput(0));
    @Nullable final DAGNode softmax = trainingNetwork.add(new SoftmaxActivationLayer(), trainingNetwork.getInput(0));
    // Loss = entropy of softmax(features) + (mean-squared image error) ^ (1/2).
    trainingNetwork.add(new SumInputsLayer(), trainingNetwork.add(new EntropyLossLayer(), softmax, softmax), trainingNetwork.add(new NthPowerActivationLayer().setPower(1.0 / 2.0), trainingNetwork.add(new MeanSqLossLayer(), image, trainingNetwork.getInput(1))));
    log.h3("Training");
    log.p("We start by training apply a very small population to improve initial convergence performance:");
    TestUtil.instrumentPerformance(trainingNetwork);
    // Arrays.copyOfRange pads the result with null rows when the upper bound exceeds the source
    // length, so the bound is clamped to keep small data sets free of null samples.
    @Nonnull final Tensor[][] primingData = Arrays.copyOfRange(trainingData, 0, Math.min(1000, trainingData.length));
    @Nonnull final ValidatingTrainer preTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(primingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(primingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        // The priming phase gets half of the configured time budget.
        preTrainer.setTimeout(timeoutMinutes / 2, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    log.p("Then our main training phase:");
    // extractPerformance removed the instrumentation, so it is re-applied for the main phase.
    TestUtil.instrumentPerformance(trainingNetwork);
    @Nonnull final ValidatingTrainer mainTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(trainingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(trainingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        mainTrainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        // Persist the training plot and register it in the notebook front matter.
        @Nonnull String filename = log.getName().toString() + EncodingProblem.modelNo++ + "_plot.png";
        ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
        log.appendFrontMatterProperty("result_plot", filename, ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    // The counter was already advanced for the plot file, so the model file gets the next number.
    @Nonnull final String modelName = "encoding_model_" + EncodingProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(trainingNetwork.getJson().toString(), modelName, modelName));
    log.h3("Results");
    // Evaluation network: same two-input signature as the training rows, but only decodes input 0.
    @Nonnull final PipelineNetwork testNetwork = new PipelineNetwork(2);
    testNetwork.add(imageNetwork, testNetwork.getInput(0));
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        // Show the source image next to its reconstruction for the first ten rows.
        Arrays.stream(trainingData).map(tensorArray -> {
            @Nullable final Tensor predictionSignal = testNetwork.eval(tensorArray).getData().get(0);
            @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
            row.put("Source", log.image(tensorArray[1].toImage(), ""));
            row.put("Echo", log.image(predictionSignal.toImage(), ""));
            return row;
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
    log.p("Learned Model Statistics:");
    log.code(() -> {
        // Statistics over every value in the training network's state arrays.
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        trainingNetwork.state().stream().flatMapToDouble(x -> Arrays.stream(x)).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Learned Representation Statistics:");
    log.code(() -> {
        // Statistics over the feature vectors (row element 0) of all training rows.
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        Arrays.stream(trainingData).flatMapToDouble(row -> Arrays.stream(row[0].getData())).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Some rendered unit vectors:");
    // Decode each unit vector of the feature space and emit the rendered images.
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
        @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
        @Nullable final Tensor tensor = imageNetwork.eval(input).getData().get(0);
        TestUtil.renderToImages(tensor, true).forEach(img -> {
            log.out(log.image(img, ""));
        });
    }
    return this;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) SampledTrainable(com.simiacryptus.mindseye.eval.SampledTrainable) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) Format(guru.nidi.graphviz.engine.Format) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) NthPowerActivationLayer(com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) LinkedHashMap(java.util.LinkedHashMap) 
SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IOException(java.io.IOException) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) DAGNode(com.simiacryptus.mindseye.network.DAGNode) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) NthPowerActivationLayer(com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 10 with DAGNetwork

use of com.simiacryptus.mindseye.network.DAGNetwork in project MindsEye by SimiaCryptus.

In the class ValidatingTrainer, method reset.

/**
 * Re-seeds the training subject of a phase, resets the phase's orientation and, when the
 * subject's layer is a {@link DAGNetwork}, shuffles every {@link StochasticComponent} in it with
 * a fresh random value.
 *
 * @param phase the training phase to reset
 * @param seed  the seed handed to the training subject
 * @return this trainer
 * @throws IterativeStopException if the first reseed attempt returns {@code false}
 */
@Nonnull
private ValidatingTrainer reset(@Nonnull final TrainingPhase phase, final long seed) {
    final boolean reseeded = phase.trainingSubject.reseed(seed);
    if (!reseeded) {
        throw new IterativeStopException();
    }
    phase.orientation.reset();
    // The subject is reseeded a second time after the orientation has been reset.
    phase.trainingSubject.reseed(seed);
    if (!(phase.trainingSubject.getLayer() instanceof DAGNetwork)) {
        return this;
    }
    final DAGNetwork network = (DAGNetwork) phase.trainingSubject.getLayer();
    network.visitLayers(layer -> {
        if (layer instanceof StochasticComponent) {
            final StochasticComponent stochastic = (StochasticComponent) layer;
            stochastic.shuffle(StochasticComponent.random.get().nextLong());
        }
    });
    return this;
}
Also used : StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) IterativeStopException(com.simiacryptus.mindseye.lang.IterativeStopException) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Nonnull(javax.annotation.Nonnull)

Aggregations

DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)19 Nonnull (javax.annotation.Nonnull)16 Layer (com.simiacryptus.mindseye.lang.Layer)11 Nullable (javax.annotation.Nullable)11 Tensor (com.simiacryptus.mindseye.lang.Tensor)10 Arrays (java.util.Arrays)10 List (java.util.List)10 ArrayList (java.util.ArrayList)9 StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent)7 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)7 Map (java.util.Map)7 TimeUnit (java.util.concurrent.TimeUnit)7 DAGNode (com.simiacryptus.mindseye.network.DAGNode)6 TestUtil (com.simiacryptus.mindseye.test.TestUtil)6 IntStream (java.util.stream.IntStream)6 Format (guru.nidi.graphviz.engine.Format)5 Graphviz (guru.nidi.graphviz.engine.Graphviz)5 File (java.io.File)5 IOException (java.io.IOException)5 HashMap (java.util.HashMap)5