Search in sources:

Example 1 with PercentileStatistics

Use of com.simiacryptus.util.data.PercentileStatistics in project MindsEye by SimiaCryptus.

In class MonitoringWrapperLayer, method getMetrics:

/**
 * Collects this wrapper's monitoring data into a map of named metrics.
 *
 * <p>The map always contains:
 * <ul>
 *   <li>the wrapped layer's class name;</li>
 *   <li>the batch and item counters;</li>
 *   <li>the forward and backward signal statistics.</li>
 * </ul>
 *
 * <p>Other entries are conditional:
 * <ul>
 *   <li>The raw forward/backward timing statistics are added only when {@code verbose} is set.</li>
 *   <li>Per-item timing estimates are added only when at least one item has been processed.</li>
 *   <li>Summary statistics over the weight buffers returned by {@code state()} are added only
 *       when {@code state()} is non-null and contains at least one value.</li>
 * </ul>
 *
 * @return a new mutable map from metric name to metric value
 */
@Nonnull
@Override
public Map<CharSequence, Object> getMetrics() {
    @Nonnull final HashMap<CharSequence, Object> map = new HashMap<>();
    map.put("class", getInner().getClass().getName());
    map.put("totalBatches", totalBatches);
    map.put("totalItems", totalItems);
    map.put("outputStatistics", forwardSignal.getMetrics());
    map.put("backpropStatistics", backwardSignal.getMetrics());
    if (verbose) {
        map.put("forwardPerformance", forwardPerformance.getMetrics());
        map.put("backwardPerformance", backwardPerformance.getMetrics());
    }
    // Guard the division. With no items processed, the ratio would be NaN or Infinity.
    // Such values are meaningless and are rejected by strict JSON serializers.
    if (totalItems > 0) {
        final double batchesPerItem = totalBatches * 1.0 / totalItems;
        // Scaling by 1000 converts the timing statistics into milliseconds per item.
        // Presumably they record seconds per batch; TODO confirm.
        map.put("avgMsPerItem", 1000 * batchesPerItem * forwardPerformance.getMean());
        map.put("medianMsPerItem", 1000 * batchesPerItem * forwardPerformance.getPercentile(0.5));
        final double backpropMean = backwardPerformance.getMean();
        final double backpropMedian = backwardPerformance.getPercentile(0.5);
        map.put("avgMsPerItem_Backward", 1000 * batchesPerItem * backpropMean);
        map.put("medianMsPerItem_Backward", 1000 * batchesPerItem * backpropMedian);
    }
    // state() is declared @Nullable, so it must be checked before iterating or calling size().
    @Nullable final List<double[]> state = state();
    if (state != null) {
        @Nonnull final ScalarStatistics statistics = new PercentileStatistics();
        for (@Nonnull final double[] s : state) {
            for (final double v : s) {
                statistics.add(v);
            }
        }
        if (statistics.getCount() > 0) {
            @Nonnull final HashMap<CharSequence, Object> weightStats = new HashMap<>();
            weightStats.put("buffers", state.size());
            weightStats.putAll(statistics.getMetrics());
            map.put("weights", weightStats);
        }
    }
    return map;
}
Also used: Nonnull(javax.annotation.Nonnull) HashMap(java.util.HashMap) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) JsonObject(com.google.gson.JsonObject) MonitoredObject(com.simiacryptus.util.MonitoredObject) Nullable(javax.annotation.Nullable) PercentileStatistics(com.simiacryptus.util.data.PercentileStatistics)

Example 2 with PercentileStatistics

Use of com.simiacryptus.util.data.PercentileStatistics in project MindsEye by SimiaCryptus.

In class TestUtil, method extractPerformance:

/**
 * Logs per-layer performance metrics for every {@link MonitoringWrapperLayer} in the network.
 * It then removes the instrumentation wrappers via {@code removeInstrumentation(network)}.
 *
 * <p>Layers are listed in descending order of mean forward time. Each entry is keyed by the
 * wrapped layer's {@code toString()} plus its class name. Each entry shows:
 * <ul>
 *   <li>the forward mean, standard deviation and sample count;</li>
 *   <li>the same figures for the backward pass, when any backward samples were recorded.</li>
 * </ul>
 * If the network contains no monitoring wrappers, an empty report is logged instead of failing.
 *
 * @param log     the notebook output that receives the heading and the code section
 * @param network the network whose monitoring wrappers are reported and then removed
 */
