Search in sources:

Example 1 with Component

Use of org.deeplearning4j.ui.api.Component in project deeplearning4j by deeplearning4j.

From class TestComponentSerialization, method assertSerializable.

/**
 * Asserts that a component survives a JSON round trip: it is serialized with Jackson, read back as a
 * {@link Component}, and the two instances must render to the same {@code toString()}.
 */
private static void assertSerializable(Component component) throws Exception {
    ObjectMapper mapper = new ObjectMapper();
    String serialized = mapper.writeValueAsString(component);
    Component roundTripped = mapper.readValue(serialized, Component.class);
    // toString() is compared instead of equals(): fields such as List<double[]> are compared by
    // array identity under equals(), so a faithful round trip would still compare unequal.
    assertEquals(component.toString(), roundTripped.toString());
}
Also used: Component (org.deeplearning4j.ui.api.Component), ObjectMapper (org.nd4j.shade.jackson.databind.ObjectMapper)

Example 2 with Component

Use of org.deeplearning4j.ui.api.Component in project deeplearning4j by deeplearning4j.

From class EvaluationTools, method rocChartToHtml(ROC).

/**
 * Renders a single {@link ROC} as a stand-alone HTML page, returned as a String. The page contains the
 * ROC chart followed by the precision/recall charts.
 *
 * @param roc ROC to render
 * @return the HTML page content
 */
public static String rocChartToHtml(ROC roc) {
    Component rocChart = getRocFromPoints(ROC_TITLE, roc.getResultsAsArray(), roc.getCountActualPositive(),
            roc.getCountActualNegative(), roc.calculateAUC());
    Component precisionRecallCharts = getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, roc.getPrecisionRecallCurve());
    return StaticPageUtil.renderHTML(rocChart, precisionRecallCharts);
}
Also used: Component (org.deeplearning4j.ui.api.Component)

Example 3 with Component

Use of org.deeplearning4j.ui.api.Component in project deeplearning4j by deeplearning4j.

From class EvaluationTools, method rocChartToHtml(ROCMultiClass, List&lt;String&gt;).

/**
 * Given a {@link ROCMultiClass} instance and (optionally) names for each class, render a one-vs-all ROC chart
 * and precision/recall charts for every class to a stand-alone HTML file (returned as a String).
 *
 * @param rocMultiClass  ROC to render
 * @param classNames     Names of the classes. May be null, or shorter than the number of classes; a class
 *                       without a name is labelled by its index only
 * @return the HTML page content
 */
public static String rocChartToHtml(ROCMultiClass rocMultiClass, List<String> classNames) {
    long[] actualCountPositive = rocMultiClass.getCountActualPositive();
    long[] actualCountNegative = rocMultiClass.getCountActualNegative();
    int numClasses = actualCountPositive.length;
    // Five components are added per class: padding div, text-padding div, header, ROC chart, PR charts
    List<Component> components = new ArrayList<>(5 * numClasses);
    for (int i = 0; i < numClasses; i++) {
        double[][] points = rocMultiClass.getResultsAsArray(i);

        String headerText = "Class " + i;
        if (classNames != null && classNames.size() > i) {
            headerText += " (" + classNames.get(i) + ")";
        }
        headerText += " vs. All";

        components.add(new ComponentDiv(HEADER_DIV_PAD_STYLE));
        components.add(new ComponentDiv(HEADER_DIV_TEXT_PAD_STYLE));
        components.add(new ComponentDiv(HEADER_DIV_STYLE, new ComponentText(headerText, HEADER_TEXT_STYLE)));
        components.add(getRocFromPoints(ROC_TITLE, points, actualCountPositive[i], actualCountNegative[i],
                rocMultiClass.calculateAUC(i)));
        components.add(getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, rocMultiClass.getPrecisionRecallCurve(i)));
    }
    return StaticPageUtil.renderHTML(components);
}
Also used: ArrayList (java.util.ArrayList), Component (org.deeplearning4j.ui.api.Component), ComponentDiv (org.deeplearning4j.ui.components.component.ComponentDiv), ComponentText (org.deeplearning4j.ui.components.text.ComponentText)

Example 4 with Component

Use of org.deeplearning4j.ui.api.Component in project deeplearning4j by deeplearning4j.

From class StatsUtils, method exportStatsAsHTML.

/**
 * Generate and export a HTML representation (including charts, etc) of the Spark training statistics<br>
 * This overload is for writing to an output stream
 *
 * @param sparkTrainingStats Stats to generate HTML page for
 * @param maxTimelineSizeMs  maximum amount of activity to show in a single timeline plot (multiple plots will be used if training exceeds this amount of time)
 * @param outputStream       Stream the UTF-8 encoded HTML page is written to. This method neither flushes nor closes it
 * @throws Exception IO errors or error generating HTML file
 */
