Search in sources:

Example 1 with MonitoringWrapperLayer

Usage of com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer in the MindsEye project by SimiaCryptus.

Class TestUtil, method extractPerformance.

/**
 * Logs per-layer performance metrics gathered by {@link MonitoringWrapperLayer} nodes, then strips the
 * instrumentation from the network via {@code removeInstrumentation}.
 *
 * <p>Entries are sorted by mean forward time, slowest first. Each line shows the forward mean, standard
 * deviation and sample count, followed by the backward statistics when any backward samples were recorded.
 * If the network contains no monitoring wrappers, an empty report is logged instead of throwing.
 *
 * @param log     the notebook output that receives the heading and the logged code block
 * @param network the network whose monitored layers are reported and then un-instrumented
 */
public static void extractPerformance(@Nonnull final NotebookOutput log, @Nonnull final DAGNetwork network) {
    log.p("Per-layer Performance Metrics:");
    log.code(() -> {
        // Keyed by the inner layer's description. Layers with an identical toString() and class
        // overwrite each other here. NOTE(review): confirm whether that collapse is intended.
        @Nonnull final Map<CharSequence, MonitoringWrapperLayer> metrics = new HashMap<>();
        network.visitNodes(node -> {
            if (node.getLayer() instanceof MonitoringWrapperLayer) {
                // Non-null: guarded by the instanceof check above.
                @Nonnull final MonitoringWrapperLayer layer = node.getLayer();
                Layer inner = layer.getInner();
                String str = inner.toString();
                str += " class=" + inner.getClass().getName();
                metrics.put(str, layer);
            }
        });
        TestUtil.log.info("Performance: \n\t" + metrics.entrySet().stream()
            // Negated mean => descending order (slowest layers first).
            .sorted(Comparator.comparing(x -> -x.getValue().getForwardPerformance().getMean()))
            .map(e -> {
                @Nonnull final PercentileStatistics performanceF = e.getValue().getForwardPerformance();
                @Nonnull final PercentileStatistics performanceB = e.getValue().getBackwardPerformance();
                return String.format("%.6fs +- %.6fs (%d) <- %s", performanceF.getMean(), performanceF.getStdDev(), performanceF.getCount(), e.getKey())
                    + (performanceB.getCount() == 0 ? "" : String.format("%n\tBack: %.6fs +- %.6fs (%d)", performanceB.getMean(), performanceB.getStdDev(), performanceB.getCount()));
            })
            .reduce((a, b) -> a + "\n\t" + b)
            // A network without monitoring wrappers yields no entries; the previous .get() threw here.
            .orElse(""));
    });
    removeInstrumentation(network);
}
Also used : Arrays(java.util.Arrays) ScheduledFuture(java.util.concurrent.ScheduledFuture) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) IntUnaryOperator(java.util.function.IntUnaryOperator) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) ImageIO(javax.imageio.ImageIO) Layer(com.simiacryptus.mindseye.lang.Layer) URI(java.net.URI) Graph(guru.nidi.graphviz.model.Graph) LongToIntFunction(java.util.function.LongToIntFunction) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) BufferedImage(java.awt.image.BufferedImage) UUID(java.util.UUID) ComponentEvent(java.awt.event.ComponentEvent) WindowAdapter(java.awt.event.WindowAdapter) DAGNode(com.simiacryptus.mindseye.network.DAGNode) Collectors(java.util.stream.Collectors) WindowEvent(java.awt.event.WindowEvent) Executors(java.util.concurrent.Executors) List(java.util.List) Stream(java.util.stream.Stream) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) IntStream(java.util.stream.IntStream) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ActionListener(java.awt.event.ActionListener) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) LinkSource(guru.nidi.graphviz.model.LinkSource) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Supplier(java.util.function.Supplier) JsonUtil(com.simiacryptus.util.io.JsonUtil) MutableNode(guru.nidi.graphviz.model.MutableNode) Charset(java.nio.charset.Charset) Factory(guru.nidi.graphviz.model.Factory) ScheduledExecutorService(java.util.concurrent.ScheduledExecutorService) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) 
WeakReference(java.lang.ref.WeakReference) LinkTarget(guru.nidi.graphviz.model.LinkTarget) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) LongSummaryStatistics(java.util.LongSummaryStatistics) PrintStream(java.io.PrintStream) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) RankDir(guru.nidi.graphviz.attribute.RankDir) IOException(java.io.IOException) FileFilter(javax.swing.filechooser.FileFilter) ActionEvent(java.awt.event.ActionEvent) PercentileStatistics(com.simiacryptus.util.data.PercentileStatistics) File(java.io.File) java.awt(java.awt) ComponentAdapter(java.awt.event.ComponentAdapter) TimeUnit(java.util.concurrent.TimeUnit) Consumer(java.util.function.Consumer) MonitoredObject(com.simiacryptus.util.MonitoredObject) IntToLongFunction(java.util.function.IntToLongFunction) Link(guru.nidi.graphviz.model.Link) Step(com.simiacryptus.mindseye.opt.Step) Comparator(java.util.Comparator) javax.swing(javax.swing) Nonnull(javax.annotation.Nonnull) HashMap(java.util.HashMap) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) Layer(com.simiacryptus.mindseye.lang.Layer) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) Nullable(javax.annotation.Nullable) PercentileStatistics(com.simiacryptus.util.data.PercentileStatistics)