public static void extractPerformance(@Nonnull final NotebookOutput log, @Nonnull final DAGNetwork network) {
    log.p("Per-layer Performance Metrics:");
    log.code(() -> {
        @Nonnull final Map<CharSequence, MonitoringWrapperLayer> metrics = new HashMap<>();
        network.visitNodes(node -> {
            if (node.getLayer() instanceof MonitoringWrapperLayer) {
                // Non-null: guaranteed by the instanceof check above.
                @Nonnull final MonitoringWrapperLayer layer = node.getLayer();
                Layer inner = layer.getInner();
                String str = inner.toString();
                str += " class=" + inner.getClass().getName();
                // if(inner instanceof MultiPrecision<?>) {
                // str += "; precision=" + ((MultiPrecision) inner).getPrecision().name();
                // }
                metrics.put(str, layer);
            }
        });
        // joining() replaces reduce(...).get(). The output is identical for a non-empty map,
        // and an empty map yields "" instead of throwing NoSuchElementException.
        final String report = metrics.entrySet().stream()
            .sorted(Comparator.comparing(x -> -x.getValue().getForwardPerformance().getMean()))
            .map(e -> {
                @Nonnull final PercentileStatistics performanceF = e.getValue().getForwardPerformance();
                @Nonnull final PercentileStatistics performanceB = e.getValue().getBackwardPerformance();
                final String forward = String.format("%.6fs +- %.6fs (%d) <- %s",
                    performanceF.getMean(), performanceF.getStdDev(), performanceF.getCount(), e.getKey());
                final String backward = performanceB.getCount() == 0 ? "" : String.format("%n\tBack: %.6fs +- %.6fs (%d)",
                    performanceB.getMean(), performanceB.getStdDev(), performanceB.getCount());
                return forward + backward;
            })
            .collect(java.util.stream.Collectors.joining("\n\t"));
        // The NotebookOutput parameter `log` shadows the class logger, hence TestUtil.log.
        TestUtil.log.info("Performance: \n\t" + report);
    });
    removeInstrumentation(network);
}
Also used : Arrays(java.util.Arrays) ScheduledFuture(java.util.concurrent.ScheduledFuture) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) IntUnaryOperator(java.util.function.IntUnaryOperator) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) ImageIO(javax.imageio.ImageIO) Layer(com.simiacryptus.mindseye.lang.Layer) URI(java.net.URI) Graph(guru.nidi.graphviz.model.Graph) LongToIntFunction(java.util.function.LongToIntFunction) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) BufferedImage(java.awt.image.BufferedImage) UUID(java.util.UUID) ComponentEvent(java.awt.event.ComponentEvent) WindowAdapter(java.awt.event.WindowAdapter) DAGNode(com.simiacryptus.mindseye.network.DAGNode) Collectors(java.util.stream.Collectors) WindowEvent(java.awt.event.WindowEvent) Executors(java.util.concurrent.Executors) List(java.util.List) Stream(java.util.stream.Stream) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) IntStream(java.util.stream.IntStream) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ActionListener(java.awt.event.ActionListener) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) LinkSource(guru.nidi.graphviz.model.LinkSource) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Supplier(java.util.function.Supplier) JsonUtil(com.simiacryptus.util.io.JsonUtil) MutableNode(guru.nidi.graphviz.model.MutableNode) Charset(java.nio.charset.Charset) Factory(guru.nidi.graphviz.model.Factory) ScheduledExecutorService(java.util.concurrent.ScheduledExecutorService) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) 
WeakReference(java.lang.ref.WeakReference) LinkTarget(guru.nidi.graphviz.model.LinkTarget) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) LongSummaryStatistics(java.util.LongSummaryStatistics) PrintStream(java.io.PrintStream) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) RankDir(guru.nidi.graphviz.attribute.RankDir) IOException(java.io.IOException) FileFilter(javax.swing.filechooser.FileFilter) ActionEvent(java.awt.event.ActionEvent) PercentileStatistics(com.simiacryptus.util.data.PercentileStatistics) File(java.io.File) java.awt(java.awt) ComponentAdapter(java.awt.event.ComponentAdapter) TimeUnit(java.util.concurrent.TimeUnit) Consumer(java.util.function.Consumer) MonitoredObject(com.simiacryptus.util.MonitoredObject) IntToLongFunction(java.util.function.IntToLongFunction) Link(guru.nidi.graphviz.model.Link) Step(com.simiacryptus.mindseye.opt.Step) Comparator(java.util.Comparator) javax.swing(javax.swing) Nonnull(javax.annotation.Nonnull) HashMap(java.util.HashMap) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) Layer(com.simiacryptus.mindseye.lang.Layer) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) Nullable(javax.annotation.Nullable) PercentileStatistics(com.simiacryptus.util.data.PercentileStatistics)

Aggregations

MonitoredObject (com.simiacryptus.util.MonitoredObject)2 PercentileStatistics (com.simiacryptus.util.data.PercentileStatistics)2 ScalarStatistics (com.simiacryptus.util.data.ScalarStatistics)2 HashMap (java.util.HashMap)2 Nonnull (javax.annotation.Nonnull)2 Nullable (javax.annotation.Nullable)2 JsonObject (com.google.gson.JsonObject)1 Layer (com.simiacryptus.mindseye.lang.Layer)1 Tensor (com.simiacryptus.mindseye.lang.Tensor)1 LoggingWrapperLayer (com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer)1 MonitoringWrapperLayer (com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer)1 StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent)1 DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)1 DAGNode (com.simiacryptus.mindseye.network.DAGNode)1 Step (com.simiacryptus.mindseye.opt.Step)1 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)1 DoubleStatistics (com.simiacryptus.util.data.DoubleStatistics)1 JsonUtil (com.simiacryptus.util.io.JsonUtil)1 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)1 RankDir (guru.nidi.graphviz.attribute.RankDir)1