public static void exportStatsAsHTML(SparkTrainingStats sparkTrainingStats, long maxTimelineSizeMs, OutputStream outputStream) throws Exception {
    Set<String> keySet = sparkTrainingStats.getKeySet();
    List<Component> components = new ArrayList<>();

    StyleChart styleChart = new StyleChart.Builder().backgroundColor(Color.WHITE).width(700, LengthUnit.Px).height(400, LengthUnit.Px).build();
    StyleText styleText = new StyleText.Builder().color(Color.BLACK).fontSize(20).build();
    Component headerText = new ComponentText("Deeplearning4j - Spark Training Analysis", styleText);
    Component header = new ComponentDiv(new StyleDiv.Builder().height(40, LengthUnit.Px).width(100, LengthUnit.Percent).build(), headerText);
    components.add(header);

    // Only keys flagged by defaultIncludeInPlots go on the timeline; every key gets its own charts below
    Set<String> keySetInclude = new HashSet<>();
    for (String s : keySet) {
        if (sparkTrainingStats.defaultIncludeInPlots(s)) {
            keySetInclude.add(s);
        }
    }
    Collections.addAll(components, getTrainingStatsTimelineChart(sparkTrainingStats, keySetInclude, maxTimelineSizeMs));

    for (String s : keySet) {
        List<EventStats> list = new ArrayList<>(sparkTrainingStats.getValue(s));
        Collections.sort(list, new StartTimeComparator());

        // x axis: index of the event after sorting with StartTimeComparator
        double[] x = new double[list.size()];
        double[] duration = new double[list.size()];
        for (int i = 0; i < duration.length; i++) {
            x[i] = i;
            duration[i] = list.get(i).getDurationMs();
        }
        components.add(lineAndHistogramDiv(s, "Duration", x, duration, styleChart));

        //TODO this is really ugly
        if (!list.isEmpty() && (list.get(0) instanceof ExampleCountEventStats || list.get(0) instanceof PartitionCountEventStats)) {
            // The type of the first element decides how every element is read
            boolean exCount = list.get(0) instanceof ExampleCountEventStats;
            double[] y = new double[list.size()];
            for (int i = 0; i < y.length; i++) {
                y[i] = (exCount ? ((ExampleCountEventStats) list.get(i)).getTotalExampleCount() : ((PartitionCountEventStats) list.get(i)).getNumPartitions());
            }
            String title = s + " / " + (exCount ? "Number of Examples" : "Number of Partitions");
            components.add(lineAndHistogramDiv(title, exCount ? "Examples" : "Partitions", x, y, styleChart));
        }
    }
    String html = StaticPageUtil.renderHTML(components);
    outputStream.write(html.getBytes(java.nio.charset.StandardCharsets.UTF_8));
}

/**
 * Builds a full-width div containing a line chart of {@code y} against {@code x}. If {@code y} is non-empty
 * and not constant, a histogram of {@code y} (from {@code getHistogram}) is placed after the line chart.
 *
 * @param title      chart title, also used as the histogram title
 * @param seriesName name of the single line-chart series
 * @param x          x values of the series
 * @param y          y values of the series (same length as {@code x})
 * @param styleChart style shared by the line chart and the histogram
 * @return the wrapping div
 */
private static Component lineAndHistogramDiv(String title, String seriesName, double[] x, double[] y, StyleChart styleChart) {
    double min = Double.MAX_VALUE;
    double max = -Double.MAX_VALUE;
    for (double v : y) {
        min = Math.min(min, v);
        max = Math.max(max, v);
    }
    // A constant series gets an explicit [value - 1, value + 1] y range; otherwise the bounds are left null
    boolean constant = (min == max);
    Component line = new ChartLine.Builder(title, styleChart).addSeries(seriesName, x, y).setYMin(constant ? min - 1 : null).setYMax(constant ? min + 1 : null).build();

    // No histogram for an empty or constant series
    Component hist = null;
    if (!constant && y.length > 0) {
        hist = getHistogram(y, 20, title, styleChart);
    }
    Component[] children = (hist != null) ? new Component[] { line, hist } : new Component[] { line };
    return new ComponentDiv(new StyleDiv.Builder().width(100, LengthUnit.Percent).build(), children);
}
Also used: StyleText (org.deeplearning4j.ui.components.text.style.StyleText), StyleChart (org.deeplearning4j.ui.components.chart.style.StyleChart), Component (org.deeplearning4j.ui.api.Component), ComponentDiv (org.deeplearning4j.ui.components.component.ComponentDiv), ComponentText (org.deeplearning4j.ui.components.text.ComponentText)