Example 2 with MonitoringWrapperLayer

Usage of com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer in the MindsEye project by SimiaCryptus.

Class TestUtil, method instrumentPerformance.

/**
 * Wraps every node's layer in a {@link MonitoringWrapperLayer} so its performance can be measured.
 * Nodes that are already wrapped are not re-wrapped. Signal-metric recording is disabled in both cases.
 *
 * @param network the network whose nodes are instrumented in place
 */
public static void instrumentPerformance(@Nonnull final DAGNetwork network) {
    network.visitNodes(node -> {
        final Layer current = node.getLayer();
        if (!(current instanceof MonitoringWrapperLayer)) {
            @Nonnull final MonitoringWrapperLayer wrapper = new MonitoringWrapperLayer(current).shouldRecordSignalMetrics(false);
            node.setLayer(wrapper);
            // Release the local reference once the node holds the wrapper.
            wrapper.freeRef();
            return;
        }
        ((MonitoringWrapperLayer) current).shouldRecordSignalMetrics(false);
    });
}
Also used : Nonnull(javax.annotation.Nonnull) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) Layer(com.simiacryptus.mindseye.lang.Layer) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer)

Example 3 with MonitoringWrapperLayer

Usage of com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer in the MindsEye project by SimiaCryptus.

Class TestUtil, method samplePerformance.

/**
 * Collects the forward and backward timing metrics of every monitored layer in the network.
 *
 * @param network the network whose layers are inspected
 * @return a map from "{inner layer toString} class={inner class name}" to a row map whose
 *         "fwd" and "rev" entries hold the forward and backward metrics respectively
 */
public static Map<CharSequence, Object> samplePerformance(@Nonnull final DAGNetwork network) {
    @Nonnull final Map<CharSequence, Object> result = new HashMap<>();
    network.visitLayers(candidate -> {
        if (!(candidate instanceof MonitoringWrapperLayer)) {
            return;
        }
        final MonitoringWrapperLayer wrapper = (MonitoringWrapperLayer) candidate;
        final Layer wrapped = wrapper.getInner();
        final String label = wrapped.toString() + " class=" + wrapped.getClass().getName();
        final HashMap<CharSequence, Object> entry = new HashMap<>();
        entry.put("fwd", wrapper.getForwardPerformance().getMetrics());
        entry.put("rev", wrapper.getBackwardPerformance().getMetrics());
        result.put(label, entry);
    });
    return result;
}
Also used : Nonnull(javax.annotation.Nonnull) HashMap(java.util.HashMap) MonitoredObject(com.simiacryptus.util.MonitoredObject) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) Layer(com.simiacryptus.mindseye.lang.Layer) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer)

Aggregations

Layer (com.simiacryptus.mindseye.lang.Layer)3 LoggingWrapperLayer (com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer)3 MonitoringWrapperLayer (com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer)3 Nonnull (javax.annotation.Nonnull)3 MonitoredObject (com.simiacryptus.util.MonitoredObject)2 Tensor (com.simiacryptus.mindseye.lang.Tensor)1 StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent)1 DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)1 DAGNode (com.simiacryptus.mindseye.network.DAGNode)1 Step (com.simiacryptus.mindseye.opt.Step)1 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)1 DoubleStatistics (com.simiacryptus.util.data.DoubleStatistics)1 PercentileStatistics (com.simiacryptus.util.data.PercentileStatistics)1 ScalarStatistics (com.simiacryptus.util.data.ScalarStatistics)1 JsonUtil (com.simiacryptus.util.io.JsonUtil)1 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)1 RankDir (guru.nidi.graphviz.attribute.RankDir)1 Factory (guru.nidi.graphviz.model.Factory)1 Graph (guru.nidi.graphviz.model.Graph)1 Link (guru.nidi.graphviz.model.Link)1