Example 5 with Component

Use of org.deeplearning4j.ui.api.Component in project deeplearning4j by deeplearning4j.

From class StatsUtils, method getTrainingStatsTimelineChart.

/**
 * Builds the timeline chart(s) of training activity for the given stats keys, followed by a colour key.
 * <p>
 * Each unique (machine ID, JVM ID, thread ID) triple seen in the included events becomes one lane.
 * Activity is split over as many charts as needed so that no chart spans more than {@code maxDurationMs}.
 * Each event is placed in a chart according to its start time. Events with zero duration are not drawn.
 *
 * @param stats         training stats to plot
 * @param includeSet    keys of the stats to include in the timeline
 * @param maxDurationMs maximum time span covered by a single timeline chart (expected to be positive)
 * @return the timeline charts in chronological order, followed by the colour-key component
 */
private static Component[] getTrainingStatsTimelineChart(SparkTrainingStats stats, Set<String> includeSet, long maxDurationMs) {
    // Pass 1: collect the lanes and the overall time range of all included events
    Set<Tuple3<String, String, Long>> uniqueTuples = new HashSet<>();
    Set<String> machineIDs = new HashSet<>();
    Set<String> jvmIDs = new HashSet<>();
    long earliestStart = Long.MAX_VALUE;
    long latestEnd = Long.MIN_VALUE;
    for (String s : includeSet) {
        List<EventStats> list = stats.getValue(s);
        for (EventStats e : list) {
            machineIDs.add(e.getMachineID());
            jvmIDs.add(e.getJvmID());
            uniqueTuples.add(new Tuple3<String, String, Long>(e.getMachineID(), e.getJvmID(), e.getThreadID()));
            earliestStart = Math.min(earliestStart, e.getStartTime());
            latestEnd = Math.max(latestEnd, e.getStartTime() + e.getDurationMs());
        }
    }

    // Short display names: "PC n" per machine ID, "JVM n" per JVM ID
    Map<String, String> machineShortNames = new HashMap<>();
    int count = 0;
    for (String s : machineIDs) {
        machineShortNames.put(s, "PC " + count++);
    }
    Map<String, String> jvmShortNames = new HashMap<>();
    count = 0;
    for (String s : jvmIDs) {
        jvmShortNames.put(s, "JVM " + count++);
    }

    // Lanes are ordered by TupleComparator. Build an index map (avoids a linear indexOf per event)
    // and the lane labels once, instead of once per chart.
    int nLanes = uniqueTuples.size();
    List<Tuple3<String, String, Long>> outputOrder = new ArrayList<>(uniqueTuples);
    Collections.sort(outputOrder, new TupleComparator());
    Map<Tuple3<String, String, Long>, Integer> laneIndex = new HashMap<>();
    List<String> laneNames = new ArrayList<>(nLanes);
    for (int i = 0; i < nLanes; i++) {
        Tuple3<String, String, Long> t3 = outputOrder.get(i);
        laneIndex.put(t3, i);
        laneNames.add(machineShortNames.get(t3._1()) + ", " + jvmShortNames.get(t3._2()) + ", Thread " + t3._3());
    }

    Color[] colors = getColors(includeSet.size());
    Map<String, Color> colorMap = new HashMap<>();
    count = 0;
    for (String s : includeSet) {
        colorMap.put(s, colors[count++]);
    }

    //Create key for charts:
    List<Component> tempList = new ArrayList<>();
    for (String s : includeSet) {
        String key = stats.getShortNameForKey(s) + " - " + s;
        tempList.add(new ComponentDiv(new StyleDiv.Builder().backgroundColor(colorMap.get(s)).width(33.3, LengthUnit.Percent).height(25, LengthUnit.Px).floatValue(StyleDiv.FloatValue.left).build(), new ComponentText(key, new StyleText.Builder().fontSize(11).build())));
    }
    Component key = new ComponentDiv(new StyleDiv.Builder().width(100, LengthUnit.Percent).build(), tempList);

    // How many charts? Round up, so that no chart spans more than maxDurationMs. With no events,
    // earliestStart > latestEnd (and subtracting them would overflow), so the span is taken as zero.
    long span = latestEnd > earliestStart ? latestEnd - earliestStart : 0L;
    long chartsNeeded = span / maxDurationMs + (span % maxDurationMs == 0 ? 0 : 1);
    int nCharts = (int) Math.max(1L, chartsNeeded);

    List<List<List<ChartTimeline.TimelineEntry>>> entriesByLane = new ArrayList<>(nCharts);
    for (int c = 0; c < nCharts; c++) {
        List<List<ChartTimeline.TimelineEntry>> lanes = new ArrayList<>(nLanes);
        for (int i = 0; i < nLanes; i++) {
            lanes.add(new ArrayList<ChartTimeline.TimelineEntry>());
        }
        entriesByLane.add(lanes);
    }

    // Pass 2: place each (non-zero duration) event into its chart and lane
    for (String s : includeSet) {
        String shortName = stats.getShortNameForKey(s);
        Color color = colorMap.get(s);
        for (EventStats e : stats.getValue(s)) {
            if (e.getDurationMs() == 0)
                continue;
            long start = e.getStartTime();
            long end = start + e.getDurationMs();
            // Chart j covers [earliestStart + j*maxDurationMs, earliestStart + (j+1)*maxDurationMs); clamp for safety
            long chartOffset = (start - earliestStart) / maxDurationMs;
            int chartIdx = (int) Math.max(0L, Math.min(chartOffset, nCharts - 1L));
            int lane = laneIndex.get(new Tuple3<String, String, Long>(e.getMachineID(), e.getJvmID(), e.getThreadID()));
            entriesByLane.get(chartIdx).get(lane).add(new ChartTimeline.TimelineEntry(shortName, start, end, color));
        }
    }

    //Sort each lane by start time:
    Comparator<ChartTimeline.TimelineEntry> byStartTime = new Comparator<ChartTimeline.TimelineEntry>() {

        @Override
        public int compare(ChartTimeline.TimelineEntry o1, ChartTimeline.TimelineEntry o2) {
            return Long.compare(o1.getStartTimeMs(), o2.getStartTimeMs());
        }
    };
    for (List<List<ChartTimeline.TimelineEntry>> chartLanes : entriesByLane) {
        for (List<ChartTimeline.TimelineEntry> l : chartLanes) {
            Collections.sort(l, byStartTime);
        }
    }

    // Height: 35px per lane, plus the 60px top and 20px bottom margins, plus 25px extra
    // (presumably for the axis; TODO confirm). Margins are top=60, bottom=20, left=200, right=10.
    StyleChart sc = new StyleChart.Builder().width(1280, LengthUnit.Px).height(35 * nLanes + (60 + 20 + 25), LengthUnit.Px).margin(LengthUnit.Px, 60, 20, 200, 10).build();

    List<Component> list = new ArrayList<>(nCharts + 1);
    for (int j = 0; j < nCharts; j++) {
        ChartTimeline.Builder b = new ChartTimeline.Builder("Timeline: Training Activities", sc);
        List<List<ChartTimeline.TimelineEntry>> chartLanes = entriesByLane.get(j);
        for (int i = 0; i < nLanes; i++) {
            b.addLane(laneNames.get(i), chartLanes.get(i));
        }
        list.add(b.build());
    }
    list.add(key);
    return list.toArray(new Component[list.size()]);
}
Also used: StyleText (org.deeplearning4j.ui.components.text.style.StyleText), StyleChart (org.deeplearning4j.ui.components.chart.style.StyleChart), List (java.util.List), Component (org.deeplearning4j.ui.api.Component), StyleDiv (org.deeplearning4j.ui.components.component.style.StyleDiv), Tuple3 (scala.Tuple3), ChartTimeline (org.deeplearning4j.ui.components.chart.ChartTimeline), ComponentDiv (org.deeplearning4j.ui.components.component.ComponentDiv), ComponentText (org.deeplearning4j.ui.components.text.ComponentText)

Aggregations

- Component (org.deeplearning4j.ui.api.Component): 8
- ComponentDiv (org.deeplearning4j.ui.components.component.ComponentDiv): 5
- ComponentText (org.deeplearning4j.ui.components.text.ComponentText): 5
- StyleChart (org.deeplearning4j.ui.components.chart.style.StyleChart): 4
- StyleText (org.deeplearning4j.ui.components.text.style.StyleText): 4
- ArrayList (java.util.ArrayList): 3
- StyleDiv (org.deeplearning4j.ui.components.component.style.StyleDiv): 3
- ObjectMapper (org.nd4j.shade.jackson.databind.ObjectMapper): 3
- Style (org.deeplearning4j.ui.api.Style): 2
- StyleAccordion (org.deeplearning4j.ui.components.decorator.style.StyleAccordion): 2
- ComponentTable (org.deeplearning4j.ui.components.table.ComponentTable): 2
- StyleTable (org.deeplearning4j.ui.components.table.style.StyleTable): 2
- Test (org.junit.Test): 2
- Configuration (freemarker.template.Configuration): 1
- Template (freemarker.template.Template): 1
- Version (freemarker.template.Version): 1
- File (java.io.File): 1
- StringWriter (java.io.StringWriter): 1
- Writer (java.io.Writer): 1
- List (java.util.List): 1